import numpy as np
import pandas as pd
import os

import subprocess
import shlex
import time
from string import Template
from scipy.signal import find_peaks
from scipy.optimize import curve_fit

## Path of the file holding the actual chamber pressure in mPa;
## the "$SHML" placeholder is expanded from the environment below.
ActualChamberPressuremPa = "$SHML/ActualChamberPressuremPa"

## piezo voltage limits [V]
piezo_voltage_limits = (0, 30)

## controller update interval [s]
sample_time = 1.0


# Expand $SHML from the environment. If it is not set, fall back to the
# default session-logbook location, and abort when that file is missing.
try:
    ActualChamberPressuremPa = Template(ActualChamberPressuremPa).substitute(
        os.environ
    )
except (ValueError, KeyError):
    ActualChamberPressuremPa = (
        "/golem/shm_golem/ActualSession/SessionLogBook/ActualChamberPressuremPa"
    )
    if not os.path.exists(ActualChamberPressuremPa):
        raise SystemExit(
            "Cannot find ActualChamberPressurePa in %s" % ActualChamberPressuremPa
        )


def set_piezo_valve_voltage(voltage: float):
    """Command the piezo gas valve to ``voltage`` volts.

    Sources the ``WorkingGas.sh`` bash library and calls its
    ``SetVoltage@GasValveTo`` function, passing the voltage formatted with
    two decimals. The command is launched via run_bash_function().
    """
    formatted_voltage = '%.2f' % voltage
    run_bash_function('WorkingGas.sh', 'SetVoltage@GasValveTo', formatted_voltage)


# def set_piezo_valve_voltage(voltage: float):
#     """Set voltage on Piezo valve using bash scipt"""
#     bash_path = None
#     function = "./Dirigent.sh  -HpV"
#     params = "%.2f" % voltage
#     run_bash_function(bash_path, function, params)


def run_bash_function(library_path, function_name, params):
    """Launch a bash command in the background, optionally after sourcing a library.

    The script handed to ``bash -c`` is ``"<function_name> <params>"``. When
    ``library_path`` is not None, ``"source <library_path>; "`` is prepended.
    The script is passed to bash as a single argument, and bash does the word
    splitting. This is why ``function_name`` may itself contain spaces
    (e.g. ``"./script.sh -flag"``). Output is discarded, and the call does not
    wait for bash to finish.

    Args:
        library_path: bash file to ``source`` first, or None.
        function_name: bash function or command line to run.
        params: argument string appended after ``function_name``.

    Returns:
        The ``subprocess.Popen`` handle. Callers may ``wait()`` on it or
        ignore it.
    """
    command = "%s %s" % (function_name, params)
    if library_path is not None:
        command = "source %s; %s" % (library_path, command)
    # Pass the script verbatim. The former shlex.split('"..."') round-trip
    # silently stripped embedded double quotes and consumed backslashes.
    return subprocess.Popen(
        ["bash", "-c", command],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
    )


# Last successfully parsed pressure [mPa]; served when a read cannot be parsed.
last_read_pressure = 0.0


def get_current_pressure(path=None) -> float:
    """Return the current chamber pressure in mPa.

    Reads the first line of the pressure file and parses it as a float. The
    file is rewritten by a bash process, so a read can catch it empty (or
    otherwise unparsable). In that case the last successfully parsed value is
    returned instead. That value is initially 0.0.

    Args:
        path: pressure file to read; defaults to ``ActualChamberPressuremPa``.
    """
    global last_read_pressure

    if path is None:
        path = ActualChamberPressuremPa
    with open(path, "r") as f:
        line = f.readline()

    try:
        value = float(line)
    except ValueError:
        # Torn/empty read: serve the previous measurement, as intended.
        return last_read_pressure
    last_read_pressure = value
    return value


def linearize_piezo_valve(x):
    """Map ``x`` to the linearized (natural-log) domain.

    Inputs below 1 are clamped to ``0``. Everything else, including NaN,
    goes through ``np.log``.
    """
    return 0 if x < 1 else np.log(x)


def linearize_piezo_valve_inversion(x):
    """Undo the log mapping of linearize_piezo_valve by returning ``exp(x)``.

    This is not an exact inverse for inputs that linearize_piezo_valve
    clamped to 0, because ``exp(0) == 1``.
    """
    restored = np.exp(x)
    return restored


# Output table for relay-test results; write_PID_calib_vals() appends one
# tab-separated row per calibration point to this file.
CALIB_FILE = "/golem/shm_golem/ActualSession/SessionLogBook/WG_PID_calibration_table4H2"
# Disabled alternative: expand environment variables in the path.
#CALIB_FILE = Template(CALIB_FILE)
#CALIB_FILE = CALIB_FILE.substitute(os.environ)


