# simple-pid 2.0.0 https://pypi.org/project/simple-pid/
from simple_pid import PID
from simple_pid.pid import _clamp
from math import exp

import subprocess
import shlex
import time
from string import Template
import os
import signal
import sys
import numpy as np
import pandas as pd

def ziegler_nichols_method_close_loop(K_u, T_u, reg_type = 'PID'):
    '''Ziegler Nichols Method for Close loop (Frequency response).

    Computes parallel-form gains from the ultimate gain *K_u* and the
    ultimate oscillation period *T_u*.  Each rule is a (Kc, Ti, Td) triple
    converted as Kp = Kc, Ki = Kc/Ti, Kd = Kc*Td.

    reg_type: "P", "PI", "PID", "some overshoot", "no overshoot",
    "Tyreus-Luyben PI" or "Tyreus-Luyben PID".

    Returns (Kp, Ki, Kd).
    Raises ValueError for an unknown *reg_type*.
    '''
    if reg_type == "P":
        return 0.5*K_u,0,0
    elif reg_type == "PI":
        return 0.45*K_u,0.54*K_u/T_u,0
    elif reg_type == "PID":
        return 0.6*K_u, 1.2*K_u/T_u,3*K_u*T_u/40
    elif reg_type == "some overshoot":
        return K_u/3, (2/3)*K_u/T_u,(1/9)*K_u*T_u
    elif reg_type == "no overshoot":
        return 0.2*K_u, 0.4*K_u/T_u,(2/30)*K_u*T_u

    # https://pages.mtu.edu/~tbco/cm416/zn.html
    # Tyreus-Luyben: Kc = K_u/3.2 (PI) or K_u/2.2 (PID), Ti = 2.2*T_u,
    # Td = T_u/6.3.  Ki and Kd must be derived from Kc, not from K_u.
    elif reg_type == "Tyreus-Luyben PI":
        Kc = K_u/3.2
        Ti = 2.2*T_u
        return Kc, Kc/Ti, 0
    elif reg_type == "Tyreus-Luyben PID":
        Kc = K_u/2.2
        Ti = 2.2*T_u
        Td = T_u/6.3
        return Kc, Kc/Ti, Kc*Td

    # Previously the exception was only constructed, never raised, so the
    # function silently returned None.
    raise ValueError(f'wrong PID type: {reg_type!r}')
            
def ziegler_nichols_method_open_loop(gain, time_const, delay, reg_type = 'PID'):
    '''Ziegler Nichols Method for Open loop (Step response).

    Tuning from a step-response model with process gain *gain*, time
    constant *time_const* and dead time *delay*.  Each rule yields
    (Kc, Ti, Td), returned as parallel-form gains Kc, Kc/Ti, Kc*Td.

    reg_type: "P", "PI", "fast PID", "PID" or "slow PID".

    Returns (Kp, Ki, Kd).
    Raises ValueError for an unknown *reg_type*.
    '''
    tau = time_const
    td = delay

    # https://www.ni.com/docs/en-US/bundle/labview/page/ziegler-nichols-autotuning-method.html
    # https://blog.opticontrols.com/archives/477
    # https://web.archive.org/web/20230402131823/http://maulana.lecture.ub.ac.id/files/2014/12/Jurnal-PID_Tunning_Comparison.pdf

    if reg_type == "P":
        Kc = tau/(gain*td)
        return Kc,0,0
    elif reg_type == "PI":
        Kc = 0.9 * tau / (gain*td)
        Ti = 3.33 * td
        return Kc, Kc/Ti,0

    elif reg_type == "fast PID":
        # Classic ZN PID gain 1.2*tau/(gain*td); was typed as tau/(gain/td).
        Kc = 1.2 * tau / (gain*td)
        Ti = 2 * td
        Td = 0.5 * td
        return Kc, Kc/Ti, Kc*Td

    # NOTE(review): "PID" and "slow PID" use Td = 0.8*tau (process time
    # constant) while "fast PID" uses Td = 0.5*td; the origin of these
    # coefficients is not verified here - confirm against the references.
    elif reg_type == "PID":
        Kc = 0.53* tau/ (gain * td)
        Ti = 4*td
        Td = 0.8 *tau
        return Kc, Kc/Ti, Kc*Td

    elif reg_type == "slow PID":
        Kc = 0.32* tau/ (gain * td)
        Ti = 4*td
        Td = 0.8 *tau
        return Kc, Kc/Ti, Kc*Td

    # Previously the exception was only constructed, never raised.
    raise ValueError(f'wrong PID type: {reg_type!r}')
        
