
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys, getopt, os
from numpy import *
from numpy.matlib import repmat
from matplotlib.pyplot import *
from sklearn import svm, grid_search, metrics
from sklearn.cross_validation import StratifiedKFold
from copy import deepcopy
from scipy.stats.mstats import mquantiles
import re
from data_object import Data


# possible arguments

ARGS = array([


# Example of use
#
# ./ --train --plot
# ./ --test --Ub=400 --Ucd=500  --pressure=27.84  --Tcd=0.01
# ./ --outliers


def use():
    """Print the accepted parameter names followed by example invocations."""
    examples = (
        "./ --train --plot",
        "./ --test --Ub=400 --Ucd=500  --pressure=27.84  --Tcd=0.01",
        "./ --outliers",
    )
    print('availible parameters:')
    print(ARGS)
    for example in examples:
        print(example)
    #--gas_filling=1 --Ubd=100 --Ust=0   --Tbd=0.01 --Tst=0 --PreIonization=1

def main():
    """Load the user settings from the command line and run the requested action.

    Every name in ARGS is accepted as ``--name=value``; the flags
    ``--train``, ``--test``, ``--outliers``, ``--plot`` and ``--help`` (or
    ``-h``) select what is done. Actions are executed in the order in which
    they appear on the command line. With no options at all the usage text
    is printed.

    Returns 0 when ``--test`` is asked for a shot without gas filling
    (breakdown probability is then reported as 0), otherwise None.
    Exits with status 2 when the command line cannot be parsed.
    """
    flags = ['train', 'test', 'outliers', 'plot', 'help']

    # every parameter starts as "" which means "not given by the user"
    params = dict((name, "") for name in ARGS)

    long_opts = [(name + '=') for name in ARGS] + flags
    try:
        opts, args = getopt.getopt(sys.argv[1:], "h", long_opts)
    except getopt.GetoptError as err:
        # print help information and exit:
        print(str(err))
        use()
        sys.exit(2)

    # Exact name match: a substring test would let e.g. "--Ub" also fill
    # a longer parameter name that merely contains it.
    for o, a in opts:
        key = o[2:]
        if key in params:
            params[key] = a

    for o, a in opts:
        if o in ("-h", "--help"):
            use()
            sys.exit()
        if o == "--train":
            data_train = load_data('data_breakdown.npz', balance_classed=True)
            data_test = load_data('data_breakdown.npz', balance_classed=False)
            train(data_train, data_test)
        elif o == "--test":
            machine = load('model.npy').item()
            data = load_data('data_breakdown.npz')

            # "" means gas_filling was not supplied; only an explicit 0
            # short-circuits the prediction
            gas_filling = params.get('gas_filling', "")
            if gas_filling != "" and int(gas_filling) == 0:
                print("probability  0")
                return 0

            raw = [params[name] for name in data.names]
            missing = array([value == "" for value in raw], dtype=bool)
            # Build a float vector directly; writing the medians into a
            # fixed-width string array would truncate them.
            X = array([nan if value == "" else float(value) for value in raw])
            if missing.any():
                # replace missing values by median in the dimension
                X[missing] = data.get_norm()[0, missing]
            X = reshape(X, (1, -1))
            print("testing")

            test(X, data, machine)

        elif o == "--outliers":
            machine = load('model.npy').item()
            data = load_data('data_breakdown.npz')
            outliers(data, machine)
        elif o == "--plot":
            machine = load('model.npy').item()
            data = load_data('data_breakdown.npz')
            plot_data(data, machine)

    if len(opts) == 0:
        use()

def load_data(path, balance_classed = True):
    """Load the data from disk and remove problematic shots.

    Parameters
    ----------
    path : str
        ``.npz`` archive with a ``shots`` array and a ``data`` entry that
        holds a dictionary of per-shot arrays.
    balance_classed : bool
        When True, the successful breakdowns are thinned so that roughly
        three times more breakdowns than failures remain.

    Returns
    -------
    Data
        The selected shots; columns are the dictionary keys that are also
        listed in ARGS, the norm is left empty.
    """
    d = load(path)
    shots = d['shots']
    data_dict = d['data'].item()
    Y = data_dict['plasma']

    # use only keys allowed in ARGS (dictionary order is kept)
    all_keys = list(data_dict.keys())
    keys = [k for k, allowed in zip(all_keys, in1d(all_keys, ARGS)) if allowed]
    X = array([data_dict[k] for k in keys]).T
    N = len(X)

    # remove damaged signal
    ok_status = re.compile('OK')
    ind = array([bool(ok_status.match(s)) for s in data_dict['plasma_status'][:N]],
                dtype=bool)
    ind &= ~isnan(Y) & ~any(isnan(X), axis=1)
    # remove outlying points (mistakes); the quantile is taken over the
    # shots that survived the checks above
    ind &= all(abs(X) <= mquantiles(abs(X[ind, :]), 0.99, axis=0), axis=1)
    # remove outlying points (mistakes)
    ind &= (data_dict['loop_voltage_max'] > 5) & (data_dict['loop_voltage_max'] < 40)
    # remove vacuum shots
    ind &= data_dict['gas_filling'] == 1
    # too low Ucd or Ub => again some error
    ind &= (data_dict['Ucd'] > 10) & (data_dict['Ub'] > 10)
    # too long plasma
    ind &= (data_dict['plasma_life'] < 17e-3) | isnan(data_dict['plasma_life'])
    # too saturated trafo => probably false detect
    ind &= ((data_dict['transformator_saturation'] < 0.8) & (data_dict['plasma'] == 1)) \
        | (data_dict['plasma'] != 1) | isnan(data_dict['transformator_saturation'])
    # remove problematic sessions
    bad_session = re.compile('^Technological/(DAS|Technological|Problems|Software|Repairs)')
    ind &= array([not bad_session.match(s) for s in data_dict['session_name'][:N]],
                 dtype=bool)

    X, Y, shots = X[ind, :], Y[ind], shots[ind]

    # remove too many successful breakdowns => now there should be only
    # 3x more breakdowns than failures
    if balance_classed:
        failure_rate = 1 - mean(Y)
        # step <= 1 means the classes are already balanced enough; it also
        # covers the no-failure case where the ratio is undefined
        step = int(0.3 / failure_rate) if failure_rate > 0 else 0
        if step > 1:
            # keep only every step-th successful shot, keep all failures
            surplus = (Y == 1) & (mod(arange(len(Y)), step) != 0)
            keep = ~surplus
            X, Y, shots = X[keep, :], Y[keep], shots[keep]

    return Data(X, Y, shots, keys, [])

def artif_data(data):
    """ Add artificial data to the clear areas without breakdown => a priori knowledge.

    Three groups of ``Nvals`` random points are drawn uniformly between the
    1% and 99% quantiles of every column of ``data.X``; in each group one
    column is overwritten ('Bfield' set to 0, 'pressure' drawn from [0, 1),
    'Ucd' drawn from [0, 10)) and all points are labelled 0 (no breakdown).

    NOTE(review): in the code visible here the generated ``X`` and ``Y`` are
    never attached to ``data``; the function returns its argument unchanged.
    A statement that appends them appears to be missing -- confirm against
    the Data class before relying on this function.
    """
    Nvals = 100  # number of the artificial points
    Ndim = data.Ndim
    MIN = zeros(Ndim)
    MAX = zeros(Ndim)
    # per-column 1% / 99% quantiles bound the sampling box
    for i in range(Ndim):
	MIN[i] = mquantiles(data.X[:,i], 0.01)
	MAX[i] = mquantiles(data.X[:,i], 0.99)

    # helper: Nvals points drawn uniformly from the box [MIN, MAX)
    def random_data():
	R = random.rand(Nvals, Ndim)
	R *= (MAX-MIN)
	R += MIN
	return R
    # group 1: 'Bfield' column forced to zero
    RAN_B = random_data()  
    RAN_B[:,data.ind('Bfield')] = 0
    RAN_P_min = random_data()
    # group 2: no filling gas
    RAN_P_min[:,data.ind('pressure')] = random.random(Nvals)*1
    # group 3: 'Ucd' column drawn from [0, 10)
    RAN_CD = random_data()
    RAN_CD[:,data.ind('Ucd')] = random.random(Nvals)*10
    X = vstack([RAN_B, RAN_P_min, RAN_CD])
    Y = zeros(len(X)) # no breakdown

    # NOTE(review): X and Y are not used below
    return data

def prepare_data(data, allow_artif_data):
    """Add a priori knowledge to the raw variables (modifies ``data`` in place).

    Steps, in order:
      * the 'pressure' column is replaced by ``log(pressure - 2)``, values
        below 2.5 being raised to 3 first;
      * the 'Ub' column is rescaled with a shot-dependent capacitor
        calibration and the 'Tcd' column, then renamed to 'Bfield';
      * optionally ``artif_data`` is applied;
      * if ``data.norm`` is empty it is taken from ``data.get_norm()`` and
        rows 1 and 2 of the 'Bfield' column are multiplied by 3.

    Parameters
    ----------
    data : Data
        Data set whose names still contain 'Ub' (i.e. not prepared before).
    allow_artif_data : bool
        Whether ``artif_data`` is called on the data set.

    Returns
    -------
    Data
        The prepared data set.
    """
    # asarray makes the comparisons below element-wise even when the shot
    # numbers arrive as a plain list; NaN shots match no range
    ShotNo = asarray(data.shots, dtype=float)

    #data.X[:,data.ind('Tcd')] = data['Tbd']-data['Tcd']   #the most important is the difference CD a BD

    # perform the logarithmic substitution !! => less than 2.5 is useless (vacuum)
    ind = data.ind('pressure')
    data.X[data['pressure'] < 2.5, ind] = 3
    data.X[:, ind] = log(data['pressure'] - 2)   # pressure

    #==============  relative calibration of the capacitors for  different shots  ===================
    # (first shot exclusive, last shot inclusive, calibration value)
    calibration = (
        (-inf, 1468, 0.6),
        (1468, 2876, 2),
        (2876, 2918, 3),
        (2918, 3305, 4.2),
        (3305, 9836, 3.9),
        (9836, inf, 16),
    )
    C = ones(data.N) * 3.9   # value used when no range matches
    for first, last, value in calibration:
        C[(ShotNo > first) & (ShotNo <= last)] = value

    # approximate the magnetic field in the time of maximal breakdown field
    data.X[:, data.ind('Ub')] = data['Ub'] * sqrt(C) * sin(255*1e-6*1/sqrt(C) * data['Tcd'])
    data.names[data.ind('Ub')] = 'Bfield'    # renames the variable ...

    if allow_artif_data:   # create artificial data in certain areas
        data = artif_data(data)

    if len(data.norm) == 0:
        data.norm = data.get_norm()
        data.norm[[1, 2], data.ind('Bfield')] *= 3  # allow faster changes in this dimension
    return data

def train(data, data_test):
    """Call SVM algorithm to make probability prediction of breakdown.

    A grid search over ``C`` and ``gamma`` of an RBF-kernel SVC is run on
    the prepared training data, a probability-enabled SVC with the best
    parameters is fitted, a classification report for ``data_test`` is
    printed and the final model is stored with ``save('model', ...)``.

    Parameters
    ----------
    data : Data
        Training set; prepared here with artificial data allowed.
    data_test : Data
        Evaluation set; prepared here without artificial data.
    """
    print('Learning')

    data = prepare_data(data, True)
    data_test = prepare_data(data_test, False)

    # range of the training parameters for SVM (exponents of 2)
    crange = linspace(10, 15, 4)
    grange = linspace(-15, -7, 4)

    parameters = {'C': 2.0**crange, 'gamma': 2.0**grange}

    print('start grid')
    print("N points %s" % (data.N,))
    print("N dim %s" % (data.Ndim,))

    #print data.X, data.Y
    machine = svm.SVC(kernel='rbf', class_weight="auto",
                      cache_size=800, verbose=False, tol=1e-2, shrinking=True)  # gaussian kernel
    machine = grid_search.GridSearchCV(machine, parameters, n_jobs=-1)
    print("Solving ....")
    # find the optimal parameters
    # NOTE(review): `cv` is handed to fit() as the surviving call fragment
    # indicates; confirm this matches the installed scikit-learn release
    # (later releases are believed to take `cv` in the constructor).
    machine.fit(data.X, data.Y, cv=StratifiedKFold(data.Y, 5))
    #save('model_all', machine)
    best = svm.SVC(probability=True, C=machine.best_estimator_.C, kernel='rbf',
                   gamma=machine.best_estimator_.gamma)

    print("=====BEST PARAMETERS========= ")
    print("C %s" % (machine.best_estimator_.C,))
    print("gamma %s" % (machine.best_estimator_.gamma,))
    print("==============================")
    # refit with probability estimates enabled, on the whole training set
    best.fit(data.X, data.Y)

    #################### report #######################
    print('all data')
    y_pred = best.predict(data_test.X)
    print(metrics.classification_report(data_test.Y, y_pred))

    save('model', best)
    # percents of wrong predictions
    print("zero-one loss function - train set - %2.1f%%" % (mean(data.Y != best.predict(data.X))*100))
    print("zero-one loss function - test set  - %2.1f%%" % (mean(data_test.Y != best.predict(data_test.X))*100))
    print('number of  SV    %s' % (shape(best.support_vectors_),))

def test(X, data, machine):
    """Get breakdown probability for the input vector X.

    Parameters
    ----------
    X : array
        User-selected variables, one row, columns ordered like ``data.names``.
    data : Data
        Used only for normalization; it is deep-copied and left untouched.
    machine :
        Fitted SVM predictor offering ``predict``, ``predict_proba`` and
        ``decision_function``.

    Returns
    -------
    (int, array)
        Breakdown probability in percent (truncated to an integer) and the
        value returned by ``machine.decision_function``.
    """
    # Copy the names: prepare_data renames 'Ub' in place, which must not
    # leak into the caller's data set.
    names = list(data.names)
    data = prepare_data(deepcopy(data), False)  # preprocess data => to get norm of final dataset
    data_new = Data(X, [nan], [nan], names, data.norm)
    data_new = prepare_data(data_new, False)   # preprocess user-selected variables with the same norm

    print('predict')
    Y = machine.predict_proba(data_new.X)
    dist = machine.decision_function(data_new.X)
    probability = int(Y[0, 1]*100)
    print("breakdown  %s <= probability of breakdown %s %%"
          % (bool(machine.predict(data_new.X)), probability))
    #print "distance from boundary", double(dist)

    return probability, dist

def outliers(data, machine):
    """Print the shot numbers whose outcome contradicts the predicted probability.

    Two lists are printed: shots predicted above 0.90 that have label 0,
    and shots predicted below 0.2 that have label 1.
    """
    prepared = prepare_data(data, False)
    breakdown_proba = machine.predict_proba(prepared.X)[:, 1]

    failed_mask = (breakdown_proba > 0.90) & (prepared.Y == 0)
    print('expected breakdown but failed    #shots:')
    print(int_(prepared.shots[failed_mask]))

    surprise_mask = (breakdown_proba < 0.2) & (prepared.Y == 1)
    print('unexpected breakdown  #shots:')
    print(int_(prepared.shots[surprise_mask]))

def plot_data(data, machine):
    """Plot 2-dimensional cuts through the 4-dimensional space.

    'pressure' and 'Ucd' span the plotted plane; 'Ub' is stepped through 20
    equally spaced values between its minimum and maximum in ``data``, and
    one figure is produced per step by ``plotting``.
    """
    axis_vars = ['pressure', 'Ucd', 'Ub']
    # the rest of variables will be constant: every ARGS entry that is not
    # one of the three axis variables (the mask must have ARGS' length)
    const = ARGS[~in1d(ARGS, axis_vars)]
    z_values = data[axis_vars[2]]
    for zvar in linspace(amin(z_values), amax(z_values), 20):
        plotting(data, const, axis_vars, zvar, machine)

def plotting(data, const, vars, zvar, machine):
    """Plot one cut through the multidimensional space and save it as PNG.

    Parameters
    ----------
    data : Data
        Raw (not yet prepared) data set; it is not modified.
    const : sequence of str
        Names of the variables that are held constant.
    vars : list of str
        Names of the variables on the x, y and z axis.
    zvar : float
        Value of the z-axis variable for this cut.
    machine :
        Machine learning object offering ``predict_proba``.

    Other (hidden) variables are set to their mean values. The figure is
    written to ``graf_<z name>_<zvar>.png``.
    """
    print("plotting %s %s" % (vars, zvar))

    max_distance = 0.5  # maximal distance (in units of MAD) of the plotted points from the cut
    MEAN = mean(data.X, 0)
    MAD = mean(abs(data.X - MEAN), 0)  # mean absolute deviation
    MEAN[data.ind(vars[2])] = zvar
    Xaxe = linspace(amin(data[vars[0]]), amax(data[vars[0]]), 100)
    Yaxe = linspace(amin(data[vars[1]]), amax(data[vars[1]]), 100)

    xx, yy = meshgrid(Xaxe, Yaxe)

    data_grid = deepcopy(data)
    data_norm = prepare_data(deepcopy(data), False)  # data need to be normalized to the same scale as training set

    N_grid = size(xx)
    data_grid.X = repmat(MEAN, N_grid, 1)
    data_grid.Y = ones(N_grid)*nan
    data_grid.shots = ones(N_grid)*nan
    data_grid.N = N_grid
    data_grid.norm = data_norm.norm
    # put the grid coordinates into the X and Y axis variables
    data_grid.X[:, data_grid.ind(vars[0])] = xx.ravel()
    data_grid.X[:, data_grid.ind(vars[1])] = yy.ravel()

    data_grid = prepare_data(data_grid, False)

    proba = machine.predict_proba(data_grid.X)[:, 1]
    proba = proba.reshape(xx.shape)

    # find shots close to the cut: within max_distance MADs of the cut in
    # every constant dimension and in the z dimension
    near = ones(data.N, dtype=bool)
    const_dim = where(in1d(data.names, const) | in1d(data.names, [vars[2]]))[0]
    for i in const_dim:
        near &= (abs(data.X[:, i] - MEAN[i]) / MAD[i]) < max_distance

    # plot the probability
    # two corner values are overwritten with 1 and 0 -- presumably to pin
    # the colour range of the filled contours to [0, 1]
    proba[0, 0] = 1
    proba[-1, -1] = 0
    contourf(xx, yy, proba, 200)
    cb = colorbar()
    cb.set_label('Probability of breakdown')

    contour(xx, yy, proba-0.8, 1, linestyles='dotted')
    contour(xx, yy, proba-0.2, 1, linestyles='dotted')
    contour(xx, yy, proba-0.5, 1, linewidths=2)

    axis([amin(Xaxe), amax(Xaxe), amin(Yaxe), amax(Yaxe)])

    # NOTE(review): only the shots selected by `near` are drawn; the mask
    # was computed but unused in the previous version -- confirm intent.
    n_near = int(count_nonzero(near))
    if n_near > 0:
        columns = [data.ind(vars[0]), data.ind(vars[1])]
        points = data.X[near][:, columns]   # fancy indexing => a copy, data stays intact
        for j in range(2):
            # small random jitter so that coincident shots stay visible
            spread = amax(points[:, j]) - amin(points[:, j])
            points[:, j] += 0.005 * spread * random.randn(n_near)
        colors = data.Y[near][:, None].repeat(3, axis=1)  # represents now RGB color
        scatter(points[:, 0], points[:, 1], 20, colors, edgecolors='none', alpha=0.8)
    title('Probability prediction of breakdown    %s = %g' % (vars[2], zvar))

    savefig('graf_%s_%05d.png' % (vars[2], zvar))
    # this function is called repeatedly; start the next cut on a clean figure
    clf()

if __name__ == "__main__":