#!/usr/bin/python2
# -*- coding: utf-8 -*-


####     Microwaves 2.0
# This algorithm transforms the sinusoidal signal from the microwave density
# measurement into phase/amplitude space.
# First, the base (carrier) frequency is estimated and a complex exponential
# with the same frequency is constructed. Second, the signal is multiplied by
# this exponential and the resulting low-frequency signal is smoothed with a
# Gaussian window. Finally, the complex phase and amplitude are calculated.

# Authors: Tomas Odstrcil, Ondrej Grover

from time import time
t = time()  # start of the import timer; the elapsed time is printed as 'include time' below

# Star import: the code below uses many numpy names unqualified (zeros, mean, unwrap, ...).
from numpy import *
# Replaced by the local FFTW-based helpers (wfft, wifft, fftconvolve) defined below:
#from numpy.fft import fft, ifft
#from scipy.signal import fftconvolve   
from scipy.constants import c,m_e,epsilon_0,e
from pygolem_lite.modules import save_adv,load_adv,saveconst
from pygolem_lite import Shot
import os
import sys
import fftw3
# NOTE(review): _centered is a private SciPy helper, and scipy.signal.signaltools is a
# deprecated module path in newer SciPy releases. Confirm it exists in the installed version.
from scipy.signal.signaltools import _centered

print 'include time ',time()-t 

def wfft(a, ext=None, nthreads=4):
    """Return the forward FFT of a 1-D array, computed with FFTW.

    Trimmed-down helper: it only works for 1-D arrays.

    Parameters
    ----------
    a : 1-D array_like
        Input samples (real or complex). The caller's array is not modified.
    ext : sequence or None, optional
        If given, ``ext[0]`` is the transform length. Shorter input is
        zero-padded to that length, and longer input is truncated (the same
        convention as the ``n`` argument of ``numpy.fft.fft``). ``None``
        (the default) transforms ``a`` at its own length.
    nthreads : int, optional
        Number of threads passed to the FFTW plan.

    Returns
    -------
    ndarray of complex
        The unnormalized forward DFT, as computed by FFTW.
    """
    a = asarray(a)
    if ext is not None:
        # ravel() accepts a plain shape tuple as well as ``[array([n])]``.
        length = int(ravel(ext)[0])
        if length <= len(a):
            a = a[:length]
        else:
            a = append(a, zeros(length - len(a), dtype=a.dtype))

    # astype() returns a new array, so FFTW never touches the caller's data.
    a = a.astype('complex')
    outarray = a.copy()
    # With the 'estimate' flag FFTW does not overwrite the arrays while
    # planning, so they may be filled before the plan is created.
    fft_forward = fftw3.Plan(a, outarray, direction='forward',
                             flags=['estimate'], nthreads=nthreads)
    fft_forward()
    return outarray
    
def wifft(a, ext=None, nthreads=4):
    """Return the backward (inverse) FFT of a 1-D array, computed with FFTW.

    The transform is *unnormalized*, as FFTW's always is:
    ``wifft(wfft(x))`` equals ``len(x) * x``. Unlike ``numpy.fft.ifft``, it
    does not divide by the length. Only 1-D arrays are supported.

    Parameters
    ----------
    a : 1-D array_like
        Input spectrum. The caller's array is not modified.
    ext : sequence or None, optional
        If given, ``ext[0]`` is the transform length. Shorter input is
        zero-padded to that length, and longer input is truncated. ``None``
        (the default) transforms ``a`` at its own length.
    nthreads : int, optional
        Number of threads passed to the FFTW plan.

    Returns
    -------
    ndarray of complex
        The unnormalized backward DFT.
    """
    a = asarray(a)
    if ext is not None:
        # ravel() accepts a plain shape tuple as well as ``[array([n])]``.
        length = int(ravel(ext)[0])
        if length <= len(a):
            a = a[:length]
        else:
            a = append(a, zeros(length - len(a), dtype=a.dtype))

    # astype() returns a new array, so FFTW never touches the caller's data.
    a = a.astype('complex')
    outarray = a.copy()
    # With the 'estimate' flag FFTW does not overwrite the arrays while
    # planning, so they may be filled before the plan is created.
    fft_backward = fftw3.Plan(a, outarray, direction='backward',
                              flags=['estimate'], nthreads=nthreads)
    fft_backward()
    return outarray
    
    
