#!/usr/bin/python2
# -*- coding: utf-8 -*-

"""
CREATED: 7/2012
AUTHOR: MICHAL ODSTRCIL

Plotting, web-page and signal-processing helpers.

NOTE(review): this file was restored from a copy whose newlines and
indentation had been collapsed.  Block nesting was reconstructed from the
syntax; places where the nesting is ambiguous are marked NOTE(review).
HTML/PHP markup inside several string literals (make_image,
make_zoom_image, make_config, emph, wiki, source_link) appears to have been
stripped from that copy and must be restored from version control.
"""

from numpy import *
from config import *
import sys


def import_matplotlib():
    """Configure matplotlib for the non-interactive Agg backend.

    :return: the ``matplotlib.pyplot`` module
    """
    import matplotlib
    matplotlib.rcParams['backend'] = 'Agg'
    matplotlib.rc('font', size='10')
    matplotlib.rc('text', usetex=True)  # FIXME !! nicer but slower !!!
    import matplotlib.pyplot as plt
    return plt


try:
    plt = import_matplotlib()
except Exception:
    # best effort: the module stays importable without matplotlib
    pass

import os, re, datetime, time
from pygolem_lite import *
from copy import deepcopy
from pygolem_lite import isnone


###################### SIMPLE DATA LOADING ##############################

def cat(path, lines = [0], return_array = False):
    """Read the file `path` and return selected lines as one string.

    :var str path: file to read
    :var list lines: indices of the lines to return; an empty list selects
        every line (the default list is never mutated)
    :var bool return_array: if True return the list of all lines instead
    """
    with open(path, 'r') as f:
        content_tmp = f.readlines()

    if return_array:
        return content_tmp

    if len(lines) == 0:
        lines = range(len(content_tmp))

    return "".join([content_tmp[i] for i in lines])  # the raw string


####################### OTHER USEFUL ROUTINES #################################

def nanmedian(data):
    """Median of `data` ignoring NaN entries."""
    return median(data[~isnan(data)])


def list2array(L):
    """Convert list output ``[tvec, data]`` from pygolem to a single array.

    Anything that is not a list or tuple is returned unchanged.  The result
    has the first element in column 0 and the data columns after it.
    """
    if type(L) is not list and type(L) is not tuple:
        return L

    if ndim(L[1]) == 1:
        dim = 1
    else:
        dim = size(L[1], 1)

    out = empty([len(L[0]), dim + 1])
    out[:, 0] = L[0]
    out[:, 1:] = reshape(L[1], (-1, dim))
    return out


def random_number_downsample(tvec, data, N):
    """Downsample data to about `N` points without moire effects.

    Regularly spaced indices are jittered by a random offset smaller than
    the sampling step; the last sample is always kept.

    :var array tvec: 1D axis
    :var array data: 2D array whose rows correspond to `tvec`
    :var int N: requested number of points
    """
    Ndata = len(tvec)
    if Ndata <= N:
        return tvec, data

    step = Ndata // N  # integer step, as with Python 2 `/` on ints
    ind = int_(linspace(0, Ndata - 1, N))
    ind_rand = ind + (random.randint(0, step, size=(N)) if step > 0 else 0)
    ind_rand[-1] = Ndata - 1
    ind_rand = unique(ind_rand)

    return tvec[ind_rand], data[ind_rand, :]


###################### PLOTTING ##############################

