import numpy as np 
import pandas as pd
import argparse
import json
from pathlib import Path
import os
import time
import torch
from scipy.linalg import eig
import sys
import datetime
sys.path.insert(0,"../../")

import src.dataset as dataset
from src.dataset import periodic_ts_1d,sin
from src.model import primal_fit_to
from src.utils import primal_left
from src.representation import polynomial_feature_map, augment
from src.kooporch import ChordalCostFunction, SpectralOTBarycenter,StandardBarycenter,HilbertSchmidtCostFunction,reduced_to_full_system,full_to_reduced_system


######################################################################################################
# Base signal settings

# Each SETTINGS_* constant is a list of rows, one row per signal component, and
# is passed as the first argument of ``periodic_ts_1d`` in the ``__main__``
# block. The last item of every row is the *name* of a function in
# ``src.dataset``; ``__main__`` replaces it with the function object via
# ``getattr(dataset, name)`` before generating the signals.
# NOTE(review): the four numeric fields look like (amplitude, growth/decay
# rate, phase, frequency) -- confirm against ``periodic_ts_1d``.

# Components of the first signal to interpolate from.
SETTINGS_1 = [[1,0.3,0,1.7,"sin"],[0.2,-0.2,0,4.7,"sin"]]
# Components of the second signal to interpolate to.
SETTINGS_2 = [[1,-0.2,0,0.7,"sin"],[0.2,0.3,0,11.3,"sin"]]
# Components of the signal saved as ``signal_init.npy``. Its fourth column
# (1.7, 4.7, 0.7, 11.3) collects the fourth-column values of SETTINGS_1 and
# SETTINGS_2, and its second column is zero in every row.
SETTINGS_INIT = [[1.0,0,0,1.7,"sin"],
                 [0.1,0,0,4.7,"sin"],
                 [0.01,0,0,0.7,"sin"],
                 [0.5,0,0,11.3,"sin"]]

######################################################################################################

#utilities
def primal_predict(T,X_init,n_samples,sampfreq,real_valued=True):
    """Roll a spectral decomposition forward in time from an initial state.

    Evaluates ``R @ (exp(D * t) * (L^H @ x0))`` for every ``t`` in the grid
    ``arange(n_samples) / sampfreq`` and stacks the results row-wise.

    Args:
        T: triple ``(D, R, L)`` of torch tensors. ``D`` is indexed as a 1-D
            vector of continuous-time eigenvalues; ``R`` and ``L`` are the
            matrices whose columns pair with the entries of ``D``.
        X_init: initial state; flattened to a column vector. Tensors and
            array-likes accepted by ``torch.as_tensor`` are supported.
        n_samples: number of time steps to generate (the first one is t=0).
        sampfreq: sampling frequency; consecutive steps are ``1/sampfreq``
            apart.
        real_valued: when true, only the real part of the prediction is
            returned.

    Returns:
        Tensor with ``n_samples`` rows and one column per row of ``R``.
    """
    D, R, L = T

    # Flatten the initial state and move it next to the operator.
    x0 = torch.as_tensor(X_init, device=L.device).reshape(-1, 1)

    # ``@`` does not promote dtypes, so a real ``X_init`` combined with complex
    # modes used to raise. Promote every operand to one common dtype instead;
    # this is a no-op when the dtypes already agree.
    work_dtype = x0.dtype
    for part in (D, R, L):
        work_dtype = torch.promote_types(work_dtype, part.dtype)
    D, R, L, x0 = (part.to(work_dtype) for part in (D, R, L, x0))

    # Build the time grid on the same device as the eigenvalues; the previous
    # CPU-only grid broke as soon as ``T`` lived on an accelerator.
    t = torch.arange(n_samples, device=D.device) / sampfreq

    # One row per mode, one column per time step.
    mode_evolution = torch.exp(D[:, None] * t[None, :])
    # Coordinates of the initial state in the mode basis.
    mode_coeffs = L.conj().T @ x0

    pred = (R @ (mode_evolution * mode_coeffs)).T
    return pred.real if real_valued else pred