def fftconvolve(in1, in2, mode="full"):
    """Convolve two 1-D arrays using FFT (FFTW via ``wfft``/``wifft``).

    Trimmed-down variant of ``scipy.signal.fftconvolve``. Because ``wfft``
    and ``wifft`` only handle 1-D data, so does this function.

    Parameters
    ----------
    in1, in2 : 1-D array_like
        Input sequences.
    mode : {'full', 'same', 'valid'}, optional
        'full'  -- the full discrete linear convolution, of length
                   ``len(in1) + len(in2) - 1`` (default).
        'same'  -- the central part, the size of the longer input.
        'valid' -- the central part of length ``abs(len(in2) - len(in1)) + 1``.

    Returns
    -------
    ndarray
        The convolution. It is real unless either input is complex.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the values above.
    """
    if mode not in ("full", "same", "valid"):
        raise ValueError("mode must be 'full', 'same' or 'valid', got %r" % (mode,))

    n1 = len(in1)
    n2 = len(in2)
    complex_result = iscomplexobj(in1) or iscomplexobj(in2)
    size = n1 + n2 - 1

    # Zero-pad to a power of two, at least `size` long. The circular
    # convolution then equals the linear one on the first `size` samples.
    fsize = int(2 ** ceil(log2(size)))
    spectrum = wfft(in1, [fsize])
    spectrum *= wfft(in2, [fsize])
    # wifft (FFTW backward) is unnormalized and returns fsize times the
    # inverse DFT, so divide to obtain the true convolution.
    ret = wifft(spectrum)[:size] / fsize
    del spectrum
    if not complex_result:
        ret = ret.real

    if mode == "full":
        return ret
    if mode == "same":
        osize = n1 if n1 > n2 else n2
    else:  # "valid"
        osize = abs(n2 - n1) + 1
    # Take the centred `osize` samples, as scipy's _centered does.
    start = (size - osize) // 2
    return ret[start:start + osize]




def Demodulation(data, win):
    """Demodulate sinusoidal signals into amplitude and unwrapped phase.

    The carrier frequency is taken as the strongest Fourier component of the
    first column. A complex exponential at that frequency is built, every
    column is multiplied by it, and each product is low-pass filtered by
    convolution with the Gaussian window ``exp(-(k / win)**2)``.

    Parameters
    ----------
    data : 2-D array_like, shape (n_samples, n_signals)
        One signal per column. The first column defines the carrier
        frequency. ``data`` itself is not modified.
    win : float
        Width of the Gaussian smoothing window in samples. The kernel covers
        ``k`` in ``[-3*win, 3*win)``.

    Returns
    -------
    amplitude, phase : 2-D arrays
        Amplitude (arbitrary units) and unwrapped phase (rad), one column per
        input signal. They have the same shape as ``data`` as long as the
        kernel (``6*win`` samples) is not longer than the signal.
    """
    t = time()
    # float() keeps the Gaussian exponent a true division even for an int win.
    win = float(win)
    # Remove the per-column offset. The subtraction creates a new floating
    # array, so integer input also works.
    y = data - mean(data, axis=0)

    # Keep only the strongest bin of the first column's spectrum (the DC term
    # is gone after the mean subtraction) and transform back. This gives a
    # complex exponential oscillating at the carrier frequency.
    fourier = wfft(y[:, 0], shape(y[:, 0]))
    peak = argmax(abs(fourier))
    peak_magnitude = abs(fourier[peak])
    fourier[:] = 0
    fourier[peak] = peak_magnitude
    cmpl_exp = wifft(fourier, shape(fourier))

    gauss = exp(-arange(-3 * win, 3 * win) ** 2 / win ** 2)

    # Mix every column down with the carrier and smooth it; the columns are
    # stacked back as columns of a 2-D array.
    signal = column_stack([fftconvolve(y[:, i] * cmpl_exp, gauss, mode='same')
                           for i in range(y.shape[1])])

    amplitude = abs(signal)
    phase = unwrap(angle(signal), axis=0)

    print('calc. time %s' % (time() - t))
    return (amplitude, phase)

def LoadData():
    """Load both microwave channels of the current shot.

    Returns
    -------
    tvec, density1, density2
        The time vector and the 'density' and 'density_2' signals, as
        returned by ``Shot().get_data('any', ...)``.
    """
    gd = Shot().get_data
    tvec, density1 = gd('any', 'density')
    # NOTE(review): the time vector of 'density_2' overwrites the first one.
    # Both channels presumably share one time base; confirm.
    tvec, density2 = gd('any', 'density_2')
    return tvec, density1, density2


    
    
