#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Main script of the plasma position algorithm. This script prepares the data
# (removes false signals and integration drifts) and finally plots the results
# computed by PlasmaPosition.py.
#
# Author: Tomas Odstrcil


# Select the non-interactive Agg backend (file output only) before pyplot is
# imported anywhere, and set the global font size and LaTeX text rendering.
import matplotlib 
matplotlib.rcParams['backend'] = 'Agg'
matplotlib.rc('font',  size='10')
matplotlib.rc('text', usetex=True)  # FIXME: LaTeX text looks nicer but is slower (and needs a LaTeX installation)


# Timer used to report how long the imports below take.
import time
cas = time.time()



from numpy import *
import os
import sys
from pygolem_lite.modules import list2array,deconvolveExp,save_adv,saveconst
from scipy.signal import  fftconvolve,gaussian  # NOTE(review): scipy.signal.gaussian is a deprecated alias removed in recent SciPy; scipy.signal.windows.gaussian is the current location
from scipy.interpolate import interpolate  # NOTE(review): appears to be referenced only in commented-out code

from pygolem_lite import Shot
print 'including time: ',  time.time()-cas




# ---- machine constants -----------------------------------------------------
a = 0.085   # [m] limiter radius (drawn as the 'limiter' line in the plots)
R_0 = 0.4   # [m] major radius; the plots show R - R_0

# Calibration factors of the four integrated Mirnov-coil signals, T/(V*s).
# The sign of the fourth coil differs for shots after 12079
# (NOTE(review): presumably a change of the coil wiring -- confirm).
shot = Shot()['shotno']
calib = -array([261, 261, -261, 261 if shot > 12079 else -261])



  
def LoadData():
    """Load the raw signals and discharge parameters of the current shot.

    Returns
    -------
    tuple
        ``(Bt, Uloop, Ip, Papouch, CD_trigger, BD_trigger, Bt_trigger,
        plasma_end, plasma_start, plasma, shot)`` where

        * ``Bt``, ``Uloop``, ``Ip`` are the ``toroidal_field``,
          ``loop_voltage`` and ``plasma_current`` signals, transposed
          (``get_position`` uses column 0 as time, column 1 as value);
        * ``Papouch`` is ``list2array`` of the mirnov_1, 5, 9 and 13
          channels (``get_position`` uses column 0 as time);
        * ``CD_trigger``, ``BD_trigger``, ``Bt_trigger`` are the ``Tcd``,
          ``Tbd`` and ``Tb`` values, replaced by NaN when the matching
          voltage ``Ucd``, ``Ubd`` or ``Ub`` is zero;
        * ``plasma_end``, ``plasma_start``, ``plasma`` and ``shot`` are read
          from ``Shot()`` unchanged.
    """
    data = Shot()

    # asarray == "no copy if not needed" (array(..., copy=False) raises on NumPy 2)
    Bt = asarray(data['toroidal_field']).T
    Uloop = asarray(data['loop_voltage']).T
    Ip = asarray(data['plasma_current']).T

    # Look up the DAS channels of the Mirnov coils 1, 5, 9 and 13.
    # NOTE(review): only the DAS returned for the last coil is used below,
    # i.e. all four coils are assumed to be recorded by the same DAS.
    channels = []
    for coil in ('mirnov_1', 'mirnov_5', 'mirnov_9', 'mirnov_13'):
        das, channel = data.get_data('any', coil, return_channel=True)
        channels.append(channel)

    # FIXME: this refers directly to the DAS channels; it should refer to
    # the Mirnov coils instead.
    Papouch = list2array(data[das, channels])

    CD_trigger = data['Tcd']
    BD_trigger = data['Tbd']
    Bt_trigger = data['Tb']
    plasma_start = data['plasma_start']
    plasma_end = data['plasma_end']
    plasma = data['plasma']

    # A trigger time is not meaningful when the corresponding voltage is zero
    # (presumably that capacitor bank was not used), so mark it as NaN.
    if data['Ucd'] == 0:
        CD_trigger = nan
    if data['Ubd'] == 0:
        BD_trigger = nan
    if data['Ub'] == 0:
        Bt_trigger = nan

    return (Bt, Uloop, Ip, Papouch, CD_trigger, BD_trigger, Bt_trigger,
            plasma_end, plasma_start, plasma, data['shotno'])