def cohen_coon_rule(gain, time_const, delay, reg_type = 'PID'):
    '''Cohen-Coon tuning from a step-response model.

    *gain* is the process gain, *time_const* the time constant and *delay*
    the dead time.  Each rule yields (Kp, Ti, Td), returned as parallel-form
    gains Kp, Kp/Ti, Kp*Td.

    reg_type: "P", "PI" or "PID".

    Returns (Kp, Ki, Kd).
    Raises ValueError for an unknown *reg_type*.
    '''
    # https://www.ni.com/docs/en-US/bundle/labview/page/cohen-coon-autotuning-method.html
    # https://blog.opticontrols.com/wp-content/uploads/2011/03/StepTest.png
    G = gain
    tau = time_const
    td = delay

    if reg_type == "P":
        Kp = tau/(G*td) * (1+td/(3*tau))
        return Kp, 0, 0

    elif reg_type == "PI":
        Kp = tau/(G*td) * (0.9 + td/(12*tau))
        Ti = td*(30+3*(td/tau))/(9+20*(td/tau))
        return Kp, Kp/Ti,0

    elif reg_type == "PID":
        Kp = tau/(G*td) * (4/3 + td/(4*tau))
        Ti = td*(32+6*(td/tau))/(13+8*(td/tau))
        Td = 4*td / (11+2*(td/tau))
        return Kp, Kp/Ti,Kp*Td

    # Previously the exception was only constructed, never raised.
    raise ValueError(f'wrong PID type: {reg_type!r}')
            
def chien_hrones_reswick_rule(gain, time_const, delay, reg_type = 'PID'):
    '''Chien-Hrones-Reswick (CHR) tuning from a step-response model.

    *gain* (K) is the process gain, *time_const* (T) the time constant and
    *delay* (L) the dead time.  The coefficients appear to follow the CHR
    set-point-response table (0 % and ~20 % overshoot); verify against the
    links below.

    reg_type: "P", "PI", "PID", "overshoot P", "overshoot PI" or
    "overshoot PID".

    Returns parallel-form gains (Kp, Ki, Kd).
    Raises ValueError for an unknown *reg_type*.
    '''
    # https://www.ni.com/docs/en-US/bundle/labview/page/chien-hrones-reswick-autotuning-method.html
    # https://taketake2.com/K5_en.html

    K = gain
    T = time_const
    L = delay

    if reg_type == 'P':
        # Kc = 0.3*T/(K*L); the T factor was missing (every other row has it).
        return 0.3 *T/ (K*L), 0,0

    elif reg_type == 'PI':
        # Kc = 0.35*T/(K*L), Ti = 1.2*T  ->  Ki = Kc/Ti ~= 0.29/(K*L)
        return 0.35 *T/ (K*L), 0.29/(K*L),0

    elif reg_type == 'PID':
        # Kc = 0.6*T/(K*L), Ti = T, Td = 0.5*L
        return 0.6 *T/ (K*L), 0.6/(K*L),0.3*T/K

    elif reg_type == 'overshoot P':
        # Kc = 0.7*T/(K*L); the T factor was missing here as well.
        return 0.7 *T/ (K*L), 0,0

    elif reg_type == 'overshoot PI':
        return 0.6 *T/ (K*L), 0.6/(K*L),0

    elif reg_type == 'overshoot PID':
        return 0.95 *T/ (K*L), 0.7/(K*L), 0.44*T/K

    # Previously the exception was only constructed, never raised.
    raise ValueError(f'wrong PID type: {reg_type!r}')

CALIB_FILE = "/golem/shm_golem/ActualSession/SessionLogBook/WG_PID_calibration_table4H2"
## PID controller parameters
def get_PID_params(requested_pressure):
    '''Return (Kp, Ki, Kd) for the calibration point closest to *requested_pressure*.

    CALIB_FILE is read as a tab-separated table that must contain the
    columns ``pressure``, ``K_u`` and ``T_u``.  The row whose ``pressure``
    is nearest to the request (ties resolved towards the lower pressure,
    because the table is sorted first) provides the ultimate gain/period
    that are fed to the closed-loop Ziegler-Nichols "PID" rule.
    '''
    table = pd.read_csv(CALIB_FILE, delimiter='\t').sort_values('pressure')

    distance = (table['pressure'] - requested_pressure).abs()
    nearest = table.loc[distance.idxmin()]

    kp, ki, kd = ziegler_nichols_method_close_loop(nearest['K_u'], nearest['T_u'])
    return kp, ki, kd
    # TODO: consider blending the parameters of the two nearest calibration
    # points when their models overlap (~1.5 mPa), instead of picking one.