def write_PID_calib_vals(piezo_voltage, pressure, K_u, T_u, path=None):
    """Append one relay-calibration result row to the calibration table.

    The table is tab-separated with the header
    ``piezo_voltage pressure K_u log_K_u T_u``. The header is written whenever
    the file is new or empty.

    Args:
        piezo_voltage: piezo valve voltage of the test [V].
        pressure: operating-point pressure of the test.
        K_u: ultimate gain from the relay test; its natural log is stored too.
        T_u: ultimate period from the relay test.
        path: output file; defaults to the module-level ``CALIB_FILE``.
    """
    if path is None:
        path = CALIB_FILE
    with open(path, "a") as f:
        # Append mode starts at end of file, so tell() == 0 means the file is
        # new or empty. This also covers an existing empty file, which the
        # old isfile() pre-check skipped, and avoids a check-then-open race.
        if f.tell() == 0:
            f.write("piezo_voltage\tpressure\tK_u\tlog_K_u\tT_u\n")

        f.write(
            f"{piezo_voltage:2.2f}\t{pressure:1.2e}\t{K_u:.4e}\t{np.log(K_u):.4e}\t{T_u:.4e}\n"
        )


counter = 0  # row label for the next sample appended by update()
start_time = time.time()  # time origin for the timestamps written by update()


def update(df, voltage):
    """Apply ``voltage`` to the valve and append one sample row to ``df``.

    The row ``[seconds since start_time, pressure, voltage]`` is stored under
    label ``counter``, which is then incremented. Both ``counter`` and
    ``start_time`` are module-level globals.
    """
    global counter

    set_piezo_valve_voltage(voltage)
    measured = get_current_pressure()
    elapsed = time.time() - start_time
    sample = [elapsed, measured, voltage]
    df.loc[counter] = sample
    counter = counter + 1


def kickoff(kickoff_voltage, kickoff_time, relax_voltage):
    """Pulse the piezo valve.

    Holds ``kickoff_voltage`` for ``kickoff_time`` seconds, then drops to
    ``relax_voltage``. The relax voltage is applied in a ``finally`` block,
    so an interrupt (e.g. Ctrl-C) during the sleep does not leave the valve
    at the kickoff voltage.
    """
    set_piezo_valve_voltage(kickoff_voltage)
    try:
        time.sleep(kickoff_time)
    finally:
        set_piezo_valve_voltage(relax_voltage)


def analyze_periodical_waveform(data, prominence=None):
    """Estimate mean peak, mean valley and period of an oscillating series.

    Args:
        data: pandas Series of samples. The period is measured in units of
            its index (sample time in this script).
        prominence: optional ``prominence`` forwarded to
            ``scipy.signal.find_peaks`` for both peaks and valleys, to reject
            noise ripples. The default None keeps every local extremum.

    Returns:
        dict with these keys:
        - ``"max"``: mean of the peak values.
        - ``"min"``: mean of the valley values.
        - ``"period"``: average of the mean peak-to-peak spacing and the mean
          valley-to-valley spacing. If only one of the two can be computed
          (it needs at least two extrema), that one is used.
        Quantities that cannot be estimated are NaN.
    """
    # Peaks (maxima) and valleys (minima, found as peaks of the negated data).
    peaks, _ = find_peaks(data, prominence=prominence)
    valleys, _ = find_peaks(-data, prominence=prominence)
    maxima = data.iloc[peaks]
    minima = data.iloc[valleys]

    def mean_spacing(index):
        # At least two extrema are needed for one spacing; this avoids
        # np.mean([]) (RuntimeWarning + NaN).
        return np.mean(np.diff(index)) if len(index) > 1 else np.nan

    spacings = [
        s
        for s in (mean_spacing(maxima.index), mean_spacing(minima.index))
        if not np.isnan(s)
    ]
    period = np.mean(spacings) if spacings else np.nan

    return {
        "max": maxima.mean(),
        "min": minima.mean(),
        "period": period,
    }


def _relay_step(df, piezo_setpoint, press_midpoint):
    """Apply one relay (bang-bang) step: open the valve below the threshold, close it otherwise."""
    if get_current_pressure() < press_midpoint:
        update(df, piezo_setpoint)
    else:
        update(df, 0)


def _save_raw_data(df, piezo_setpoint, out_dir="/tmp/pressure-test"):
    """Dump the raw samples to CSV on a best-effort basis; a failure only prints a warning."""
    fname = f'freq_response_{piezo_setpoint:.2f}V_{time.strftime("%Y%m%d_%H%M%S")}.csv'
    try:
        os.makedirs(out_dir, exist_ok=True)
        df.to_csv(os.path.join(out_dir, fname))
    except Exception as err:  # saving is optional; never abort the calibration
        print(f"Warning: could not save raw data to {out_dir}: {err}")