def compute_spectral_decomposition(traj,sampfreq,context_window,poly_order,rank,tikhonov_reg,symmetry):
    """Fit a spectral decomposition to a trajectory and sort it by frequency.

    The trajectory is passed through ``augment`` (with ``context_window``) and
    ``polynomial_feature_map`` (with ``order=poly_order``); the resulting
    features are handed to ``primal_fit_to`` with time step ``1/sampfreq``.

    Args:
        traj: trajectory forwarded to ``augment``.
        sampfreq: sampling frequency; its inverse is used as ``dt``.
        context_window: second argument of ``augment``.
        poly_order: order of the polynomial feature map.
        rank: forwarded to ``primal_fit_to``.
        tikhonov_reg: forwarded to ``primal_fit_to``.
        symmetry: forwarded to ``primal_fit_to``.

    Returns:
        ``(Ds, Rs, Ls)``: the fitted ``'values'``, the ``'right'`` columns and
        the columns returned by ``primal_left``, all reordered by increasing
        imaginary part of the values.
    """
    features = polynomial_feature_map(
        augment(traj, context_window),
        order=poly_order,
    )
    fit = primal_fit_to(
        Z=features,
        dt=1/sampfreq,
        tikhonov_reg=tikhonov_reg,
        rank=rank,
        symmetry=symmetry,
    )

    # Single permutation shared by values and both sets of vectors.
    order = np.argsort(fit['values'].imag)

    values_sorted = fit['values'][order]
    right_sorted = fit['right'][:, order]
    left_sorted = primal_left(fit, features)[:, order]
    return values_sorted, right_sorted, left_sorted

def real_baseline_barycenter(Tt_lst,rank,ponderation,sampfreq,step=1,dtype=torch.cfloat):
    """Baseline barycenter: combine the operators entry-wise, then re-decompose.

    Every system is expanded with ``reduced_to_full_system``, turned into the
    matrix ``R @ diag(exp(D * step / sampfreq)) @ L^H``, weighted by
    ``ponderation`` and combined. The combined matrix is eigendecomposed with
    ``scipy.linalg.eig`` and truncated to ``rank`` eigenpairs.

    Args:
        Tt_lst: iterable of systems accepted by ``reduced_to_full_system``.
        rank: number of eigenpairs to return.
        ponderation: 1-D tensor with one weight per system.
        sampfreq: sampling frequency used to convert between continuous-time
            and discrete-time eigenvalues.
        step: number of sampling periods spanned by the discrete operator.
        dtype: torch dtype of the returned tensors.

    Returns:
        Tuple ``(D, R, L)`` of CPU tensors: continuous-time eigenvalues, right
        eigenvectors (as columns) and left eigenvectors (as columns), all
        ordered consistently by increasing eigenvalue magnitude.
    """
    dt = step / sampfreq

    full_systems = [reduced_to_full_system(T) for T in Tt_lst]
    operators = [
        R @ (torch.exp(D * dt).view(-1, 1) * L.conj().T)
        for D, R, L in full_systems
    ]

    # NOTE(review): the operators are already multiplied by their weights, so
    # ``torch.mean`` additionally divides by the number of systems. If the
    # weights sum to one, ``torch.sum`` may be what is intended -- confirm.
    T_bar = torch.mean(ponderation.view(-1, 1, 1) * torch.stack(operators), dim=0)
    T_bar = T_bar.detach().cpu().numpy()

    # scipy returns (eigenvalues, left eigenvectors, right eigenvectors).
    eigvals, left_vecs, right_vecs = eig(
        T_bar, left=True, right=True, overwrite_a=True, overwrite_b=True
    )

    # NOTE(review): ascending order keeps the ``rank`` eigenvalues of SMALLEST
    # magnitude; if the dominant modes are wanted the order should be
    # reversed -- confirm the intended truncation.
    keep = np.argsort(np.abs(eigvals))[:rank]

    # Apply the permutation exactly once. The previous code sorted the
    # eigenvalues, then indexed the sorted array with the same permutation
    # again, so the returned eigenvalues no longer matched the eigenvector
    # columns.
    D_bar = np.log(eigvals[keep]) / dt

    return (
        torch.tensor(D_bar, dtype=dtype),
        torch.tensor(right_vecs[:, keep], dtype=dtype),
        torch.tensor(left_vecs[:, keep], dtype=dtype),
    )