## Path of the file that holds the current chamber pressure [mPa].
## '$SHML' is expanded from the environment below.
ActualChamberPressuremPa = '$SHML/ActualChamberPressuremPa'

## piezo voltage limit [V]
piezo_voltage_limit = 30

## controller update interval [s]
sample_time = 1.

## tab-separated log of the control loop
LOGFILE = "/tmp/PressurePIDLog"

# Expand $SHML from the environment.  If it is not set, fall back to the
# standard session path, which then has to exist or the script exits.
try:
    ActualChamberPressuremPa = Template(ActualChamberPressuremPa).substitute(os.environ)
except (ValueError, KeyError):
    ActualChamberPressuremPa = "/golem/shm_golem/ActualSession/SessionLogBook/ActualChamberPressuremPa"
    if not os.path.exists(ActualChamberPressuremPa):
        raise SystemExit(f'Cannot find ActualChamberPressurePa in {ActualChamberPressuremPa}')

# Graceful shutdown: the control loop keeps running while ``run`` is True;
# SIGINT (Ctrl+C) or SIGTERM flips it to False so the loop can finish and
# the valve can be closed.
run = True


def handler_stop_signals(signum, frame):
    '''Signal handler that requests the control loop to stop (arguments unused).'''
    global run
    run = False


for _stop_signal in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_stop_signal, handler_stop_signals)
del _stop_signal


def set_piezo_valve_voltage(voltage: float):
    '''Command the piezo gas valve to *voltage* [V].

    Sources the ``WorkingGas.sh`` bash library and calls its
    ``SetVoltage@GasValveTo`` function with the voltage formatted to two
    decimals.
    '''
    run_bash_function(
        library_path='WorkingGas.sh',
        function_name='SetVoltage@GasValveTo',
        params='%.2f' % voltage,
    )
    
#def set_piezo_valve_voltage(voltage: float):
#    '''Set voltage on Piezo valve using bash script'''
#    bash_path = None
#    function  = './Dirigent.sh  -HpV'
#    params    = '%.2f' % voltage
#    run_bash_function(bash_path, function, params)
    
def run_bash_function(library_path, function_name, params):
    '''Start ``function_name params`` in a background ``bash -c`` process.

    If *library_path* is not None, it is ``source``d first so that
    *function_name* may be a function defined in that bash library.  The
    process is not waited for, and its stdout/stderr are discarded.

    Returns the ``subprocess.Popen`` handle, so a caller can ``wait()`` when
    ordering matters.  Existing callers ignore the return value.
    '''
    command = '%s %s' % (function_name, params)
    if library_path is not None:
        command = 'source %s; %s' % (library_path, command)
    # Hand the whole script to bash as one argv element.  The former
    # shlex.split('"..."') round trip produced the same argument only while
    # the text contained no double quotes or backslashes.
    return subprocess.Popen(['bash', '-c', command],
                            stdout=subprocess.DEVNULL,
                            stderr=subprocess.STDOUT)
    
    
def linearize_piezo_valve(x):
    '''Map a PID output back to a piezo voltage by taking its natural log.

    The controller works in the exponential domain, because the valve
    responds exponentially to voltage.  Outputs below 1 would give a
    negative voltage, so they are mapped to 0.
    '''
    # (an alternative scaling, 100*log(x)/log(100), was tried earlier)
    return 0 if x < 1 else np.log(x)

def linearize_piezo_valve_inversion(x):
    '''Inverse of the log linearisation: piezo voltage -> PID output domain (exp).

    For x >= 0, log(exp(x)) == x, so this undoes the voltage mapping.
    '''
    # (an alternative scaling, exp(log(100)*x/100), was tried earlier)
    pid_domain_value = np.exp(x)
    return pid_domain_value



