#!/usr/bin/python2
# -*- coding: utf-8 -*-
""" CREATED: 7/2012
AUTHOR: MICHAL ODSTRCIL
"""
from numpy import *
from config import *
import sys
#try:
#from matplotlib.pyplot import *
#except:
#pass
def import_matplotlib():
    """ Configure matplotlib for off-screen rendering and return pyplot.

    Before pyplot is imported the backend is set to 'Agg', the default font
    size to 10 and LaTeX text rendering is switched on.

    :return: the ``matplotlib.pyplot`` module
    """
    import matplotlib
    matplotlib.rcParams['backend'] = 'Agg'
    matplotlib.rc('font', size='10')
    # FIXME !! nicer but slower !!! (usetex)
    matplotlib.rc('text', usetex=True)
    from matplotlib import pyplot
    return pyplot
# Import pyplot at module load so that module-level helpers can use `plt`.
# Best effort: the module must stay importable when matplotlib is missing or
# cannot be configured.  `except Exception` (not a bare `except:`) so that
# KeyboardInterrupt / SystemExit are not swallowed.
try:
    plt = import_matplotlib()
except Exception:
    pass
import os, re, datetime, time
from pygolem_lite import *
from copy import deepcopy
from pygolem_lite import isnone
###################### SIMPLE DATA LOADING ##############################
def cat(path, lines = [0], return_array = False):
    """ Read the text file at `path`.

    :var str path: file to read
    :var list lines: indices of the lines returned (concatenated, in the given
        order); an empty list selects every line. Default: the first line only.
    :var bool return_array: return the list of all lines (line endings kept)
        instead of a string; `lines` is then ignored
    :return: str, or list of str if `return_array` is True
    """
    with open(path, 'r') as fh:
        all_lines = fh.readlines()
    if return_array:
        return all_lines
    selected = lines if len(lines) else range(len(all_lines))
    return ''.join(all_lines[k] for k in selected)
####################### OTHER USEFUL ROUTINES #################################
def nanmedian(data):
    """ Median of the array `data` with NaN entries ignored. """
    valid = data[~isnan(data)]
    return median(valid)
def list2array(L):
    """
    Convert the [tvec, data] list/tuple output of pygolem into one 2D array.

    :param L: [tvec, data]; anything that is not a list or tuple is returned
        unchanged
    :return: array of shape (len(tvec), ncol+1); column 0 holds tvec and the
        remaining columns the data (1D data gives a single data column)
    """
    if not (type(L) is list or type(L) is tuple):
        return L
    tvec, values = L[0], L[1]
    ncol = 1 if ndim(values) == 1 else size(values, 1)
    merged = empty([len(tvec), ncol + 1])
    merged[:, 0] = tvec
    merged[:, 1:] = reshape(values, (-1, ncol))
    return merged
def random_number_downsample(tvec, data, N):
    """
    Downsample a signal to at most `N` points without moire effects.

    Regularly spaced indices are each shifted by a random offset smaller than
    the index step, so no periodic pattern is imposed on the data.  The last
    sample of the signal is always kept.

    :var array tvec: 1D time vector
    :var array data: values sampled on `tvec`; the first axis must match
        `tvec` (1D and 2D data are both accepted)
    :var int N: requested number of points
    :return: (tvec, data) downsampled, indices sorted and unique; the inputs
        unchanged if len(tvec) <= N
    """
    Ndata = len(tvec)
    if Ndata <= N:
        return tvec, data
    ind = int_(linspace(0, Ndata-1, N))
    # floor division (same as Python 2 `/` on ints); Ndata > N so step >= 1.
    # The jitter is below the linspace step, so indices never exceed Ndata-1.
    step = Ndata // N
    ind_rand = ind + random.randint(0, step, size=N)
    ind_rand[-1] = Ndata-1
    ind_rand = unique(ind_rand)   # sorted, duplicates removed
    # data[idx] also works for 1D data (data[idx, :] required 2D input)
    return tvec[ind_rand], data[ind_rand]
