import argparse
import numpy as np
import pandas as pd
from functools import partial
from joblib import Parallel, delayed
from pathlib import Path
import os
import sys
sys.path.insert(0,"../../")

from src.representation import augment, polynomial_feature_map
from src.model import primal_fit_to
import src.dataset as dataset
from src.dataset import periodic_ts_1d, square_approx
from src.utils import primal_left
from src.numpy_metric import hs_metric,operator_metric,chordal_metric,eigenvalue_metric,subspace_metric,martin_metric


def _compute_spectral_decomposition(n_samples,sampfreq,noise_std,context_window,poly_order,rank,tikhonov_reg,symmetry,lst):
    """Generate a synthetic signal, fit a Koopman estimator to it, and return its decomposition.

    The signal is produced by ``periodic_ts_1d`` from the mode list ``lst``, lifted with
    ``augment`` (context window) followed by ``polynomial_feature_map``, and fitted with
    ``primal_fit_to`` using the time step ``1/sampfreq``.

    Args:
        n_samples: number of samples forwarded to ``periodic_ts_1d``.
        sampfreq: sampling frequency forwarded to ``periodic_ts_1d``; also sets ``dt = 1/sampfreq``.
        noise_std: noise standard deviation forwarded to ``periodic_ts_1d``.
        context_window: window length forwarded to ``augment``.
        poly_order: ``order`` argument of ``polynomial_feature_map``.
        rank, tikhonov_reg, symmetry: forwarded unchanged to ``primal_fit_to``.
        lst: per-mode signal configuration forwarded to ``periodic_ts_1d``. It is the last
            parameter; callers in this script bind all the others with ``functools.partial``.

    Returns:
        Tuple ``(Ds, Rs, Ls)``: the fit's ``'values'`` entry, its ``'right'`` entry
        (presumably eigenvalues / right eigenvectors — confirm in ``src.model``), and the
        result of ``primal_left`` on the fit and the feature matrix.
    """
    signal = periodic_ts_1d(lst, n_samples, sampfreq, noise_std)
    features = polynomial_feature_map(augment(signal, context_window), order=poly_order)
    estimator = primal_fit_to(
        Z=features,
        dt=1/sampfreq,
        tikhonov_reg=tikhonov_reg,
        rank=rank,
        symmetry=symmetry,
    )
    # Read the fit's entries before calling primal_left, preserving the original access order.
    values, right = estimator['values'], estimator['right']
    left = primal_left(estimator, features)
    return values, right, left

