#!/usr/bin/env python
# -*- coding: utf-8 -*-


# ======================   GEFIT 0.1 (Golem EFIT) ======================================
# The plasma position calculation is based on a least-squares fit of the signals from the Mirnov coils and the plasma current.
# For a trial position and plasma current, the magnetic field in the coils is calculated from the formula for a current loop
# and compared with the measured values. The method should be robust; it was tested on poor-quality Mirnov coil data.
# The magnetic field from external coils can be included by setting extCoilsPos and extCoilsCurr; however, this has not been tested yet.
#
# Author: Tomáš Odstrčil   tomasodstrcil@gmail.com

#TODO problems: during a large disruption the signal jump is so massive that the
#integration fails and an integration drift appears on some of the coils. This causes an error of up to 20% in the signal.

import time
#t = time.time()
from numpy import *
from numpy.linalg import norm,lstsq
#from scipy.interpolate import interpolate
from scipy.optimize import minimize 
from multiprocessing import Process, Pool, cpu_count
#from Deconvolution import * 
from pygolem_lite.modules import deconvolveExp
import sys   
from MagFieldCalc import *
#print 'import time PP:', time.time()-t


# Geometry constants [m] shared by the whole module:
#   a   -- minor radius; the fitted plasma centre is kept within this
#          distance of (R_0, 0)
#   R_0 -- reference major radius; the position fit works with R - R_0
a, R_0 = 0.085, 0.4



#http://www.phy.uct.ac.za/courses/python/examples/fitresonance.py
#http://www.scipy.org/doc/api_docs/SciPy.optimize.minpack.html

	
def directConductorApprox(B):
    """Quick estimate of the plasma centre from opposite coil pairs.

    The radial shift is the normalized difference of the coil pair
    (B[1], B[3]) and the vertical shift that of the pair (B[2], B[4]), each
    multiplied by -0.093 m (presumably the radius of the coil ring -- confirm).
    A shift larger than the module radius ``a`` is scaled back onto the
    circle of radius ``a``.

    Parameters
    ----------
    B : indexable; entries 1..4 are the coil signals. Entry 0 is not used
        here (MainCycle passes the plasma current in that slot).

    Returns
    -------
    (R, Z) : major-radius position (R_0 already added) and vertical
        position of the estimated plasma centre [m].
    """
    dR = (B[3] - B[1]) / (B[3] + B[1]) * -0.093
    dZ = (B[4] - B[2]) / (B[4] + B[2]) * -0.093
    dist = hypot(dZ, dR)
    if dist > a:
        # Pull estimates outside the allowed circle back onto its edge.
        shrink = a / dist
        dZ *= shrink
        dR *= shrink
    return dR + R_0, dZ
    





def fun(x, param, B_coils, y, Bout):
    """Forward model: predicted plasma current and coil signals.

    The plasma is represented by a single current loop whose field at the
    coils is evaluated with Bloop_analytic.

    Parameters
    ----------
    x : per-coil rows; columns 0-1 are the coil (R, Z) positions, columns 2:
        the coil orientation vectors (as assembled in CalcPlasmaPosition).
    param : (R - R_0, Z, Ip) of the plasma loop.
    B_coils : field of the external coils at the coils (0 when unused).
    y : output buffer of length n_coils + 1; overwritten and returned.
    Bout : work buffer handed to Bloop_analytic (presumably filled in place
        -- confirm against MagFieldCalc).

    Returns
    -------
    y : y[0] = |Ip|, y[1:] = predicted signal of each coil.
    """
    loop_R = param[0] + R_0
    loop_Z = param[1]
    current = param[2]

    field = Bloop_analytic(x[:, 0], x[:, 1], loop_R, loop_Z, Bout) * current
    # NOTE: the orientation (sign) of the currents measured by the Fluke
    # probes must be correct!
    field += B_coils
    # NOTE(review): the component-wise product with the coil vector followed
    # by the Euclidean norm is not the projection B.n (that would be a sum of
    # the products, with sign); kept unchanged -- confirm this is intended.
    field *= x[:, 2:]

    y[1:] = sqrt(square(field).sum(axis=1))
    y[0] = abs(current)
    return y

	
