#!/usr/bin/python2
# -*- coding: utf-8 -*-
# Faster version than MAIN_2, but more complicated


####     Microwaves 2.0
# This algorithm transforms the sine signal from the microwave density measurement into phase/amplitude space.
# The first step estimates the base (carrier) frequency and builds a complex exponential
# with the same frequency. In the second step the signal is multiplied by this exponential
# and the resulting low-frequency signal is smoothed with a Gaussian window. Finally the complex phase and amplitude are calculated.

# Authors: Tomas Odstrcil, Ondrej Grover

from time import time
# Timestamp taken at import/start-up. None of the functions visible in this
# file read it: each function that times itself assigns its own local `t`.
t = time()
#import matplotlib 
#matplotlib.rcParams['backend'] = 'Agg'
#matplotlib.rc('font',  size='10')
#matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
#import matplotlib.pyplot as plt

from numpy import *
from scipy.fftpack import fft, ifft,fftfreq
from scipy.signal import fftconvolve, medfilt 
from pygolem_lite import save_adv,load_adv,saveconst, Shot
from pygolem_lite.modules import multiplot, get_data, paralel_multiplot
import os
import sys
from matplotlib.pylab import *
from  BlockConv import BlockConv



from scipy.constants import c,m_e,epsilon_0,e
# Constants of the measurement. c, m_e, epsilon_0 and e come from
# scipy.constants; pi comes from the star-imports above.
a = 0.085   #[m] main() divides the line density by 2*a to obtain the average density
f_0 = 75e9 #[Hz] presumably the frequency of the probing microwaves -- TODO confirm
lambda_0 = c/f_0   #[m] wavelength that corresponds to f_0
# Conversion factor: main() multiplies the phase [rad] by ne_0 to get the line density.
ne_0 = 4*pi*m_e*epsilon_0*c**2/(e**2*lambda_0)


