#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Main script of the plasma position algorithm. This script prepares the data
# (removes false signals and integration drifts) and finally plots the results
# computed by PlasmaPosition.py.
#
# Author: Tomas Odstrcil


# Plotting setup. The script only writes image files, so the non-interactive
# Agg backend is selected; font size and LaTeX text rendering are configured
# here, before the imports below (which may themselves pull in matplotlib).
import matplotlib 
matplotlib.rcParams['backend'] = 'Agg'
matplotlib.rc('font',  size='10')
matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!


# Start a timer to report how long the following imports take.
import time
cas = time.time()



# NOTE: the numpy star import provides most array functions used in this file
# (array, interp, cumsum, fft, save, load, ...).
from numpy import *
import os
import sys
from pygolem_lite.modules import list2array,deconvolveExp,save_adv,saveconst
from scipy.signal import  fftconvolve,gaussian
from scipy.interpolate import interpolate

from pygolem_lite import Shot
# Report the import time.
print 'including time: ',  time.time()-cas




# Constants
a = 0.085  # [m] limiter radius (drawn as the 'limiter' line in the position plot)
R_0 = 0.4  # [m] major radius; the Mirnov coil ring is centred on (R_0, 0)
shot = Shot()['shotno']
# Calibration of the four Mirnov coils [T/(V*s)]. Only the sign of the fourth
# coil depends on the shot number (it flips after shot 12079).
# NOTE(review): presumably a wiring change at that shot -- confirm.
calib = -array([261, 261, -261, 261 if shot > 12079 else -261])  # T/(V*s)



  
def LoadData():
    """Load the signals and discharge parameters of the current shot.

    All values are read from a single ``Shot()`` instance.

    Returns
    -------
    tuple
        ``(Bt, Uloop, Ip, Papouch, CD_trigger, BD_trigger, Bt_trigger,
        plasma_end, plasma_start, plasma, shot)`` where

        * ``Bt``, ``Uloop``, ``Ip`` -- 'toroidal_field', 'loop_voltage' and
          'plasma_current' converted to arrays and transposed (callers use
          column 0 as the time base and column 1 as the value);
        * ``Papouch`` -- the Mirnov coils 1, 5, 9 and 13 read from their
          data-acquisition channels and combined by ``list2array``;
        * ``CD_trigger``, ``BD_trigger``, ``Bt_trigger`` -- trigger times
          'Tcd', 'Tbd', 'Tb', replaced by NaN when the matching voltage
          ('Ucd', 'Ubd', 'Ub') is zero;
        * ``plasma_end``, ``plasma_start``, ``plasma``, ``shot`` -- the values
          stored under 'plasma_end', 'plasma_start', 'plasma' and 'shotno'.
    """
    data = Shot()

    # asarray avoids a copy when possible (as array(..., copy=False) did on old
    # numpy) but, unlike it, does not raise on numpy >= 2 when a copy is needed.
    Bt = asarray(data['toroidal_field']).T
    Uloop = asarray(data['loop_voltage']).T
    Ip = asarray(data['plasma_current']).T

    # Look up the data-acquisition system (DAS) and channel of each Mirnov coil.
    # FIXME: the data are addressed directly by channel; they should rather be
    # addressed by Mirnov coil.
    # NOTE(review): only the DAS of the last coil is used below, so all four
    # coils are assumed to be recorded by the same DAS -- confirm.
    channels = []
    for coil in ('mirnov_1', 'mirnov_5', 'mirnov_9', 'mirnov_13'):
        das, channel = data.get_data('any', coil, return_channel=True)
        channels.append(channel)

    Papouch = list2array(data[das, channels])

    # A trigger time is discarded (NaN) when the corresponding voltage is zero.
    CD_trigger = data['Tcd'] if data['Ucd'] != 0 else nan
    BD_trigger = data['Tbd'] if data['Ubd'] != 0 else nan
    Bt_trigger = data['Tb'] if data['Ub'] != 0 else nan

    plasma_start = data['plasma_start']
    plasma_end = data['plasma_end']
    plasma = data['plasma']
    shot = data['shotno']

    return (Bt, Uloop, Ip, Papouch, CD_trigger, BD_trigger, Bt_trigger,
            plasma_end, plasma_start, plasma, shot)