def multiplot(data_all, title_name, file_name, figsize = (9,None), dpi = 100,
              orientation = 'vertical', file_type = 'png', reduction = False,
              debug_info = True):
    """
    Plot several graphs to one multiplot. Use data `data_all` provided by
    function get_data !!

    :var dict data_all: Input data or list of input data
    :var str title_name: Title of the image
    :var str file_name: Name of saved image (without filetype)
    :var list figsize: Size of the output image
    :var int dpi: DPI of output image
    :var str orientation: Choose multiplot orientation: vertical/horizontal
    :var str file_type: Image file type, recomended: png, svgz !! svgz is slow !!
    :var bool reduction: Decrease number of plotted points => speed up
        plotting, use for all nice smooth lines !!
    :var bool debug_info: Allow printing of some debug info

    Raises IOError for an empty data set and NotImplementedError if the
    plotting of one of the panels fails.
    """
    import subprocess
    from matplotlib.ticker import MultipleLocator, FormatStrFormatter, LogLocator, MaxNLocator

    plt = import_matplotlib()

    class MyFormatter(plt.ScalarFormatter):  # improved format of axis
        def __call__(self, x, pos=None):
            self.set_scientific(True)
            self.set_useOffset(True)
            self.set_powerlimits((-2, 3))
            if pos == 0:
                return ''
            else:
                return plt.ScalarFormatter.__call__(self, x, pos)

    is_icon = dpi <= 50

    if all(isnone(data_all)):
        raise IOError('Empty input data set for plotting')
    else:
        data_all = deepcopy(data_all)  # avoid mysterious changes in plots

    if type(data_all) is dict:
        data_all = [data_all]

    i = 0  # !! remove empty fields !!
    while i < len(data_all):
        if len(data_all[i]) == 0 or all(isnone(data_all[i])):
            data_all.pop(i)
        else:
            i += 1

    n_plots = len(data_all)
    t_start = time.time()

    if figsize[1] is None:  # if the size is not determined, make it flexible
        figsize = array(figsize)
        figsize[1] = 3 * n_plots

    fig = plt.figure(num=None, figsize=figsize, edgecolor='k')
    plt.subplots_adjust(wspace=0.3, hspace=0)

    for i, data in enumerate(data_all):
        # `==` instead of `is`: identity of equal strings is not guaranteed
        if orientation == 'vertical':
            ax = fig.add_subplot(n_plots, 1, i + 1)
            ax.xaxis.set_major_formatter(plt.NullFormatter())
            ax.yaxis.set_major_formatter(MyFormatter())  # improved format of axis
        else:
            ax = fig.add_subplot(1, n_plots, i + 1)

        if is_icon:  # very small image
            reduction = True

        try:
            _plot_adv(data, reduction, ax, dpi)
        except Exception as e:
            raise NotImplementedError("plotting failed", str(e), "\n")

        try:  # problem with log scale
            ax.ticklabel_format(style='plain', axis='y')
            ax.yaxis.set_major_locator(MaxNLocator(max(3, int(figsize[1] / double(n_plots) * 5))))
        except Exception:
            pass

        ax.tick_params(axis='both' if not is_icon else 'none', reset=False,
                       which='both', length=2, width=0.5)

        if i == 0 and title_name != "" and not is_icon:
            ax.set_title(fix_str(title_name))

        if orientation == 'horizontal':
            plt.xlabel(data['xlabel'] if type(data) == dict else data[0]['xlabel'])

    if orientation == 'vertical':
        # only the bottom panel shows the x tick labels and the x label
        ax = fig.add_subplot(n_plots, 1, n_plots)
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        if not is_icon:
            try:
                plt.xlabel(fix_str(data['xlabel'] if type(data) == dict else data[0]['xlabel']))
            except Exception as e:
                print("plotting xlabel failed %s \n%s" % (str(e), data))

    t_save = time.time()

    # strip up to two trailing ".ext" parts from the requested name
    file_name = re.sub(r'(.+)\.(.+)$', r'\1', re.sub(r'(.+)\.(.+)$', r'\1', file_name))
    image_name = file_name + '.' + file_type
    plt.savefig(image_name, dpi=dpi, bbox_inches='tight')  # better but slower
    plt.clf()

    if file_type in ['png', 'jpg']:
        # argument list instead of a shell string: safe for names with
        # spaces or shell metacharacters
        try:
            subprocess.call(['convert', image_name, '-trim', image_name])
        except OSError as e:
            print("image trimming failed: " + str(e))

    if debug_info:
        print("saving %s" % (time.time() - t_save))
        print("plotting time %.2fs" % (time.time() - t_start))


def paralel_multiplot(*kargs, **kwargs):
    """Run :func:`multiplot` with the same arguments in a separate process."""
    import multiprocessing
    try:
        multiprocessing.Process(target=multiplot, args=kargs, kwargs=kwargs).start()
    except Exception as e:
        print("PLOTTING FAILED:" + str(e))