def Demodulation(y,win,dt):
    """Mix the channels with the carrier and low-pass them with a Gaussian.

    The carrier is taken as the strongest bin in the positive half of the
    spectrum of the first column of `y`.  Every column is multiplied by a
    complex exponential at that bin and convolved with a normalised Gaussian
    window; the modulus and argument of the result give the amplitude and the
    phase of the input signal.

    NOTE: `y` is modified in place (its global mean and a block-wise offset
    are subtracted), exactly as before.  Pass a copy if the raw samples are
    still needed afterwards.

    Parameters
    ----------
    y : 2-D float ndarray, samples along axis 0, one column per channel.
    win : width parameter of the Gaussian window, in samples.
    dt : sampling step, used to build the frequency axis.

    Returns
    -------
    signal : transposed stack of the filtered complex channels (presumably
        of shape (n, channels) because BlockConv is called with mode='same';
        BlockConv is defined elsewhere -- TODO confirm).
    norm_ampl : norm of the positive half spectrum of column 0 over n//2.
    f_carrier : amplitude-weighted mean frequency around the carrier bin.
    k : share of the spectral norm outside the carrier window
        (the "unharmonics factor").
    """
    t = time()

    # An integer width would silently floor-divide in the exponent of the
    # Gaussian under Python 2.
    win = float(win)

    y -= mean(y, axis = 0)
    n = size(y,0)
    n_channels = size(y,1)
    N = 2 ** int(ceil(log2(n)))
    half = N//2

    # The spectrum is taken BEFORE the block-wise offset is removed below.
    fourier = fft(y[:,0], N) # Fourier transform of the first channel

    # Subtract the slowly varying offset: the mean of every complete block of
    # `reduct` samples is removed from that block, written back explicitly
    # instead of relying on reshape() handing out a view onto `y`.
    reduct = 5*999
    n_whole = (n//reduct)*reduct
    if n_whole:
        block_mean = y[:n_whole].reshape(n_whole//reduct, reduct, n_channels).mean(axis=1)
        y[:n_whole] -= repeat(block_mean, reduct, axis=0)

    # find the carrier frequency
    max_frequency_index = argmax(abs(fourier[:half]))

    f = fftfreq(N,dt)
    # +-100 bins around the peak, clamped to the positive half spectrum: a
    # negative start would wrap around and give an empty window (NaN result).
    lo = max_frequency_index-100
    if lo < 0:
        lo = 0
    hi = max_frequency_index+100
    if hi > half:
        hi = half
    s = slice(lo, hi)
    amplitude = abs(fourier[s])
    f_carrier = sum(f[s]*amplitude)/sum(amplitude)

    # find the unharmonics factor
    amplitude = linalg.norm(amplitude)
    norm1 = linalg.norm(fourier[:half])
    # maximum() only guards against a tiny negative value caused by rounding
    k = sqrt(maximum(norm1**2-amplitude**2, 0.0))/norm1

    fourier[:] = 0 # cancel out all other frequencies
    fourier[max_frequency_index] = 1
    cmpl_exp = ifft(fourier)[:n]

    gauss = exp(-arange(-3*win,3*win)**2/win**2)
    gauss/= sum(gauss)

    # TODO (from the original author): use an IIR filtfilt instead
    signal = [BlockConv(y[:,i]*cmpl_exp,gauss,mode='same') for i in range(n_channels)]
    # asarray(): array(..., copy=False) raises for a list under NumPy 2
    signal = asarray(signal).T

    # same text as the former ``print 'calc. time', time()-t`` statement
    print('%s %s' % ('calc. time', time()-t))
    return signal,norm1/(n//2),f_carrier,k

def LoadData():
    """Load both interferometer channels and the timing marks of the shot.

    The names of the two channels depend on the shot number.

    Returns
    -------
    tuple ``(tvec, start, end, density1, density2, Bt_trigger)`` where `tvec`
    is the time axis returned with the second channel, `start`/`end` are the
    'plasma_start'/'plasma_end' entries and `Bt_trigger` is the 'Tb' entry.
    """
    # One Shot object is reused for every lookup instead of building a new one
    # per access (assumes every argument-less Shot() refers to the same shot).
    Data = Shot()
    Bt_trigger = Data['Tb']
    gd = Data.get_data
    shotno = Data['shotno']

    # Channel names by shot number:
    #   shotno > 22280           interframpsign / interfdiodeoutput
    #   21300 < shotno <= 22280  tek_ch1 / tek_ch3
    #   18674 < shotno <= 21300  interframpsign / interfdiodeoutput
    #   otherwise                density1 / density2
    if 21300 < shotno <= 22280:
        names = ('tek_ch1', 'tek_ch3')
    elif shotno > 18674:
        names = ('interframpsign', 'interfdiodeoutput')
    else:
        names = ('density1', 'density2')

    tvec, density1  = gd('any', names[0])
    tvec, density2  = gd('any', names[1])

    start  = Data['plasma_start']
    end  = Data['plasma_end']

    return tvec, start, end, density1,density2,Bt_trigger


    
    
def graphs():
    """Draw the demodulation and electron-density figures from saved results.

    All curves are read back from the files under 'results/'.  The figures are
    handed to multiplot() as 'graphs/demodulation' and
    'graphs/electron_density', and to paralel_multiplot() as 'icon'.
    """
    # Every file comes with its own time axis; only the one read last is kept
    # and shared by all curves below.
    _, phase_pila = load_adv('results/phase_saw')
    _, phase_sinus = load_adv('results/phase_sinus')
    _, phase = load_adv('results/phase_substracted')
    _, phase_corr = load_adv('results/phase_corrected')
    _, amplitude = load_adv('results/amplitude_sinus')
    _, n_e = load_adv('results/electron_density')
    tvec, ne_corr = load_adv('results/electron_density_corr')

    def phase_curve(values, title, fmt):
        # one curve of the upper (phase) panel
        return get_data([tvec, values], title, 'phase [rad]', xlim=[0,40], fmt=fmt)

    phase_panel = [
        phase_curve(-phase_pila+mean(phase_pila), 'phase 1', "--"),
        phase_curve(-phase_sinus+mean(phase_sinus), 'phase 2', "--"),
        phase_curve(phase, 'substracted phase', "k"),
        phase_curve(phase_corr, 'corrected phase', "k:"),
    ]
    amplitude_panel = get_data([tvec,amplitude], 'amplitude', 'amplitude [a.u.]',
                               xlim=[0,40], ylim=[0,None])
    multiplot([phase_panel, amplitude_panel], '', 'graphs/demodulation', (10,6))

    # upper y-limit 1 (after the 1e-19 rescale) for low-density shots
    if amax(n_e) < 1e18:
        upper = 1
    else:
        upper = None

    density_label = '$<n_e>$ [$10^{19}\,m^{-3}$]'
    # constant line at ne_0*pi/(2*a); its multiples 0..6 are drawn in yellow
    step = ne_0*pi/(2*a)+tvec*0

    curves = [
        get_data('electron_density', 'Average electron density', density_label,
                 data_rescale=1e-19, ylim=[None,upper]),
        get_data([tvec, ne_corr], 'Corrected electron density', density_label,
                 data_rescale=1e-19, c='r'),
    ]
    for j in range(7):
        curves.append(get_data([tvec,j*step], data_rescale=1e-19, c='y'))
    data = [curves]

    multiplot(data, '', 'graphs/electron_density', (9,3))
    paralel_multiplot(data, '', 'icon', (4,3), 40)



def RobustDensityUnwrap(tvec, amplitude, phi0,start, end,  n_points ):
    """Fit a smooth, non-negative phase curve that matches `phi0` modulo pi.

    The curve is piecewise linear through `n_points` equally spaced knots
    strictly between `start` and `end` and is pinned to zero at `start` and
    `end`.  The knot values are found with scipy's basinhopping, minimising
    the cost defined in `fun` below.  A diagnostic figure is saved as 'plot'.

    NOTE: `phi0` is shifted in place so that its median before `start` is 0.

    Parameters
    ----------
    tvec : time axis of `phi0` and `amplitude`.
    amplitude : weights of the data term (normalised by their median).
    phi0 : measured phase; modified in place, see above.
    start, end : time interval spanned by the knots.
    n_points : number of free knots.

    Returns
    -------
    The fitted curve evaluated on `tvec`.
    """
    t = time()

    phi0 -= median(phi0[tvec<start])

    def fun(y,x,start,end, A, phi0,tvec):
        x = r_[start,x,end]
        y = r_[0,y,0]
        phi = interp(tvec,x,y)
        c1 = linalg.norm(A*((phi0-phi-pi/2)%pi-pi/2))/len(tvec)  # distance from the data modulo pi
        c2 = 10*sum(-phi[phi<0])/len(tvec)     # positivity
        c3 = sum(abs((diff(y))))/100/sqrt(len(x))         # smoothness
        return c1+c2+c3

    from scipy.optimize import basinhopping

    x = linspace(start,end,n_points+2)[1:-1]

    # initial guess: the measured phase clipped at zero, sampled at the knots
    y0 = interp(x,tvec,maximum(0,phi0))
    args = x,start,end, amplitude/median(amplitude), phi0,tvec
    # same values as the former positional call (10, 1, 2, {...})
    out = basinhopping(fun, y0, niter=10, T=1, stepsize=2,
                       minimizer_kwargs={'args':args})

    phi_robust = interp(tvec,r_[start,x,end],r_[0,out.x,0])

    # same text as the former ``print 'RobustDensityUnwrap: ', time()-t``
    print('%s %s' % ('RobustDensityUnwrap: ', time()-t))

    # diagnostic figure: wrapped residual, measured phase and fitted curve
    plot(tvec, (phi0-phi_robust-pi/2)%(pi)-pi/2)
    plot(tvec, phi0)
    plot(tvec, phi_robust)
    savefig('plot')
    clf()

    return phi_robust




def main():
    """Command-line entry point.

    ``sys.argv[1]`` selects the step: "analysis" computes and stores the
    results, "plots" draws them and stores the 'status' constant.  Any other
    value only makes sure the 'graphs' and 'results' directories exist, as
    before.
    """
    # A missing argument used to surface as a bare IndexError on sys.argv[1].
    if len(sys.argv) < 2:
        sys.exit('usage: %s analysis|plots' % sys.argv[0])
    mode = sys.argv[1]

    for path in ['graphs', 'results' ]:
        if not os.path.exists(path):
            os.mkdir(path)

    if mode ==  "analysis":
        _run_analysis()

    if mode ==  "plots":
        graphs()
        saveconst('status', 0)


def _report(label, value):
    # Prints the same text as the former Python 2 ``print label, value``.
    print('%s %s' % (label, value))


def _run_analysis():
    """Demodulate both channels and store phases, amplitude and densities."""
    print('analysis')

    win = 30e-6 #[s]

    t = time()
    # If the phase difference comes out negative, swap the order of the two
    # channels here.  Note that they are already unpacked in swapped order
    # relative to what LoadData() returns.
    tvec,start, end, density2,density1,Bt_trigger = LoadData()

    # NOTE(review): divides by the number of samples rather than the number of
    # intervals (len-1); left untouched because int(win/dt/2) below may round
    # differently with the corrected value.
    dt = (tvec[-1]-tvec[0])/len(tvec)
    density1 = density1[tvec>0]
    density2 = density2[tvec>0]
    tvec = tvec[tvec>0]

    # the channel with the smaller standard deviation goes first
    if std(density1)  >  std(density2):
        density1,density2 = density2,density1

    _report('load time ', time()-t)
    signals = vstack((density1,density2)).T
    cmplx_signal,norm_ampl,f_carrier,k = Demodulation(signals,win/dt,dt)

    downsample = int(win/dt/2)
    if downsample < 1:
        downsample = 1   # a zero step is not a valid slice

    tvec = tvec[::downsample]
    amplitude = abs(cmplx_signal[::downsample,:])
    phase = angle(cmplx_signal[::downsample,:])
    phase = unwrap(phase, axis = 0)

    phase_pila, phase_sinus = phase.T
    phase = phase_pila-phase_sinus

    # Zero level of the phase: median before the Bt trigger if such data
    # exist, otherwise the value at plasma start, plasma end or first sample.
    if tvec[0]< Bt_trigger:
        phase -= median(phase[tvec<Bt_trigger])
    elif start > tvec[0] and start < tvec[-1]:
        phase-= phase[tvec.searchsorted(start)]
    elif end > tvec[0] and end < tvec[-1]:
        phase-= phase[tvec.searchsorted(end)]
    else:
        phase-= phase[0]

    ind_plasma = (tvec > start) & (tvec < end)

    switched = sign(mean(phase[ind_plasma])) == -1
    if switched:
        phase *= -1   # flip the sign if the cables were switched

    amplitude *= norm_ampl/median(amplitude,0)[None,:]

    # No correction is enabled at the moment (the RobustDensityUnwrap() call
    # and a jump-detection prototype were disabled), so the "corrected"
    # outputs equal the uncorrected ones.
    phase_robust = phase

    save_adv('results/phase_saw', tvec, phase_pila)
    save_adv('results/phase_sinus', tvec, phase_sinus)
    save_adv('results/phase_substracted', tvec, phase)
    save_adv('results/amplitude_sinus', tvec, amplitude)
    save_adv('results/phase_corrected', tvec, phase_robust )

    p = exp(1-mean((norm_ampl/amplitude)**2.5))

    n_e = ne_0*phase
    ne_corr = ne_0*phase_robust

    # stored before the division by 2*a below
    save_adv('results/electron_density_line', tvec, n_e)
    saveconst('results/electron_density_mean', mean(n_e[ind_plasma]))

    n_e /= 2*a
    ne_corr/= 2*a
    save_adv('results/electron_density', tvec, n_e)
    save_adv('results/electron_density_corr', tvec, ne_corr)

    saveconst('results/carrier_freq', abs(f_carrier))
    saveconst('results/harmonics_distortion', k)
    saveconst('results/norm_ampl', norm_ampl)
    saveconst('results/probability', p)
    saveconst('results/reliability', 0 )

    _report('carrier_freq ', f_carrier)
    _report('harmonics_distortion ', k)
    _report('norm_ampl ', norm_ampl)
    _report('probability ', p)



# Run the step selected on the command line only when this file is executed
# as a script, not when it is imported.
if __name__ == "__main__":
    main()
    	 
