#!/usr/bin/env amspython

"""
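Computes the Young's modulus, the yield point (0.2% offset method) and Poisson's ratio from a stress-strain curve.

First create stress-strain-curve.csv with stress_strain_curve.py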

Then call this script with
$AMSBIN/amspython young_yield_poisson.py stress-strain-curve.csv

This script accompanies the tutorial at 
https://www.scm.com/doc/Tutorials/MolecularDynamicsAndMonteCarlo/PolymersMechanicalProperties.html
"""

import sys
import os
from scm.plams import *
import numpy as np 
import matplotlib.pyplot as plt
from scipy.stats import linregress
from scipy.interpolate import interp1d

def main():
    """
    Command-line entry point: analyse a stress-strain curve and report the
    Young's modulus, the yield point and Poisson's ratio.

    Expects exactly one argument: the path to the file written by
    stress_strain_curve.py. It is read with np.loadtxt (default whitespace
    delimiter), skipping one header line, with the columns
    strain_x, strain_y, strain_z, stress_xx, stress_yy, stress_zz.
    Plots are written to the current directory.
    """
    if len(sys.argv) != 2 or not os.path.exists(sys.argv[1]):
        print("\n USAGE: " + str(sys.argv[0]) + " stress-strain-curve.csv\n", file=sys.stderr)
        sys.exit(1)

    young_png = "young_modulus.png"
    zoomed_young_png = "zoomed_young_modulus.png"
    poisson_png = "poisson_ratio.png"

    # ndmin=2 keeps the column indexing below valid even for a single data row.
    data = np.loadtxt(sys.argv[1], skiprows=1, ndmin=2)
    # Only the three strains and the axial (xx) stress are needed here.
    strain_x, strain_y, strain_z, stress_xx = data[:, 0], data[:, 1], data[:, 2], data[:, 3]

    young_modulus, yield_point_strain, yield_point_stress = young_modulus_yield_point(
        strain_x, stress_xx, output_fname=young_png, zoomed_output_fname=zoomed_young_png
    )

    poisson_ratio = get_poisson_ratio(strain_x, strain_y, strain_z, output_fname=poisson_png)

    print("The Young's modulus is {:.2f} GPa".format(young_modulus))
    print("The yield point is at ({:.3f}, {:.3f} GPa)".format(yield_point_strain, yield_point_stress))
    print("Poisson's ratio is {:.2f}".format(poisson_ratio))
    # This script only reads the CSV file; it only writes the plots.
    print("Saved graphs in {}, {} and {}".format(young_png, zoomed_young_png, poisson_png))

def moving_average(x, w):
    """
    Moving (boxcar) average of x with a window size of w points.

    Only complete windows are used, so the result has len(x) - w + 1 elements.

    Raises ValueError unless 1 <= w <= len(x). With w > len(x), numpy's 'valid'
    convolution would silently swap its operands and return meaningless values.
    """
    x = np.asarray(x)
    if w < 1 or w > len(x):
        raise ValueError(
            "Moving-average window must be between 1 and the number of data points "
            "({}), got {}".format(len(x), w)
        )
    return np.convolve(x, np.ones(w), 'valid') / w

def find_all_intersections(x1, y1, x2, y2, n_points=1000):
    """
    Resample two piecewise-linear curves on a shared grid and locate where they cross.

    Both curves are linearly interpolated on n_points evenly spaced x values
    spanning the overlap of their x ranges.

    Returns a 4-tuple:
        indices: grid indices i for which the sign of y1 - y2 differs between
                 grid points i and i+1
        xx: the n_points grid x values
        y1_interp: curve 1 evaluated on xx
        y2_interp: curve 2 evaluated on xx
    """
    curve1 = interp1d(x1, y1, kind='linear')
    curve2 = interp1d(x2, y2, kind='linear')

    # Common domain: from the larger of the two minima to the smaller of the two maxima.
    lower = max(min(x1), min(x2))
    upper = min(max(x1), max(x2))
    grid = np.linspace(lower, upper, n_points)

    curve1_on_grid = curve1(grid)
    curve2_on_grid = curve2(grid)

    gap_sign = np.sign(curve1_on_grid - curve2_on_grid)
    crossing_indices = np.nonzero(np.diff(gap_sign))[0]

    return crossing_indices, grid, curve1_on_grid, curve2_on_grid

def find_last_intersection(x1, y1, x2, y2, n_points=1000):
    """
    Returns the (x, y) point of the last crossing (largest x) of the curves
    (x1, y1) and (x2, y2) on their common x range.

    The curves are resampled on n_points grid points by find_all_intersections.
    The crossing is then refined by linear interpolation between the two grid
    points that bracket the last sign change of y1 - y2, instead of snapping to
    the left grid point.

    Raises ValueError if the curves do not cross.
    """
    indices, xx, y1_grid, y2_grid = find_all_intersections(x1, y1, x2, y2, n_points=n_points)
    if len(indices) == 0:
        raise ValueError("The two curves do not intersect on their common x range")

    # indices come from np.diff over the grid, so i + 1 is always a valid index.
    i = indices[-1]
    gap_left = y1_grid[i] - y2_grid[i]
    gap_right = y1_grid[i + 1] - y2_grid[i + 1]
    # The signs of the two gaps differ, so the denominator is non-zero and t is in [0, 1].
    t = gap_left / (gap_left - gap_right)

    x_cross = xx[i] + t * (xx[i + 1] - xx[i])
    y1_cross = y1_grid[i] + t * (y1_grid[i + 1] - y1_grid[i])
    y2_cross = y2_grid[i] + t * (y2_grid[i + 1] - y2_grid[i])
    return x_cross, (y1_cross + y2_cross) / 2.0

def young_modulus_yield_point(strain, stress, linear_fit_max=0.03, moving_avg_window=200, output_fname=None, zoomed_output_fname=None, offset=0.002, zoom_xlim=None, zoom_ylim=None):
    """
    Calculates the Young's modulus and the yield point (offset method).

    The Young's modulus is the slope of a linear fit of stress vs. strain over
    all points with strain < linear_fit_max. The yield point is the last
    intersection of that fit, shifted by `offset` along the strain axis
    (0.002 is the usual 0.2% offset), with a moving average of the
    stress-strain curve.

    Parameters:
        strain, stress: 1D arrays of equal length
        linear_fit_max: strain below which points are used for the linear fit
        moving_avg_window: number of points in the moving-average window;
            must not exceed len(strain)
        output_fname: if given, save a plot of the data and fits to this file
        zoomed_output_fname: if given together with output_fname, also save a
            zoomed-in copy of that plot
        offset: strain offset of the offset line
        zoom_xlim, zoom_ylim: axis limits of the zoomed plot; by default a
            window of +-0.01 strain and +-0.1 stress units centred on the yield point

    Returns: a 3-tuple
        Young modulus (GPa, i.e. the units of stress)
        Yield point strain
        Yield point stress (GPa)

    Raises:
        ValueError if fewer than two points have strain < linear_fit_max, or if
        moving_avg_window exceeds the number of data points. Errors from
        find_last_intersection propagate if no yield point is found.
    """
    strain = np.asarray(strain, dtype=float)
    stress = np.asarray(stress, dtype=float)

    fit_mask = strain < linear_fit_max
    if np.count_nonzero(fit_mask) < 2:
        raise ValueError("Need at least two points with strain < {} for the linear fit".format(linear_fit_max))
    res = linregress(strain[fit_mask], stress[fit_mask])
    young_modulus = res.slope

    linear_fit_x = strain
    linear_fit_y = res.slope * strain + res.intercept

    # Same line shifted to the right by `offset` along the strain axis.
    offset_linear_fit_x = strain
    offset_linear_fit_y = res.slope * (strain - offset) + res.intercept

    if moving_avg_window > len(strain):
        raise ValueError("moving_avg_window ({}) exceeds the number of data points ({})".format(moving_avg_window, len(strain)))
    moving_avg_x = moving_average(strain, moving_avg_window)
    moving_avg_y = moving_average(stress, moving_avg_window)

    yield_point_x, yield_point_y = find_last_intersection(offset_linear_fit_x, offset_linear_fit_y, moving_avg_x, moving_avg_y)

    if output_fname is not None:
        plt.clf()
        # '.' = point markers without connecting lines (avoids the fmt/marker-kwarg conflict warning).
        plt.plot(strain, stress, '.', markersize=1.2)
        plt.plot(linear_fit_x, linear_fit_y, linewidth=1, label="Linear fit < {}".format(linear_fit_max))
        plt.plot(offset_linear_fit_x, offset_linear_fit_y, linewidth=1, label="{:g}% offset line".format(offset * 100))
        plt.plot(moving_avg_x, moving_avg_y, linewidth=1, label="Moving average (window: {})".format(moving_avg_window))
        plt.xlabel("Axial strain x")
        plt.ylabel("Stress xx (GPa)")
        plt.title("Young's modulus: {:.2f} GPa".format(young_modulus))
        plt.legend()
        plt.savefig(output_fname)
        if zoomed_output_fname is not None:
            # Centre the zoom on the yield point unless explicit limits were given.
            if zoom_xlim is None:
                zoom_xlim = [yield_point_x - 0.01, yield_point_x + 0.01]
            if zoom_ylim is None:
                zoom_ylim = [yield_point_y - 0.1, yield_point_y + 0.1]
            plt.xlim(zoom_xlim)
            plt.ylim(zoom_ylim)
            plt.title("Yield point: ({:.3f}, {:.3f} GPa)".format(yield_point_x, yield_point_y))
            plt.savefig(zoomed_output_fname)
        plt.clf()

    return young_modulus, yield_point_x, yield_point_y

def get_poisson_ratio(strain_axial, strain_trans_1, strain_trans_2, output_fname=None):
    """
    Returns the Poisson's ratio.

    Each transverse strain is fitted linearly against the axial strain; the
    Poisson's ratio is minus the mean of the two slopes.

    Parameters:
        strain_axial: 1D array of axial strains
        strain_trans_1, strain_trans_2: 1D arrays of the two transverse strains
            (plotted as "Strain y" and "Strain z")
        output_fname: if given, save a plot of the strains and fits to this file
    """
    strain_axial = np.asarray(strain_axial, dtype=float)
    strain_trans_1 = np.asarray(strain_trans_1, dtype=float)
    strain_trans_2 = np.asarray(strain_trans_2, dtype=float)

    res1 = linregress(strain_axial, strain_trans_1)
    res2 = linregress(strain_axial, strain_trans_2)
    poisson_ratio = -(res1.slope + res2.slope) / 2.0

    if output_fname is not None:
        plt.clf()
        # '.' = point markers without connecting lines (avoids the fmt/marker-kwarg conflict warning).
        plt.plot(strain_axial, strain_trans_2, '.', markersize=1.2, label="Strain z")
        plt.plot(strain_axial, res2.slope * strain_axial + res2.intercept, label="{:.2f}*x+{:.3f}".format(res2.slope, res2.intercept))
        plt.plot(strain_axial, strain_trans_1, '.', markersize=1.2, label="Strain y")
        plt.plot(strain_axial, res1.slope * strain_axial + res1.intercept, label="{:.2f}*x+{:.3f}".format(res1.slope, res1.intercept))
        plt.xlabel("Axial strain")
        plt.ylabel("Transverse strain")
        plt.title("Poisson's ratio: {:.2f}".format(poisson_ratio))
        plt.legend()
        plt.savefig(output_fname)
        plt.clf()

    return poisson_ratio

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