# Most recent successfully parsed pressure [mPa]; served when a read fails.
last_read_pressure = 0.
def get_current_pressure() -> float:
    '''Current pressure in mPa.

    Parses the first line of the ActualChamberPressuremPa file.  If it is not
    a valid float (e.g. an empty read while bash rewrites the file), the last
    good value is returned instead (0.0 before the first successful read).
    '''
    global last_read_pressure

    with open(ActualChamberPressuremPa, "r") as pressure_file:
        first_line = pressure_file.readline()

    try:
        last_read_pressure = float(first_line)
    except ValueError:
        pass  # keep serving the previous measurement
    return last_read_pressure
 
def log(pressure,control_voltage,pressure_request, s, p, i, d):
    '''Append one control-loop sample to LOGFILE as a tab-separated row.

    Columns are: time (HH:MM:SS), pressure, piezo voltage and the P, I, D
    components.  A header row is written first if LOGFILE does not exist yet.
    *pressure_request* and *s* (setpoint) are accepted but not written.
    '''
    needs_header = not os.path.isfile(LOGFILE)
    with open(LOGFILE, 'a+') as logfile:
        if needs_header:
            logfile.write("time\tpressure\tpiezo_voltage\tP\tI\tD\n")
        fields = (
            time.strftime('%H:%M:%S'),
            f"{pressure}",
            f"{control_voltage:2.2f}",
            f"{p:.4e}",
            f"{i:.4e}",
            f"{d:.4e}",
        )
        logfile.write("\t".join(fields) + "\n")

def kickoff(kickoff_voltage, kickoff_time, relax_voltage):
    '''Pulse the valve: apply *kickoff_voltage*, wait *kickoff_time* s, then *relax_voltage*.

    The voltage commands go through set_piezo_valve_voltage.  The wait is
    measured between issuing the two commands.
    '''
    # opening pulse
    set_piezo_valve_voltage(kickoff_voltage)
    time.sleep(kickoff_time)
    # settle at the holding voltage
    set_piezo_valve_voltage(relax_voltage)