if __name__ == "__main__":
    # Experiment driver: generate two noisy periodic signals, fit a spectral
    # decomposition to each, then for every interpolation ratio compute and
    # save three barycenters (entry-wise baseline, Hilbert-Schmidt, spectral
    # OT) under <save_folder>/exp_<exp_id>/.

    parser = argparse.ArgumentParser(description="Interpolation between 1D linear systems")

    parser.add_argument("--exp_id", type=int, default=1, help="Experiment id")
    parser.add_argument("--seed", type=int, default=2, help="Random seed for reproducibility")
    parser.add_argument("--save_folder", type=str, default="results", help="Folder to save results")

    parser.add_argument("--n_points", type=int, default=9, help="Number of points in the trajectory")
    parser.add_argument("--sampfreq", type=int, default=800, help="Sampling frequency")
    parser.add_argument("--n_samples", type=int, default=5000, help="Duration of the trajectory")
    parser.add_argument("--noise_std", type=float, default=0.01, help="Scale of the Brownian noise")
    parser.add_argument("--rank", type=int, default=4, help="Rank for the spectral decomposition")
    parser.add_argument("--context_window", type=int, default=400, help="Context window for the spectral decomposition")
    parser.add_argument("--poly_order", type=int, default=1, help="Polynomial order for the spectral decomposition")
    # ``type=list`` would split the command-line string into single characters;
    # the help text promises a JSON string, so parse it as JSON.
    parser.add_argument("--settings_1", type=json.loads, default=SETTINGS_1, help="Settings for signal 1 as a json string")
    parser.add_argument("--settings_2", type=json.loads, default=SETTINGS_2, help="Settings for signal 2 as a json string")
    parser.add_argument("--settings_init", type=json.loads, default=SETTINGS_INIT, help="Initial settings for the signals as a json string")

    parser.add_argument("--device", type=str, default="cpu", help="Device to run the computations on (e.g., 'cpu' or 'cuda')")
    parser.add_argument("--dtype", type=str, default="cfloat", help="Data type for computations (e.g., 'float', 'double', 'cfloat', 'cdouble')")
    # ``--data_dtype`` is accepted as an alias so the option matches the
    # underscore naming of every other flag; ``--data-dtype`` keeps working.
    parser.add_argument("--data-dtype", "--data_dtype", dest="data_dtype", type=str, default="float", help="Data type for the input data (e.g., 'float', 'double', 'cfloat', 'cdouble')")
    parser.add_argument("--real_scale", type=float, default=1.0, help="Real scale for the cost matrix")
    parser.add_argument("--imag_scale", type=float, default=1.0, help="Imaginary scale for the cost matrix")
    parser.add_argument("--alpha", type=float, default=0.9, help="Alpha parameter for the cost matrix")
    parser.add_argument("--p", type=float, default=2.0, help="P parameter for the cost matrix")
    parser.add_argument("--ot_lr", type=float, default=1e-2, help="Learning rate for the OT barycenter optimization")
    parser.add_argument("--ot_max_epochs", type=int, default=200, help="Number of epochs for the OT barycenter optimization")
    parser.add_argument("--ot_max_iter", type=int, default=10, help="Maximum number of iterations per coordinate descent (OT barycenter)")
    parser.add_argument("--ot_tol", type=float, default=1e-6, help="Tolerance for convergence in the OT barycenter optimization")
    parser.add_argument("--ot_verbose", type=int, default=1, help="Verbosity level for the OT barycenter optimization")
    parser.add_argument("--hs_lr", type=float, default=3e-5, help="Learning rate for the HS barycenter optimization")
    parser.add_argument("--hs_max_epochs", type=int, default=2000, help="Number of epochs for the HS barycenter optimization")
    parser.add_argument("--hs_max_iter", type=int, default=1, help="Maximum number of iterations per coordinate descent (HS barycenter)")
    parser.add_argument("--hs_tol", type=float, default=1e-6, help="Tolerance for convergence in the HS barycenter optimization")
    parser.add_argument("--hs_verbose", type=int, default=1, help="Verbosity level for the HS barycenter optimization")
    # This option used to be parsed and written to settings.json but never
    # used: both fits hard-coded 1e-6. It is now passed through, and its
    # default is the value that was actually applied, so default runs give the
    # same results and settings.json records the truth.
    parser.add_argument("--tikhonov_reg", type=float, default=1e-6, help="Tikhonov regularization parameter")

    args = parser.parse_args()

    # generate experiment folder (exist_ok avoids the check-then-create race)
    experiment_path = Path(args.save_folder) / f"exp_{args.exp_id}"
    experiment_path.mkdir(parents=True, exist_ok=True)

    # save experiment settings while every value is still JSON-serialisable
    # (i.e. before dtype names / function names are resolved to objects)
    setting_file = experiment_path / "settings.json"
    with open(setting_file, 'w') as f:
        json.dump(vars(args), f, indent=2)

    # process arguments
    args.dtype = getattr(torch, args.dtype)
    args.data_dtype = getattr(torch, args.data_dtype)

    def update_settings(x):
        """Return a copy of row ``x`` with its trailing name resolved in ``dataset``."""
        # Build a new list instead of popping in place so the module-level
        # SETTINGS_* defaults are left untouched.
        *params, fname = x
        return params + [getattr(dataset, fname)]

    args.settings_1 = [update_settings(x) for x in args.settings_1]
    args.settings_2 = [update_settings(x) for x in args.settings_2]
    args.settings_init = [update_settings(x) for x in args.settings_init]

    # random setting
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Generate signals
    signal_1 = periodic_ts_1d(args.settings_1,args.n_samples,args.sampfreq,args.noise_std)
    signal_2 = periodic_ts_1d(args.settings_2,args.n_samples,args.sampfreq,args.noise_std)
    signal_init = periodic_ts_1d(args.settings_init,args.context_window+1,args.sampfreq,args.noise_std)
    signal_init_path = experiment_path / "signal_init.npy"
    np.save(signal_init_path, signal_init)

    # Compute spectral decompositions
    T1 = compute_spectral_decomposition(signal_1,args.sampfreq,args.context_window,args.poly_order,args.rank,args.tikhonov_reg,None)
    T1_path = experiment_path / "T1.npz"
    np.savez(T1_path, D=T1[0], R=T1[1], L=T1[2],signal = signal_1)
    T2 = compute_spectral_decomposition(signal_2,args.sampfreq,args.context_window,args.poly_order,args.rank,args.tikhonov_reg,None)
    T2_path = experiment_path / "T2.npz"
    np.savez(T2_path, D=T2[0], R=T2[1], L=T2[2],signal = signal_2)

    # system simplification
    T1 = [torch.tensor(x,device=args.device,dtype=args.dtype) for x in T1]
    T1 = full_to_reduced_system(T1)
    T2 = [torch.tensor(x,device=args.device,dtype=args.dtype) for x in T2]
    T2 = full_to_reduced_system(T2)

    # COMPUTE BARYCENTERS
    cost_function = ChordalCostFunction(alpha=args.alpha, real_scale=args.real_scale, imag_scale=args.imag_scale, p=args.p)
    T_lst = [T1,T2]
    # Regroup as (all D's, all R's, all L's) to average component-wise.
    lst = list(zip(*T_lst))
    # n_points interior ratios; the two end points (pure T1 / pure T2) are skipped.
    ratios = torch.linspace(0,1,args.n_points+2)

    def mean_init():
        """Build a fresh (D, R, L) initial guess: the plain average of the inputs."""
        # Rebuilt before every fit so each optimiser starts from its own
        # tensors. NOTE(review): whether ``fit`` modifies ``init`` in place is
        # not visible here -- the per-fit rebuild of the original is kept.
        return tuple(
            torch.mean(torch.stack(parts), dim=0).detach().type(args.dtype).to(args.device)
            for parts in lst[:3]
        )

    # Output folders are the same for every ratio; create them once.
    base_dir = experiment_path / "base"
    hs_dir = experiment_path / "hs"
    ot_dir = experiment_path / "ot"
    for out_dir in (base_dir, hs_dir, ot_dir):
        out_dir.mkdir(parents=True, exist_ok=True)

    print("Computing barycenters ...")
    for i,ratio in enumerate(ratios[1:-1]):
        # Computing the ponderations
        ponderations = torch.tensor([1-ratio,ratio],dtype=args.data_dtype,device=args.device)

        # Computing the baseline barycenter
        base_path = base_dir / f"base_bar_{i+1}.npz"
        base_bar = real_baseline_barycenter(T_lst, rank=args.rank, ponderation=ponderations,sampfreq=args.sampfreq)
        numpy_base_bar = [x.detach().cpu().numpy() for x in base_bar]
        np.savez(base_path, D=numpy_base_bar[0], R=numpy_base_bar[1], L=numpy_base_bar[2])
        print(f"Base Barycenter {i+1}/{args.n_points} done.")

        # Computing the HS barycenter
        hs_path = hs_dir / f"hs_bar_{i+1}.npz"
        T_init = mean_init()
        hs_cost_function = HilbertSchmidtCostFunction(sampfreq=args.sampfreq,step=1)
        start_time = datetime.datetime.now()
        sba = StandardBarycenter(cost_function=hs_cost_function, dual=False, verbose=args.hs_verbose, lr=args.hs_lr, max_epochs=args.hs_max_epochs, max_iter=args.hs_max_iter, tol=args.hs_tol,device=args.device)
        T_bar, epoch_loss = sba.fit(Tt_lst=T_lst, init=T_init,ponderations=ponderations)
        end_time = datetime.datetime.now()
        duration = (end_time - start_time).total_seconds()
        numpy_hs_bar = [x.detach().cpu().numpy() for x in T_bar]
        np.savez(hs_path, D=numpy_hs_bar[0], R=numpy_hs_bar[1], L=numpy_hs_bar[2],losses = epoch_loss.detach().cpu().numpy(),duration = duration)
        print(f"HS Barycenter {i+1}/{args.n_points} done.")

        # Computing the OT barycenter
        ot_path = ot_dir / f"ot_bar_{i+1}.npz"
        T_init = mean_init()
        ot_start_time = datetime.datetime.now()
        soba = SpectralOTBarycenter(cost_function, lr = args.ot_lr,max_iter=args.ot_max_iter, max_epochs=args.ot_max_epochs, tol=args.ot_tol, verbose=args.ot_verbose, device=args.device, dual=False)
        T_bar, P_lst, losses = soba.fit(T_lst, T_init,ponderations=ponderations)
        ot_duration = (datetime.datetime.now() - ot_start_time).total_seconds()
        numpy_ot_bar = [x.detach().cpu().numpy() for x in T_bar]
        np.savez(ot_path, D=numpy_ot_bar[0], R=numpy_ot_bar[1], L=numpy_ot_bar[2], losses=losses.detach().cpu().numpy(),duration =ot_duration)
        print(f"OT Barycenter {i+1}/{args.n_points} done.")

    print("All barycenters computed.")
    
    