# Ultimate period is around 9-10 s. By default, wait about one period
# (stabilization_time), then test for roughly 3 periods (test_time).
def get_system_params(
    piezo_setpoint, relax_time=25, stabilization_time=15, test_time=35
):
    """Run a relay test at ``piezo_setpoint`` and record the ultimate gain/period.

    Phases (seconds, measured from the start of this call):
      1. ``[0, relax_time)``: hold the valve at ``piezo_setpoint``. The maximum
         pressure becomes the operating point. The midpoint between max and
         min pressure becomes the relay switching threshold.
      2. ``[relax_time, relax_time + stabilization_time)``: relay control
         (setpoint below the threshold, 0 V above) lets the oscillation
         settle. These samples are recorded but not analyzed.
      3. The next ``test_time`` seconds: the same relay control. These
         samples are analyzed for oscillation amplitude and period.

    The valve is set to 0 V afterwards, even if the run is interrupted. Raw
    samples are saved best-effort to /tmp/pressure-test/. Results are printed
    and appended via write_PID_calib_vals().
    """
    global counter, start_time

    # update() timestamps samples against the module-level start_time, so it
    # must be reset here. A local assignment would leave every run after the
    # first with timestamps offset by all previous runs, and the test-window
    # selection below would then include relax/stabilization data.
    start_time = time.time()
    counter = 0

    df = pd.DataFrame(columns=["time", "pressure", "piezo_voltage"])
    start_test_time = relax_time + stabilization_time
    end_time = start_test_time + test_time

    try:
        ## open-loop phase: hold the valve at the setpoint
        while time.time() - start_time < relax_time:
            update(df, piezo_setpoint)
            time.sleep(sample_time)

        ## get relay switching point (midpoint of the reached pressure range)
        press_operation_point = df.pressure.max()
        press_midpoint = (df.pressure.max() + df.pressure.min()) / 2

        ## relay phase: stabilization, then test (same control law)
        while time.time() - start_time < end_time:
            _relay_step(df, piezo_setpoint, press_midpoint)
            time.sleep(sample_time)
    finally:
        # never leave the valve open, even on error or Ctrl-C
        set_piezo_valve_voltage(0.0)

    ## save values for later
    _save_raw_data(df, piezo_setpoint)

    df = df.set_index("time")
    # A boolean mask keeps the selection label-based whatever the index dtype.
    test_df = df.loc[df.index >= start_test_time, "pressure"]

    res = analyze_periodical_waveform(test_df)
    amp2 = res["max"] - res["min"]  # peak-to-peak pressure oscillation
    T_u = res["period"]

    K_u = 4 * linearize_piezo_valve_inversion(piezo_setpoint) / (np.pi * amp2)

    print("PID Close loop calibaration results")
    print(f"{press_operation_point=:.2f}")
    print(f"{piezo_setpoint=:.2f}")
    print(f"{K_u=:.2e}")
    print(f"{T_u=:.2f}")

    write_PID_calib_vals(
        piezo_voltage=piezo_setpoint,
        pressure=press_operation_point,
        K_u=K_u,
        T_u=T_u,
    )


def main():
    """Run the relay PID calibration at several chamber-pressure points.

    Loads the static valve calibration table (column 0: piezo voltage [V],
    column 1: pressure). Each target pressure is converted to a piezo voltage,
    and get_system_params() runs at that voltage. There is a 20 s pause
    between consecutive runs.
    """
    from scipy.interpolate import interp1d

    cal_data = np.loadtxt(
        '/dev/shm/golem/ActualSession/SessionLogBook/WG_calibration_table4H2'
    )
    voltage = cal_data[:, 0]
    pressure = cal_data[:, 1]

    p_min, p_max = pressure.min(), pressure.max()
    if p_max >= 40 and p_min <= 10:
        # The table covers the whole range, so interpolate it directly,
        # starting from the lowest of 5 / 7.5 / 10 that it still reaches.
        # interp1d raises ValueError outside the table, so the top point 40
        # must be covered as well. The old `max < 40` test guaranteed a crash.
        lowest = next(p for p in (5, 7.5, 10) if p_min <= p)
        pressure_calib_points = [lowest, 15, 25, 40]
        pressure_interpolation = interp1d(pressure, voltage)
    else:
        # The table does not cover the range. Fit pressure = a*exp(b*V) and
        # invert the fitted curve sampled on 1..100 V.
        pressure_calib_points = [5, 15, 25, 40]

        def press_func(x, a, b):
            return a * np.exp(b * x)

        # Seed the fit from a straight-line fit of log(pressure) vs voltage.
        # curve_fit's default start (a=b=1) is likely far from the solution.
        p0 = None
        positive = pressure > 0
        if np.count_nonzero(positive) >= 2:
            slope, intercept = np.polyfit(
                voltage[positive], np.log(pressure[positive]), 1
            )
            p0 = (np.exp(intercept), slope)
        popt, _ = curve_fit(press_func, voltage, pressure, p0=p0)

        v_grid = np.linspace(1, 100, 100)
        pressure_interpolation = interp1d(press_func(v_grid, *popt), v_grid)

    for i, calib_point in enumerate(pressure_calib_points):
        piezo_voltage = pressure_interpolation(calib_point)
        print(f'Running PID calibration on {piezo_voltage:.2f}V')
        get_system_params(piezo_setpoint=piezo_voltage)

        ## wait before next calibration
        if i < len(pressure_calib_points) - 1:
            time.sleep(20)



# Run the calibration only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