def MainCycle(args):
    """Fit plasma position and current for every sample of one data chunk.

    This is the worker that CalcPlasmaPosition maps over the data chunks via
    ``Pool.map``, hence the single tuple argument.

    Parameters
    ----------
    args : tuple (tvec, data, x, sigma, Bext, extCoilsCurr)
        tvec : times of the samples [s]; only used in the failure message.
        data : 2D array, one row per sample; column 0 is the plasma current,
            the other columns are the coil signals.
        x : per-coil rows (R, Z, orientation vector) passed on to fun().
        sigma : per-channel errors used to weight the residuals.
        Bext : field of the external coils per unit current, or None.
        extCoilsCurr : external coil currents; row t is used for sample t of
            this chunk. Ignored when Bext is None.

    Returns
    -------
    (position, retrofit, reziduum)
        position : (N, 2) fitted (R, Z) of the plasma centre [m].
        retrofit : (N, n_channels) model prediction for the accepted
            parameters (the initial guess where the fit failed).
        reziduum : (N,) final value of the cost function; 1e6 where the fit
            failed.
    """
    # Unpacked here rather than in the signature: tuple parameters are
    # Python-2-only syntax (PEP 3113).
    tvec, data, x, sigma, Bext, extCoilsCurr = args

    N, n_det = data.shape

    position = empty((N, 2))
    retrofit = empty((N, n_det))
    reziduum = empty(N)

    I0 = median(data[:, 0])
    if I0 == 0:
        # A zero median would make the current scale infinite and every fit
        # NaN; any non-zero reference current keeps the rescaling finite.
        I0 = abs(data[:, 0]).max() or 1.0
    # Rescale the parameters (R - R_0, Z, Ip) seen by the simplex so that
    # they are of comparable magnitude; loop invariant.
    scale = 1 / array((1, 1, 0.5 / I0))

    # Work buffers reused by every fun() evaluation.
    y = empty(size(x, 0) + 1)
    Bout = empty((size(x, 0), 2))

    for t in range(N):
        row = data[t, :]

        # Initial guess from the coil-pair ratios, shifted 4 cm to larger R
        # (empirical offset, origin not documented here).
        R, Z = directConductorApprox(row)
        R += 0.04
        params0 = array([R - R_0, Z, row[0]])

        if Bext is None:
            B_coils = 0
        else:
            B_coils = dot(Bext, extCoilsCurr[t, :])

        def fitfun(params):
            # Weighted chi^2 of the misfit plus a steep (power 50) wall that
            # keeps the plasma centre inside the radius a.
            p = params * scale
            misfit = (row - fun(x, p, B_coils, y, Bout)) / sigma
            return norm(misfit) ** 2 + norm(p[:2] / a) ** 50

        # NOTE(review): recent SciPy releases name these Nelder-Mead options
        # 'xatol'/'fatol'; check the installed version still honours
        # 'xtol'/'ftol'.
        res = minimize(fitfun, params0 / scale, method='Nelder-Mead',
                       options={'xtol': 1e-4, 'ftol': 1e-2})

        if res.success:
            fitted = res.x * scale
            reziduum[t] = res.fun
        else:
            # Fall back to the (unscaled) initial guess and flag the sample.
            fitted = params0
            reziduum[t] = 1e6
            print('position calculation failure in  %s s' % tvec[t])

        position[t, 0] = fitted[0] + R_0
        position[t, 1] = fitted[1]
        retrofit[t, :] = fun(x, fitted, B_coils, y, Bout)

    return position, retrofit, reziduum
    
#http://www.phy.uct.ac.za/courses/python/examples/fitresonance.py
#http://www.scipy.org/doc/api_docs/SciPy.optimize.minpack.html CalcPlasmaPosition(tvec_ds,detectorPos, detectorSignal_ds, detectorDriftError, Ip_ds,IpDriftError, shot_num = shot_num)

