#!/usr/bin/env python 
# -*- coding: utf-8 -*-

###########################################################################################################

#############   Algorithm which calculates, from ion intensities, the coefficients used to estimate
#############   the electron temperature and impurity density.

###########################################################################################################

# Ion species; the integers stored in ions_index are positions in this list.
ion_names = ['HI', 'OI', 'OII', 'OIII', 'HeI', 'CII', 'CIII',
             'NI', 'NII', 'NIII', 'M1', 'MIV']
# Elements; element j is represented by the ions listed in ions_index[j].
element_names = ['H', 'O', 'He', 'C', 'N', 'M1', 'M2']
ions_index = [[0], [1, 2, 3], [4], [5, 6], [7, 8, 9], [10], [11]]
# Element indices grouped by how many ions represent them.
# (Generator expressions, so no loop variable leaks into module scope on Python 2.)
single_ion = list(elem for elem, members in enumerate(ions_index) if len(members) == 1)
multi_ion = list(elem for elem, members in enumerate(ions_index) if len(members) > 1)
import matplotlib 
#matplotlib.rcParams['backend'] = 'Agg'
#matplotlib.rc('font',  size='10')
#matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
from numpy import *
from matplotlib.pyplot import *
import os
from numpy.linalg import norm
from scipy.linalg import qr,inv,pinv,eig,cholesky,solve_triangular
from scipy.stats.stats import pearsonr,hmean
from scipy.sparse import block_diag
from scipy.optimize import fmin_powell
from numpy.matlib import repmat
from time import time
from scipy.stats.stats import nanmean,nanmedian

n_ions = 12
n_elemets = 7
k = 2  #polynome order

err = load('err.npy')
T = load('temperature.npy')
shots = load('shots.npy')
n = len(T)
# Slice of err holding the temperature uncertainties: the n values that
# precede the last n_elemets*n entries (layout assumed from this slicing;
# TODO confirm against the code that writes err.npy).
errT = err[-(1+n_elemets)*n:-n_elemets*n]
T_spitzer = zeros_like(T)
T_spitzer_err = zeros_like(T)

# Range of shot numbers scanned below (last_shot excluded).
first_shot, last_shot = 7100, 11000
T_all = zeros(last_shot - first_shot)
shots_all = arange(first_shot, last_shot)


def _load_shot_temperature(shot):
    """Load the electron temperature trace of one shot.

    Returns ``(temp, tvec, plasma_start, plasma_end)``.

    The cached ``electron_temperature.npz`` is tried first.  Its plasma
    limits come from the ``PlasmaStart``/``PlasmaEnd`` files, or, if those
    cannot be read, from the time span of the finite samples.  If the npz
    path fails for any reason, the median-filtered text file is read
    instead, re-cached as npz, and the plasma limits files are then
    mandatory.  Exceptions from that fallback path propagate to the caller.
    """
    folder = './electronTemperatures/' + str(shot) + '/'
    try:
        data = load(folder + 'electron_temperature.npz')
        try:
            temp = data['data']*data['scale']
            tvec = linspace(data['t_start'], data['t_end'], len(temp))
        finally:
            # NpzFile keeps the archive open until it is explicitly closed.
            data.close()
        try:
            plasma_start = loadtxt(folder + 'PlasmaStart')
            plasma_end = loadtxt(folder + 'PlasmaEnd')
        except Exception:
            plasma_start = amin(tvec[isfinite(temp)])
            plasma_end = amax(tvec[isfinite(temp)])
    except Exception:
        table = loadtxt(folder + 'ElectronTemperatureMedianFilter.txt')
        tvec = table[:, 0]
        temp = table[:, 1]
        savez(folder + 'electron_temperature.npz', data=temp, scale=1,
              t_start=tvec[0], t_end=tvec[-1])
        plasma_start = loadtxt(folder + 'PlasmaStart')
        plasma_end = loadtxt(folder + 'PlasmaEnd')
    return temp, tvec, plasma_start, plasma_end