def Resample(tvec, vec, t_min, t_max, n):
    """Smooth ``vec`` with a Gaussian and resample it on ``n`` points in [t_min, t_max].

    The Gaussian standard deviation (in input samples) is half the number of
    input samples per output interval, so the smoothing acts as a low-pass
    filter before down-sampling.  The window spans ``ceil(3*std)`` samples,
    but at least one sample (i.e. no smoothing when the output is denser
    than the input).  Output points before ``tvec[0]`` are set to 0; points
    after ``tvec[-1]`` take the last smoothed value.

    Parameters
    ----------
    tvec : 1-D array, equidistant increasing sample times of ``vec``
    vec : 1-D array of samples
    t_min, t_max : float, range of the new time base
    n : int >= 1, number of output points

    Returns
    -------
    (tvec_new, vec_new) : two 1-D arrays of length ``n``

    Raises
    ------
    ValueError if ``n < 1``.
    """
    if n < 1:
        raise ValueError('Resample: n must be a positive number of points')
    n = int(n)

    dt = tvec[1] - tvec[0]
    std = float(t_max - t_min) / dt / n / 2
    # The window length must be an integer (SciPy window functions reject floats).
    win_len = max(int(ceil(3 * std)), 1)
    if win_len > 1:
        # Same definition as scipy.signal.windows.gaussian(win_len, std).
        k = arange(win_len) - (win_len - 1) / 2.0
        gauss_win = exp(-0.5 * (k / std) ** 2)
        gauss_win /= gauss_win.sum()
    else:
        # A one-sample window is the identity; building it from the formula
        # could give 0/0 = NaN for a vanishing std.
        gauss_win = ones(1)
    vec_smooth = fftconvolve(vec, gauss_win, mode='same')

    tvec_new = linspace(t_min, t_max, n)
    vec_new = interp(tvec_new, tvec, vec_smooth, left=0, right=None)
    return tvec_new, vec_new