def CalcPlasmaPosition(tvec, detectorPos, detectorSignal, detectorDriftError,
                       Ip, IpDriftError, extCoilsPos=None, extCoilsCurr=None):
    """Fit the plasma position for every time sample, in parallel.

    All inputs are expected on the same time base and already absolutely
    calibrated in SI units.

    Parameters
    ----------
    tvec : time vector [s].
    detectorPos : coil positions, one (R, Z) row per coil [m].
    detectorSignal : coil signals, one row per time sample.
    detectorDriftError : per-coil error estimate used to weight the fit.
    Ip : plasma current, one value per time sample.
    IpDriftError : error estimate of the plasma current.
    extCoilsPos : optional external loop positions, one (R, Z) row per loop.
    extCoilsCurr : optional external loop currents, one row per time sample.

    Returns
    -------
    tvec, position, radius, reziduum, retrofit, data, chi2
        position : (n, 2) fitted (R, Z) of the plasma centre [m].
        radius : a minus the distance of the centre from (R_0, 0).
        reziduum : final cost of each fit (1e6 marks a failed fit).
        retrofit : model prediction for each sample, same layout as data.
        data : measured values; column 0 is the plasma current.
        chi2 : per-channel mean squared misfit divided by sigma**2.
    """
    print('CalcPlasmaPosition')
    t1 = time.time()

    # Unit vectors perpendicular to each coil's radius vector from the centre
    # of the coil array: (dx, dy) -> (-dy, dx) / |d|.
    coil_perpend = detectorPos - mean(detectorPos, axis=0)
    coil_perpend[:, 1] *= -1
    coil_perpend = coil_perpend[:, ::-1] / hypot(coil_perpend[:, 1], coil_perpend[:, 0])[:, None]

    x = hstack((detectorPos, coil_perpend))
    data = vstack((Ip, detectorSignal.T)).T
    sigma = hstack((IpDriftError, detectorDriftError))

    # `is None`, not `== None`: comparing an ndarray with None is element-wise
    # in current NumPy and cannot be used as an `if` condition.
    if extCoilsPos is None or extCoilsCurr is None:
        Bext = None
    else:
        n_ext = size(extCoilsPos, 0)
        Bext = zeros(shape(detectorPos) + (n_ext,))
        for i in range(n_ext):
            # Field of the i-th external loop located at (R, Z) = extCoilsPos[i].
            Bext[..., i] = Bloop_analytic(x[:, 0], x[:, 1], extCoilsPos[i, 0], extCoilsPos[i, 1])

    # Main calculation, spread over several worker processes.
    n_cpu = cpu_count()
    n_chunks = n_cpu * 4
    data_spl = array_split(data, n_chunks, axis=0)
    tvec_spl = array_split(tvec, n_chunks)
    # MainCycle indexes the external currents with its chunk-local sample
    # index, so they must be split exactly like the data.
    if Bext is None:
        curr_spl = [None] * n_chunks
    else:
        curr_spl = array_split(extCoilsCurr, n_chunks, axis=0)
    inputs = [(tv, d, x, sigma, Bext, c)
              for tv, d, c in zip(tvec_spl, data_spl, curr_spl)]

    pool = Pool(n_cpu)
    try:
        list_data = pool.map(MainCycle, inputs)
    finally:
        pool.close()
        pool.join()

    positions, retrofits, reziduums = zip(*list_data)
    position = vstack(positions)
    retrofit = vstack(retrofits)
    reziduum = hstack(reziduums)

    radius = a - sqrt((position[:, 0] - R_0) ** 2 + position[:, 1] ** 2)
    chi2 = mean((data - retrofit) ** 2, axis=0) / sigma ** 2

    print('position calculated in %g s' % (time.time() - t1))
    print('chi2:  %s  total  %s' % (chi2, norm(chi2)))

    return tvec, position, radius, reziduum, retrofit, data, chi2

  

    