for shot in range(first_shot, last_shot):
    # Entries of `shots` whose integer part equals this shot number.
    shot_index = where(abs(floor(shots) - shot) < 0.5)[0]
    n_spec = size(shot_index)

    try:
        temp, tvec, plasma_start, plasma_end = _load_shot_temperature(shot)
    except Exception:
        # Best effort: report the shot and move on (output kept as before).
        print('%s %s' % (shot, ' failured'))
        continue

    print(shot)
    if shot > 9300:
        temp = temp*2
    else:
        # Older shots: rescale the time base by 1/1000 (presumably ms -> s;
        # TODO confirm) and shift the plasma start by 0.5 before rescaling.
        tvec = tvec/1000.0
        plasma_start = (plasma_start - 0.5)/1000.0
        plasma_end = plasma_end/1000.0

    # Median temperature over the whole plasma phase of this shot.
    T_all[shot - first_shot] = median(temp[(tvec > plasma_start) & (tvec < plasma_end)])

    # Split the plasma phase into n_spec equal time bins; bin i is assigned
    # to the i-th matching entry of `shots`.
    tvec_spec = linspace(plasma_start, plasma_end, n_spec+1)
    for i in range(n_spec):
        shot_temp = temp[(tvec > tvec_spec[i]) & (tvec < tvec_spec[i+1])]
        shot_median = nanmedian(shot_temp)
        T_spitzer[shot_index[i]] = shot_median
        # Scaled mean absolute deviation from the median (factor 1.4 kept as is).
        T_spitzer_err[shot_index[i]] = nanmean(abs(shot_temp - shot_median))*1.4
	
plot(shots_all, T_all, '.')
show()

# Select the spectra to fit.  T_spitzer is still 0 for entries never
# assigned by the loading loop.  nanmedian of an empty or all-NaN time bin
# gives NaN, and NaN != 0 is True, so NaN entries must be rejected
# explicitly: a single non-finite x, y or y-error makes every value of the
# cost function NaN and the minimiser cannot proceed.
ind = where((T_spitzer != 0) & isfinite(T_spitzer) & isfinite(T) & isfinite(errT))[0]
# The spectrometer quantity is negated before fitting.
T = -T[ind]
errT = errT[ind]

T_spitzer = T_spitzer[ind]
T_spitzer_err = T_spitzer_err[ind]

# Number of points that enter the fit.
n = size(ind)
print(n)
# Number of evaluations of f, used only to throttle its progress output.
# (Shadows the builtin iter(); name kept for compatibility.)
iter = 0
def f(var,x,xerr,y,yerr):
    """Robust errors-in-variables cost of the model y = d*arctan(a*T + b) + c.

    Parameters
    ----------
    var : array
        ``[a, b, c, d, T_1, ..., T_m]``: the four model parameters followed
        by one latent abscissa value per data point.
    x, xerr : array
        Measured abscissa values and their uncertainties.
    y, yerr : array
        Measured ordinate values and their uncertainties.

    Returns
    -------
    float
        ``sum(r**2/sqrt(1 + r**2)) / (len(y) + len(x) - len(var))``, where
        ``r`` are all normalised residuals ``(model - y)/yerr`` and
        ``(T - x)/xerr``.  Residuals are first divided by their largest
        magnitude (algebraically the same value, but without squaring
        huge numbers).  Every 1000th call prints the call count and cost.
    """
    global iter
    a, b, c, d = var[0], var[1], var[2], var[3]
    T_spitz = var[4:]
    T_new = d*arctan(T_spitz*a+b)+c
    boundary = 0  # placeholder for an additional penalty term (unused)

    Resid = hstack(((T_new-y)/yerr, (T_spitz-x)/xerr))
    max_Resid = double(amax(abs(Resid)))
    if max_Resid == 0:
        # Perfect fit: the loss is exactly 0 (the scaled form would be 0/0 = NaN).
        chi2 = 0.0
    else:
        Resid /= max_Resid
        Resid **= 2
        chi2 = max_Resid*sum(Resid/sqrt((1/max_Resid)**2 + Resid))/(size(y)+size(x)-size(var))

    if iter%1000 == 0:
        print('%s %s %s' % (iter, chi2, boundary))
    iter += 1
    return chi2+boundary
    