def _plot_adv(data, reduction, ax, dpi):
    """
    advanced plotter with dictionary input,
    internal function, do not use directly

    :var dict/list data: one object (or list of objects) made by get_data;
        all of them are drawn to the same axes
    :var bool reduction: plot only every k-th point (k = len(tvec)//2000)
    :var ax: matplotlib axes used for scale, locators and grid
    :var int dpi: dpi <= 50 is treated as an icon (no labels/annotations)
    :return: `ax`, or None if there is nothing to plot
    """
    if type(data) is dict:
        data = [data]

    is_icon = dpi <= 50

    i = 0  # !! remove empty fields !!
    while i < len(data):
        if all(isnone(data[i])) or len(data[i]) == 0:
            data.pop(i)
        else:
            i += 1

    if len(data) == 0:  # empty list
        return

    use_legend = False
    N_data = len(data)

    for i in range(N_data):
        if len(data[i]) == 0 or all(isnan(data[i]['data'])):
            continue
        d = data[i]
        kwargs = d['kwargs']
        # integer indices (not a boolean mask) so that the strided
        # reduction below is valid even without 'xlim'
        ind_plot = arange(len(d['tvec']))

        if kwargs['label'] != "":
            use_legend = True

        if 'xlim' in d.keys():  # speed up plotting
            inside = where((d['xlim'][0] <= d['tvec']) & (d['xlim'][1] >= d['tvec']))[0]
            # extend the data by one point from the range
            ind_plot = concatenate([inside - 1, inside, inside + 1])
            ind_plot = ind_plot[(ind_plot >= 0) & (ind_plot < len(d['tvec']))]
            ind_plot = unique(ind_plot)

        if reduction:
            ind_plot = ind_plot[::max(1, len(d['tvec']) // 2000)]

        # downsampling of errobar data !!!
        for name in ['xerr', 'yerr']:
            if not kwargs[name] is None and not isscalar(kwargs[name]):
                kwargs[name] = [kwargs[name][ind_plot, 0], kwargs[name][ind_plot, 1]]

        plt.errorbar(d['tvec'][ind_plot], d['data'][ind_plot], linewidth=0.5,
                     capsize=0, **kwargs)

    # axes-wide settings are taken from the first data set
    if data[0]['yscale'] != '':
        ax.set_yscale(data[0]['yscale'])

    if amax(data[0]['tvec']) - amin(data[0]['tvec']) < 100:
        ax.xaxis.set_minor_locator(plt.MultipleLocator(1))

    if 'vlines' in data[0].keys():
        for line in data[0]['vlines']:
            plt.axvline(x=line, color='b', linestyle='--')

    for key in ['xstart', 'xend']:
        if key in data[0].keys():
            plt.axvline(x=data[0][key], color='k', linestyle='--')

    if 'xlim' in data[0].keys():
        Xrange = [min([data[i]['xlim'][0] for i in range(N_data)]),
                  max([data[i]['xlim'][1] for i in range(N_data)])]
        plt.xlim(Xrange)

    if 'ylim' in data[0].keys():
        ymin = inf
        ymax = -inf
        for i in range(N_data):
            ymin_tmp, ymax_tmp = _detect_range(data[i]['tvec'], data[i]['data'],
                                               Xrange, data[i]['ylim'])
            ymin = min(ymin_tmp, ymin)
            ymax = max(ymax_tmp, ymax)
        Yrange = [ymin, ymax]
        plt.ylim(Yrange)

    if not is_icon:  # in case of small icon do no apply
        if data[0]['ylabel'] != "" and data[0]['ylabel']:
            plt.ylabel(data[0]['ylabel'])

        if 'annotate' in data[0].keys():
            for annot in data[0]['annotate']:
                # missing coordinates default to the upper left corner
                if annot['pos'][0] is None:
                    annot['pos'][0] = Xrange[0] + (Xrange[1] - Xrange[0]) * 0.05
                if annot['pos'][1] is None:
                    annot['pos'][1] = Yrange[1] - (Yrange[1] - Yrange[0]) * 0.05
                plt.annotate(annot['text'], annot['pos'], rotation=annot['angle'],
                             fontsize=annot['fontsize'])

    # NOTE(review): it is not recoverable from the collapsed source whether
    # the grid settings were nested under `if not is_icon` -- confirm.
    if 'ygrid' in data[0].keys() and data[0]['ygrid']:
        ax.yaxis.grid(color='gray', linestyle='dashed')
    if 'xgrid' in data[0].keys() and data[0]['xgrid']:
        ax.xaxis.grid(color='gray', linestyle='dashed')

    if use_legend:
        # FIXME !! nicer but slower !!!
        leg = plt.legend(loc='best', fancybox=True)
        leg.get_frame().set_alpha(0.7)

    return ax


def _detect_range(tvec, data, xlim, ylim):
    """
    smart way how to determine best ylim in the plots
    internal function do not use directly

    The automatic range is the 0.5% - 99.5% quantile interval of the finite
    data inside `xlim`, widened by 20% (or anchored to zero for a constant
    line).  Only the entries of `ylim` that are None are replaced.

    :return: new two-element list, the input `ylim` is left untouched
    """
    from scipy.stats.mstats import mquantiles

    ylim = list(ylim)  # do not modify the caller's list in place

    ind_ok = ~isnan(data) & ~isinf(data) & (tvec > xlim[0]) & (tvec < xlim[1])
    yrange = [mquantiles(data[ind_ok], 0.005), mquantiles(data[ind_ok], 0.995)]
    dy = abs(yrange[1] - yrange[0])

    if dy / amax(abs(array(yrange))) < 1e-6:  # constant line ...
        yrange[0] = min(0, yrange[0] * 1.1)
        yrange[1] = max(0, yrange[1] * 1.1)
    else:
        yrange[0] -= 0.2 * dy
        yrange[1] += 0.2 * dy

    for i in [0, 1]:
        if ylim[i] is None:
            ylim[i] = yrange[i]

    return ylim


def get_data(diagn, diagn_name = None, ylabel = None, xlabel = 'Time [ms]', shot_num = None, xlim = [], ylim = [], \
             tvec_rescale = 1e3, data_rescale = 1, integrated = False, columns = [], yscale = '', \
             reduction = False, plot_limits = True, smoothing = 0, line_format = "", \
             vlines = [], annotate = [], xgrid= False, ygrid = False, **kwargs ):
    """
    Load data/diagn and return object that is used for plotting.

    :var str/list/Data class diagn: Input can be string with name of diagn
        from main database, name of file with locally saved data on disk,
        list with [tvec,data]
    :var str diagn_name: Name of diagnostics that will be used in legend,
        if diagn has more signal than name will be `name 0`, `name 1`, ....,
        or diagn_name can be list with name for each column
    :var str ylabel: Title of ylabel
    :var str xlabel: Title of xlabel
    :var list xlim: Range for X axis, use [1.132 ,None] to define only one limit
    :var list ylim: Range for Y axis, use [1.132 ,None] to define only one limit
    :var int shot_num: Number of shot, default is autodetect from current folder
    :var double tvec_rescale: Rescale X axis to ms !! Be careful if the axis
        is not time in seconds !!
    :var double data_rescale: Rescale Y axis
    :var bool reduction: Try to reduce number of plotted points,
        automatically used for integrated signal
    :var bool integrated: Integrate the input signal
    :var list columns: Plot only the selected columns
    :var str yscale: scale of y-axis, use `log`
    :var bool plot_limits: If there is plasma, plot lines at start and end
        of the plasma
    :var double smoothing: Smooth the signal
    :var str line_format: Standard formating string
    :var bool xgrid: Plot automatic xgrid in the plot
    :var bool ygrid: Plot automatic ygrid in the plot
    :var array xerr: error in X direction
    :var array yerr: error in Y direction
    :var **kwargs : keyword arguments to be passed to the standard plot()
        command, optional; this method accepts **kwargs as a standard
        :func:`matplotlib.pyplot.plot` function

    Advanced:
    :var list vlines : list of positions of vertical lines in plot
        (see historical analysis)
    :var list annotate: list of dicts =>
        plt.annotate(annot['text'], annot['pos'], rotation=annot['angle'] )

    :return: dict for a single column, list of dicts for several columns;
        on any error a message is printed and None is returned.

    The mutable default arguments are never modified inside the function.
    """
    try:
        # process input setting
        if line_format != "":
            kwargs['fmt'] = line_format

        cData = Shot(shot_num)
        is_valid = False

        if type(diagn) is str:  # if diagn
            is_valid = cData._is_valid_name(diagn)
            if is_valid == 0:  # not in database
                [tvec, data] = load_adv(diagn)  # prefere locally saved data
            elif is_valid in (1, 2):  # data
                [tvec, data] = cData[diagn]
            elif is_valid == 3:  # DAS, channel_name
                [tvec, data] = cData["any", diagn]
            else:
                raise NotImplementedError("ambiguous diagn name")
        elif (type(diagn) is list or type(diagn) is tuple) and len(diagn) == 2:
            [tvec, data] = deepcopy(diagn)
        elif type(diagn) is Data:
            [tvec, data] = [diagn.tvec, diagn]
        else:
            raise NotImplementedError("Unsupported input")

        # Errobars ===================
        # Always normalise the error bars.  The former condition dropped a
        # user supplied `yerr` whenever `xerr` was not given as well.
        is_data_obj = type(data) is Data
        if 'xerr' not in kwargs:
            kwargs['xerr'] = data.tvec_err if is_data_obj else None
        if 'yerr' not in kwargs:
            kwargs['yerr'] = data.data_err if is_data_obj else None

        for name in ['xerr', 'yerr']:
            err = kwargs[name]
            if not any(isnone(err)) and ndim(err) == 0:
                err = float(err)
            if not any(isnone(err)) and not isscalar(err):
                # bring to shape (n, 2): lower and upper error
                err = array(squeeze(err), ndmin=2)
                if size(err, 0) == 1:
                    err = err.T
                if size(err, 1) == 1:
                    err = err.repeat(2, axis=1)
            kwargs[name] = err
        # =================================

        if smoothing > 0:
            _, data, _ = DiffFilter(data, mean(diff(tvec)), 2000, 1e6)

        if len(columns) > 0:
            data = data[:, columns]

        data *= data_rescale
        tvec *= tvec_rescale
        if not any(isnone(kwargs['yerr'])):
            kwargs['yerr'] *= data_rescale
        if not any(isnone(kwargs['xerr'])):
            kwargs['xerr'] *= tvec_rescale

        Ncol = size(data, 1) if ndim(data) == 2 else 1
        data = reshape(data, (-1, Ncol))  # workaround for 1D arrays

        if integrated:
            dt = mean(diff(tvec))
            data = cumsum(data, 0) * dt

        if reduction or integrated or smoothing > 10:
            # downsample data to 4000 points
            tvec, data = random_number_downsample(tvec, data, 4000)

        # default values
        plasma = False
        start = -inf
        start_pl = nan
        end_pl = nan
        end = inf

        try:
            plasma = cData['plasma']
        except Exception:
            pass

        # tvec_rescale == 1e3 => x axis is probably time !!!
        if plasma and plot_limits and tvec_rescale == 1e3:
            start_pl = cData['plasma_start'] * 1e3
            end_pl = cData['plasma_end'] * 1e3
            if cData['ucd'] > 10:
                start = min(start_pl - 2, cData['tcd'] * 1e3)  # 2ms before plasma start
            else:
                start = start_pl - 2
            end = end_pl + 2
        elif tvec_rescale == 1e3:
            try:
                start = cData['tb'] * 1e3 - 1
                end = 35
            except Exception:
                pass

        t_range = amax(tvec) - amin(tvec)
        if len(xlim) == 0:
            xlim = array([max(start, amin(tvec) - t_range * 0.05),
                          min(end, amax(tvec) + t_range * 0.05)])
            # NOTE(review): nesting reconstructed -- the sanity check is
            # applied to the automatic range only, so that a user range
            # with a None limit is not compared.
            if xlim[0] > xlim[1]:  # some error (no intersection)
                xlim = array([amin(tvec) - t_range * 0.05, amax(tvec) + t_range * 0.05])

        if type(data) is Data and ylabel is None:
            if not all(isnone(data.ax_labels)):
                ylabel = data.ax_labels[1]

        if type(data) is Data and diagn_name is None:
            diagn_name = data.info

        output = list()
        columns = range(Ncol)
        if len(ylim) == 0:
            ylim = [None, None]

        ### ============ auto-labeling =====================
        for i in columns:
            column_name = None
            # type is tested first so that diagn_name == None does not fail
            if type(diagn_name) is list and len(diagn_name) == len(columns):
                # every column has its own label
                column_name = diagn_name[i]
            else:
                if is_valid == 2:  # use predefined name in config if it is DAS
                    column_name = cData.channel_name(diagn, i)
                if column_name is None:  # nothing matched ...
                    if len(columns) == 1:  # only one variable
                        column_name = diagn_name
                    elif diagn_name:  # any name
                        column_name = diagn_name + ' ' + str(i)
                    else:
                        column_name = ""

            if "label" in kwargs:
                kwargs.pop("label")  ## BUG FIXME !!!!
            kwargs_tmp = dict(label=fix_str(column_name), **kwargs)

            o = {'tvec': tvec, 'data': data[:, i], 'ylabel': fix_str(ylabel),
                 'xlabel': fix_str(xlabel), 'ylim': ylim, 'xlim': xlim,
                 'xstart': start_pl, 'xend': end_pl, 'yscale': yscale,
                 'vlines': vlines, 'annotate': annotate, 'xgrid': xgrid,
                 'ygrid': ygrid, 'kwargs': kwargs_tmp}
            output.append(deepcopy(o))

        if len(output) == 1:
            return output[0]
        else:
            return output

    except Exception as e:
        print("Some mistake during get_data: " + str(e))


def fix_str(string):
    """
    Fix the string to avoid any problems during plotting.

    Returns "" for an empty (isnone) input.  Strings without `$` get every
    underscore that is not already preceded by a backslash escaped for
    LaTeX, including a leading underscore and runs of underscores.
    """
    if isnone(string):
        return ""
    # hot fix of "latex font": if you can use $ solve it yourself
    if not re.match(r".*\$.*", string):
        string = re.sub(r'(?<!\\)_', r'\\_', string)  # basic latex friendly fix_string
    return string


###################### WEB PAGES ##############################

def get_page_paths(shot, page, default_path = 'shots'):
    """
    Return basic paths need for web pages generation

    :return: (directory of the page inside the shot directory,
              shot directory with trailing slash, file name of the page)
    """
    base_path = default_path + '/' + str(shot) + '/'
    [page_path, page] = os.path.split(page)
    page_path = base_path + page_path
    return page_path, base_path, page


def emph(text):
    """Return `text` converted to a string.

    NOTE(review): the surrounding literals are empty; emphasis markup seems
    to have been stripped from this copy -- confirm.
    """
    return '' + str(text) + ''


def modified(file):
    """Return the date ('%Y-%m-%d') from the last field of ``os.stat(file)``,
    or 'N/A' if it cannot be obtained."""
    try:
        return datetime.datetime.fromtimestamp(os.stat(file)[-1]).strftime('%Y-%m-%d')
    except Exception:
        return 'N/A'


def make_image(img_path, name = ""):
    """
    Show image from `img_path` with title `name`. Use loading animation if
    not available. Use svgz if possible and if the browser is capable to load.

    NOTE(review): the HTML/PHP template below is kept exactly as found.  It
    contains a single `%s` for nine values, so the formatting raises
    TypeError until the original template is restored.
    """
    img_path = re.sub(r"(.+)(\.[\w]{3,4})$", r"\1", img_path)  # remove file ending (.png)
    out = ""
    if name != "":
        out += "\n\n" + name + "\n\n\n"

    full_path = img_path
    rand_end = "?%s" % random.randint(int(1e6))  # presumably defeats browser caching
    out += """
"; } else { echo "%s
"; } } else { echo "
"; echo " " ; } ?>""" % (full_path, full_path, full_path, rand_end, name,
                         full_path, rand_end, name, full_path)
    return out


def make_zoom_image(img_path, name = "", group = ""):
    """
    Load image from `img_path` with title 'name`. Make image zoomable,
    if possible use svgz version.

    NOTE(review): the markup of the returned snippet (and the use of
    `group`, `img_path` and `rand_end`) is missing in this copy; only the
    text that survived is emitted.
    """
    rand_end = "?%s" % random.randint(int(1e6))
    img_path = re.sub(r"(.+)(\.[\w]{3,4})$", r"\1", img_path)  # remove file ending (.png)
    out = ""
    if name != "":
        out += "\n\n" + name + "\n\n\n"
    out += """ """ + "" + name + " "
    return out


def make_config(path):
    """
    Try to nicely load and show config from text file in shape
        value = 1
        value2 = 3

    NOTE(review): the table markup is missing in this copy; the separator
    `:` / `=` of every line is replaced by a space (originally a new cell).
    """
    try:
        cfg = cat(path, [], True)
        out = """ """
        for line in cfg:
            # replace : or = by new cell
            out += '' + re.sub(r'(.+)(\:|=)(.+)', r'\1 \3', line) + ''
        out += ""
    except Exception:
        out = "\n\nMissing config file " + path + "\n\n"
    return out


def wiki(link):
    """
    Short cut to generate link to wiki

    NOTE(review): both branches return the same text because the anchor
    markup is missing in this copy.
    """
    if re.match(r"^\/", link):  # link to internal wiki, starts with /
        return "WIKI"
    else:
        return "WIKI"


def source_link(path, revision, default_path, link, name):
    """
    Link to formatted link in bitbucket.org

    NOTE(review): the anchor markup using `path`, `revision` and `link` is
    missing in this copy; only `[name]` is returned.
    """
    # remove shots/ + shot_number from path
    path = re.sub("^" + default_path + r"\/[\d]+\/", "", path)
    # collapse double slashes ("\/" as replacement inserted a backslash)
    path = re.sub(r"\/\/", "/", path + '/')
    return '[' + name + ']'


################################# Historical analysis #############################################################

def get_history(diagn, shots, verbose=True, dtype = "float"):
    """
    Rapidly loads and returns values of scalar variables from selected shots

    :var str diagn: name of the scalar diagnostics
    :var list shots: shot numbers
    :var bool verbose: print the progress to stdout
    :var str dtype: dtype of the output array
    :return: array with one value per shot; for a float dtype the missing
        shots are NaN
    """
    # load default_path => either vshots or shots
    N = len(shots)
    shot_obj = Shot(shots[-1])
    shot = shot_obj.shot_num

    if os.path.exists(shotdir + '/' + str(shot) + '/' + diagn):
        path = shotdir + '/' + str(shot) + '/' + diagn
    else:
        path = shot_obj.get_data(diagn, return_path=True)

    data = zeros(N, dtype=dtype)
    try:
        data *= nan
    except Exception:
        pass  # dtype without NaN

    for i in range(N):
        # the path of the reference shot with the shot number replaced
        path_tmp = re.sub(str(shot), str(shots[i]), path)
        if os.path.exists(path_tmp) and os.path.getsize(path_tmp) > 0:
            try:
                with open(path_tmp, 'r') as fh:
                    f = fh.read()
                try:
                    data[i] = float(f)
                except ValueError:
                    data[i] = f  # not a number, e.g. a status string
            except Exception as e:
                print("%s ERROR : %s shot %s" % (diagn, str(e), shots[i]))

        if verbose:
            sys.stdout.write("\r %2.1f %%" % (i / double(N) * 100))
            sys.stdout.flush()

    sys.stdout.write("\r")
    return data


############################# Advanced algorithms, do not touch :) #################################################
# Author: Tomas Odstrcil
# Date: 2012

def DiffFilter(signal, dt, win, lam):
    """Regularised differentiation/smoothing evaluated in the Fourier domain.

    :var array signal: data, one signal per column
    :var double dt: sampling time of the signal
    :var int win: increased length due to edge effects
    :var double lam: regularization parameter
    :return: (derivative f/dt, retrofit = cumulative sum of f,
              chi2 = mean squared residuum of the retrofit per column)
    """
    from scipy.signal import medfilt, fftconvolve  # the import takes most of the time

    signal = array(signal.T, copy=False, ndmin=2).T
    signal = medfilt(signal, (11, 1))

    # extend the signal due to edge effects
    # (`//` keeps the floor division of the Python 2 original)
    extended_signal = zeros((size(signal, 0) + win, size(signal, 1)))
    extended_signal[win // 2:(-win) // 2, :] = signal
    extended_signal[:win // 2, :] = median(signal[:win // 4, :], axis=0)
    extended_signal[(-win) // 2 - 1:, :] = median(signal[(-win) // 4:, :], axis=0)

    N = size(extended_signal, 0)

    # calculate (II + lam*DD)^{-1}*I^T, but in the Fourier domain!
    csignal = -cumsum(extended_signal, axis=0, out=extended_signal)
    fsig = fft.rfft(csignal, axis=0).T
    DDfft = (4 * sin(linspace(0, pi / 2, N // 2 + 1)) ** 2)
    IIfft = 1 / (DDfft + DDfft[1] * 1e-4)
    A = IIfft + lam * DDfft
    fsig /= A
    f = fft.irfft(fsig.T, axis=0)

    # remove edge effects
    f = f[win // 2 - 1:(-(win + 1)) // 2, :]

    retrofit = cumsum(f, axis=0)
    chi2 = sum((retrofit - signal) ** 2, axis=0) / size(signal, 0)

    return f / dt, retrofit, chi2


def DiffFilter_old(signal, dt, max_win= 2000, regularization = 1e6):
    """Older variant of :func:`DiffFilter` based on an explicit convolution
    kernel obtained from a dense linear system.

    :return: (derivative, squeezed retrofit, chi2 per column)
    """
    from scipy.signal import medfilt, fftconvolve

    Nsig = size(signal, 1) if ndim(signal) == 2 else 1
    signal = reshape(signal, (-1, Nsig))
    signal = medfilt(signal, (11, 1))

    lam = regularization * (dt / 5e-6) ** -4
    win = 15 * (24 * lam) ** 0.25
    if win > max_win:
        print('Warning: too small window')
    win = min(win, max_win)  # this dependence is an estimate
    win = int(win) + 4

    Integ = tril(ones((win, win)), 0)
    II = dot(Integ.T, Integ)
    D = diag(ones(win)) - diag(ones(win - 1), 1)
    DD = dot(D.T, D)
    b = zeros(win)
    b[win // 2] = 1
    B = dot(Integ.T, b)
    A = II + lam * DD
    c = array(linalg.solve(A, B.T))

    extended_signal = zeros((size(signal, 0) + win - 1, Nsig))
    extended_signal[win // 2:(-win) // 2 + 1, :] = signal
    extended_signal[:win // 2, :] += median(signal[:win // 4], axis=0)
    extended_signal[(-win) // 2:, :] += median(signal[(-win) // 4:], axis=0)

    diffsig = empty(shape(signal))
    for i in range(Nsig):
        diffsig[:, i] = fftconvolve(extended_signal[:, i], c, mode='valid') / dt

    retrofit = cumsum(diffsig, axis=0) * dt
    retrofit += mean(signal[:int(4e-3 / dt), :], axis=0)  # FIXME

    chi2 = sum((retrofit - signal) ** 2, axis=0) / len(signal)
    return diffsig, squeeze(retrofit), chi2


def GapsFilling(signal, win = 100, lam = 1e-1):
    """
    =============================== Gaps Filling Filter 0.1 =====================
    reconstruct the corrupted data (data with nans) by tikhonov-philips
    regularization with regularization by laplace operator. And return
    smoothed retrofit.
    Reconstruction is based on the inversion of the identical operator with
    zeros at the lines corresponding to the missing signal.

    due to memory and speed limitation the reconstruction is done on the
    overlapping intervals with width "win"

    signal - long data vector
    win - width of the reconstruction interval - it must be much bigger than
        the gaps width
    lam - regularization parameter, depends on the noise in data

    Author: Tomas Odstrcil 2012
    """
    from scikits.sparse.cholmod import cholesky, analyze, cholesky_AAt
    from scipy.sparse import spdiags, eye

    n = len(signal)

    # signal extension -- avoid boundary effects
    # (`//` keeps the floor division of the Python 2 original)
    n_ext = (n // win + 1) * win
    pad = n_ext - n
    ext_signal = zeros(n_ext)
    ext_signal[pad // 2:(-pad) // 2] = signal
    ext_signal[:pad // 2 + 1] = median(signal[~isnan(signal)][:win // 2])
    ext_signal[(-pad) // 2 - 1:] = median(signal[~isnan(signal)][win // 2:])

    intervals = arange(0, n + win // 2, win // 2)

    ind_nan = isnan(ext_signal)
    ext_signal[ind_nan] = 0
    recon = copy(ext_signal)

    diag_data = ones((2, win))
    diag_data[1, :] *= -1
    D = spdiags(diag_data, (0, 1), win, win, format='csr')
    DD = D.T * D
    I = eye(win, win, format='csc')

    Factor = cholesky_AAt(DD, 1. / lam)

    for i in range(len(intervals) - 2):
        # use overlapping intervals !!!
        gaps = spdiags(int_(ind_nan[intervals[i]:intervals[i + 2]]), 0, win, win, format='csc')
        Factor.update_inplace(gaps / sqrt(lam), subtract=True)  # speed up trick
        g = Factor(ext_signal[intervals[i]:intervals[i + 2]] / lam)
        Factor.update_inplace(gaps / sqrt(lam), subtract=False)  # speed up trick
        # keep only the central half of every reconstructed interval
        recon[(intervals[i] + intervals[i + 1]) // 2:(intervals[i + 1] + intervals[i + 2]) // 2] = \
            g[len(g) // 4:(-len(g)) // 4, 0]

    chi2 = sum((ext_signal - recon)[~ind_nan] ** 2) / n
    recon = recon[pad // 2:(-pad) // 2]

    return recon, chi2


def deconvolution(signal, responseFun, win, regularization):
    """
    Main algorithm performing the deconvolution; returns the deconvolved
    signal and the retrofit. It is based on a simple method maximizing the
    smoothness. The window `win` must be larger than the "support" of the
    function g.

    :var array signal: 1D signal
    :var array responseFun: response of the system; it is normalised to unit
        sum on a copy, the caller's array is not modified
    :var int win: window length (made odd internally)
    :var double regularization: weight of the smoothness term
    :return: (deconvolved signal, retrofit)
    """
    from scipy.signal import fftconvolve
    from scipy.sparse import spdiags
    from numpy.matlib import repmat

    win = 2 * (win // 2) + 1
    responseFun = responseFun / sum(responseFun)

    # build the convolution matrix - the matrix of the system response to a
    # delta function
    i_middle = argmax(responseFun)
    diags = arange(i_middle - len(responseFun), +i_middle)  # TODO find out where it shifts
    diag_data = repmat(responseFun[::-1], win, 1).T
    ConvMatrix = spdiags(diag_data, diags, win, win).todense()

    D = diag(ones(win)) - diag(ones(win - 1), 1)
    DD = dot(D.T, D)

    # prepare data for the main deconvolution algorithm, the system Tg = f
    # is solved
    T = matrix(ConvMatrix)
    TT = dot(T.T, T)
    f = zeros(win)
    f[win // 2] = 1

    responseFun2 = zeros(2 * len(responseFun))
    responseFun2[len(responseFun):] = responseFun
    responseFun = responseFun2

    g = linalg.solve(TT + regularization * DD, dot(T.T, f).T)
    g = squeeze(array(g))

    # (`//` keeps the floor division of the Python 2 original)
    extended_signal = zeros(len(signal) + win - 1)
    extended_signal[win // 2:(-win) // 2 + 1] = signal
    extended_signal[:win // 2] += median(signal[:win // 4], axis=0)
    extended_signal[(-win) // 2 + 1:] += median(signal[(-win) // 4:], axis=0)

    deconv_sig = fftconvolve(extended_signal, g, mode='valid')
    retrofit = fftconvolve(deconv_sig, responseFun, mode='same')

    return deconv_sig, retrofit


# Example:
#Ich = load('Ich.npy')
#tvec = Ich[:,0]
#signal = Ich[:,1]
#dt = tvec[1]-tvec[0]
#t_exp = 1e-4
#responseFun = exp(-tvec[:int(6*t_exp/dt)]/t_exp)
#win = 100
#deconvolution( signal,responseFun, win, 1e-2)


def deconvolveExp(signal, t_exp, dt, win, regularization):
    """Deconvolve / convolve with exponential kernel.

    :var array signal: 1D signal
    :var double t_exp: time constant of the kernel; t_exp > 0 deconvolves,
        t_exp < 0 convolves, abs(t_exp) <= 2*dt returns the signal unchanged
    :var double dt: sampling time
    :var int win: window for :func:`deconvolution`
    :var double regularization: regularization for :func:`deconvolution`
    :return: pair of signals, see the branches below
    """
    from scipy.signal import fftconvolve

    # kernel shorter than two samples (also t_exp == 0, which would divide
    # by zero below) => nothing to do
    if abs(t_exp) <= 2 * dt:
        return signal, signal

    tvec = arange(len(signal)) * dt
    responseFun = exp(-tvec[:int(6 * abs(t_exp) / dt)] / abs(t_exp))
    responseFun /= sum(responseFun)

    if t_exp < 0:
        responseFun2 = zeros(2 * len(responseFun))
        responseFun2[len(responseFun):] = responseFun
        responseFun = responseFun2
        conv_sig = fftconvolve(signal, responseFun, mode='full')
        nc = len(conv_sig)
        n = len(signal)
        conv_sig = conv_sig[(nc - n) // 2:(-(nc - n)) // 2]
        return conv_sig, signal
    else:
        return deconvolution(signal, responseFun, win, regularization)