def RemoveDriftsAuto(signal, Bt, Uloop, plasma_start, plasma_end, Bt_trigger, E_trigger,
                     Bt_crosstalk=True, Uloop_crosstalk=False, Stabil_crosstalk=False,
                     Trafo_effect=True, fluke=None,
                     projections_file='./constants/projections.txt'):
    """Remove integration drifts and crosstalk from the coil signals.

    A linear model (two integration-offset ramps plus the selected crosstalk
    terms) is fitted by least squares to the samples outside the plasma and
    the fitted model is subtracted from the whole signal.

    Parameters
    ----------
    signal : 2D array; column 0 is time [s], columns 1.. the detector signals.
    Bt, Uloop : 2D arrays with columns (time [s], value); resampled onto the
        time base of ``signal`` when used.
    plasma_start, plasma_end : plasma start/end times [s]; NaN when no plasma
        was detected.
    Bt_trigger, E_trigger : trigger times [s] of the toroidal field and of the
        current drive (CD).
    Bt_crosstalk, Uloop_crosstalk, Stabil_crosstalk, Trafo_effect : select the
        crosstalk terms included in the fit.
    fluke : 2D array (time [s], value), required when ``Stabil_crosstalk`` is
        set (presumably the stabilization current -- confirm).
    projections_file : path the fitted coefficients are written to.

    Returns
    -------
    (corrected, res, AutoRemoveGraph)
        corrected : array with time in column 0 and the corrected signals after.
        res : per detector, sqrt(1e15 * mean squared residual) over the
            fitted samples.
        AutoRemoveGraph : (frames, vlines, hlines, rectangles) for plotting.

    Raises
    ------
    ValueError : if ``Stabil_crosstalk`` is set but ``fluke`` is not given.
    """
    n_det = size(signal, 1) - 1

    # Restrict to the time range covered by all diagnostics.
    t_max = min(Bt[-1, 0], Uloop[-1, 0], signal[-1, 0])
    t_min = max(Bt[0, 0], Uloop[0, 0], signal[0, 0])
    dt = mean(diff(signal[:, 0]))
    signal = signal[:int(t_max / dt + 1), :]
    tvec = signal[:, 0]
    n = size(tvec)

    # Columns 0 and 1 of the regression basis are the two integration-offset
    # ramps; each selected crosstalk term takes the next free column.
    index = 2

    if Uloop_crosstalk:
        i_Uloop = index
        index += 1
        Uloop = interp(tvec, Uloop[:, 0], Uloop[:, 1], left=0, right=None)

    # The transformer term is derived from Bt, so Bt has to be resampled
    # whenever either of the two Bt-based terms is requested.
    if Bt_crosstalk or Trafo_effect:
        Bt = interp(tvec, Bt[:, 0], Bt[:, 1], left=0, right=None)

    if Bt_crosstalk:
        i_Bt = index
        index += 1

    if Trafo_effect:
        t_exp = 21.2e-3  # [s] response time of the transformer term
        i_Traf_res = index
        index += 1
        Bt_conv, _ = deconvolveExp(Bt, -t_exp, dt, 0, 0)

    if Stabil_crosstalk:
        if fluke is None:
            raise ValueError('Stabil_crosstalk=True requires the fluke signal')
        i_St = index
        index += 1
        stabil = interp(tvec, fluke[:, 0], fluke[:, 1], left=0, right=None)

    if isnan(plasma_start) or isnan(plasma_end):
        # No plasma: collapse the plasma interval onto the Bt trigger.
        plasma_start = Bt_trigger
        plasma_end = Bt_trigger

    Bt_trigger_n = argmin(abs(tvec - Bt_trigger))
    E_trigger_n = argmin(abs(tvec - E_trigger))
    plasma_start_n = max(argmin(abs(tvec - plasma_start)), E_trigger_n)
    plasma_end_n = max(argmin(abs(tvec - plasma_end)), E_trigger_n)

    interval0 = list(range(int(t_min / dt), Bt_trigger_n))                  # clean area before any trigger
    interval1 = list(range(Bt_trigger_n + int(0.5e-3 / dt), E_trigger_n))  # between Bt and CD trigger
    interval2 = list(range(E_trigger_n, plasma_start_n))                   # CD trigger .. breakdown
    interval3 = list(range(plasma_end_n, n))                               # plasma end .. end of data
    interval = interval1 + interval2 + interval3                           # samples used for the drift fit

    mainBasis = zeros((n, index))
    # Integration offsets: one ramp growing inside interval0, a second ramp
    # growing everywhere else.
    in_interval0 = zeros(n)
    in_interval0[interval0] = 1
    mainBasis[:, 0] = cumsum(in_interval0)
    mainBasis[:, 1] = cumsum(1 - in_interval0)
    if Bt_crosstalk:
        mainBasis[:, i_Bt] = Bt
    if Uloop_crosstalk:
        mainBasis[:, i_Uloop] = Uloop
    if Trafo_effect:
        mainBasis[:, i_Traf_res] = Bt_conv
    if Stabil_crosstalk:
        mainBasis[:, i_St] = stabil

    x = lstsq(mainBasis[interval, :], signal[interval, 1:])[0]

    corr_signal = signal[:, 1:] - dot(mainBasis, x)
    res = (corr_signal[interval, :] ** 2).sum(axis=0)
    res *= 1e15 / len(interval)
    res = sqrt(res)

    # Prepare the graph.
    vlines = [(plasma_end * 1000, '--'), (plasma_start * 1000, '--')]
    hlines = [(0, '-.')]
    frames = []
    for i_det in range(n_det):
        curves = [
            (single(1e6 * signal[:, i_det + 1]), 'Original signal', '--'),
            (single(1e6 * corr_signal[:, i_det]), 'Corrected signal', 'k-'),
            (single(1e6 * dot(mainBasis[:, :2], x[:2, i_det])), 'Integration offset', '-'),
        ]
        if Bt_crosstalk:
            curves.append((single(1e6 * mainBasis[:, i_Bt] * x[i_Bt, i_det]),
                           'Toroidal mag. field', '-'))
        if Trafo_effect:
            curves.append((single(1e6 * mainBasis[:, i_Traf_res] * x[i_Traf_res, i_det]),
                           'Trafo mag. field', '-'))
        if Uloop_crosstalk:
            curves.append((single(1e6 * mainBasis[:, i_Uloop] * x[i_Uloop, i_det]),
                           'Uloop', '-'))
        frames.append((1000 * tvec, curves,
                       'mc' + str(i_det * 4 + 1) + ' [ mV$\\cdot$ms]',
                       '$\\chi^2$ = %2.1f' % res[i_det]))

    # Shade the (non-empty) intervals used for the drift fit.
    rectangles = [(tvec[inter[0]] * 1e3, tvec[inter[-1]] * 1e3)
                  for inter in (interval1, interval2, interval3) if inter]

    AutoRemoveGraph = (frames, vlines, hlines, rectangles)

    # Save the projection constants.
    proj_lines = ['mc1\tmc5\tmc9\tmc13\n',
                  'drift1 [V] ' + str(x[0, :]) + '\n',
                  'drift2 [V] ' + str(x[1, :]) + '\n']
    if Bt_crosstalk:
        proj_lines.append('Bt crosstalk [V/T]\t' + str(x[i_Bt, :]) + '\n')
    if Trafo_effect:
        proj_lines.append('Trafo crosstalk (respons time %1.1f ms)[V/T]\t' % (1000 * t_exp)
                          + str(x[i_Traf_res, :]) + '\n')
    if Uloop_crosstalk:
        proj_lines.append('Uloop (residual) crosstalk [V/V]\t' + str(x[i_Uloop, :]) + '\n')

    with open(projections_file, 'w') as f:
        f.write(''.join(proj_lines))

    return vstack((tvec, corr_signal.T)).T, res, AutoRemoveGraph
    
    
  