// -----------------------------------------------------------------------------
// Copyright 2008 Steve Hanov. All rights reserved.
//
// For permission to use, please contact steve.hanov@gmail.com. Permission will
// usually be granted without charge.
// -----------------------------------------------------------------------------
#include <math.h>
#include "ScalogramCreator.h"
#include "Waveform.h"
#include "fftw3.h"
#include "SoundImage.h"
#include <complex>
#include "dbg.h" // must be last

/**
  * Minimal stopwatch built on GetTickCount(). reset() returns the ticks
  * elapsed since the previous start()/reset() (or since construction) and
  * restarts the measurement from "now".
  */
class Profiler
{
public:
    unsigned ticks;

    // Start timing at construction so that reset() never reads an
    // uninitialized value, even if start() is never called.
    Profiler() : ticks( GetTickCount() ) {}

    // Restart the measurement from the current tick count.
    void start() {
        ticks = GetTickCount();
    }

    // Return the ticks elapsed since the last start()/reset() and restart.
    // The unsigned subtraction stays correct across a wrap of the counter
    // as long as the interval itself fits in an unsigned.
    unsigned reset() {
        unsigned current = GetTickCount();
        unsigned ret = current - ticks;
        ticks = current;
        return ret;
    }
};

/**
  * Find the smallest power of two that is >= num. For example, if 40 is
  * passed in, it will return 64; an exact power of two is returned unchanged.
  *
  * Uses exact integer arithmetic: the previous log()/pow() formulation could
  * round log(2^k)/log(2) slightly above k and return twice the needed size.
  *
  * @param num  requested minimum size
  * @return     the power of two; 1 for num <= 1. Inputs above 2^30 have no
  *             power-of-two result representable in an int, so 2^30 is
  *             returned for them instead of overflowing.
  */
static int round_2_up(int num)
{
    int result = 1;
    // Stop before shifting into the sign bit (signed overflow is undefined).
    while ( result < num && result < (1 << 30) ) {
        result <<= 1;
    }
    return result;
}

/**
  * Calculate the Morlet wavelet in the Fourier domain, using the given scale
  * and f0 values.
  *
  * Evaluates
  *     sqrt(scale*f0) * pi^0.25 * sqrt(2)
  *         * exp(-0.5 * (2*pi*x*scale*f0 - 2*pi*f0)^2)
  * whose exponential factor peaks where x * scale == 1. No band limiting is
  * done here; the caller decides which x values to evaluate.
  *
  * @param x      frequency-domain coordinate (run() in this file passes an
  *               FFT bin index)
  * @param scale  scale factor (run() in this file passes period / N)
  * @param f0     wavelet centre-frequency parameter
  * @return       the wavelet's spectral weight at x
  */
static float FTWavelet( float x, float scale, float f0 )
{
    static const float pi = (float)3.14159265358979323846;
    // pi^0.25 * sqrt(2)
    static const double multiplier = 1.8827925275534296252520792527491;

    // Deliberately NOT static: a function-local static is initialised only on
    // the first call, so every later call with a different f0 would silently
    // keep using the first f0.
    const double two_pi_f0 = 2.0 * pi * f0;

    scale *= f0;

    // 1.88279*exp(-0.5*(2*pi*x*scale - 2*pi*f0)^2)
    const double phase = 2*pi*x*scale - two_pi_f0;
    float basic = (float)(multiplier * exp(-0.5 * phase * phase));

    return sqrt(scale)*basic;
}

/**
  * Construct the task. The observer is forwarded unchanged to the Task base
  * class. The base is named by its type: qualifying it as
  * Task<...>::Task names the constructor instead and is rejected by
  * standard-conforming compilers.
  */
ScalogramCreator::ScalogramCreator( TaskObserver* observer ) :
    Task<ScalogramCreatorParams>( observer )
{
}

/**
  * Empty: the FFTW buffers and plans used by run() are created and released
  * inside run() itself.
  */
ScalogramCreator::~ScalogramCreator()
{
}

/**
  * Compute a Morlet-wavelet scalogram of params.wave, one scale (row) at a
  * time.
  *
  * The sound is zero-padded to N samples and forward-transformed once. Then,
  * for each period from params.lowerScale to params.upperScale (stepping
  * geometrically by df), the spectrum is multiplied by the wavelet spectrum
  * in a band around N / period and inverse-transformed. The result is passed
  * to params.image->addRow() and rendered with
  * params.image->renderRowToDib().
  *
  * Progress is reported at most about twice a second. The sweep stops early
  * when _stopRequested is set. A per-row timing breakdown is printed to
  * stdout. Nothing is computed when lowerScale <= 0, and nothing is computed
  * if the FFTW buffers cannot be allocated.
  */