## Add new functionality to simple PID
class AdvPID(PID):
    '''simple-pid ``PID`` with an optional filtered derivative and a
    "setpoint reached" detector.

    Extra parameters on top of ``simple_pid.PID``:

    Tf : float or None
        Time constant of the first-order derivative filter.  ``None`` or a
        value <= 0 disables filtering; simple-pid's own derivative is kept.
    setpoint_stable_limit : float or None
        |error| below which the process counts as settled.  ``None``
        disables detection, and ``setpoint_reached`` stays False.
    setpoint_stable_time : float
        How long (same time unit as ``dt``) the error must stay inside the
        limit before ``setpoint_reached`` becomes True.

    ``proportional_on_measurement`` / ``differential_on_measurement`` are
    fixed to False / True.  The constructor arguments are accepted only for
    signature compatibility and are ignored.
    '''
    def __init__(
            self,
            Kp=1.0,
            Ki=0.0,
            Kd=0.0,
            Tf=0.0,
            setpoint=0,
            sample_time=0.01,
            output_limits=(None, None),
            auto_mode=True,
            proportional_on_measurement=False,
            differential_on_measurement=True,
            error_map=None,
            time_fn=None,
            starting_output=0.0,
            setpoint_stable_limit = None,
            setpoint_stable_time = 1.0,
        ):

        super().__init__(
            Kp,
            Ki,
            Kd,
            setpoint=setpoint,
            sample_time=sample_time,
            output_limits=output_limits,
            auto_mode=auto_mode,
            # fixed on purpose, see the read-only properties below
            proportional_on_measurement=False,
            differential_on_measurement=True,
            error_map=error_map,
            time_fn=time_fn,
            starting_output=starting_output,
            )
        self.Tf = Tf

        self.setpoint_stable_limit = setpoint_stable_limit
        self.setpoint_stable_time = setpoint_stable_time
        # Must exist before the first call: __call__ does ``+= dt`` on it.
        # It used to be created lazily, which raised AttributeError when the
        # very first error was already inside the stable band.
        self._setpoint_value_curr_stable_time = 0.0
        self.setpoint_reached = False

    def __call__(self, input_, dt=None):
        '''Return the next clamped controller output for measurement *input_*.

        *dt* is the time since the last update; when None it is taken from
        ``time_fn``.  Raises ValueError for a non-positive *dt*.
        '''
        # Same gates as simple-pid's PID.__call__, evaluated here so that a
        # call the parent would skip does not step the derivative filter or
        # the stability timer (that used to double-count dt).
        if not self.auto_mode:
            return self._last_output

        now = self.time_fn()
        if dt is None:
            dt = now - self._last_time if (now - self._last_time) else 1e-16
        elif dt <= 0:
            raise ValueError('dt has negative value {}, must be positive'.format(dt))

        if (self.sample_time is not None and dt < self.sample_time
                and self._last_output is not None):
            return self._last_output

        prev_state = self._last_error
        super().__call__(input_, dt)
        error = self._last_error  # real (error_map-ed) error from simple-pid
        state = error             # value carried to the next call in _last_error

        ## stolen from https://github.com/eadali/advanced-pid/blob/main/advanced_pid/pid.py
        # Derivative Kd*s/(Tf*s + 1) on the error, discretised exactly over dt.
        # Filter state x is carried in _last_error as e = -(Tf/Kd)*x,
        # i.e. x = -Kn*Kd*e.
        if self.Kd != 0.0 and self.Tf is not None and self.Tf > 0.0:
            # On the first sample, start the filter at steady state: no derivative kick.
            e0 = prev_state if prev_state is not None else error
            Kn = 1.0 / self.Tf
            decay = exp(-Kn*dt)
            x = -Kn * self.Kd * e0
            x = decay * x - Kn * (1.0 - decay) * self.Kd * error
            self._derivative = x + Kn * self.Kd * error
            state = -(self.Tf/self.Kd) * x
        # otherwise keep the derivative simple-pid has just computed

        # Compute final output
        output = _clamp(self._proportional + self._integral + self._derivative,
                        self.output_limits)

        # Keep track of state
        self._last_output = output
        self._last_error = state

        ## setpoint-reached detection,
        ## from https://github.com/ThunderTecke/PID_Py/blob/main/PID_Py/PID.py
        # Uses the real error, not the filter state stored in _last_error.
        if self.setpoint_stable_limit is not None:
            if abs(error) < self.setpoint_stable_limit:
                self._setpoint_value_curr_stable_time += dt
            else:
                self._setpoint_value_curr_stable_time = 0.0
            self.setpoint_reached = (
                self._setpoint_value_curr_stable_time > self.setpoint_stable_time)
        else:
            self._setpoint_value_curr_stable_time = 0.0
            self.setpoint_reached = False

        return output

    def reset(self):
        '''Reset simple-pid's internal state and the setpoint-reached detector.'''
        super().reset()
        self._setpoint_value_curr_stable_time = 0.0
        self.setpoint_reached = False

    ## do not change proportional_on_measurement, differential_on_measurement
    @property
    def proportional_on_measurement(self):
        return False

    @property
    def differential_on_measurement(self):
        return True

    @proportional_on_measurement.setter
    def proportional_on_measurement(self, _):
        pass

    @differential_on_measurement.setter
    def differential_on_measurement(self, _):
        pass
    

def main():
    '''Run the chamber-pressure control loop until SIGINT/SIGTERM.

    The requested pressure [mPa] is the first command-line argument
    (0.0 if absent).  PID gains come from the calibration table.  Every
    ``sample_time`` seconds the pressure is read, the PID output is converted
    to a piezo voltage and applied, and the sample is logged.  The valve is
    always set to 0 V on exit, including when the loop raises.
    '''
    pressure_request = float(sys.argv[1]) if len(sys.argv) > 1 else 0.0

    Kp, Ki, Kd = get_PID_params(pressure_request)
    print(f"{Kp=}\t{Ki=}\t{Kd=}")
    pid = AdvPID(Kp, Ki, Kd, Tf = None,
              setpoint=pressure_request,
              output_limits = (linearize_piezo_valve_inversion(0), linearize_piezo_valve_inversion(piezo_voltage_limit)),
              )

    try:
        while run:
            pressure = get_current_pressure()
            pid_output = pid(pressure)

            ## piezoelectric valve has exponential response to voltage,
            ## whereas PID controllers are used for linear systems
            ## -> apply linearizing function to pid_output
            control_voltage = linearize_piezo_valve(pid_output)
            set_piezo_valve_voltage(control_voltage)

            p,i,d = pid.components

            log(pressure,control_voltage,pressure_request,pid.setpoint, p, i, d)

            time.sleep(sample_time)
    finally:
        # Close the valve on any exit.  Previously an exception in the loop
        # (e.g. unreadable pressure file) skipped this and left it open.
        set_piezo_valve_voltage(0.)

if __name__ == "__main__":
    main()