def graphs():
    """Plot the saved demodulation results and the electron density.

    Reads the arrays written by the "analysis" step from ``results/`` and
    writes two figures:

    * ``graphs/demodulation.png``: saw, signal and subtracted phases (top
      panel) and the signal amplitude (bottom panel).
    * ``graphs/electron_density.png``: the electron density, with the
      shot's plasma start and end marked by dashed vertical lines.
    """
    import matplotlib
    matplotlib.rcParams['backend'] = 'Agg'
    matplotlib.rc('font', size='10')
    matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
    import matplotlib.pyplot as plt

    class MyFormatter(plt.ScalarFormatter):
        """ScalarFormatter that leaves the tick at position 0 unlabelled."""

        def __call__(self, x, pos=None):
            if pos == 0:
                return ''
            return plt.ScalarFormatter.__call__(self, x, pos)

    tvec, phase_pila = load_adv('results/phase_saw')
    tvec, phase_sinus = load_adv('results/phase_sinus')
    tvec, phase = load_adv('results/phase_substracted')
    tvec, amplitude = load_adv('results/amplitude_sinus')
    tvec, n_e = load_adv('results/electron_density')
    t_ms = tvec * 1000

    # Demodulation figure: two panels stacked with no gap (hspace=0).
    fig = plt.figure(num=None, figsize=(10, 6), dpi=80, facecolor='w', edgecolor='k')
    plt.subplots_adjust(hspace=0, wspace=0)

    ax = fig.add_subplot(2, 1, 1)
    # The bottom panel carries the time axis: hide the top x tick labels and
    # the y tick at position 0, which would touch the panel below.
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(MyFormatter())
    plt.plot(t_ms, -phase_pila + mean(phase_pila), '--', label='saw phase')
    plt.plot(t_ms, -phase_sinus + mean(phase_sinus), '--', label='signal phase')
    plt.plot(t_ms, phase, 'k', label='substracted phase')
    plt.axis('tight')
    plt.xlim(0, None)
    plt.ylabel('phase [rad]')
    leg = plt.legend(loc='best', fancybox=True)
    leg.get_frame().set_alpha(0.5)

    fig.add_subplot(2, 1, 2)
    plt.plot(t_ms, amplitude, label='amplitude')
    plt.xlim(0, None)
    plt.ylim(0, None)
    leg = plt.legend(loc='best', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    # The x label belongs to the bottom panel. On the top panel it would be
    # drawn over this one because of hspace=0.
    plt.xlabel('time [ms]')
    plt.ylabel('amplitude [a.u.]')
    plt.savefig('graphs/demodulation.png', bbox_inches='tight')
    plt.close()

    # Electron density figure with the plasma start/end marked.
    shot = Shot()
    plasma_start = shot['plasma_start']
    plasma_end = shot['plasma_end']
    plt.figure(num=None, figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')
    plt.plot(t_ms, n_e / 1e19, label='$n_e$')
    plt.ylabel(r'$<n_e>$ [$10^{19}\,m^{-3}$]')
    plt.xlabel('time [ms]')
    plt.xlim(0, 20)
    plt.ylim(0, None)
    plt.axvline(x=1000 * plasma_start, linestyle='--')
    plt.axvline(x=1000 * plasma_end, linestyle='--')

    plt.savefig('graphs/electron_density.png', bbox_inches='tight')
    plt.close()
    
    
    

def _analysis():
    """Demodulate the current shot and save phases, amplitude and density."""
    win = 30e-6  # [s] width of the Gaussian smoothing window
    t = time()
    tvec, density1, density2 = LoadData()
    # N samples span N-1 sampling intervals.
    dt = (tvec[-1] - tvec[0]) / (len(tvec) - 1)
    print('load time  %s' % (time() - t))

    # Column 0: 'density_2' (saved as the saw phase, and used by Demodulation
    # to find the carrier). Column 1: 'density' (saved as the sine phase).
    signals = vstack((density2, density1)).T
    (amplitude, phase) = Demodulation(signals, win / dt)

    # Keep about two samples per smoothing window. The step must be at least 1,
    # because a slice step of 0 raises ValueError.
    downsample = max(1, int(win / dt / 2))
    amplitude = amplitude[::downsample, 1]
    phase_pila = phase[::downsample, 0]
    phase_sinus = phase[::downsample, 1]
    tvec = tvec[::downsample]

    phase = phase_pila - phase_sinus
    phase -= median(phase)

    save_adv('results/phase_saw', tvec, phase_pila)
    save_adv('results/phase_sinus', tvec, phase_sinus)
    save_adv('results/phase_substracted', tvec, phase)
    save_adv('results/amplitude_sinus', tvec, amplitude)

    # Phase shift -> density: dphi = r_e * lambda_0 * integral(n_e dl), with
    # r_e = e^2 / (4 pi eps0 m_e c^2). The integral is then divided by the
    # path length 2*a.
    a = 0.01    # [m]
    f_0 = 75e9  # [Hz]
    lambda_0 = c / f_0
    n_e = 4 * pi * m_e * epsilon_0 * c ** 2 / (e ** 2 * lambda_0 * 2 * a) * phase
    save_adv('results/electron_density', tvec, n_e)


def main():
    """Command-line entry point: ``<script> analysis`` or ``<script> plots``.

    Always makes sure the ``graphs`` and ``results`` directories exist.
    ``analysis`` demodulates the shot and saves the results. ``plots`` draws
    the figures and saves the status constant 0. Any other argument does
    nothing, and a missing argument exits with a usage message.
    """
    for path in ['graphs', 'results']:
        if not os.path.exists(path):
            os.mkdir(path)

    if len(sys.argv) < 2:
        sys.exit('usage: %s {analysis|plots}' % sys.argv[0])

    if sys.argv[1] == "analysis":
        _analysis()
    elif sys.argv[1] == "plots":
        graphs()
        saveconst('status', 0)



if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    main()
    	 