void
ScalogramCreator::run( ScalogramCreatorParams params )
{
    unsigned setupTicks = 0;
    unsigned multTicks = 0;
    unsigned fftwTicks = 0;
    unsigned addTicks = 0;
    unsigned renderTicks = 0;

    // A non-positive lower scale can never advance the geometric sweep below
    // (0 * df == 0) and would divide by zero when locating the band.
    if ( params.lowerScale <= 0 ) {
        return;
    }

    Profiler prof;
    prof.start();

    // Extra zero padding proportional to the largest period; presumably
    // keeps the circular convolution from wrapping onto the signal.
    unsigned avoid_overlap = (unsigned)(params.upperScale * 20);

    // Geometric step between scales: df^(upperScale - lowerScale) equals
    // upperScale / lowerScale.
    double df = pow(params.upperScale/params.lowerScale, 1.0 /
            (params.upperScale-params.lowerScale));
    // When upperScale == lowerScale the exponent is 1/0 and df comes out as
    // exactly 1, so "period *= df" would never leave the loop. Any factor > 1
    // makes the single requested scale run exactly once.
    if ( !(df > 1.0) ) {
        df = 2.0;
    }
    unsigned N = round_2_up( params.wave->size + avoid_overlap )*2;

    params.image->upperScale = (float)params.upperScale;
    params.image->lowerScale = (float)params.lowerScale;

    // Initialize the fast Fourier transform (see the FFTW3 documentation).
    // Both transforms work in place: forward on data, inverse on ans.
    fftwf_complex* data = (fftwf_complex*) fftwf_malloc(sizeof(fftwf_complex) * N);
    fftwf_complex* ans = (fftwf_complex*) fftwf_malloc(sizeof(fftwf_complex) * N);
    if ( !data || !ans ) {
        if ( data ) fftwf_free(data);
        if ( ans ) fftwf_free(ans);
        return;
    }

    fftwf_plan plan_inverse = fftwf_plan_dft_1d(N, ans, ans, FFTW_BACKWARD,  FFTW_ESTIMATE);
    fftwf_plan plan_forward = fftwf_plan_dft_1d(N, data, data, FFTW_FORWARD, FFTW_ESTIMATE );

    // Everything beyond the wave samples is zero padding.
    memset( data, 0, sizeof(fftwf_complex)*N);
    memset( ans, 0, sizeof(fftwf_complex)*N);

    printf("\n");

    // Real-valued input: imaginary parts stay zero.
    for( unsigned i =0 ; i < params.wave->size; i++ ) {
        data[i][0] = (float)params.wave->samples[i];
    }

    fftwf_execute(plan_forward);

    setupTicks += prof.reset();

    // for each scale level,
    int row = 0;
    unsigned lastProgressTicks = GetTickCount();
    for ( double period = params.lowerScale; period <= params.upperScale; period *= df, row += 1 )
    {
        unsigned total = setupTicks+fftwTicks+multTicks+addTicks+renderTicks;
        // Guard against 0/0 (printed as NaN) before any measurable time passed.
        float percentPerTick = total ? 100.0f / total : 0.0f;
        printf("%d: %g setup:%.0f%% fftw:%.0f%% mult:%.0f%% add:%.0f%% render:%.0f%%     \r", row,
                period,
                setupTicks*percentPerTick, fftwTicks*percentPerTick,
                multTicks*percentPerTick, addTicks*percentPerTick,
                renderTicks*percentPerTick
                );
        // Only update the progress twice a second.
        if ( GetTickCount() - lastProgressTicks > 500 ) {
            setProgress( (double)row / params.image->bands );
            lastProgressTicks = GetTickCount();
        }

        if ( _stopRequested ) {
            break;
        }

        // Multiply the Fourier transform of the sound with the Fourier
        // transform of the wavelet. The wavelet is only evaluated on bins
        // [0.9, 1.1) * N / period (zero elsewhere). The band is clamped to N
        // in floating point before casting, so small periods (< 1.1) cannot
        // index past the buffers and huge quotients cannot overflow the cast.
        memset( ans, 0, sizeof(fftwf_complex)*N);
        const double bandStart = 0.9 * N / period;
        const double bandEnd = 1.1 * N / period;
        const unsigned start = bandStart < N ? (unsigned)bandStart : N;
        const unsigned end = bandEnd < N ? (unsigned)bandEnd : N;
        for( unsigned x = start; x < end; x++ ) {
            // NOTE(review): only the real part of the sound's spectrum is
            // used and ans[x][1] stays 0. Possibly intentional given the
            // zero padding; confirm before changing, since also using
            // data[x][1] would change the output amplitudes.
            ans[x][0] = FTWavelet( (float)x, (float)period/(N), (float)params.f0 ) * data[x][0];
        }

        multTicks += prof.reset();

        // Perform inverse Fourier transform of the result (in place on ans).
        fftwf_execute(plan_inverse);

        fftwTicks += prof.reset();
        params.image->addRow( row, (float*)ans );

        addTicks += prof.reset();
        // Render the row to the DibImage on the fly, so the user doesn't get
        // bored.
        params.image->renderRowToDib( row, params.dib, params.type, params.map, 0,
            params.wave->size, (float*)ans );
        renderTicks += prof.reset();
    }

    fftwf_destroy_plan(plan_forward);
    fftwf_destroy_plan(plan_inverse);
    fftwf_free(data);
    fftwf_free(ans);
}