###################### PLOTTING ##############################
def multiplot(data_all, title_name, file_name, figsize = (9,None), dpi = 100, orientation = 'vertical', file_type = 'png', reduction = False, debug_info = True):
    """ Plot several graphs into one multiplot image. Use data `data_all` provided by function get_data !!
    :var dict data_all: Input data or list of input data (one element per panel)
    :var str title_name: Title of the image
    :var str file_name: Name of saved image (up to two file extensions are stripped)
    :var list figsize: Size of the output image; height None => 3 per panel
    :var int dpi: DPI of output image; dpi <= 50 makes an icon without title/labels
    :var str orientation: Choose multiplot orientation: vertical/horizontal
    :var str file_type: Image file type, recommended: png, svgz !! svgz is slow !!
    :var bool reduction: Decrease number of plotted points => speed up plotting, use for all nice smooth lines !!
    :var bool debug_info: Allow printing of some debug info
    :raises IOError: when there is no data to plot
    :raises NotImplementedError: when plotting of one panel fails
    """
    import subprocess
    from matplotlib.ticker import MaxNLocator
    plt = import_matplotlib()

    class MyFormatter(plt.ScalarFormatter): # improved format of axis
        """ Scientific-notation tick labels; the label at position 0 is hidden
        (presumably so it does not overlap the neighbouring panel). """
        def __call__(self, x, pos=None):
            self.set_scientific(True)
            self.set_useOffset(True)
            self.set_powerlimits((-2, 3))
            if pos == 0:
                return ''
            return plt.ScalarFormatter.__call__(self, x, pos)

    is_icon = dpi <= 50
    if all(isnone(data_all)):
        raise IOError('Empty input data set for plotting')
    data_all = deepcopy(data_all) # avoid mysterious changes in plots
    if type(data_all) is dict:
        data_all = [data_all]
    # !! remove empty fields !!
    data_all = [d for d in data_all if not (len(d) == 0 or all(isnone(d)))]
    n_plots = len(data_all)
    if n_plots == 0:
        raise IOError('Empty input data set for plotting')

    t_start = time.time()
    if figsize[1] is None: # if the size is not determined, make it flexible
        figsize = (figsize[0], 3*n_plots)
    fig = plt.figure(num=None, figsize=figsize, edgecolor='k')
    fig.subplots_adjust(wspace=0.3, hspace=0)
    if is_icon: # very small image
        reduction = True

    ax = None
    for i, data in enumerate(data_all):
        # `==`, not `is`: string identity is an interpreter implementation detail
        if orientation == 'vertical':
            ax = fig.add_subplot(n_plots, 1, i+1)
            ax.xaxis.set_major_formatter(plt.NullFormatter())
            ax.yaxis.set_major_formatter(MyFormatter()) # improved format of axis
        else:
            ax = fig.add_subplot(1, n_plots, i+1)
        try:
            _plot_adv(data, reduction, ax, dpi)
        except Exception as e:
            raise NotImplementedError("plotting failed", str(e), "\n")
        try: # problem with log scale
            ax.ticklabel_format(style='plain', axis='y')
            ax.yaxis.set_major_locator(MaxNLocator(max(3, int(figsize[1]/double(n_plots)*5))))
        except Exception:
            pass
        # icons get no tick settings ('none' is not a valid `axis` value)
        if not is_icon:
            ax.tick_params(axis='both', reset=False, which='both', length=2, width=0.5)
        if i == 0 and title_name != "" and not is_icon:
            ax.set_title(fix_str(title_name))
        if orientation == 'horizontal':
            ax.set_xlabel(data['xlabel'] if type(data) == dict else data[0]['xlabel'])

    if orientation == 'vertical':
        # bottom panel: restore numeric x tick labels. Reuse the last axes;
        # calling add_subplot again with the same arguments creates a new,
        # overlapping axes in newer matplotlib versions.
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        if not is_icon:
            try:
                ax.set_xlabel(fix_str(data['xlabel'] if type(data) == dict else data[0]['xlabel']))
            except Exception as e:
                print("plotting xlabel failed %s\n%s" % (e, data))

    t_save = time.time()
    # strip up to two extensions; unlike a '(.+)\.(.+)$' regex this leaves
    # dots in directory names (e.g. "../dir/img") alone
    file_name = os.path.splitext(os.path.splitext(file_name)[0])[0]
    image_name = file_name + '.' + file_type
    fig.savefig(image_name, dpi=dpi, bbox_inches='tight') # better but slower
    fig.clf()
    plt.close(fig)  # otherwise pyplot keeps a reference to every figure
    if file_type in ['png', 'jpg']:
        # trim the border with ImageMagick; best effort, ignored if not installed
        try:
            subprocess.call(['convert', image_name, '-trim', image_name])
        except OSError:
            pass
    if debug_info:
        print("saving %s" % (time.time() - t_save))
    print("plotting time %.2fs" % (time.time() - t_start))
def paralel_multiplot(*kargs, **kwargs):
    """ Start `multiplot` in a separate process and return immediately.

    Accepts exactly the arguments of `multiplot`.  The process is not joined,
    so errors raised inside it are not reported back to the caller; only a
    failure to start the process is caught and printed.
    """
    import multiprocessing
    try:
        multiprocessing.Process(target=multiplot, args=kargs, kwargs=kwargs).start()
    except Exception as e:
        # The old handler did "..." + kargs[0] (the data dict/list), which
        # itself raised TypeError.  Report the target file name instead.
        file_name = kargs[2] if len(kargs) > 2 else kwargs.get('file_name')
        print("PLOTTING FAILED: %s (%s)" % (file_name, e))