def CorrectRCcircuit(tvec, sig, tau):
    """Integrate a signal recorded through an RC (exponential) stage.

    The correction works in the frequency domain on the zero-mean signal.
    Its spectrum is divided by two factors: the transfer function of an
    exponential kernel ``exp(-t/tau)`` truncated after ``N//2 - 1`` samples
    and normalised to 1 at zero frequency (deconvolution), and the discrete
    differentiation factor ``1 - exp(-2*pi*i*f*dt)`` (integration).  The
    result is shifted by ``-(integ_sig[0] + sig[0])`` and scaled by the
    sample spacing ``dt``.

    An O(n) recursive variant, suitable for real-time use, would decay the
    previous sample by ``exp(-dt/tau)`` and accumulate the differences.

    Parameters
    ----------
    tvec : 1-D array of equidistant sample times
    sig : 1-D array of samples, same length as ``tvec``; not modified
    tau : time constant of the RC stage, in the units of ``tvec``

    Returns
    -------
    1-D real array of the same length as ``sig``.

    Raises
    ------
    ValueError if fewer than two samples are given.
    """
    N = len(sig)
    if N < 2:
        raise ValueError('CorrectRCcircuit needs at least two samples')
    # spacing of N equidistant samples spanning [tvec[0], tvec[-1]]
    dt = (tvec[-1] - tvec[0]) / float(N - 1)
    # Work on a zero-mean copy so that the caller's array is left untouched.
    sig = sig - mean(sig)

    f = fft.fftfreq(N, d=dt)
    fsig = fft.fft(sig)
    fsig[0] = 0  # enforce an exactly zero mean

    ifac = 1 - exp(-2 * pi * 1j * f * dt)  # inverse of the discrete integration factor

    # Spectrum of the truncated exponential kernel: sum_{k<M} q**k with M = N//2 - 1.
    # |q| = exp(-dt/tau) < 1, so 1 - q never vanishes.
    q = exp(-(1. / tau + 2 * pi * 1j * f) * dt)
    fexp = (1 - q ** (N // 2 - 1)) / (1 - q)
    fexp /= fexp[0]

    transfer = ones_like(fsig)
    transfer /= fexp             # deconvolution of the RC response
    transfer[1:] /= ifac[1:]     # integration (ifac is 0 in the f = 0 bin)

    integ_sig = real(fft.ifft(fsig * transfer))
    integ_sig -= integ_sig[0] + sig[0]
    integ_sig *= dt
    return integ_sig
    
    
    
def get_position():
    from PlasmaPosition import CalcPlasmaPosition,RemoveDriftsAuto

    cas = time.time()

    Bt,Uloop,Ip,signal,CD_trigger,BD_trigger, Bt_trigger,plasma_end,plasma_start, plasma,shotno  = LoadData()

    tvec = signal[:,0]
    signal = signal[:,1:]
    t0 = tvec[0]
    dt = tvec[1]-tvec[0]
    
    #BUG 1. chanell has RC integrator
    if shotno >= 11381 and shotno <= 11837:
	signal[:,0] = CorrectRCcircuit(tvec,signal[:,0], 0.1e-3)
	
	signal[:,1:] = cumsum(signal[:,1:],axis = 0, out=signal[:,1:])*dt #možná by šla udělat nějaká lepší integrace
    else:    
	signal = cumsum(signal,axis = 0, out=signal)*dt #možná by šla udělat nějaká lepší integrace
    
    print 'load data: ',  time.time()-cas

    reduce = len(signal)/400   # calculate in 400 points
    
    if mean(abs(Uloop[:,1])) == 0 or mean(abs(Bt[:,1])) == 0 :  #selhání diagnostik
	print 'all diagnostics failured '
	return None


    #correct effects of the chamber currents  (projection constants)
    Uproj = array([-0.5e-07, -1.0e-07, -0.2e-07,-0.8e-07])
    U_loop_ds = interp(tvec, Uloop[:,0],Uloop[:,1], left=0, right=None)
    #f_uloop = interpolate.interp1d(Uloop[:,0],Uloop[:,1],bounds_error = False, fill_value = 0, copy = False)
    #U_loop_ds = f_uloop(tvec)  
    for i in range(4):
	signal[:,i]-= Uproj[i]*U_loop_ds


    Bt_crosstalk = True
    Uloop_crosstalk = False

    if abs(median(Bt[:,1])) < 0.001:
	Bt_crosstalk = False
    if abs(median(Uloop[:,1])) < 0.05:
	Uloop_crosstalk = False



    #prepare signals
    E_trigger = nanmin([CD_trigger, BD_trigger])

    #remove crosstalks+drifts
    signal,err, AutoRemoveGraph = RemoveDriftsAuto(vstack((tvec,signal.T)).T, Bt,Uloop,plasma_start,plasma_end,Bt_trigger,E_trigger)

    #now is a plasma presence is necassary
    if not plasma:
	return [AutoRemoveGraph,Bt_trigger]




    #create detectors
	  
    rho  = empty((4,1))
    zeta = empty((4,1))
    
    r_det = 0.093
    phi_det = linspace(0,1.5*pi,4 )
    
    rho[:,0]  = r_det*cos(phi_det)+R_0
    zeta[:,0] = r_det*sin(phi_det)
 
    detectorPos = hstack((rho,zeta))
    

    tvec_ds, Ip_ds =  Resample(Ip[:,0],Ip[:,1],plasma_start,plasma_end,reduce)
    detectorSignal = signal[:,1:]
    tvec = signal[:,0]

    detectorSignal_ds = list()
    for i in range(size(detectorSignal,1)):
	#print shape(tvec), shape(detectorSignal[:,i])
	tvec_ds,data_ds =  Resample(tvec,detectorSignal[:,i],plasma_start,plasma_end,reduce)
	detectorSignal_ds.append(data_ds)
    detectorSignal_ds = array(detectorSignal_ds, copy = False).T
    
    
 



    #   estimated errors
    #TODO here can by applied information about lower precision of some detector

    IpDriftError =7
    detectorDriftError = ones(4)* 1.5e-7
    
    
    detectorDriftError*= abs(array(calib))
    detectorSignal_ds*= array(calib)


    pos_tvec, position, radius, residuum, retrofit, data,chi2 = CalcPlasmaPosition(single(tvec_ds),single(detectorPos), single(detectorSignal_ds), single(detectorDriftError), single(Ip_ds), single(IpDriftError))
    
    save_adv('results/R_position', pos_tvec, position[:,0])
    save_adv('results/Z_position', pos_tvec, position[:,1])
    save_adv('results/plasma_radius', pos_tvec, radius)
    savetxt('results/R_position.txt', vstack((pos_tvec, position[:,0])).T, fmt='%.4e %.3e' )
    savetxt('results/Z_position.txt', vstack((pos_tvec, position[:,1])).T, fmt='%.4e %.2e' )
    savetxt('results/plasma_radius', vstack((pos_tvec, radius)).T, fmt='%.4e %.2e'  )
    #savetxt('results/residuum', mean(residuum), fmt='%.1e')
    saveconst('results/residuum', median(residuum/ data[:,0]**2*1e6))
    
    
    
    #f_Ip = interpolate.interp1d( Ip[:,0],Ip[:,1] , bounds_error = False, fill_value = 0,copy = False)
    Ip = interp(tvec,  Ip[:,0],Ip[:,1])
    
    return pos_tvec, position,  residuum, retrofit, single(data),single(tvec),single(Ip), IpDriftError, single(detectorSignal), detectorDriftError, plasma_start, plasma_end,Bt_trigger,chi2,AutoRemoveGraph


def graphs(  inputs , filetype = 'png' ):

    if len(inputs) == 2: #vacuum discharge
	[AutoRemoveGraph,Bt_trigger] = inputs
    else:
	[pos_tvec,position,residuum,retrofit,data,tvec,Ip,IpDriftError, detectorSignal,detectorDriftError,plasma_start,plasma_end,Bt_trigger,chi2,AutoRemoveGraph] = inputs 

        #===============  Plot results======================
    import matplotlib 
    matplotlib.rcParams['backend'] = 'Agg'
    matplotlib.rc('font',  size='10')
    matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
    import matplotlib.pyplot as plt
	


    class MyFormatter(plt.ScalarFormatter): 
	def __call__(self, x, pos=None): 
	    if pos==0: 
		return '' 
	    else: return plt.ScalarFormatter.__call__(self, x, pos)  
	
    
     
    #plot graph from auto removing   of crosstalks  

    t1 = time.time() 
    (frames, vlines,hlines,rectangles) = AutoRemoveGraph

    fig = plt.figure(num=None, figsize=(10, 10), dpi=80, facecolor='w', edgecolor='k')
    plt.subplots_adjust(hspace=0, wspace = 0)
    
    for i,frame in enumerate(frames):
	(t, curves, ytext,text) = frame
	
	ax = fig.add_subplot(len(frames),1,i+1)
	ax.xaxis.set_major_formatter( plt.NullFormatter() )
	ax.yaxis.set_major_formatter( MyFormatter() )
	ax.text(.05,.1,text,horizontalalignment='left',verticalalignment='bottom',transform=ax.transAxes,backgroundcolor = 'w')
	
	for curve in curves:
	    (y,label, linestyle) = curve
	    plt.plot(t,y, linestyle , label = label)
	
	for hline in hlines:
	    (y0,linestyle) =  hline
	    plt.axhline(y = y0,linestyle = linestyle)
	for vline in vlines:
	    (x0,linestyle) =  vline
	    plt.axvline(x = x0,linestyle = linestyle)
	plt.ylabel(ytext)

	plt.axis('tight')
	plt.xlim(Bt_trigger*1000, None)
	(y_min,y_max) = plt.ylim()
    

	for (x_min,x_max) in rectangles:
	    r = plt.Rectangle((x_min, y_min),x_max-x_min,y_max-y_min)
	    r.set_clip_box(ax.bbox)
	    r.set_alpha(0.05)
	    ax.add_artist(r)
	    
    handles, labels = ax.get_legend_handles_labels()
    handles.append(r)
    labels.append('Minimized interval')
    ax.xaxis.set_major_formatter( plt.ScalarFormatter() )

    plt.xlabel('t [ms]')
    leg = plt.legend(handles, labels,loc='best', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    plt.savefig('./graphs/signal_correction.'+filetype,bbox_inches='tight')
    plt.close()

    print ' plot graph from auto removing   of crosstalks   ',time.time()-t1
	

    if len(inputs) == 2:  #no plasma
	return 

    
    n_det = size(detectorSignal,1)

    t0 = time.time()

    t = time.time()   
    
    
    
    #plot position
    fig = plt.figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    plt.plot(pos_tvec*1e3, (position[:,0]-0.4)*100,'b', label = 'R-R$_0$', linewidth=0.5)
    plt.plot(pos_tvec*1e3, position[:,1]*100,'r', label = 'Z', linewidth=0.5)
    plt.axhline(y = a*100, linestyle = '--' ,label = 'limiter')
    plt.axhline(y = -a*100, linestyle = '--' )
    leg = plt.legend(loc='lower left', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    plt.xlim(pos_tvec[0]*1e3, pos_tvec[-1]*1e3)
    plt.ylim(-10,10)
    plt.ylabel('R,Z [cm]')
    plt.xlabel('t [ms]')
    plt.subplots_adjust( bottom=0.2)
    plt.savefig('graphs/plasma_position.'+filetype,bbox_inches='tight')

    fig.set_figwidth(4)
    fig.set_figheight(3)
    fig.dpi = 40
    leg.get_frame().set_alpha(0)
    plt.savefig('icon.png', bbox_inches='tight', dpi= 40)
    
    
    
    plt.close()
    
        #plot position polar
    
    r = hypot(position[:,0]-0.4,position[:,1])
    phi = arctan2(position[:,1],-(position[:,0]-0.4))
    
    
    fig = plt.figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    axL = fig.add_subplot(1,1,1)
    plt.plot(pos_tvec*1e3, r*100,'b', label='radial coordinate $\\rho$', linewidth=0.5)
    axL.set_ylabel('$\\rho$ [cm]')
    axL.set_ylim(0,10)
    axL.set_xlim(pos_tvec[0]*1e3, pos_tvec[-1]*1e3)
    axL.set_xlabel('t [ms]')
    axL.yaxis.tick_left()
    handlesL, labelsL = axL.get_legend_handles_labels()

    
    axR = fig.add_subplot(1,1,1, sharex=axL, frameon=False)
    axR.yaxis.tick_right()
    axR.yaxis.set_label_position("right")
    
    plt.plot(pos_tvec*1e3, phi,'r', label='angular coordinate $\phi$', linewidth=0.5)
    axR.set_ylabel('$\phi$ [rad]')
    axR.set_xlim(pos_tvec[0]*1e3, pos_tvec[-1]*1e3)
    axR.set_ylim(-pi,pi)
    handlesR, labelsR = axR.get_legend_handles_labels()
    handles = hstack((handlesL,handlesR))
    labels = hstack((labelsL,labelsR))

    leg = axL.legend(handles,labels,loc='lower left', fancybox=True)
    
    #leg = plt.legend(loc='lower left', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    #plt.ylim(-10,10)
    #plt.xlabel('t [ms]')
    plt.subplots_adjust( bottom=0.2)
    fig.savefig('graphs/plasma_position_polar.'+filetype,bbox_inches='tight')

    #fig.set_figwidth(4)
    #fig.set_figheight(3)
    #fig.dpi = 40
    #leg.get_frame().set_alpha(0)
    #plt.savefig('icon.png', bbox_inches='tight', dpi= 40)
    
    
    
    plt.close()
    
    print 'plot position ',time.time()-t
    t = time.time()
    
    
    #plot residuum
    fig = plt.figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    plt.semilogy(pos_tvec*1e3, residuum/ data[:,0]**2*1e6, 'r',  label = 'residuum of the fit/$I_p^2$')
    plt.axhline(y = 10)
    plt.axhline(y = 1e2)
    plt.text(mean(pos_tvec*1e3),3 , r'reliable fit', color='k')
    plt.text(mean(pos_tvec*1e3),50 , r'considerable fit', color='k')
    plt.text(mean(pos_tvec*1e3),1e3 , r'poor fit', color='k')

    plt.xlim(pos_tvec[0]*1e3, pos_tvec[-1]*1e3)

    plt.ylim(1,1e4)
    plt.xlabel('t [ms]')
    plt.ylabel('SSE/$I^2_p$ [kA$^{-2}$]')
    plt.subplots_adjust( bottom=0.2)
    leg = plt.legend(loc='upper left', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    plt.savefig('graphs/residuum.'+filetype, bbox_inches='tight')

    plt.close()
    
    print 'plot residuum ',time.time()-t
    t = time.time()

    
    
    cas = time.time()

    
    
    # plot plasma radius
    
    t = time.time()
    # plot plasma radius
    radius = a-sqrt((position[:,0]-0.4)**2+position[:,1]**2)
    fig = plt.figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    plt.plot(pos_tvec*1e3, radius*100,'b', label = 'Plasma radius', linewidth=0.5)
    plt.xlim(pos_tvec[0]*1e3, pos_tvec[-1]*1e3)
    plt.ylim(0,10)
    plt.axhline(y = a*100, linestyle = '--' ,label = 'limiter')
    plt.ylabel('r [cm]')
    plt.xlabel('t [ms]')

    leg = plt.legend(loc='upper right', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    plt.subplots_adjust( bottom=0.2)
    plt.savefig('graphs/plasma_radius.'+filetype,bbox_inches='tight')

    plt.close()
    
    
    print ' plasma radius ',time.time()-t
    t = time.time()
    
    #plot raw data
    fig = plt.figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    plt.plot(tvec[:-50]*1000,Ip[:-50]/IpDriftError,'-.', label = '$I_p$') 
    for i in range(4):
	plt.plot(tvec*1e3,detectorSignal[:,i]/detectorDriftError[i]*calib[i], label = 'mc'+str(i*4+1), linewidth=0.5)
    plt.axvline(x = plasma_start*1000, linestyle= '--')
    plt.axvline(x = plasma_end*1000, linestyle = '--')
    plt.axhline(y = 0, linestyle = '-.')
    plt.xlabel('t [ms]')
    plt.ylabel('signal/estimated error [-]')
    leg = plt.legend(loc='upper left', fancybox=True)
    leg.get_frame().set_alpha(0.7)
    plt.subplots_adjust( bottom=0.2)
    plt.axis('tight')
    plt.xlim(Bt_trigger*1000, min(1000*plasma_end+5, 40))
    plt.ylim(-10,None)

    plt.savefig('graphs/preprocesed_signal.'+filetype,bbox_inches='tight')
    
    print ' plot raw data',time.time()-t

    plt.close('all')    
     


    t = time.time()
  


    ## plot plasma profile R,Z
    #time_horizont_profile = empty((50,size(data,0)))
    #time_vertical_profile = empty((50,size(data,0)))
    #x = linspace(-a,a,50)
    #for i in range(size(data,0)):
	#circle = 1-((x-position[i,0]+0.4)/radius[i])**2
	#circle*= (1+sign(circle))/2
	#time_horizont_profile[:,i] = sqrt(circle)
    #for i in range(size(data,0)):
	#circle = 1-((x-position[i,1])/radius[i])**2
	#circle*= (1+sign(circle))/2
	#time_vertical_profile[:,i] = sqrt(circle)

    #fig = figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    #pcolor( pos_tvec*1e3, (x+0.4)*100, time_horizont_profile)
    #xlabel('t [ms]')
    #ylabel( 'R [cm]')
    #axis('tight')
    #subplots_adjust( bottom=0.2)
    #savefig('graphs/plasma_profile_r.png')
    #close()


    #fig = figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    #pcolor( pos_tvec*1e3, x*100, time_vertical_profile)
    #xlabel('t [ms]')
    #ylabel( 'Z [cm]')
    #axis('tight')
    #subplots_adjust( bottom=0.2)
    #savefig('graphs/plasma_profile_z.png')
    #close()


    #plot retrofit  
    from scipy.stats.mstats  import mquantiles

    fig = plt.figure(num=None, figsize=(10, 10), dpi=80, facecolor='w', edgecolor='k')
    plt.subplots_adjust(hspace=0, wspace = 0)
    c = ('r', 'g', 'b', 'y')
    y_min_lim = amin(data[:,1:])
    y_max_lim = amax(data[:,1:])

    for i in range(1,n_det+1):
	ax = fig.add_subplot(n_det+1,1,i)
	ax.xaxis.set_major_formatter( plt.NullFormatter() )
	ax.yaxis.set_major_formatter( MyFormatter() )
	plt.plot(pos_tvec*1e3,     data[:,i]/data[:,0]*1e6, c[i%4],  label = 'mc'+str((i-1)*4+1)+'/I$_p$')
	plt.plot(pos_tvec*1e3, retrofit[:,i]/data[:,0]*1e6, c[i%4]+'-.',  label = 'retrofit')
	
	ax.text(.05,.1,'$\\chi^2$ = %2.1f'% chi2[i],horizontalalignment='left',verticalalignment='bottom',transform=ax.transAxes)

	leg = plt.legend(loc='upper left', fancybox=True)
	leg.get_frame().set_alpha(0.5)
	data_max = mquantiles(data[:,i]/data[:,0]*1e6, 0.95)*2

	plt.ylim(0, data_max)
	plt.ylabel('mc'+str((i-1)*4+1)+'/$I_p$ [mT/kA]')


    ax = fig.add_subplot(n_det+1,1,n_det+1)
    plt.plot(pos_tvec*1e3,     ones(size(data,0)),   label = 'I$_p$/I$_p$')
    plt.plot(pos_tvec*1e3,     retrofit[:,0]/data[:,0],   '--', label = 'retrofit')
    ax.text(.05,.1,'$\\chi^2$ = %2.1f'% chi2[0],horizontalalignment='left',verticalalignment='bottom',transform=ax.transAxes)

    leg = plt.legend(loc='upper left', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    plt.ylim(0,2)
    plt.ylabel('$I_p/I_p$ [-]')
    plt.xlabel('t [ms]')
    plt.savefig('./graphs/retrofit.'+filetype,bbox_inches='tight')

    plt.close()


    print 'retrofit  ',time.time()-t
    t = time.time()
    

    
    #plot animation -- slow!!!
    #print "plotting animation" 
    #import MagFieldProfileGen
    #MagFieldProfileGen.PlotProfile(pos_tvec,data[:,0],position, 'animation')
    
    print 'graphs plotted in %g s' % (time.time()-t0)

def main():
    """Command-line entry point.

    * ``<script> analysis`` runs ``get_position`` and stores its result in
      ``out.npy``.
    * ``<script> plots`` loads and deletes ``out.npy``, draws the graphs and
      saves ``status`` = 0 via ``saveconst``.

    The ``graphs``, ``results`` and ``constants`` directories are created
    if they are missing.
    """
    for path in ['graphs', 'results', 'constants']:
        if not os.path.exists(path):
            os.mkdir(path)

    if len(sys.argv) < 2:
        sys.exit('usage: %s analysis|plots' % sys.argv[0])
    mode = sys.argv[1]

    if mode == "analysis":
        out = get_position()
        # Pack the heterogeneous result into an explicit 1-D object array;
        # letting numpy infer a shape from differently shaped arrays is
        # fragile (newer numpy refuses ragged input). None -> empty array.
        n_items = 0 if out is None else len(out)
        packed = empty(n_items, dtype=object)
        for i in range(n_items):
            packed[i] = out[i]
        save('out', packed)

    elif mode == "plots":
        # Object arrays are pickled: numpy >= 1.16.3 needs allow_pickle=True.
        # The file is written by the "analysis" step above.
        try:
            out = load('out.npy', allow_pickle=True)
        except TypeError:  # numpy < 1.10 has no allow_pickle (and always allows it)
            out = load('out.npy')
        os.remove('out.npy')
        if out.ndim == 0 or len(out) == 0:
            sys.exit('no analysis output (diagnostics failed); nothing to plot')
        graphs(out, 'png')
        saveconst('status', 0)


if __name__ == "__main__":
    main()
    	 