if __name__ == "__main__":

    print("Starting experiment...")

    parser = argparse.ArgumentParser(description="Test Experiment")

    #Experiment related parameters
    parser.add_argument("--exp_id", type=int, required=True, help="Experiment id")
    parser.add_argument('--n_trials',type=int, required=True, help="Number of metric computation trials.")

    parser.add_argument('--min_frequency', type=float,required=True, help="Minimum frequency of the observed signal. Possibly different from the baseline signal")
    parser.add_argument('--max_frequency', type=float,required=True, help="Maximum frequency of observed signals. Possibly different from the baseline signal")
    parser.add_argument('--n_step_frequency', type=int,required=True, help="Min to max freq in n_step")

    parser.add_argument('--min_decay', type=float,required=True, help="Minimum Decay of the observed signals. Possibly different from the baseline signal")
    parser.add_argument('--max_decay', type=float,required=True, help="Maximum decay of the observed signals. Possibly different from the baseline signal")
    parser.add_argument('--n_step_decay', type=int,required=True, help="Min to max decay in n_step")

    parser.add_argument('--min_power', type=float,required=True, help="Minimum power of the cos_to_squared function. Possibly different from the baseline signal")
    parser.add_argument('--max_power', type=float,required=True, help="Maximum power of the cos_to_squared function. Possibly different from the baseline signal")
    parser.add_argument('--n_step_power', type=int,required=True, help="Min to max power in n_step")

    parser.add_argument('--min_sampfreq', type=int, required=True, help="Minimum sampling frequency in Hertz")
    parser.add_argument('--max_sampfreq', type=int, required=True, help="Maximum sampling frequency in Hertz")
    parser.add_argument('--n_step_sampfreq', type=int, required=True, help="Min to max sampling frequency in n_step")

    #Metric configuration
    parser.add_argument('--alpha', type=float, default=[0.75,0.5,0.25,0.1,0.01], nargs="+", help='Chordal cost function alpha parameter')
    parser.add_argument('--real_scale', type=float, default=10.0, help='Chordal cost function real part scale')
    parser.add_argument('--imag_scale', type=float, default=1.0, help='Chordal cost function imaginary part scale')

    #Baseline signal configuration (one entry per mode; all lists must have equal length)
    parser.add_argument('--base_freq', type=float, default=[0.5,1.0], nargs="+",help='Baseline frequency')
    parser.add_argument('--base_amp', type=float, default=[1.0,1.0], nargs="+", help='Baseline amplitude')
    parser.add_argument('--base_decay', type=float, default=[0.0,0.0], nargs="+", help='Baseline decay')
    parser.add_argument('--base_phase', type=float, default=[0.0,0.0], nargs="+", help='Baseline phase')
    parser.add_argument('--base_func', type=str, default=["cos","cos"], nargs="+", help='Baseline function')
    parser.add_argument('--base_std', type=float, default=1e-2, help='Baseline noise standard deviation')
    parser.add_argument('--base_sampfreq', type=int, default=200, help='Sampling frequency in Hertz')
    parser.add_argument('--n_samples', type=int, default=4001, help='Number of samples per generated signal')

    #Baseline koopman solver configuration
    parser.add_argument('--poly_order',type=int, default=1, help='Maximum polynomial degree for observables functions')
    parser.add_argument('--rank', type=int, default=2, help='Rank of the estimated koopman operator')
    parser.add_argument('--n_jobs', type=int, default=1, help='Maximum number of concurrently running jobs')
    parser.add_argument('--tikhonov_reg', type=float, default=1e-8, help='Tikhonov regularisation value')
    parser.add_argument('--symmetry', type=str, default=None, help='Koopman operator type: ["symmetric", "antisymmetric", None]')
    parser.add_argument('--seed', type=int, default=None, help="random seed")

    #folders
    parser.add_argument('--save_folder', type=str, default="results", help= "Folder where results will be saved")

    args = parser.parse_args()

    # Seeding.
    # NOTE(review): np.random.seed only seeds the current process. With
    # n_jobs > 1, joblib's default backend runs jobs in separate worker
    # processes, so numpy randomness inside those jobs is presumably not
    # controlled by this seed. Confirm if reproducibility with n_jobs > 1 matters.
    if args.seed is not None:
        np.random.seed(args.seed)

    #build functions
    def compute_all_metrics(Ds, Rs, Ls, Dt, Rt, Lt,metrics,metrics_names):
        """Evaluate each metric on a (source, target) pair of decompositions.

        Every callable in ``metrics`` is called as ``metric(Ds, Rs, Ls, Dt, Rt, Lt)``
        and its result is stored under the name at the same position in
        ``metrics_names``. Pairing stops at the shorter of the two sequences, and a
        repeated name keeps the result of its last occurrence.

        Returns:
            dict mapping metric name -> metric value, in ``metrics`` order.
        """
        return {
            label: fn(Ds, Rs, Ls, Dt, Rt, Lt)
            for fn, label in zip(metrics, metrics_names)
        }

    # Validate the per-mode baseline arguments up front. parser.error is used
    # instead of assert, because asserts are stripped under `python -O`.
    #[amplitude,decay,phase,frequency,function]
    if not (len(args.base_amp) == len(args.base_decay) == len(args.base_phase) == len(args.base_freq) == len(args.base_func)):
        parser.error("Base arguments (amplitude, decay, phase, frequency, function) must have the same dimensions")
    n_modes = len(args.base_amp)

    # Context window (in samples) covering one period of the fastest baseline
    # mode, and the time it spans (seconds, since base_sampfreq is in Hertz).
    base_context_window = int(1/np.max(args.base_freq) * args.base_sampfreq)
    base_context_dt = base_context_window * 1/args.base_sampfreq

    # Metrics comparing a source decomposition (Ds, Rs, Ls) to a target one
    # (Dt, Rt, Lt). Note: the `sampfreqs`/`sampfreqt` keywords are given the
    # context-window duration here, not a sampling frequency.
    base_metrics = [
        partial(metric, sampfreqs=base_context_dt, sampfreqt=base_context_dt)
        for metric in (hs_metric, operator_metric, eigenvalue_metric, subspace_metric, martin_metric)
    ]
    metrics_names = ["HS","operator","eigenvalue","subspace","martin"]
    for alpha in args.alpha:
        base_metrics.append(partial(chordal_metric,alpha=alpha,real_scale=args.real_scale,imag_scale=args.imag_scale))
        # round() rather than int(): 100*alpha is inexact in floating point
        # (e.g. 100*0.29 == 28.999999999999996). Truncation would mislabel the
        # column, or make two alphas collide on the same result key.
        metrics_names.append(f"chordal_{round(100*alpha)}")

    # Baseline decomposition: every solver setting is bound, so only the
    # per-mode signal configuration list remains to be passed.
    base_compute_spectral_decomposition = partial(_compute_spectral_decomposition,args.n_samples,args.base_sampfreq,args.base_std,base_context_window,args.poly_order,args.rank,args.tikhonov_reg,args.symmetry)

    # Create the experiment folder. exist_ok tolerates a folder that already
    # exists (e.g. created by another process) but still raises on real
    # failures, instead of silently swallowing every error.
    dir_path = Path(args.save_folder)
    dir_path.mkdir(parents=True, exist_ok=True)
    
    # Frequency ablation study: the target signal equals the baseline, except
    # that its last mode's frequency is swept over [min_frequency, max_frequency].
    filename = "frequency_ablation_study.csv"
    path = dir_path / filename
    if not path.exists():
        print("Computing frequency ablation study...")
        freqs = np.linspace(args.min_frequency,args.max_frequency,args.n_step_frequency)
        def freq_single_run(freq):
            """Return all metrics between the baseline and a copy whose last mode has frequency ``freq``."""
            config_s = []
            config_t = []
            for i,(b_amp,b_decay,b_phase,b_freq,b_func) in enumerate(zip(args.base_amp,args.base_decay,args.base_phase,args.base_freq,args.base_func)):
                func = getattr(dataset,b_func)
                config_s.append([b_amp,b_decay,b_phase,b_freq,func])
                # Only the last mode of the target is perturbed.
                t_freq = freq if i == n_modes-1 else b_freq
                config_t.append([b_amp,b_decay,b_phase,t_freq,func])
            Ts = base_compute_spectral_decomposition(config_s) #Ts = (Ds,Rs,Ls)
            Tt = base_compute_spectral_decomposition(config_t) #Tt = (Dt,Rt,Lt)
            return compute_all_metrics(*Ts, *Tt, base_metrics, metrics_names)
        # Jobs are submitted value-major (every trial of freqs[0], then of
        # freqs[1], ...), and joblib returns results in submission order. The
        # label column must therefore repeat each value n_trials times; tiling
        # the array would mislabel rows whenever n_trials > 1.
        results = Parallel(args.n_jobs)(delayed(freq_single_run)(freq) for freq in freqs for _ in range(args.n_trials))
        df = pd.DataFrame(results)
        df["frequency"] = np.repeat(freqs, args.n_trials)
        df.to_csv(path)
        print("Frequency ablation study completed.")
    else:
        print("Frequency ablation study already computed, skipping...")

    # Decay ablation study: the target signal equals the baseline, except that
    # its last mode's decay is swept over [min_decay, max_decay].
    filename = "decay_ablation_study.csv"
    path = dir_path / filename
    if not path.exists():
        print("Computing decay ablation study...")
        decays = np.linspace(args.min_decay,args.max_decay,args.n_step_decay)
        def decay_single_run(decay):
            """Return all metrics between the baseline and a copy whose last mode has decay ``decay``."""
            config_s = []
            config_t = []
            for i,(b_amp,b_decay,b_phase,b_freq,b_func) in enumerate(zip(args.base_amp,args.base_decay,args.base_phase,args.base_freq,args.base_func)):
                func = getattr(dataset,b_func)
                config_s.append([b_amp,b_decay,b_phase,b_freq,func])
                # Only the last mode of the target is perturbed.
                t_decay = decay if i == n_modes-1 else b_decay
                config_t.append([b_amp,t_decay,b_phase,b_freq,func])
            Ts = base_compute_spectral_decomposition(config_s) #Ts = (Ds,Rs,Ls)
            Tt = base_compute_spectral_decomposition(config_t) #Tt = (Dt,Rt,Lt)
            return compute_all_metrics(*Ts, *Tt, base_metrics, metrics_names)
        # Results come back value-major (all trials of decays[0] first), so each
        # label is repeated n_trials times rather than tiling the whole array.
        results = Parallel(args.n_jobs)(delayed(decay_single_run)(decay) for decay in decays for _ in range(args.n_trials))
        df = pd.DataFrame(results)
        df["decay"] = np.repeat(decays, args.n_trials)
        df.to_csv(path)
        print("Decay ablation study completed.")
    else:
        print("Decay ablation study already computed, skipping...")

    # Mode ablation study: the target signal's last mode function is replaced
    # by square_approx(power, .) for each power in the sweep. The target is
    # fitted with rank args.rank + 2*(power-1).
    filename = "mode_ablation_study.csv"
    path = dir_path / filename
    if not path.exists():
        print("Computing mode ablation study...")
        # NOTE(review): the message refers to half the sampling frequency, but
        # the check compares against the full base_sampfreq. Confirm which bound
        # is intended. (parser.error instead of assert: asserts vanish under -O.)
        if not args.max_power < args.base_sampfreq:
            parser.error("Maximum power should be less than half the sampling frequency to avoid aliasing")
        mode_compute_spectral_decomposition = partial(_compute_spectral_decomposition,n_samples=args.n_samples,sampfreq=args.base_sampfreq,noise_std=args.base_std,context_window=base_context_window,poly_order=args.poly_order,tikhonov_reg=args.tikhonov_reg,symmetry=args.symmetry)
        def mode_single_run(power):
            """Return all metrics between the baseline and a copy whose last mode is ``square_approx(power, .)``."""
            config_s = []
            config_t = []
            for i,(b_amp,b_decay,b_phase,b_freq,b_func) in enumerate(zip(args.base_amp,args.base_decay,args.base_phase,args.base_freq,args.base_func)):
                func = getattr(dataset,b_func)
                config_s.append([b_amp,b_decay,b_phase,b_freq,func])
                if i == n_modes-1:
                    # Only the last mode of the target is perturbed.
                    config_t.append([b_amp,b_decay,b_phase,b_freq,lambda x : square_approx(power,x)])
                else:
                    config_t.append([b_amp,b_decay,b_phase,b_freq,func])
            Ts = mode_compute_spectral_decomposition(rank = args.rank, lst = config_s)
            Tt = mode_compute_spectral_decomposition(rank = args.rank + 2*(power-1), lst = config_t)
            return compute_all_metrics(*Ts, *Tt, base_metrics, metrics_names)
        powers = np.linspace(args.min_power,args.max_power,args.n_step_power,dtype=int)
        # Results come back value-major (all trials of powers[0] first), so each
        # label is repeated n_trials times rather than tiling the whole array.
        results = Parallel(args.n_jobs)(delayed(mode_single_run)(power) for power in powers for _ in range(args.n_trials))
        df = pd.DataFrame(results)
        df["power"] = np.repeat(powers, args.n_trials)
        df.to_csv(path)
        print("Mode ablation study completed.")
    else:
        print("Mode ablation study already computed, skipping...")

    # Sampling-frequency ablation study: the unchanged baseline signal is
    # regenerated at each sampling frequency in the sweep and compared to the
    # baseline sampled at base_sampfreq.
    filename = "sampfreq_ablation_study.csv"
    path = dir_path / filename
    if not path.exists():
        print("Computing sampling frequency ablation study...")
        sampfreqs = np.linspace(args.min_sampfreq,args.max_sampfreq,args.n_step_sampfreq,dtype=int)
        def sampfreq_single_run(sf):
            """Return all metrics between the baseline and the same signal sampled at ``sf`` Hz."""
            # The baseline context window (in samples) is reused for every sf,
            # so the target window spans base_context_window/sf seconds.
            context_dt = base_context_window * 1/sf
            metrics = [
                partial(metric,sampfreqs=base_context_dt,sampfreqt=context_dt)
                for metric in (hs_metric, operator_metric, eigenvalue_metric, subspace_metric, martin_metric)
            ]
            metrics += [
                partial(chordal_metric,alpha=alpha,real_scale=args.real_scale,imag_scale=args.imag_scale)
                for alpha in args.alpha
            ]
            config = [
                [b_amp,b_decay,b_phase,b_freq,getattr(dataset,b_func)]
                for b_amp,b_decay,b_phase,b_freq,b_func in zip(args.base_amp,args.base_decay,args.base_phase,args.base_freq,args.base_func)
            ]
            compute_spectral_decomposition = partial(_compute_spectral_decomposition,args.n_samples,sf,args.base_std,base_context_window,args.poly_order,args.rank,args.tikhonov_reg,args.symmetry)
            Ts = base_compute_spectral_decomposition(config) #Ts = (Ds,Rs,Ls)
            Tt = compute_spectral_decomposition(config) #Tt = (Dt,Rt,Lt)
            return compute_all_metrics(*Ts, *Tt, metrics, metrics_names)
        # Results come back value-major (all trials of sampfreqs[0] first), so
        # each label is repeated n_trials times rather than tiling the array.
        results = Parallel(args.n_jobs)(delayed(sampfreq_single_run)(sf) for sf in sampfreqs for _ in range(args.n_trials))
        df = pd.DataFrame(results)
        df["sampfreq"] = np.repeat(sampfreqs, args.n_trials)
        df.to_csv(path)
        print("Sampling frequency ablation study completed.")

    else:
        print("Sampling frequency ablation study already computed, skipping...")

    print("Experiment completed.")