def _plot_adv(data, reduction, ax, dpi):
    """ advanced plotter with dictionary input,
    internal function, do not use directly

    Draw one panel of `multiplot` into the axes `ax`.

    :param data: dict produced by `get_data` or a list of such dicts, all drawn
        into `ax`; panel-wide settings (yscale, ylabel, vlines, xstart/xend,
        annotations, grids) are taken from the first dict
    :param bool reduction: thin every signal to about 2000 points
    :param ax: matplotlib axes to draw into
    :param int dpi: output resolution; dpi <= 50 (icon) skips the y-label and
        the annotations
    :return: `ax`, or None when there is nothing to plot
    """
    from matplotlib.ticker import MultipleLocator
    if type(data) is dict:
        data = [data]
    is_icon = dpi <= 50
    # !! remove empty fields !!  (new list: the caller's list is not modified)
    data = [d for d in data if not (all(isnone(d)) or len(d) == 0)]
    if len(data) == 0: # empty list
        return
    first = data[0]

    use_legend = False
    for d in data:
        if len(d) == 0 or all(isnan(d['data'])):
            continue
        # shallow copy: the errorbar thinning below must not alter the input
        kwargs = dict(d['kwargs'])
        tvec = d['tvec']
        n_points = len(tvec)
        # integer indices: a boolean mask thinned by [::step] would no longer
        # match the length of the data
        ind_plot = arange(n_points)
        if kwargs.get('label', "") != "":
            use_legend = True
        if 'xlim' in d: # speed up plotting: draw only the visible range
            lo, hi = d['xlim'][0], d['xlim'][1]
            mask = ones(n_points, dtype=bool)
            if lo is not None:   # None = unbounded on that side
                mask &= lo <= tvec
            if hi is not None:
                mask &= hi >= tvec
            inside = nonzero(mask)[0]
            # extend the data by one point on each side of the range
            # (works also when a single point is inside, unlike squeeze())
            ind_plot = concatenate([inside-1, inside, inside+1])
            ind_plot = ind_plot[(ind_plot >= 0) & (ind_plot < n_points)]
            ind_plot = unique(ind_plot)
        if reduction:
            ind_plot = ind_plot[::max(1, n_points//2000)]
        # downsampling of errorbar data !!!  (expects [lower, upper] columns)
        for name in ['xerr', 'yerr']:
            err = kwargs.get(name)
            if err is not None and not isscalar(err):
                kwargs[name] = [err[ind_plot, 0], err[ind_plot, 1]]
        ax.errorbar(tvec[ind_plot], d['data'][ind_plot], linewidth=0.5, capsize=0, **kwargs)

    if first.get('yscale', '') != '':
        ax.set_yscale(first['yscale'])
    if amax(first['tvec']) - amin(first['tvec']) < 100:
        ax.xaxis.set_minor_locator(MultipleLocator(1))
    for line in first.get('vlines', []):
        ax.axvline(x=line, color='b', linestyle='--')
    for key in ['xstart', 'xend']:
        if key in first:
            ax.axvline(x=first[key], color='k', linestyle='--')

    if 'xlim' in first:
        lows = [d['xlim'][0] for d in data]
        highs = [d['xlim'][1] for d in data if d['xlim'][1] is not None]
        # None means "automatic"; same result as Python 2 min()/max() with None
        Xrange = [None if None in lows else min(lows),
                  max(highs) if highs else None]
        ax.set_xlim(Xrange)
        if None in Xrange:
            Xrange = list(ax.get_xlim())   # numeric limits for the code below
    else:
        Xrange = list(ax.get_xlim())

    if 'ylim' in first:
        ymin = inf ; ymax = -inf
        for d in data:
            ymin_tmp, ymax_tmp = _detect_range(d['tvec'], d['data'], Xrange, d['ylim'])
            ymin = min(ymin_tmp, ymin)
            ymax = max(ymax_tmp, ymax)
        Yrange = [ymin, ymax]
        ax.set_ylim(Yrange)
    else:
        Yrange = list(ax.get_ylim())

    if not is_icon: # in case of small icon do not apply
        if first.get('ylabel'):
            ax.set_ylabel(first['ylabel'])
        for annot in first.get('annotate', []):
            # copy: 'pos' may be a tuple and the input must stay unchanged
            pos = list(annot['pos'])
            if pos[0] is None: pos[0] = Xrange[0] + (Xrange[1] - Xrange[0])*0.05
            if pos[1] is None: pos[1] = Yrange[1] - (Yrange[1] - Yrange[0])*0.05
            ax.annotate(annot['text'], pos, rotation=annot['angle'], fontsize=annot['fontsize'])
    if first.get('ygrid'):
        ax.yaxis.grid(color='gray', linestyle='dashed')
    if first.get('xgrid'):
        ax.xaxis.grid(color='gray', linestyle='dashed')
    if use_legend:
        # FIXME !! nicer but slower !!!
        leg = ax.legend(loc='best', fancybox=True)
        leg.get_frame().set_alpha(0.7)
    return ax
def _detect_range(tvec, data, xlim, ylim):
""" smart way how to determine best ylim in the plots
internal fucntion do not use directly
"""
from scipy.stats.mstats import mquantiles
ind_ok = ~isnan(data) & ~isinf(data) & (tvec > xlim[0]) & (tvec < xlim[1])
yrange = [mquantiles(data[ind_ok], 0.005), mquantiles(data[ind_ok], 0.995)]
dy = abs(yrange[1] - yrange[0])
if dy/amax(abs(array(yrange))) < 1e-6: # constant line ...
yrange[0] = min(0, yrange[0]*1.1 )
yrange[1] = max(0, yrange[1]*1.1 )
else:
yrange[0] -= 0.2*dy
yrange[1] += 0.2*dy
for i in [0,1]:
if ylim[i] is None:
ylim[i] = yrange[i]
return ylim
def _normalize_errorbar(err):
    """ Bring one errorbar specification into the form used by `_plot_adv`.

    A value judged None by `isnone` is returned unchanged, a 0-d value becomes
    a float; any other array is squeezed to 2D, a single row is turned into a
    column and a single column is duplicated into two ([lower, upper]) columns.
    """
    if not any(isnone(err)) and ndim(err) == 0:
        err = float(err)
    if not any(isnone(err)) and not isscalar(err):
        err = array(squeeze(err), ndmin=2)
        if size(err, 0) == 1:
            err = err.T
        if size(err, 1) == 1:
            err = err.repeat(2, axis=1)
    return err


def get_data(diagn, diagn_name = None, ylabel = None, xlabel = 'Time [ms]', shot_num = None, xlim = [], ylim = [], \
        tvec_rescale = 1e3, data_rescale = 1, integrated = False, columns = [], yscale = '', \
        reduction = False, plot_limits = True, smoothing = 0, line_format = "", \
        vlines = [], annotate = [], xgrid= False, ygrid = False, **kwargs ):
    """
    Load data/diagn and return object that is used for plotting.
    :var str/list/Data class diagn: Input can be string with name of diagn from main database, name of file with locally saved data on disk, list with [tvec,data]
    :var str diagn_name: Name of diagnostics that will be used in legend, if diagn has more signals then the name will be `name 0`, `name 1`, ...., or diagn_name can be list with name for each column
    :var str ylabel: Title of ylabel
    :var str xlabel: Title of xlabel
    :var list xlim: Range for X axis, use [1.132 ,None] to define only one limit
    :var list ylim: Range for Y axis, use [1.132 ,None] to define only one limit
    :var int shot_num: Number of shot, default is autodetect from current folder
    :var double tvec_rescale: Rescale X axis to ms !! Be careful if the axis is not time in seconds !!
    :var double data_rescale: Rescale Y axis
    :var bool reduction: Try to reduce number of plotted points, automatically used for integrated signal
    :var bool integrated: Integrate the input signal
    :var list columns: Plot only the selected columns
    :var str yscale: scale of y-axis, use `log`
    :var bool plot_limits: If there is plasma, plot lines at start and end of the plasma
    :var double smoothing: Smooth the signal (> 0 enables DiffFilter smoothing, > 10 also downsamples)
    :var str line_format: Standard formatting string
    :var bool xgrid: Plot automatic xgrid in the plot
    :var bool ygrid: Plot automatic ygrid in the plot
    :var array xerr: error in X direction
    :var array yerr: error in Y direction
    :var str label: legend name, used when `diagn_name` is not given
    :var **kwargs : keyword arguments to be passed to matplotlib's errorbar() command, optional
    this method accepts **kwargs as a standard :func:`matplotlib.pyplot.plot` function
    Advanced:
    :var list vlines : list of positions of vertical lines in plot (see historical analysis)
    :var list annotate: list of dicts => plt.annotate(annot['text'], annot['pos'], rotation=annot['angle'] )
    :return: dict (one column) or list of dicts (one per column) for `multiplot`;
        None if anything fails (the error is printed)
    """
    try:
        # process input setting
        if line_format != "":
            kwargs['fmt'] = line_format
        # A user `label` used to be dropped silently (old "BUG FIXME");
        # it now names the signal unless diagn_name is given.
        user_label = kwargs.pop('label', None)
        if diagn_name is None:
            diagn_name = user_label
        cData = Shot(shot_num)
        is_valid = False
        if type(diagn) is str: # name of a diagnostic or of a local file
            is_valid = cData._is_valid_name(diagn)
            if is_valid == 0: # not in database
                tvec, data = load_adv(diagn) # prefer locally saved data
            elif is_valid in (1, 2): # data
                tvec, data = cData[diagn]
            elif is_valid == 3: # DAS, channel_name
                tvec, data = cData["any", diagn]
            else:
                raise NotImplementedError("ambiguous diagn name")
        elif (type(diagn) is list or type(diagn) is tuple) and len(diagn) == 2:
            tvec, data = deepcopy(diagn)
        elif type(diagn) is Data:
            tvec, data = diagn.tvec, diagn
        else:
            raise NotImplementedError("Unsupported input")

        # metadata of a Data object, read before `data` becomes a plain array
        is_data_obj = type(data) is Data
        if 'xerr' not in kwargs:
            kwargs['xerr'] = data.tvec_err if is_data_obj else None
        if 'yerr' not in kwargs:
            kwargs['yerr'] = data.data_err if is_data_obj else None
        data_ax_labels = data.ax_labels if is_data_obj else None
        data_info = data.info if is_data_obj else None

        # Errorbars =================== always normalised (a user `yerr`
        # given without `xerr` used to be discarded)
        for name in ['xerr', 'yerr']:
            kwargs[name] = _normalize_errorbar(kwargs[name])
        tvec = asarray(tvec)
        data = asarray(data)
        # =================================
        if smoothing > 0:
            _, data, _ = DiffFilter(data, mean(diff(tvec)), 2000, 1e6)
        if len(columns) > 0:
            data = data[:, columns]
        # rescale into new arrays: the former in-place `*=` modified the
        # caller's arrays (e.g. a Data object passed as diagn) and fails for
        # integer arrays
        data = data * data_rescale
        tvec = tvec * tvec_rescale
        if not any(isnone(kwargs['yerr'])):
            kwargs['yerr'] = kwargs['yerr'] * data_rescale
        if not any(isnone(kwargs['xerr'])):
            kwargs['xerr'] = kwargs['xerr'] * tvec_rescale
        Ncol = size(data, 1) if ndim(data) == 2 else 1
        data = reshape(data, (-1, Ncol)) # workaround for 1D arrays
        if integrated:
            dt = mean(diff(tvec))
            data = cumsum(data, 0)*dt
        if reduction or integrated or smoothing > 10:
            # downsample to ~4000 points; select through an index column so
            # that the errorbars stay aligned with the downsampled data
            n_full = len(tvec)
            _, keep = random_number_downsample(tvec, arange(n_full).reshape(-1, 1), 4000)
            keep = keep[:, 0]
            tvec = tvec[keep]
            data = data[keep]
            for name in ['xerr', 'yerr']:
                err = kwargs[name]
                if not any(isnone(err)) and not isscalar(err) and size(err, 0) == n_full:
                    kwargs[name] = err[keep]

        # default values
        plasma = False
        start = -inf
        start_pl = nan
        end_pl = nan
        end = inf
        try:
            plasma = cData['plasma']
        except Exception:
            pass
        if plasma and plot_limits and tvec_rescale == 1e3: # tvec_rescale == 1e3 => x axis is probably time !!!
            start_pl = cData['plasma_start']*1e3
            end_pl = cData['plasma_end']*1e3
            if cData['ucd'] > 10:
                start = min(start_pl - 2, cData['tcd']*1e3) # 2ms before plasma start
            else:
                start = start_pl - 2
            end = end_pl + 2
        elif tvec_rescale == 1e3:
            try:
                start = cData['tb']*1e3-1
                end = 35
            except Exception:
                pass
        t_range = amax(tvec) - amin(tvec)
        if len(xlim) == 0:
            xlim = array([max(start, amin(tvec)-t_range*0.05), min(end, amax(tvec)+t_range*0.05)])
            if xlim[0] > xlim[1]: # some error (no intersection)
                xlim = array([amin(tvec)-t_range*0.05, amax(tvec)+t_range*0.05])
        if is_data_obj and ylabel is None:
            if not all(isnone(data_ax_labels)):
                ylabel = data_ax_labels[1]
        if is_data_obj and diagn_name is None:
            diagn_name = data_info
        if len(ylim) == 0:
            ylim = [None, None]

        output = list()
        ### ============ auto-labeling =====================
        for i in range(Ncol):
            column_name = None
            # type check first: len(None) used to abort the whole call
            if type(diagn_name) in (list, tuple) and len(diagn_name) == Ncol: # every column has its own label
                column_name = diagn_name[i]
            elif is_valid == 2: # use predefined name in config if it is DAS
                column_name = cData.channel_name(diagn, i)
            if column_name is None: # nothing matched ...
                if Ncol == 1: # only one variable
                    column_name = diagn_name
                elif diagn_name: # any name
                    column_name = str(diagn_name) + ' ' + str(i)
                else:
                    column_name = ""
            kwargs_tmp = dict(label=fix_str(column_name), **kwargs)
            o = {'tvec': tvec, 'data': data[:, i], 'ylabel': fix_str(ylabel),
                 'xlabel': fix_str(xlabel), 'ylim': ylim, 'xlim': xlim, 'xstart': start_pl, 'xend': end_pl, 'yscale': yscale,
                 'vlines': vlines, 'annotate': annotate, 'xgrid': xgrid, 'ygrid': ygrid,
                 'kwargs': kwargs_tmp}
            output.append(deepcopy(o))
        if len(output) == 1:
            return output[0]
        return output
    except Exception as e:
        # deliberate best effort: callers get None instead of an exception
        print("Some mistake during get_data: " + str(e))
def fix_str(string):
    """ Fix the string to avoid problems during plotting with LaTeX text rendering.

    Strings without any '$' are plain text: every underscore that is not
    already preceded by a backslash gets escaped with a backslash.  Strings
    containing '$' are assumed to be LaTeX already and are returned unchanged.

    :return: the fixed string, or "" when `isnone(string)` is true
    """
    if isnone(string): return ""
    # `in` also sees a '$' after a newline, which re.match('.*\$') missed
    if '$' not in string: # hot fix of "latex font"; if you use $ solve it yourself
        # lookbehind also escapes a leading underscore and every underscore
        # of a run such as "a__b", which the old '[^\\]_' pattern skipped
        string = re.sub(r'(?<!\\)_', r'\\_', string)
    return string
###################### WEB PAGES ##############################
def get_page_paths(shot, page, default_path = 'shots'):
    """
    Return the basic paths needed to generate the web pages of one shot.

    :param shot: shot number
    :param str page: page file, optionally with a directory ("dir/page.php")
    :param str default_path: root directory of the shots
    :return: (page_path, base_path, page): base_path is
        "<default_path>/<shot>/", page_path is base_path followed by the
        directory part of `page`, and page is its file-name part
    """
    base_path = '/'.join([default_path, str(shot), ''])
    page_dir, page_file = os.path.split(page)
    return base_path + page_dir, base_path, page_file
def emph(text):
    """ Return str(text) wrapped in the docutils "literal" HTML markup. """
    template = '<tt class="file docutils literal"><span class="pre">%s</span></tt>'
    return template % (str(text),)
def modified(file):
    """ Return the last modification date of `file` as 'YYYY-MM-DD' (local
    time), or 'N/A' if the file cannot be stat-ed. """
    try:
        # st_mtime; the former os.stat(file)[-1] is st_ctime (inode change time)
        mtime = os.stat(file).st_mtime
    except (OSError, TypeError, ValueError):
        return 'N/A'
    return datetime.datetime.fromtimestamp(mtime).strftime('%Y-%m-%d')
def make_image(img_path, name = ""):
    """ Show image from `img_path` with title `name`. Use loading animation if not available. Use svgz if possible and if the browser is capable to load.

    The file extension of `img_path` is removed; the generated PHP snippet
    shows <img_path>.svgz for Firefox/Chrome when it exists, otherwise
    <img_path>.png, and a loading animation when the png is missing.

    :param str img_path: image path, with or without extension
    :param str name: title shown above the image (omitted if empty)
    :return: HTML/PHP string
    """
    img_path = re.sub(r"(.+)(\.[\w]{3,4})$", r"\1", img_path) # remove file ending (.png)
    out = ""
    # `!=`, not `is not`: identity of string literals is not guaranteed
    if name != "":
        out += "<h4>"+name+"</h4>\n"
    full_path = img_path
    # random query string, presumably to defeat browser caching of
    # regenerated images (int argument: randint(1e6) passed a float)
    rand_end = "?%s" % random.randint(1000000)
    template = """<?php
$u_agent = $_SERVER['HTTP_USER_AGENT'];
if (file_exists("%s.png"))
{
// in firefox use svgz images !!
if (file_exists("%s.svgz") && (preg_match('/Firefox/i',$u_agent) || preg_match('/Chrome/i',$u_agent) ))
{
echo "<img src='%s.svgz%s' alt='%s' align='middle'/><br/>";
}
else
{
echo "<img src='%s.png%s' alt='%s' align='middle'/><br/>";
}
}
else
{
echo "<img src='/_static/loading.gif'/><br/>";
echo "<!-- Missing file %s.png --> " ;
}
?>"""
    out += template % (full_path, full_path, full_path, rand_end, name, full_path, rand_end, name, full_path)
    return out
def make_zoom_image(img_path, name = "", group = ""):
    """ Load image from `img_path` with title `name`. Make image zoomable, if possible use svgz version.

    The png is shown inline; the link (rel='lightbox<group>') opens the svgz
    version for Firefox/Chrome when it exists, otherwise the png.

    :param str img_path: image path, with or without extension
    :param str name: title (heading, alt and title attributes), may be empty
    :param str group: suffix of the lightbox group name
    :return: HTML/PHP string
    """
    # random query string, presumably to defeat browser caching
    # (int argument: randint(1e6) passed a float)
    rand_end = "?%s" % random.randint(1000000)
    img_path = re.sub(r"(.+)(\.[\w]{3,4})$", r"\1", img_path) # remove file ending (.png)
    out = ""
    # `!=`, not `is not`: identity of string literals is not guaranteed
    if name != "":
        out += "<h4>"+name+"</h4>\n"
    out += """<?php
$u_agent = $_SERVER['HTTP_USER_AGENT'];
if (file_exists('""" + img_path + ".svgz" + """') && (preg_match('/Firefox/i',$u_agent) || preg_match('/Chrome/i',$u_agent) )){
$path='"""+ img_path +".svgz"+rand_end+"""'; }
else {
$path='"""+ img_path + ".png"+rand_end+"""'; }
?> """
    title = "' title='" + name if name != "" else ""
    out += ("<a href='<?php echo $path ?>' rel='lightbox" + group + "'><img src='"
            + img_path + ".png" + rand_end + "' alt='" + name + title + "'/></a> ")
    return out
def make_config(path):
    """ Try to nicely load and show config from text file in shape
    value = 1
    value2 = 3

    Every line becomes one row of a docutils-style HTML field table; the line
    is split at its FIRST ':' or '=' into a name cell and a value cell.
    Lines without a separator go to the name cell with an empty value cell.

    :param str path: config file
    :return: HTML string, or a "Missing config file" heading if the file
        cannot be read
    """
    try:
        with open(path, 'r') as fh:
            cfg = fh.readlines()
    except (IOError, OSError):
        return "<h4> Missing config file " + path + "</h4>"
    out = """
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
"""
    # non-greedy key: split at the first separator, so values such as
    # "url = http://x" or "time: 12:30" stay intact
    separator = re.compile(r'(.+?)(:|=)(.+)')
    rows = []
    for line in cfg:
        m = separator.match(line)
        if m:
            cells = m.group(1) + ' </th><td class="field-body"> ' + m.group(3) + line[m.end():]
        else:
            cells = line + ' </th><td class="field-body">'
        rows.append(' <tr class="field"><th class="field-name"> ' + cells + ' </td> </tr>')
    out += ''.join(rows)
    out += "</tbody></table>"
    return out
def wiki(link):
    """
    Shortcut to generate a small superscript "WIKI" link.

    A link starting with '/' points into the internal wiki at
    http://golem.fjfi.cvut.cz:5001; any other link is used as it is.
    """
    target = link
    if re.match('/', link): # link to internal wiki, starts with /
        target = 'http://golem.fjfi.cvut.cz:5001' + link
    return "<a href='" + target + "'><font size=1><b><sup>WIKI</sup></b></font></a>"
def source_link(path , revision,default_path, link, name):
    """
    Link to formatted source file in bitbucket.org.

    :param path: unused, kept for backward compatibility
    :param str revision: repository revision (commit id or branch)
    :param default_path: unused, kept for backward compatibility
    :param str link: file path relative to includes/
    :param str name: link text, shown as [name]
    :return: HTML anchor pointing to includes/<link> at `revision` of the
        golem_velin repository
    """
    # `path` used to be rewritten with regexes built from `default_path`, but
    # the result was never used (and a regex-special default_path raised)
    return ('<a href="https://bitbucket.org/michalodstrcil/golem_velin/src/' + revision
            + '/includes/' + link + '">[' + name + ']</a>')
################################# Historical analysis #############################################################
def get_history(diagn, shots, verbose=True, dtype = "float" ):
    """
    Rapidly loads and returns values of scalar variables from selected shots.

    The file of `diagn` is located once for the last shot in `shots` (either
    <shotdir>/<shot>/<diagn> or the path reported by Shot.get_data); the
    paths of the other shots are derived by replacing that shot number.

    :param str diagn: name of the variable
    :param list shots: shot numbers; the last one is used to locate the file
    :param bool verbose: print a progress percentage on stdout
    :param dtype: dtype of the output array
    :return: array of len(shots); NaN (for float-like dtypes) where the file is
        missing or empty; non-numeric content is stored as is when the dtype
        allows it (otherwise an error is printed)
    """
    N = len(shots)
    shot_data = Shot(shots[-1])   # renamed: `Data` shadowed the Data class
    shot = shot_data.shot_num
    local_path = shotdir + '/' + str(shot) + '/' + diagn
    if os.path.exists(local_path):
        path = local_path
    else:
        path = shot_data.get_data(diagn, return_path=True)
    data = zeros(N, dtype=dtype)
    try:
        # `nan` (the former `NaN` no longer exists in numpy 2 and the bare
        # except silently left zeros); fails harmlessly for integer dtypes
        data *= nan
    except (TypeError, ValueError):
        pass
    # replace the shot number only as a whole number, so e.g. shot 123 does
    # not also rewrite "1234" elsewhere in the path
    shot_re = re.compile(r'(?<!\d)' + re.escape(str(shot)) + r'(?!\d)')
    for i, shot_i in enumerate(shots):
        path_tmp = shot_re.sub(str(shot_i), path)
        if os.path.exists(path_tmp) and os.path.getsize(path_tmp) > 0:
            try:
                with open(path_tmp, 'r') as fh:
                    content = fh.read()
                try:
                    data[i] = float(content)
                except ValueError:
                    data[i] = content   # non-numeric value, e.g. a status string
            except Exception as e:
                print("%s ERROR : %s shot %s" % (diagn, e, shot_i))
        if verbose:
            sys.stdout.write("\r %2.1f %%" % (i/double(N)*100))
            sys.stdout.flush()
    sys.stdout.write("\r")
    return data
############################# Advanced algorithms, do not touch :) #################################################
# Author: Tomas Odstrcil
# Date: 2012
def DiffFilter(signal,dt, win,lam):
    """ Regularised numerical derivative of `signal` (Author: Tomas Odstrcil).

    Computes f = (II + lam*DD)^{-1} I^T signal in the Fourier domain (I is the
    integration operator, D the first difference), after an 11-sample median
    pre-filter and padding of both ends.

    :param signal: data, 1D or 2D with one signal per column (lists accepted)
    :param float dt: sampling time of the signal
    :param int win: number of padding samples (half at each end) against edge effects
    :param float lam: regularization parameter
    :return: (f/dt, retrofit, chi2): the derivative, retrofit = cumsum(f)
        (the smoothed signal) and the mean squared residual of the retrofit
        per column; arrays have one column per signal
    """
    from scipy.signal import medfilt # importing takes most of the run time
    # 1D input becomes a single column (n, 1)
    signal = atleast_2d(asarray(signal).T).T
    signal = medfilt(signal, (11, 1))
    # extend the signal due to edge effects: pad with the median of the
    # first / last win/4 samples.  `//` = Python 2 integer division.
    extended_signal = zeros((size(signal,0)+win, size(signal,1)))
    extended_signal[win//2:-(win)//2,:] = signal
    extended_signal[:win//2,:] = median(signal[:win//4,:], axis=0)
    # NOTE(review): this end slice also overwrites the last real sample -- TODO confirm intent
    extended_signal[-win//2-1:,:] = median(signal[-win//4:,:], axis=0)
    N = size(extended_signal,0)
    # calculate (II + lam*DD)^{-1}*I^T, but in the Fourier domain!
    # (cumsum is written into extended_signal in place)
    csignal = -cumsum(extended_signal, axis = 0, out = extended_signal)
    fsig = fft.rfft(csignal, axis = 0).T
    # 4*sin^2 on the rfft grid: Fourier symbol of DD (grid exact for even N)
    DDfft = (4*sin(linspace(0, pi/2, N//2+1))**2)
    IIfft = 1/(DDfft+DDfft[1]*1e-4)
    A = IIfft + lam*DDfft
    fsig /= A
    # n=N: without it irfft returns N-1 rows for odd N and the chi2 line
    # failed on a shape mismatch
    f = fft.irfft(fsig.T, n=N, axis = 0)
    # remove edge effects (padding)
    f = f[win//2-1:-(win+1)//2,:]
    retrofit = cumsum(f, axis = 0)
    chi2 = ((retrofit-signal)**2).sum(axis = 0)/size(signal,0)
    return f/dt, retrofit, chi2
def DiffFilter_old(signal,dt, max_win= 2000, regularization = 1e6):
    """ Older, real-space version of `DiffFilter` (regularised derivative).

    Builds a derivative filter c of length `win` by solving
    (II + lam*DD) c = I^T b, with I the lower-triangular integration matrix,
    D the first difference and b a unit impulse in the middle of the window,
    and convolves the edge-padded signal with it.

    :param signal: 1D signal or 2D array with one signal per column
    :param float dt: sampling time
    :param int max_win: upper limit of the filter length
    :param float regularization: smoothing strength, rescaled by (dt/5e-6)**-4
    :return: (diffsig, retrofit, chi2) -- the derivative (one column per
        signal), its cumulative integral offset by the mean of the first
        4e-3/dt samples (4 ms if dt is in seconds) and squeezed, and the mean
        squared residual per column

    NOTE(review): relies on Python 2 integer division (`win/2`, `win/4`).
    """
    from scipy.signal import medfilt, fftconvolve
    # 1D input becomes a single column (n, 1)
    Nsig = size(signal,1) if ndim(signal) == 2 else 1
    signal = reshape(signal, (-1, Nsig))
    # 11-sample median pre-filter along the time axis
    signal = medfilt(signal,(11, 1))
    # regularisation rescaled with the sampling time (reference dt = 5e-6)
    lam = regularization*(dt/5e-6)**-4
    # filter length derived from lam
    win = 15*(24*lam)**0.25
    if win > max_win:
        print 'Warning: too small window'
    win = min(win, max_win) # empirically guessed dependence
    win = int(win)+4
    # Integ: lower-triangular integration matrix, D: first difference
    Integ = tril( ones((win, win)), 0)
    II = dot(Integ.T,Integ)
    D = diag(ones(win))-diag(ones(win-1),1)
    DD = dot(D.T,D)
    # b: unit impulse in the middle of the window
    b = zeros(win)
    b[win/2] = 1
    B = dot(Integ.T,b)
    # derivative filter c = (II + lam*DD)^{-1} I^T b
    A = II + lam*DD
    c = array(linalg.solve( A,B.T))
    # pad by win-1 samples, shifted by the medians of the first / last win/4 samples
    extended_signal = zeros((size(signal,0)+win-1, Nsig))#ones(len(signal)+win-1)
    extended_signal[win/2:-(win)/2+1,:] = signal
    extended_signal[:win/2,:] += median(signal[:win/4],axis=0)
    # NOTE(review): this slice seems to include the last real sample too
    # (it is then shifted by the median) -- TODO confirm
    extended_signal[-win/2:,:] += median(signal[-win/4:],axis=0)
    diffsig = empty(shape(signal))
    for i in range(Nsig):
        diffsig[:,i] = fftconvolve(extended_signal[:,i] , c, mode = 'valid')/dt
    retrofit = cumsum(diffsig, axis = 0)*dt
    # integration constant: mean of the first 4e-3/dt samples (hard-coded)
    retrofit += mean(signal[:int(4e-3/dt),:], axis = 0) #FIXME
    #plot(diffsig)
    ##figure()
    #plot(c)
    #show()
    #plot(signal)
    #plot(retrofit)
    #plot(signal-retrofit)
    #show()
    chi2 = sum((retrofit-signal)**2, axis = 0)/len(signal)
    return diffsig, squeeze(retrofit), chi2
def GapsFilling(signal,win = 100, lam = 1e-1):
    """
    =============================== Gaps Filling Filter 0.1 =====================
    Reconstruct corrupted data (samples that are NaN) by Tikhonov-Phillips
    regularisation with a difference operator and return a smoothed retrofit.
    The reconstruction is based on inverting the identity operator with zeros
    at the rows corresponding to the missing samples.
    Due to memory and speed limitations the reconstruction is done on
    overlapping intervals of width `win`.

    :param signal: long 1D data vector, NaN marks missing samples
    :param int win: width of the reconstruction interval; it must be much
        bigger than the width of the gaps
    :param float lam: regularisation parameter, depends on the noise in the data
    :return: (recon, chi2) -- reconstruction of len(signal) and the mean
        squared residual over the valid samples
    Author: Tomas Odstrcil 2012
    """
    from scikits.sparse.cholmod import cholesky_AAt
    from scipy.sparse import spdiags
    n = len(signal)
    # signal extension -- avoid boundary effects.  pad_lo/pad_hi reproduce the
    # Python 2 slices [(n_ext-n)/2 : -(n_ext-n)/2] exactly.
    n_ext = (n//win+1)*win
    pad_lo = (n_ext-n)//2
    pad_hi = n_ext-n-pad_lo
    valid = signal[~isnan(signal)]
    ext_signal = zeros(n_ext)
    ext_signal[pad_lo:n_ext-pad_hi] = signal
    ext_signal[:pad_lo+1] = median(valid[:win//2])
    # end padding uses the median of the LAST win/2 valid samples
    # (was [win/2:], i.e. everything except the first win/2 samples)
    ext_signal[-(pad_hi+1):] = median(valid[-(win//2):])
    intervals = arange(0, n+win//2, win//2)
    ind_nan = isnan(ext_signal)
    ext_signal[ind_nan] = 0
    recon = ext_signal.copy()
    # D: first difference operator on one window
    diag_data = ones((2,win))
    diag_data[1,:] *= -1
    D = spdiags(diag_data, (0,1), win, win, format='csr')
    DD = D.T*D
    Factor = cholesky_AAt(DD, 1./lam)
    for i in range(len(intervals)-2):
        a, b, c = intervals[i], intervals[i+1], intervals[i+2]
        # use overlapping intervals !!!
        gaps = spdiags(int_(ind_nan[a:c]), 0, win, win, format='csc')
        Factor.update_inplace(gaps/sqrt(lam), subtract=True) # speed up trick
        g = Factor(ext_signal[a:c]/lam)
        Factor.update_inplace(gaps/sqrt(lam), subtract=False) # speed up trick
        # keep only the central half of each window
        recon[(a+b)//2:(b+c)//2] = g[len(g)//4:-len(g)//4, 0]
    chi2 = ((ext_signal-recon)[~ind_nan]**2).sum()/n
    recon = recon[pad_lo:n_ext-pad_hi]
    return recon, chi2
def deconvolution( signal,responseFun, win, regularization):
    """
    Main deconvolution algorithm.

    Returns the deconvolved signal and its reconstruction (retrofit).  It is
    based on a simple method maximising smoothness: the deconvolution filter
    g solves (T^T T + regularization*D^T D) g = T^T f, where T is the
    convolution matrix of the response, D the first difference and f a unit
    impulse in the middle of the window.
    The window `win` must be bigger than the "support" of the response.

    :param signal: 1D signal
    :param responseFun: 1D response of the system to a delta function; it is
        normalised to unit sum (the caller's array is not modified)
    :param int win: length of the filter g, forced to be odd
    :param float regularization: weight of the smoothness term
    :return: (deconv_sig, retrofit), both of len(signal)
    """
    from scipy.signal import fftconvolve
    from scipy.sparse import spdiags
    win = 2*(int(win)//2)+1
    # new normalised array instead of the former in-place `/=`, which changed
    # the caller's kernel and failed for integer kernels
    responseFun = asarray(responseFun, dtype=float)
    responseFun = responseFun/responseFun.sum()
    # build the convolution matrix -- response of the system to a delta function
    i_middle = argmax(responseFun)
    diags = arange(i_middle-len(responseFun), i_middle) # TODO: check which way this shifts
    diag_data = tile(responseFun[::-1], (win, 1)).T
    T = spdiags(diag_data, diags, win, win).toarray()
    D = diag(ones(win))-diag(ones(win-1),1)
    DD = dot(D.T,D)
    # prepare the data for the main deconvolution algorithm, solving T g = f
    TT = dot(T.T,T)
    f = zeros(win)
    f[win//2] = 1
    g = linalg.solve(TT + regularization*DD, dot(T.T, f))
    # kernel for the retrofit: the response preceded by as many zeros
    kernel = zeros(2*len(responseFun))
    kernel[len(responseFun):] = responseFun
    # pad the signal by win-1 samples, shifted by the medians of its first /
    # last win/4 samples
    pad = win//2
    extended_signal = zeros(len(signal)+win-1)
    extended_signal[pad:pad+len(signal)] = signal
    extended_signal[:pad] += median(signal[:win//4], axis=0)
    extended_signal[pad+len(signal):] += median(signal[-win//4:], axis=0)
    deconv_sig = fftconvolve(extended_signal, g, mode='valid')
    retrofit = fftconvolve(deconv_sig, kernel, mode='same')
    return deconv_sig, retrofit
# Example usage of `deconvolution` (kept for reference):
#Ich = load('Ich.npy')
#tvec = Ich[:,0]
#signal = Ich[:,1]
#dt = tvec[1]-tvec[0]
#t_exp = 1e-4
#responseFun = exp(-tvec[:int(6*t_exp/dt)]/t_exp)
#win = 100
#deconvolution( signal,responseFun, win, 1e-2)
# Deconvolve (t_exp > 0) or convolve (t_exp < 0) with an exponential kernel
def deconvolveExp( signal,t_exp,dt,win,regularization):
    """ Deconvolve (t_exp > 0) or convolve (t_exp < 0) `signal` with the
    exponential kernel exp(-t/|t_exp|), truncated at 6*|t_exp| and normalised
    to unit sum.

    :param signal: 1D signal sampled with step `dt`
    :param float t_exp: time constant; its sign selects the operation
    :param float dt: sampling step
    :param int win: window passed to `deconvolution`
    :param float regularization: passed to `deconvolution`
    :return: (signal, signal) if |t_exp| <= 2*dt; (convolved signal, signal)
        if t_exp < 0; otherwise the result of `deconvolution`
    """
    from scipy.signal import fftconvolve
    tau = abs(t_exp)
    n_kernel = int(6*tau/dt)
    kernel = exp(-(arange(len(signal))*dt)[:n_kernel]/tau)
    kernel /= kernel.sum()
    if tau <= 2*dt:
        return signal, signal
    if t_exp < 0:
        # causal convolution: kernel preceded by zeros, centre of 'full' output
        causal = zeros(2*len(kernel))
        causal[len(kernel):] = kernel
        full = fftconvolve(signal, causal, mode = 'full')
        cut = len(full) - len(signal)
        return full[cut//2:-cut//2], signal
    return deconvolution(signal, kernel, win, regularization)