def Resample(tvec, vec, t_min, t_max, n):
    """Smooth ``vec`` and resample it onto ``n`` equidistant points in [t_min, t_max].

    Parameters
    ----------
    tvec : 1-D array
        Equidistant time base of ``vec`` (the step is taken from the first
        two samples).
    vec : 1-D array
        Signal sampled on ``tvec``.
    t_min, t_max : float
        Time interval of the output.
    n : int
        Number of output samples. ``n == 1`` is special-cased: the raw
        samples with ``t_min <= tvec <= t_max`` are returned unchanged.

    Returns
    -------
    (tvec_new, vec_new)
        The new time base and the resampled signal.

    Before linear interpolation, the signal is smoothed with a normalised
    Gaussian kernel whose standard deviation is about half of the new
    sampling step, which limits aliasing. Output points before ``tvec[0]``
    are set to 0; after ``tvec[-1]`` the last smoothed value is held.
    """
    if n == 1:
        ind = (tvec >= t_min) & (tvec <= t_max)
        return tvec[ind], vec[ind]

    # Gaussian width in input samples: ~half of the new sampling step.
    std = (t_max - t_min) / (tvec[1] - tvec[0]) / n / 2
    # The window length must be an integer: modern scipy rejects a float
    # length, and old scipy silently built an off-centre window from it. Keep
    # it odd so the 'same' convolution does not shift the signal by half a sample.
    win_len = max(int(ceil(3 * std)), 1)
    if win_len % 2 == 0:
        win_len += 1
    gauss_win = gaussian(win_len, std)
    gauss_win /= sum(gauss_win)

    vec_smooth = fftconvolve(vec, gauss_win, mode='same')
    tvec_new = linspace(t_min, t_max, n)
    vec_new = interp(tvec_new, tvec, vec_smooth, left=0, right=None)
    return tvec_new, vec_new




