Staff/UndergradStudents/archive/MarHetf/Scintillator_calibration.py

# -*- coding: utf-8 -*-

#########################################
#                                       #
#      Scintillator Calibration tool    #
#        Author: Martin Hetflejš        #
#   E-mail: mhetflejs@protonmail.com    #
#           Created: 10/2017            #  
#        Last update: 15.11.2017        # 
#                                       #
#########################################

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit, OptimizeWarning
from scipy import stats
from astropy.convolution import convolve, Gaussian1DKernel
#import pandas as pd
import subprocess
from os import path, unlink
from time import sleep
import warnings
import tempfile

# Promote scipy's OptimizeWarning to an exception so that fit() can catch it
# with ``except OptimizeWarning`` when curve_fit cannot estimate the covariance.
warnings.simplefilter("error", OptimizeWarning)
# Use a serif font family for every matplotlib figure produced by this script.
plt.rc('font', family='serif')

####################################################################################################
#                                         PULSE DETECTION                                          #
####################################################################################################

def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False, show=False, ax=None):
    """Detect peaks (or valleys) in 1-D data by amplitude and spacing.

    Parameters
    ----------
    x : array_like
        Data to search; converted to a float64 array (a copy, so the caller's
        data is never modified).
    mph : float or None
        Minimum peak height; candidates with ``x[i] < mph`` are dropped. With
        ``valley=True`` the data is negated first, so ``mph`` applies to ``-x``.
    mpd : int
        Minimum peak distance in samples. When > 1, peaks are visited from
        tallest to shortest and every other peak within ``mpd`` samples of a
        kept peak is removed.
    threshold : float
        When > 0, drop peaks that do not exceed both immediate neighbours by
        at least ``threshold``.
    edge : {'rising', 'falling', 'both'} or falsy
        For flat-topped peaks, report the first ('rising'), last ('falling')
        or both samples of the plateau; a falsy value reports only strict
        local maxima.
    kpsh : bool
        With ``mpd > 1``, keep nearby peaks that have exactly the same height.
    valley : bool
        Detect local minima instead of maxima.
    show : bool
        Plot the data and the detected peaks via ``_plot``.
    ax : matplotlib Axes or None
        Axes to draw into when ``show`` is True.

    Returns
    -------
    numpy.ndarray of int
        Indices of the detected peaks in increasing order. Inputs shorter than
        3 samples yield an empty array. NaN samples, their immediate
        neighbours, and the first/last samples are never reported.
    """
    x = np.atleast_1d(x).astype('float64')
    if x.size < 3:
        return np.array([], dtype=int)
    if valley:
        x = -x
    # find indices of all peaks
    dx = x[1:] - x[:-1]
    # handle NaN's: treat them as +inf so they break up slopes around them
    indnan = np.where(np.isnan(x))[0]
    if indnan.size:
        x[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf
    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ['rising', 'both']:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ['falling', 'both']:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))
    # NaN's and values close to NaN's cannot be peaks.
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    if ind.size and indnan.size:
        near_nan = np.unique(np.hstack((indnan, indnan - 1, indnan + 1)))
        ind = ind[np.isin(ind, near_nan, invert=True)]
    # first and last values of x cannot be peaks
    if ind.size and ind[0] == 0:
        ind = ind[1:]
    if ind.size and ind[-1] == x.size - 1:
        ind = ind[:-1]
    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[x[ind] >= mph]
    # remove peaks - neighbors < threshold
    if ind.size and threshold > 0:
        dx = np.min(np.vstack([x[ind] - x[ind - 1], x[ind] - x[ind + 1]]), axis=0)
        ind = np.delete(ind, np.where(dx < threshold)[0])
    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        ind = ind[np.argsort(x[ind])][::-1]  # sort ind by peak height, tallest first
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # mark neighbours within mpd; keep equal-height ones if kpsh is True
                idel = idel | ((ind >= ind[i] - mpd) & (ind <= ind[i] + mpd)
                               & (x[ind[i]] > x[ind] if kpsh else True))
                idel[i] = 0  # Keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    if show:
        # restore the caller-visible values before plotting
        if indnan.size:
            x[indnan] = np.nan
        if valley:
            x = -x
        _plot(x, mph, mpd, threshold, edge, valley, ax, ind)

    return ind