def TotalLqr(f, x,xerr,y,yerr):
    """Fit y = d*arctan(a*T + b) + c with errors in both variables and plot it.

    The parameters a, b, c, d and one latent value T_i per point are found
    by minimising ``f(var, x, xerr, y, yerr)`` with ``fmin_powell``, where
    ``var = [a, b, c, d, T_1, ..., T_n]``.  The optimum is cached in
    ``var.npy``.  A cached vector is reused only if its length matches the
    current data; a cache from a different data set of the same length
    cannot be detected.

    Parameters
    ----------
    f : callable
        Cost function to minimise.
    x, xerr : array
        Abscissa values and uncertainties.  Non-finite entries of ``xerr``
        are replaced by ``inf`` on a private copy; the caller's array is
        left untouched.
    y, yerr : array
        Ordinate values and uncertainties.

    Side effects: may write ``var.npy``, prints the fitted a, b, c, d and
    shows several matplotlib figures (each ``show()`` may block).
    """
    xerr = array(xerr, dtype=float)
    xerr[~isfinite(xerr)] = inf
    n_points = len(x)

    x0 = hstack((0.23, -2.25, -0.93, 2, x))
    try:
        var = load('var.npy')
    except Exception:
        # Missing or unreadable cache: fall through to a fresh fit.
        var = None
    # A cache written for a different number of points would silently
    # misalign the parameters and the latent temperatures.
    if var is None or shape(var) != shape(x0):
        var = fmin_powell(f, x0, args=(x, xerr, y, yerr), xtol=1e-5, ftol=1e-5,
                          maxfun=10**6, maxiter=10**7)
        save('var', var)

    a, b, c, d = var[0], var[1], var[2], var[3]
    T_spitz = var[4:]
    print('%s %s %s %s' % (a, b, c, d))

    T_new = d*arctan(T_spitz*a+b)+c

    errorbar(x, y, yerr=yerr, xerr=xerr, fmt='.', capsize=0, linewidth=0.1, markersize=0.5)
    plot(T_spitz, T_new, '.')
    show()

    plot(x, y, '.', markersize=0.5)
    plot(T_spitz, T_new, '.')
    show()

    # Invert the model: T = (tan((y - c)/d) - b)/a.  tan only inverts arctan
    # on (-pi/2, pi/2), so the phase is clipped just inside that interval.
    # Points the model cannot reach saturate instead of wrapping around to
    # the wrong branch or diverging.
    phase_limit = 0.95*pi/2
    phase = clip((y-c)/d, -phase_limit, phase_limit)
    T = (tan(phase)-b)/a
    # Propagated uncertainty |dT/dy| * yerr; abs() keeps error bars non-negative.
    Terr = abs(yerr/(a*d*cos(phase)**2))
    errorbar(x, T, yerr=Terr, xerr=xerr, fmt='.', capsize=0, linewidth=0.1, markersize=0.5)
    show()

    plot(x, T, '.', markersize=0.5)
    # Identity line T = x over the displayed range.
    plot([0, 60], [0, 60], '-')
    # NOTE(review): x is the first argument (T_spitzer at the call site) but
    # is labelled "spectrometer" here; the labels may be swapped -- confirm.
    xlabel('T spectrometer [eV]')
    ylabel('T spitzer [eV]')
    ylim(0,60)
    xlim(0,60)
    show()

    errorbar(arange(n_points), T, yerr=Terr, capsize=0, linewidth=0.1, markersize=0.5)
    show()
    
# Separator before the calibration fit output.
print('-------------------')

# Fit the spectrometer quantity against the Spitzer temperature and plot the calibration.
TotalLqr(f, x=T_spitzer, xerr=T_spitzer_err, y=T, yerr=errT)





#1.21772166429



































