#!/usr/bin/python2
# -*- coding: utf-8 -*-
from numpy import *
import pyfftw
# Turn on pyfftw's cache of FFTW objects used by the pyfftw.interfaces
# wrappers, and set the cache keep-alive time to 30 (seconds, per the pyfftw
# documentation).
# NOTE(review): BlockConv below builds pyfftw.FFTW objects directly rather
# than going through pyfftw.interfaces, so presumably this cache serves other
# callers -- confirm it is still needed.
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(30)
# NOTE(review): a commented-out timing print used to live here; it referred to
# `time` and `t`, neither of which is defined in this file's visible code.
def BlockConv(signalin, impulse, mode='full'):
    """Convolve a 1-D signal with a 1-D impulse response by FFT overlap-add.

    The signal is cut into blocks of ``len(impulse)`` samples. Each block is
    zero-padded to an FFT length ``N`` (a power of two, at least 128 and at
    least ``2*len(impulse)``) and multiplied in the frequency domain by the
    impulse spectrum. It is then transformed back and overlap-added into the
    output. The transforms use pre-planned pyfftw objects working on
    single-precision (complex64) buffers.

    If the impulse is empty, the signal is empty, or the impulse is longer
    than a quarter of the signal, the work is delegated to
    ``scipy.signal.fftconvolve`` instead.

    Parameters
    ----------
    signalin : array_like
        Input signal (1-D).
    impulse : array_like
        Impulse response (1-D).
    mode : {'full', 'same', 'valid'}
        Output size, with the same meaning as in ``scipy.signal.fftconvolve``:

        - 'full': ``len(signalin) + len(impulse) - 1`` samples.
        - 'same': ``len(signalin)`` samples centred on the 'full' result.
        - 'valid': ``len(signalin) - len(impulse) + 1`` samples.

    Returns
    -------
    ndarray
        The convolution. On the overlap-add path this is always a complex128
        array (computed in single precision). The scipy fallback returns
        whatever ``fftconvolve`` returns (real for real inputs).

    Raises
    ------
    ValueError
        If ``mode`` is not 'full', 'same' or 'valid'.
    """
    if mode not in ('full', 'same', 'valid'):
        raise ValueError("mode must be 'full', 'same' or 'valid'")
    signalin = asarray(signalin)
    impulse = asarray(impulse)
    S = len(impulse)    # impulse length, also used as the block length
    L_ = len(signalin)  # signal length
    # `//` keeps the original Python-2 integer-division threshold.
    if S == 0 or S > L_ // 4:
        # Empty inputs, or impulses too long for block processing to pay off.
        from scipy.signal import fftconvolve
        return fftconvolve(signalin, impulse, mode)

    # FFT size: a power of two >= 2*S, so the (2*S-1)-sample linear
    # convolution of one block with the impulse does not wrap around.
    N = max(2 ** int(log2(S) + 2), 128)
    n_blocks = (L_ + S - 1) // S  # ceil(L_ / S); the last block may be partial

    # Zero-padded impulse spectrum, computed once.
    imp = zeros(N, dtype=impulse.dtype)
    imp[:S] = impulse
    spec_imp = fft.fft(imp)

    # Aligned work buffers shared by both plans: the forward plan reads `sig`
    # and writes `out`; the backward plan reads `out` and writes `sig`.
    sig = pyfftw.n_byte_align_empty(N, 16, dtype=csingle)
    out = pyfftw.n_byte_align_empty(N, 16, dtype=csingle)
    fft_forward = pyfftw.FFTW(sig, out, direction='FFTW_FORWARD',
                              flags=['FFTW_ESTIMATE', 'FFTW_DESTROY_INPUT'])
    fft_backward = pyfftw.FFTW(out, sig, direction='FFTW_BACKWARD',
                               flags=['FFTW_ESTIMATE', 'FFTW_DESTROY_INPUT'])

    # Each block writes 2*S samples starting at its own offset.
    sigout = zeros(n_blocks * S + S, dtype=complex)
    for i in range(n_blocks):
        p = S * i
        block = signalin[p:p + S]
        sig[:len(block)] = block
        sig[len(block):] = 0
        # execute() runs the raw, unnormalised transforms, and the inverse-FFT
        # 1/N scaling is applied exactly once below. Calling the plan objects
        # instead would also apply pyfftw's default 1/N normalisation to the
        # backward transform, scaling the result twice.
        fft_forward.execute()
        out *= spec_imp
        fft_backward.execute()
        sigout[p:p + 2 * S] += sig[:2 * S] / N

    if mode == 'full':
        return sigout[:L_ + S - 1]
    if mode == 'same':
        start = (S - 1) // 2  # same centring as scipy.signal.fftconvolve
        return sigout[start:start + L_]
    return sigout[S - 1:L_]  # 'valid'; L_ >= 4*S here, so it is never empty