def _plot(x, mph, mpd, threshold, edge, valley, ax, ind):
    """Plot results of the detect_peaks function, see its help.

    Draws ``x`` with the indices ``ind`` marked by red crosses into ``ax``
    (a new 8x4 figure when ``ax`` is None) and then calls ``plt.show()``.
    Prints a message and returns without plotting if matplotlib is missing.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib is not available.')
        return

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 4))

    ax.plot(x, 'b', lw=1)
    if ind.size:
        label = 'valley' if valley else 'peak'
        label = label + 's' if ind.size > 1 else label
        ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
                label='%d %s' % (ind.size, label))
        ax.legend(loc='best', framealpha=.5, numpoints=1)
    ax.set_xlim(-.02 * x.size, x.size * 1.02 - 1)
    finite = x[np.isfinite(x)]
    # With no finite samples min()/max() of an empty array would raise, so
    # fall back to matplotlib's automatic y-limits in that case.
    if finite.size:
        ymin, ymax = finite.min(), finite.max()
        yrange = ymax - ymin if ymax > ymin else 1
        ax.set_ylim(ymin - 0.1 * yrange, ymax + 0.1 * yrange)
    ax.set_xlabel('Data #', fontsize=14)
    ax.set_ylabel('Amplitude', fontsize=14)
    mode = 'Valley detection' if valley else 'Peak detection'
    ax.set_title("%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"
                 % (mode, str(mph), mpd, str(threshold), edge))
    # Enable the grid on the axes we drew on. plt.grid() targets the
    # *current* axes (not necessarily a caller-supplied ax) and, called with
    # no arguments, toggles the grid instead of turning it on.
    ax.grid(True)
    plt.show()

####################################################################################################
#                                            ISF READER                                            #
####################################################################################################

def _isf_field_start(string, pattern):
    """Return the index just past the first occurrence of *pattern* in *string*.

    *string* is the raw (bytes) ISF header and *pattern* a header key such as
    ``'NR_P'``. Raises ValueError when the key is absent; previously the -1
    returned by ``bytes.find`` was silently used as an offset, so a missing
    key parsed whatever text sat near the start of the header.
    """
    key = bytes(pattern, 'utf-8')
    ii = string.find(key)
    if ii < 0:
        raise ValueError('ISF header field {!r} not found'.format(pattern))
    return ii + len(key)


def getfloat(string, pattern):
    """Return header field *pattern* parsed as a float (value ends at ';')."""
    start = _isf_field_start(string, pattern)
    return float(string[start:].split(b';')[0])


def getint(string, pattern):
    """Return header field *pattern* parsed as an int (value ends at ';')."""
    start = _isf_field_start(string, pattern)
    return int(string[start:].split(b';')[0])


def getstr(string, pattern):
    """Return the raw bytes of header field *pattern* (value ends at ';').

    The single separator character following the key is skipped.
    """
    start = _isf_field_start(string, pattern) + 1
    return string[start:].split(b';')[0]


def getquotedstr(string, pattern):
    """Return the bytes between the first pair of double quotes after *pattern*."""
    start = _isf_field_start(string, pattern) + 1
    return string[start:].split(b'"')[1]


def _isf_parse_header(hdata):
    """Parse the short-form ISF preamble bytes *hdata* into an attribute object."""
    class HeadFile:
        pass

    head = HeadFile()
    head.byte_num = getint(hdata, 'BYT_N')
    head.bit_num = getint(hdata, 'BIT_N')
    head.encoding = getstr(hdata, 'ENC')
    head.bin_format = getstr(hdata, 'BN_F')
    head.byte_order = getstr(hdata, 'BYT_O')
    head.setting = getquotedstr(hdata, 'WFI')
    head.point_format = getstr(hdata, 'PT_F')
    head.x_unit = getquotedstr(hdata, 'XUN')
    head.x_zero = getfloat(hdata, 'XZE')
    head.x_increment = getfloat(hdata, 'XIN')
    head.pt_off = getfloat(hdata, 'PT_O')
    head.y_unit = getquotedstr(hdata, 'YUN')
    head.y_multipt_const = getfloat(hdata, 'YMU')
    head.y_zero = getfloat(hdata, 'YZE')
    head.y_offset = getfloat(hdata, 'YOF')
    head.n_samples = getint(hdata, 'NR_P')
    head.vscale = getfloat(hdata, 'VSCALE')
    head.hscale = getfloat(hdata, 'HSCALE')
    head.vpos = getfloat(hdata, 'VPOS')
    head.voffset = getfloat(hdata, 'VOFFSET')
    head.hdelay = getfloat(hdata, 'HDELAY')
    return head


def isf_read(filename):
    """Decode a Tektronix ISF waveform into voltage and time samples.

    Parameters
    ----------
    filename : str or bytes
        Despite the name, this is the ISF *payload* itself (get_data passes the
        scope's reply to ``wfmsend=Get``), not a path. ``str`` payloads are
        written to a temporary file in text mode as before; ``bytes`` /
        ``bytearray`` payloads are written in binary mode so the sample bytes
        reach the parser unchanged.

    Returns
    -------
    v, t, head
        ``v``: float32 samples ``YZE + YMU * (raw - YOF)``; ``t``: float32
        times ``XZE + XIN * n``; ``head``: object holding the header fields.

    Raises
    ------
    ValueError
        If a header field is missing, the data is not binary (BIN) signed
        integer (RI) Y-points, the byte order is unknown, or the ``#`` data
        block marker is absent.

    Prints FAILED and exits the process when the payload is missing or empty.
    """
    if filename is None:
        print("FAILED")
        print("Missing isf data")
        exit()
    if len(filename) == 0:
        print("FAILED")
        print("Empty isf data")
        exit()

    mode = 'wb' if isinstance(filename, (bytes, bytearray)) else 'w+'
    with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tempin:
        temp_path = tempin.name
        tempin.write(filename)

    try:
        with open(temp_path, 'rb') as f:
            hdata = f.read(511)
            head = _isf_parse_header(hdata)

            if head.encoding != b'BIN' or head.bin_format != b'RI' or head.point_format != b'Y':
                raise ValueError('Unable to process ISF file: only BIN/RI/Y data is supported.')

            # MSB first = big-endian ('>'), LSB first = little-endian ('<').
            # The previous mapping was inverted.
            byte_orders = {b'MSB': '>', b'LSB': '<'}
            if head.byte_order not in byte_orders:
                raise ValueError('Unrecognized byte order: {!r}'.format(head.byte_order))
            machineformat = byte_orders[head.byte_order]

            # The curve is an IEEE 488.2 definite-length block:
            # '#', one digit n, n digits of byte count, then the samples.
            ii = hdata.find(b'#')
            if ii < 0:
                raise ValueError('ISF data block marker "#" not found.')
            f.seek(ii + 1, 0)
            n_digits = int(f.read(1))
            # Skip exactly the n length digits (the old "+1" skipped one
            # sample byte too many, misaligning every sample).
            f.seek(n_digits, 1)

            data_type = machineformat + 'i' + str(head.byte_num)
            data = np.frombuffer(f.read(), dtype=data_type, count=head.n_samples)
    finally:
        # Remove the temporary file on every path, including errors.
        unlink(temp_path)

    v = head.y_zero + head.y_multipt_const * (data - head.y_offset)
    t = head.x_zero + head.x_increment * np.arange(head.n_samples)

    return np.single(v), np.single(t), head

def isf2array(filename, compression=None):
    """Return an ISF waveform as an (n_samples, 2) array of ``[time, voltage]`` rows.

    *filename* is passed straight to isf_read. *compression* is accepted for
    backward compatibility only and is ignored: isf_read takes a single
    argument, so forwarding it raised TypeError on every call.
    """
    v, t, _head = isf_read(filename)
    return np.vstack([t, v]).T

####################################################################################################
#                                          CALIBRATION  TOOL                                       #
####################################################################################################

def gauss(x, *p): # Fit functions for Cs137
    """Single Gaussian ``A * exp(-(x - mu)**2 / (2 * sigma**2))``.

    ``p`` must be exactly ``(A, mu, sigma)``. The variadic signature lets
    curve_fit pass the parameter vector positionally.
    """
    amplitude, centre, width = p
    exponent = -(x - centre) ** 2 / (2. * width ** 2)
    return amplitude * np.exp(exponent)

def gauss_cobalt(x, *p): # Fit function for Co60
    """Sum of two Gaussians, one per line of the double-peak fit.

    ``p`` must be exactly ``(A, mu, sigma, B, nu, psi)``: amplitude, centre
    and width of the first Gaussian, then of the second.
    """
    amp_a, centre_a, width_a, amp_b, centre_b, width_b = p
    first = amp_a * np.exp(-(x - centre_a) ** 2 / (2. * width_a ** 2))
    second = amp_b * np.exp(-(x - centre_b) ** 2 / (2. * width_b ** 2))
    return first + second

def oscilloscope_settings(stdout_type=subprocess.DEVNULL): # Initial settings for the oscilloscope
    """Send the fixed acquisition setup to the oscilloscope at 192.168.2.40.

    Each setting is one ``wget`` request issued through command(), with its
    stdout routed to *stdout_type*. The first request that exits with a
    non-zero status prints ``FAILED`` and terminates the script. After all
    requests succeed, waits one second.
    """
    # Disabled in the original script (kept for reference):
    #   "command=SAVE:SETUP 5"            - save current settings
    #   "command=:FPANEL:PRESS+menuoff"   - dismiss the initial menu message
    #   "command=ch4:scale 25e-3"         - set voltage scale
    setup_requests = (
        # horizontal time scale
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=hor:scale 20e-3"',
        # vertical position of CH4
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=ch4:pos -3.4"',
        # download waveforms in the binary (internal) format
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=save:waveform:fileformat internal"',
        # only CH4 sends data
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=select:ch4 on;ch1 off;ch2 off;ch3 off"',
        # invert CH4 polarity (display only)
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=select:ch4:invert on"',
        # 50 Ohm input impedance
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=select:ch4:impedance 50"',
        # maximum record length
        'wget -t1 -q -O - http://192.168.2.40/download.cgi?"command=hor:recordlength 5000000"',
    )
    for request in setup_requests:
        if command(request, stdout_type).returncode:
            print('    FAILED')
            exit()
    sleep(1)
        
def command(command, stdout_type=subprocess.PIPE): # Sending commands to bash or cmd
    """Run *command* through the system shell and return the CompletedProcess.

    stdout goes to *stdout_type* (captured by default) and is decoded as text
    with universal newlines. stderr is not redirected and the exit status is
    not checked; callers inspect ``returncode`` themselves.
    """
    completed = subprocess.run(
        command,
        universal_newlines=True,
        stdout=stdout_type,
        shell=True,
    )
    return completed

def get_data(data_name=None, debug=True, acq_count=1, threshold=0): # Data download, find pulse heights
    """Acquire CH4 waveforms from the oscilloscope and return detected pulse heights.

    The vertical scale is chosen by running the scope's CH4 autoset, reading
    the resulting ``CH4:SCAle?`` value, undoing the autoset and re-applying
    that scale. Each acquisition then forces a trigger, takes a single
    sequence, downloads the waveform and runs detect_peaks on the samples
    above *threshold*.

    Parameters
    ----------
    data_name : unused
        Kept for backward compatibility.
    debug : bool
        If exactly ``True``, plot each acquisition's peak detection.
    acq_count : int
        Number of acquisitions.
    threshold : float
        Samples ``<= threshold`` are discarded before peak detection.

    Returns
    -------
    list
        Heights of all detected peaks, in acquisition order. Acquisitions
        that yield 5000 or more peaks are skipped entirely (presumably noise;
        NOTE(review): confirm the intent of this limit).

    Prints FAILED and exits if the scope does not answer the scale query.
    """
    peaks_array = []
    command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=CH4:Autoset EXECute"')
    voltrange = command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=CH4:SCAle?"').stdout
    command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=CH4:Autoset Undo"')

    reply = voltrange.split() if voltrange else []
    if not reply:
        print('FAILED')
        print('No reply to CH4:SCAle? from the oscilloscope')
        exit()
    volt_scale = float(reply[0])
    print(volt_scale)
    # Re-apply the autoset's vertical scale. (This used the misspelled
    # str.fotmat, which raised AttributeError on every call.)
    command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=CH4:scale {}"'.format(volt_scale))

    # TODO: baseline treatment

    for acq in range(acq_count): # Number of acquisitions
        print(acq)
        command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=:FPANEL:PRESS FORCETRIG"')
        command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=:FPANEL:PRESS SINGLESEQ"')
        command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=select:ch4 on"')
        command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"command=save:waveform:fileformat internal"')
        # NOTE(review): command() decodes the reply as text, which is lossy for
        # the binary ISF samples; a bytes-capturing download would be safer.
        isf_data = command('wget -T2 -t1 -q -O - http://192.168.2.40/download.cgi?"wfmsend=Get"').stdout
        data, _t, _head = isf_read(isf_data)

        # Dropping samples removes the baseline between pulses, so mpd below is
        # measured in this compressed index space rather than in real time.
        data = np.asarray(data)
        data = data[data > threshold]

        peaks_index_array = detect_peaks(data, mph=0.0025, mpd=35, show=debug is True)

        if len(peaks_index_array) < 5000:
            peaks_array.extend(data[peaks_index_array])
    return peaks_array

def get_peak_position(data, title_label, debug=True): # Find peak, call fit()
    """Locate the photopeak(s) in the pulse-height spectrum *data* and fit them.

    The 550-bin histogram of *data* is smoothed with a Gaussian kernel whose
    width starts at 5 bins and grows (by 2, or by 1 for 'Na 22') until
    detect_peaks finds an acceptable set of peaks. A window around the chosen
    peak is then passed to fit().

    Returns
    -------
    tuple
        ``(coeff, errors)`` for 'Cs 137', 'Mn 54' and other labels.
        ``(coeff, errors, coeff2, errors2)`` for 'Na 22', fitted at the first
        and last detected peaks.

    If no acceptable peak set is found before the kernel width exceeds 300,
    prints a message, plots the smoothed spectrum and exits.
    """
    # np.histogram yields the same counts as the former plt.hist call without
    # drawing into (and then closing) the current figure. The counts do not
    # change between iterations, so they are computed once.
    counts, _ = np.histogram(data, bins=550)
    counts = counts.astype(float)
    smooth_factor = 5
    while True:
        smoothed = convolve(counts, Gaussian1DKernel(smooth_factor), boundary='extend')
        peaks = detect_peaks(smoothed, mph=5, mpd=10, show=debug)
        n_peaks = len(peaks)

        if title_label in ('Cs 137', 'Mn 54'):
            if n_peaks == 1:
                return fit(int(peaks[0]-peaks[0]/3.03), int(peaks[0]+peaks[0]/1.54), data, title_label, [10., 0.6, 0.5])
            if smooth_factor > 10 and 0 < n_peaks < 5:
                top = max(peaks)
                return fit(int(top-top/5.03), int(top+top/1.54), data, title_label, [10., 0.6, 0.5])
            step = 2
        elif title_label == 'Na 22':
            if n_peaks == 2 or (smooth_factor > 5 and 0 < n_peaks < 5):
                coeff, errors = fit(int(peaks[0]-10), int(peaks[0]+15), data, title_label, [10., 0.6, 0.5])
                coeff2, errors2 = fit(int(peaks[-1]-20), int(peaks[-1]+25), data, title_label, [10., 0.9, 0.5])
                return coeff, errors, coeff2, errors2
            step = 1
        else:
            # Other labels (e.g. 'Co 60'): fit() chooses its own initial guess.
            # This branch called fit() without its required p0 argument
            # (TypeError) and never returned the result.
            if n_peaks == 1:
                return fit(int(peaks[0]-peaks[0]/3.03), int(peaks[0]+peaks[0]/1.54), data, title_label, None)
            if smooth_factor > 10 and 0 < n_peaks < 5:
                top = max(peaks)
                return fit(int(top-top/5.03), int(top+top/1.54), data, title_label, None)
            step = 2

        if smooth_factor > 300:
            print('\nCouldn\'t find the photopeak. Decrease the number of bins. Signal plot is shown for debug purpose.')
            detect_peaks(smoothed, mph=0.01, mpd=35, show=True)
            exit()
        smooth_factor += step

def fit(start, stop, data, title_label, p0=None): # Fit and plot data
    """Fit a window of the 550-bin histogram of *data* and plot the result.

    Parameters
    ----------
    start, stop : int
        Bin-index window ``[start, stop)`` of the histogram to fit. A negative
        *start* is clamped to 0; before, it wrapped around to the end of the
        histogram and produced an empty or wrong window.
    data : array_like
        Pulse heights to histogram.
    title_label : str
        'Na 22', 'Cs 137' or 'Mn 54' fit a single Gaussian (gauss). 'Co 60'
        fits the double Gaussian (gauss_cobalt) with a fixed initial guess
        that overrides *p0*.
    p0 : sequence of float or None
        Initial guess for the single-Gaussian fit; None uses ``[10., 0.6, 0.5]``.

    Returns
    -------
    coeff, errors
        Best-fit parameters and their errors (square roots of the diagonal of
        the covariance matrix).

    Raises ValueError for an unknown *title_label*. Prints FAILED and exits
    when the window has fewer non-empty bins than fit parameters, or when
    curve_fit cannot estimate the covariance (its OptimizeWarning is
    expected to be raised as an exception).
    """
    n_bins = 550
    hist, bin_edges = np.histogram(data, bins=n_bins)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2

    window = slice(max(int(start), 0), stop)
    hist_cut = hist[window]
    bin_centres_cut = bin_centres[window]
    nonzero = hist_cut != 0  # empty bins are excluded from the fit

    if title_label in ('Na 22', 'Cs 137', 'Mn 54'):
        model = gauss
        if p0 is None:
            p0 = [10., 0.6, 0.5]
    elif title_label == 'Co 60':
        model = gauss_cobalt
        p0 = [10., 0.10, 0.5, 8., 0.12, 0.5]
    else:
        raise ValueError('No fit model for {!r}'.format(title_label))

    if np.count_nonzero(nonzero) < len(p0):
        print("FAILED")
        print("Too few non-empty bins in the fit window...")
        exit()

    try:
        coeff, var_matrix = curve_fit(model, bin_centres_cut[nonzero], hist_cut[nonzero], p0=p0, maxfev=1000000)
    except OptimizeWarning:
        print("FAILED")
        print("Covariance of the parameters could not be estimated...")
        exit()
    hist_fit = model(bin_centres_cut, *coeff)
    errors = np.sqrt(np.diag(var_matrix))

    plt.figure(1)
    plt.title(title_label, fontsize=16)
    # NOTE(review): isf_read scales samples by the scope's YUN unit (likely V);
    # confirm that the mV axis label matches.
    plt.xlabel('U [mV]')
    plt.ylabel('N [-]')
    plt.hist(data, n_bins, histtype='step')
    plt.plot(bin_centres_cut, hist_fit, 'r', label='Fitted data')
    plt.grid(True)
    print(coeff, errors)
    return coeff, errors


if __name__ == "__main__":
    print("Scintillator calibration tool\n__________________________________\n")
    print('Posible elements: (Cs, Na, Mn)\n')
    elements = input('Enter elements:').title().split()

    print('\nConnecting to the oscilloscope...    ', end=" ", flush=True)
    # NOTE: no connectivity probe. The original ran ``ping 192.168.2.40`` and
    # then unconditionally overwrote its return code with 0, so the result was
    # never used (and a bare ``ping`` never terminates on Linux).
    print('DONE')
    print('Setting up the oscilloscope...       ', end=" ", flush=True)
    # oscilloscope_settings() is intentionally not called (it was commented out).
    print('DONE')

    acq_count_set = 10
    # Fitted photopeaks keyed like keV_dict below: key -> (coeff, errors).
    fitted = {}
    for element in elements:
        # Reject unknown symbols before prompting, so they do not cost a
        # full acquisition run.
        if element not in ('Mn', 'Cs', 'Na'):
            print('Element {} was not recognized'.format(element))
            continue
        print("\nPut {} capsule on the scintillator...".format(element))
        print('Press ENTER to continue...', end=" ", flush=True)
        input("\033[F")
        print("Acquiring data...                    ", end=" ", flush=True)
        data = get_data(acq_count=acq_count_set)
        print("DONE")
        print('Looking for peak and fitting data...     ', end=" ", flush=True)
        if element == 'Cs':
            fitted['Cs'] = get_peak_position(data, 'Cs 137')
        elif element == 'Mn':
            fitted['Mn'] = get_peak_position(data, 'Mn 54')
        else:  # 'Na': two peaks, calibrated against 511 and 1274.5 keV
            coeff_na1, errors_na1, coeff_na2, errors_na2 = get_peak_position(data, 'Na 22')
            fitted['Na1'] = (coeff_na1, errors_na1)
            fitted['Na2'] = (coeff_na2, errors_na2)
        print("DONE")
        plt.show()

    keV_dict = {'Cs': 661.7, 'Mn': 835, 'Na1': 511, 'Na2': 1274.5}
    keV = []
    volts = []
    errs = []
    # Points are collected in the fixed order Cs, Mn, Na1, Na2. This replaces
    # ``coeff_cs in locals()``, which raised NameError for elements that were
    # not measured and TypeError (unhashable ndarray) for those that were.
    for line in ('Cs', 'Mn', 'Na1', 'Na2'):
        if line in fitted:
            coeff, errors = fitted[line]
            volts.append(coeff[1])  # fitted Gaussian centre (mu)
            errs.append(errors[1])
            keV.append(keV_dict[line])

    if len(keV) < 2:
        print('At least two calibration lines are needed for the linear fit, got {}.'.format(len(keV)))
    else:
        plt.figure(3)
        plt.plot(keV, volts, '.')

        def linereg(x, *p):
            a, b = p
            return a*x + b

        with warnings.catch_warnings():
            # With exactly two points the covariance cannot be estimated. The
            # module-level filter would turn that OptimizeWarning into an
            # exception, although only the coefficients are needed here.
            warnings.simplefilter('ignore', OptimizeWarning)
            lin_coeff, lin_var_matrix = curve_fit(linereg, keV, volts, [1., 1.], sigma=errs, maxfev=1000000)
        x_line = np.arange(400, 1300, .01)
        plt.plot(x_line, linereg(x_line, lin_coeff[0], lin_coeff[1]))
        print('Calibration constants are {} and {}'.format(lin_coeff[0], lin_coeff[1]))
        plt.show()


#TODO
#autoset? (autoset>ver:scale 20ms>invert>trigger?)/(autoset>get ver:scale>undo autoset>set ver:scale) ✓
#calibration curve ✓
#background level (acquire empty data > histogram > its maximum is the baseline)(mph in detect_peaks)
#automatic bin size
#if find_peak fails > change the bin size and try again