#!/usr/bin/python2
# -*- coding: utf-8 -*-

from numpy import *

import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(30)
#print 'include time ',time()-t 



def BlockConv(signalin, impulse, mode='full'):
    """Convolve a 1-D signal with a short impulse using block overlap-add.

    The signal is cut into blocks of ``len(impulse)`` samples. Each block is
    zero-padded to a power-of-two FFT size and multiplied with the impulse
    spectrum. It is then inverse-transformed with pyfftw, and the pieces are
    overlap-added into the output. When the impulse is longer than a quarter
    of the signal (or is empty), the work is delegated to
    ``scipy.signal.fftconvolve`` instead.

    Parameters
    ----------
    signalin : array_like
        1-D input signal.
    impulse : array_like
        1-D impulse response.
    mode : {'full', 'same', 'valid'}
        Output size, with the same meaning as in ``scipy.signal.fftconvolve``:
        'full' -> len(signalin) + len(impulse) - 1 samples;
        'same' -> len(signalin) samples centred on the full result;
        'valid' -> len(signalin) - len(impulse) + 1 samples.

    Returns
    -------
    ndarray
        The convolution. On the overlap-add path the result is always
        complex128, and the per-block FFTs run in single precision
        (complex64), so it carries single-precision rounding error.

    Raises
    ------
    ValueError
        If ``mode`` is not one of 'full', 'same', 'valid'.
    """
    if mode not in ('full', 'same', 'valid'):
        raise ValueError("mode must be 'full', 'same' or 'valid'")

    signalin = asarray(signalin)
    impulse = asarray(impulse)

    S = len(impulse)    # impulse length, also used as the block length
    L_ = len(signalin)  # original (unpadded) signal length

    # Long (or empty) impulses gain nothing from block processing. The
    # S == 0 check also keeps the divisions by S below from failing.
    if S == 0 or S > L_ // 4:
        from scipy.signal import fftconvolve
        return fftconvolve(signalin, impulse, mode)

    # Zero-pad the signal to a whole number of S-sample blocks, with at
    # least two samples of headroom past the original end.
    n_ext = ((L_ + 1) // S + 1) * S - L_
    signalin = hstack((signalin, zeros(n_ext)))
    L = len(signalin)  # padded length, a multiple of S

    # FFT size: a power of two of at least 2*S, never smaller than 128.
    # The linear convolution of an S-sample block with the S-sample
    # impulse has 2*S-1 samples, so there is no circular aliasing.
    N = max(2 ** int(log2(S) + 2), 128)

    # Aligned work buffers shared by the forward and backward plans.
    sig = pyfftw.n_byte_align_empty(N, 16, dtype=csingle)
    out = pyfftw.n_byte_align_empty(N, 16, dtype=csingle)

    # Output accumulator: the last block (p = L - S) writes up to index L + S.
    sigout = zeros(L + S, dtype=complex)

    # Impulse spectrum, computed once and reused for every block.
    imp = zeros(N, dtype=impulse.dtype)
    imp[:S] = impulse
    spec_imp = fft.fft(imp)

    flags = ('FFTW_ESTIMATE', 'FFTW_DESTROY_INPUT')
    fft_forward = pyfftw.FFTW(sig, out, direction='FFTW_FORWARD', flags=flags)
    fft_backward = pyfftw.FFTW(out, sig, direction='FFTW_BACKWARD', flags=flags)

    for p in range(0, L, S):
        # Load one block, zero-padded to N samples. The backward transform
        # overwrote the whole of `sig` on the previous pass.
        sig[:S] = signalin[p:p + S]
        sig[S:] = 0

        # execute() runs the raw, unnormalised FFTW plans. Calling the FFTW
        # objects instead would already scale the inverse by 1/N
        # (normalise_idft defaults to True), and the explicit division
        # below would then normalise twice.
        fft_forward.execute()
        out *= spec_imp
        fft_backward.execute()

        # Overlap-add the 2*S-sample block result.
        sigout[p:p + 2 * S] += sig[:2 * S] / N

    if mode == 'full':
        return sigout[:L_ + S - 1]
    if mode == 'same':
        start = (S - 1) // 2
        return sigout[start:start + L_]
    # 'valid': only the samples where the impulse fully overlaps the signal.
    return sigout[S - 1:L_]