def CorrectRCcircuit(tvec, sig, tau):
    """Integrate a signal and undo the response of an RC circuit in the Fourier domain.

    The zero-mean spectrum of ``sig`` is divided by the DC-normalised spectrum
    of an exponential decay exp(-t/tau), truncated after N//2-1 samples
    (deconvolution). It is then divided by ``1 - exp(-2*pi*i*f*dt)``, which
    integrates it (inverse of a first difference). The DC term is set to zero.

    Parameters
    ----------
    tvec : 1-D array
        Equidistant time base of ``sig``.
    sig : 1-D array
        Measured signal. It is not modified.
    tau : float
        RC time constant [s].

    Returns
    -------
    1-D float array of the same length as ``sig``: the corrected and
    integrated signal, multiplied by the sampling step ``dt``.
    """
    n_samples = len(sig)
    # Sampling step: N samples span N-1 intervals.
    dt = (tvec[-1] - tvec[0]) / (n_samples - 1)
    # Zero-mean float copy; the caller's array is left untouched.
    sig = asarray(sig, dtype=float) - mean(sig)

    freq = fft.fftfreq(n_samples, d=dt)
    fsig = fft.fft(sig)
    fsig[0] = 0  # => mean(x) = 0

    # Transfer function of a first difference; dividing by it integrates.
    ifac = 1 - exp(-2 * pi * 1j * freq * dt)

    # Spectrum of the truncated exponential response, normalised to unit DC gain.
    q = exp(-(1. / tau + 2 * pi * 1j * freq) * dt)
    fexp = (1 - q ** (n_samples // 2 - 1)) / (1 - q)
    fexp /= fexp[0]

    transfer = 1 / fexp        # deconvolution of the RC response
    transfer[1:] /= ifac[1:]   # integration (ifac[0] == 0; the DC term is already 0)

    integ_sig = real(fft.ifft(fsig * transfer))
    integ_sig -= integ_sig[0] + sig[0]
    integ_sig *= dt
    return integ_sig
    
    
    
def get_position():
    from PlasmaPosition import CalcPlasmaPosition,RemoveDriftsAuto

    cas = time.time()

    Bt,Uloop,Ip,signal,CD_trigger,BD_trigger, Bt_trigger,plasma_end,plasma_start, plasma,shotno  = LoadData()

    tvec = signal[:,0]
    signal = signal[:,1:]
    t0 = tvec[0]
    dt = tvec[1]-tvec[0]
    
    #BUG 1. chanell has RC integrator
    if shotno >= 11381 and shotno <= 11837:
	signal[:,0] = CorrectRCcircuit(tvec,signal[:,0], 0.1e-3)
	
	signal[:,1:] = cumsum(signal[:,1:],axis = 0, out=signal[:,1:])*dt #možná by šla udělat nějaká lepší integrace
    else:    
	signal = cumsum(signal,axis = 0, out=signal)*dt #možná by šla udělat nějaká lepší integrace
    
    print 'load data: ',  time.time()-cas

    reduce=500
    if mean(abs(Uloop[:,1])) == 0 or mean(abs(Bt[:,1])) == 0 :  #selhání diagnostik
	print 'all diagnostics failured '
	return None


    #correct effects of the chamber currents  (projection constants)
    Uproj = array([-0.5e-07, -1.0e-07, -0.2e-07,-0.8e-07])
    Uloop_ds = interp(tvec, Uloop[:,0],Uloop[:,1], left=0, right=None)
    signal-= outer(Uloop_ds,Uproj)
   
    Bt_crosstalk = True
    Uloop_crosstalk = False

    if abs(median(Bt[:,1])) < 0.001:
	Bt_crosstalk = False
    if abs(median(Uloop[:,1])) < 0.05:
	Uloop_crosstalk = False



    #prepare signals
    E_trigger = nanmin([CD_trigger, BD_trigger])

    #remove crosstalks+drifts
    signal,err, AutoRemoveGraph = RemoveDriftsAuto(vstack((tvec,signal.T)).T, Bt,Uloop,plasma_start,plasma_end,Bt_trigger,E_trigger)

    #now is a plasma presence is necassary
    if not plasma:
	return [AutoRemoveGraph,Bt_trigger]




    #create detectors
	  
    rho  = empty((4,1))
    zeta = empty((4,1))
    
    r_det = 0.093
    phi_det = linspace(0,1.5*pi,4 )
    
    rho[:,0]  = r_det*cos(phi_det)+R_0
    zeta[:,0] = r_det*sin(phi_det)
 
    detectorPos = hstack((rho,zeta))
    

    tvec_ds, Ip_ds =  Resample(Ip[:,0],Ip[:,1],plasma_start,plasma_end,reduce)
    detectorSignal = signal[:,1:]
    tvec = signal[:,0]

    detectorSignal_ds = list()
    for i in range(size(detectorSignal,1)):
	tvec_ds,data_ds =  Resample(tvec,detectorSignal[:,i],plasma_start,plasma_end,reduce)
	detectorSignal_ds.append(data_ds)
    detectorSignal_ds = array(detectorSignal_ds, copy = False).T
    
    #print detectorSignal_ds.shape
    #exit()
 



    #   estimated errors
    #TODO here can by applied information about lower precision of some detector

    IpDriftError =7
    detectorDriftError = ones(4)* 1.5e-7
    
    
    detectorDriftError*= abs(array(calib))
    detectorSignal_ds*= array(calib)


    pos_tvec, position, radius, residuum, retrofit, data,chi2 = \
            CalcPlasmaPosition(single(tvec_ds),single(detectorPos), single(detectorSignal_ds),  
                    single(detectorDriftError), single(Ip_ds), single(IpDriftError))
    
    save_adv('results/R_position', pos_tvec, position[:,0])
    save_adv('results/Z_position', pos_tvec, position[:,1])
    save_adv('results/plasma_radius', pos_tvec, radius)
    savetxt('results/R_position.txt', vstack((pos_tvec, position[:,0])).T, fmt='%.4e %.3e' )
    savetxt('results/Z_position.txt', vstack((pos_tvec, position[:,1])).T, fmt='%.4e %.2e' )
    savetxt('results/plasma_radius', vstack((pos_tvec, radius)).T, fmt='%.4e %.2e'  )
    #savetxt('results/residuum', mean(residuum), fmt='%.1e')
    saveconst('results/residuum', median(residuum/ data[:,0]**2*1e6))
    
    
    
    #f_Ip = interpolate.interp1d( Ip[:,0],Ip[:,1] , bounds_error = False, fill_value = 0,copy = False)
    Ip = interp(tvec,  Ip[:,0],Ip[:,1])
    
    return pos_tvec, position,  residuum, retrofit, single(data),single(tvec),single(Ip), IpDriftError, single(detectorSignal), detectorDriftError, plasma_start, plasma_end,Bt_trigger,chi2,AutoRemoveGraph

def graphs(  inputs , filetype = 'png' ):

    if len(inputs) == 2: #vacuum discharge
        [AutoRemoveGraph,Bt_trigger] = inputs
    else:
        [pos_tvec,position,residuum,retrofit,data,tvec,Ip,IpDriftError, detectorSignal,detectorDriftError,plasma_start,plasma_end,Bt_trigger,chi2,AutoRemoveGraph] = inputs 

        #===============  Plot results======================
    import matplotlib 
    matplotlib.rcParams['backend'] = 'Agg'
    matplotlib.rc('font',  size='10')
    matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
    import matplotlib.pyplot as plt
        


    class MyFormatter(plt.ScalarFormatter): 
        def __call__(self, x, pos=None): 
            if pos==0: 
                return '' 
            else: return plt.ScalarFormatter.__call__(self, x, pos)  
        
    
     
    #plot graph from auto removing   of crosstalks  
    t0 = time.time()

    (frames, vlines,hlines,rectangles) = AutoRemoveGraph

    fig = plt.figure(figsize=(10, 10), dpi=80, facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace=0, wspace = 0)
    
    for i,frame in enumerate(frames):
        (t, curves, ytext,text) = frame
        
        ax = fig.add_subplot(len(frames),1,i+1)
        ax.xaxis.set_major_formatter( plt.NullFormatter() )
        ax.yaxis.set_major_formatter( MyFormatter() )
        ax.text(.05,.1,text,horizontalalignment='left',verticalalignment='bottom',transform=ax.transAxes,backgroundcolor = 'w')
        
        for curve in curves:
            (y,label, linestyle) = curve
            ax.plot(t,y, linestyle , label = label)
        
        for hline in hlines:
            (y0,linestyle) =  hline
            ax.axhline(y = y0,linestyle = linestyle)
        for vline in vlines:
            (x0,linestyle) =  vline
            ax.axvline(x = x0,linestyle = linestyle)
        ax.set_ylabel(ytext)

        ax.axis('tight')
        ax.set_xlim(Bt_trigger*1000, None)
        (y_min,y_max) = ax.get_ylim()
    

        for (x_min,x_max) in rectangles:
            r = plt.Rectangle((x_min, y_min),x_max-x_min,y_max-y_min)
            r.set_clip_box(ax.bbox)
            r.set_alpha(0.05)
            ax.add_artist(r)
            
    handles, labels = ax.get_legend_handles_labels()
    handles.append(r)
    labels.append('Minimized interval')
    ax.xaxis.set_major_formatter( plt.ScalarFormatter() )

    ax.set_xlabel('t [ms]')
    leg = ax.legend(handles, labels,loc='best', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    fig.savefig('./graphs/signal_correction.'+filetype,bbox_inches='tight')
    plt.close()

    print 'plot graph from auto removing   of crosstalks   ',time.time()-t0
        

    if len(inputs) == 2:  #no plasma
        return 

    
    n_det = size(detectorSignal,1)


    t = time.time()   
    
    
    
    #plot position
    fig = plt.figure(figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    fig.subplots_adjust( bottom=0.2)
    ax=fig.add_subplot(111)
    minorLocator   = plt.MultipleLocator(0.5)
    ax.xaxis.set_minor_locator(minorLocator)
    Rplot,= ax.plot(pos_tvec*1e3, (position[:,0]-0.4)*100,'b', label='R-R$_0$',lw=0.5)
    Zplot,= ax.plot(pos_tvec*1e3, position[:,1]*100,'r', label = 'Z', linewidth=0.5)
    upper_lim=ax.axhline(y = a*100, ls = '--' ,label = 'limiter')
    lower_lim=ax.axhline(y = -a*100, ls = '--' )
    leg = ax.legend(loc='lower left', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    ax.set_xlim(pos_tvec[0]*1e3, pos_tvec[-1]*1e3)
    ax.set_ylim(-10,10)
    ax.set_ylabel('R,Z [cm]')
    ax.set_xlabel('t [ms]')
    fig.savefig('graphs/plasma_position.'+filetype,bbox_inches='tight')
    print 'plot position ',time.time()-t

    
    
    # plot plasma radius
    
    t = time.time()
    
    radius = a-hypot(position[:,0]-0.4, position[:,1])
    Rplot.set_ydata(radius*100)
    Rplot.set_label('Plasma radius')
    ax.set_ylim(0,10)
    Zplot.set_visible(False)
    lower_lim.set_visible(False)
    ax.set_ylabel('r [cm]')
    leg = ax.legend((Rplot,),(Rplot.get_label(),),loc='upper right', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    fig.savefig('graphs/plasma_radius.'+filetype,bbox_inches='tight')
    print 'plasma radius ',time.time()-t

    t = time.time()

    #plot residuum
    Rplot.set_visible(False)
    Zplot.set_visible(True)
    Zplot.set_ydata(residuum/data[:,0]**2*1e6)
    Zplot.set_label('residuum of the fit/$I_p^2$')
    ax.set_yscale('log')
    upper_lim.set_ydata([1e2,]*2)
    upper_lim.set_linestyle('-')
    lower_lim.set_ydata([10,]*2)
    lower_lim.set_linestyle('-')
    lower_lim.set_visible(True)

    txt1=ax.text(mean(pos_tvec*1e3),3 , 'reliable fit', color='k')
    txt2=ax.text(mean(pos_tvec*1e3),50 , 'considerable fit', color='k')
    txt3=ax.text(mean(pos_tvec*1e3),1e3 , 'poor fit', color='k')


    ax.set_ylim(1,1e4)
    ax.set_ylabel('SSE/$I^2_p$ [kA$^{-2}$]')
    leg = ax.legend((Zplot,),(Zplot.get_label(),),loc='upper left', fancybox=True)
    leg.get_frame().set_alpha(0.7)
  
    fig.savefig('graphs/residuum.'+filetype, bbox_inches='tight')
    print 'plot residuum ',time.time()-t



    t = time.time()

    #plot position polar
    rho_plot,phi_plot = Rplot,Zplot
    r = hypot(position[:,0]-0.4,position[:,1])
    phi = arctan2(position[:,1],-(position[:,0]-0.4))
    txt3.set_visible(False)
    txt1.set_visible(False)
    txt2.set_visible(False)
    ax.set_yscale('linear')

    axL = ax
    rho_plot.set_ydata(r*100)
    rho_plot.set_visible(True)
    rho_plot.set_label(r'radial coordinate $\rho$')
    phi_plot.set_ydata(phi/pi*5+5)
    phi_plot.set_label('angular coordinate $\phi$')
    axL.set_ylabel(r'$\rho$ [cm]')
    axL.set_ylim(0,10)
    axL.set_xlabel('t [ms]')
    axL.yaxis.tick_left()
    handlesL, labelsL = axL.get_legend_handles_labels()

    
    axR = fig.add_subplot(111, sharex=axL, frameon=False)
    axR.yaxis.tick_right()
    axR.yaxis.set_label_position("right")

    axR.set_ylabel('$\phi$ [rad]')
    axR.set_ylim(-pi,pi)
    handlesR, labelsR = axR.get_legend_handles_labels()
    handles = hstack((handlesL[:2],handlesR))
    labels  = hstack((labelsL[:2],labelsR))

    leg = axL.legend(handles,labels,loc='lower left', fancybox=True)
    leg.get_frame().set_alpha(0.7)

    fig.savefig('graphs/plasma_position_polar.'+filetype,bbox_inches='tight')
    print 'plot angle position ',time.time()-t

    
    
    t = time.time()
    
    #plot raw data
    fig.clf()
    fig.subplots_adjust( bottom=0.2)

    ax = fig.add_subplot(111)
    ax.plot(tvec[:-50]*1e3,Ip[:-50]/IpDriftError,'-.', label = '$I_p$') 
    detectorSignal*=calib/detectorDriftError
    for i in range(4):
        ax.plot(tvec*1e3,detectorSignal[:,i],label='mc%d'%(i*4+1),lw=0.5)
    ax.axvline(x = plasma_start*1e3, ls= '--')
    ax.axvline(x = plasma_end*1e3, ls = '--')
    ax.axhline(y = 0, ls = '-.')
    ax.set_xlabel('t [ms]')
    ax.set_ylabel('signal/estimated error [-]')
    leg = ax.legend(loc='upper left', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    ax.axis('tight')
    ax.set_xlim(Bt_trigger*1e3, min(1e3*plasma_end+5, 40))
    ax.set_ylim(-10,None)

    fig.savefig('graphs/preprocesed_signal.'+filetype,bbox_inches='tight')
    print 'plot raw data',time.time()-t

    plt.close('all')    
     


    t = time.time() 
  


    #plot retrofit  
    from scipy.stats.mstats  import mquantiles

    fig = plt.figure( figsize=(10, 10), dpi=80, facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace=0, wspace = 0)
    c = ('r', 'g', 'b', 'y')
    y_min_lim = amin(data[:,1:])
    y_max_lim = amax(data[:,1:])

    for i in range(1,n_det+1):
        ax = fig.add_subplot(n_det+1,1,i)
        ax.xaxis.set_major_formatter( plt.NullFormatter() )
        ax.yaxis.set_major_formatter( MyFormatter() )
        ax.plot(pos_tvec*1e3,data[:,i]/data[:,0]*1e6,c[i%4],label='mc%d'%(i*4-3)+'/I$_p$')
        ax.plot(pos_tvec*1e3,retrofit[:,i]/data[:,0]*1e6,c[i%4]+'-.',label='retrofit')
        
        ax.text(.05,.1,r'$\chi^2$ = %2.1f'% chi2[i],horizontalalignment='left',
                    verticalalignment='bottom',transform=ax.transAxes)

        leg = ax.legend(loc='upper left', fancybox=True)
        leg.get_frame().set_alpha(0.5)
        data_max = mquantiles(data[:,i]/data[:,0]*1e6, 0.95)*2

        ax.set_ylim(0, data_max)
        ax.set_ylabel('mc%d'%(i*4-3)+'/$I_p$ [mT/kA]')


    ax = fig.add_subplot(n_det+1,1,n_det+1)
    ax.plot(pos_tvec*1e3,ones(size(data,0)),label = 'I$_p$/I$_p$')
    ax.plot(pos_tvec*1e3,retrofit[:,0]/data[:,0],'--', label = 'retrofit')
    ax.text(.05,.1,'$\\chi^2$ = %2.1f'%chi2[0],horizontalalignment='left',
                verticalalignment='bottom',transform=ax.transAxes)

    leg = ax.legend(loc='upper left', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    ax.set_ylim(0,2)
    ax.set_ylabel('$I_p/I_p$ [-]')
    ax.set_xlabel('t [ms]')
    fig.savefig('./graphs/retrofit.'+filetype,bbox_inches='tight')

    plt.close()


    print 'retrofit  ',time.time()-t

    
    print 'graphs plotted in %g s' % (time.time()-t0)


def main():
    """Command-line entry point: ``<script> analysis`` or ``<script> plots``.

    ``analysis`` runs ``get_position`` and stores its result in ``out.npy``.
    ``plots`` loads and deletes ``out.npy``, draws the graphs, writes
    ``status`` = 0 and creates the ``icon.png`` thumbnail. When the analysis
    returned None (failed diagnostics), ``plots`` only reports that there is
    nothing to plot.
    """
    for path in ['graphs', 'results', 'constants']:
        if not os.path.exists(path):
            os.mkdir(path)

    if len(sys.argv) < 2:
        sys.exit('usage: %s analysis|plots' % sys.argv[0])
    mode = sys.argv[1]

    if mode == "analysis":
        out = get_position()
        # Store the heterogeneous result explicitly as a 1-D object array, so
        # numpy never tries to build a (ragged) numeric array from it.
        if out is None:
            packed = array(None, dtype=object)
        else:
            packed = empty(len(out), dtype=object)
            for i, item in enumerate(out):
                packed[i] = item
        save('out', packed)

    elif mode == "plots":
        # out.npy is written by the "analysis" step of this script, so
        # unpickling it is trusted.
        try:
            out = load('out.npy', allow_pickle=True)
        except TypeError:  # numpy < 1.10: no allow_pickle keyword (always unpickles)
            out = load('out.npy')
        os.remove('out.npy')
        if out.ndim == 0:  # get_position returned None
            sys.stderr.write('no results to plot\n')
            return
        graphs(out, 'png')
        saveconst('status', 0)
        os.system('convert -resize 150 graphs/plasma_position.png icon.png')


if __name__ == "__main__":
    main()
    	 
