import numpy as np
from huggingface_hub import login, hf_hub_download, list_repo_files
import os
from pathlib import Path
import shutil
import matplotlib.pyplot as plt
import sys
sys.path.append('../../')

from src.kooporch import ChordalCostFunction, SpectralOTBarycenter
from src.utils import primal_left
import torch
from src.model import primal_fit_to

######################################################################################################
######################################################################################################
# PARAMETERS # 

# connection to Hugging Face
# The access token is read from the HF_TOKEN environment variable rather than
# being hard-coded: a token committed to source control is a leaked credential
# (any token previously embedded here should be revoked). An empty value is
# treated as "not set". The token is only required when a file listed in
# ``paths`` is missing locally and has to be downloaded.
hf_token = os.environ.get("HF_TOKEN") or None

# Specify repository and folder details
repo_id = "BGLab/FlowBench"  # Repository ID on Hugging Face
dataset_path = "FPO_NS_2D_1024x256/"  # top-level folder inside the repository

# Files to download, relative to ``dataset_path``. Each case directory
# contributes one flow file and its ``input_geometry.npz``.
paths = [
    "harmonics/93/Re_145.npz",
    "harmonics/93/input_geometry.npz",
    "skelenton/48/Re_156.npz",
    "skelenton/48/input_geometry.npz",
]


# Local download directory, resolved relative to the current working directory.
save_path = Path("datasets")


######################################################################################################
######################################################################################################


def _download_missing(paths, save_path, repo_id, dataset_path, token):
    """Download each listed dataset file from the Hugging Face Hub unless a local copy exists.

    A repo-relative path such as ``"harmonics/93/Re_145.npz"`` is stored as
    ``save_path / "harmonics_93" / "Re_145.npz"``. The directory part is
    flattened by replacing ``/`` with ``_``.

    Args:
        paths: file paths relative to ``dataset_path`` inside the repository.
        save_path: local root directory for the copies.
        repo_id: Hugging Face dataset repository id.
        dataset_path: folder prefix inside the repository (ends with ``/``).
        token: Hugging Face access token, or None.

    Raises:
        ValueError: if a file has to be downloaded but ``token`` is None.
    """
    for path in paths:
        dir_path = save_path / os.path.dirname(path).replace("/", "_")
        file_path = dir_path / os.path.basename(path)
        if file_path.is_file():
            print(f"{path} already downloaded")
            continue
        if token is None:
            raise ValueError("Please provide a Hugging Face token.")
        dir_path.mkdir(parents=True, exist_ok=True)
        # Pass the token per request instead of calling ``login`` for every
        # file, which would also persist the token to the local credential store.
        cached = hf_hub_download(
            repo_id=repo_id,
            filename=dataset_path + path,
            repo_type="dataset",
            token=token,
        )
        # hf_hub_download returns a path inside the HF cache; copy it into our layout.
        shutil.copy(cached, file_path)
        print("Saved:", file_path)


def _load_case(case_dir, data_name):
    """Load one case: the ``"data"`` array of ``data_name`` and its geometry mask.

    Returns:
        ``(raw_data, mask)``. ``mask`` is ``input_geometry.npz["mask"]`` as a
        masked array in which entries > 0.5 are masked out. Masked entries are
        not drawn by ``imshow``, so only the <= 0.5 region (presumably the solid
        body — TODO confirm) is painted over the field.
    """
    # ``with`` closes the NpzFile handles; indexing loads the arrays into memory.
    with np.load(case_dir / data_name) as npz:
        raw_data = npz["data"]
    with np.load(case_dir / "input_geometry.npz") as npz:
        mask = npz["mask"]
    return raw_data, np.ma.masked_where(mask > 0.5, mask)


def _plot_barycenter_figure(T_s, T_bar, T_t, shape_s, shape_t, mask_s, mask_t,
                            out_path, idxs=(0, 1, 3)):
    """Plot selected modes of the source operator, the barycenter and the target operator.

    Each of ``T_s``, ``T_bar`` and ``T_t`` is indexed at element 1, which is
    the ``"right"`` matrix of the spectral triple built in the main script.
    Row ``i`` shows column ``idxs[i]`` of that matrix. Only its real part is
    shown, reshaped to the 2D grid ``shape_s`` (source and barycenter) or
    ``shape_t`` (target). The source and target panels get their geometry mask
    overlaid; the barycenter panel has none. The figure is saved to
    ``out_path`` as PDF.
    """
    plt.rcParams["font.family"] = "Helvetica"
    plt.rcParams["xtick.labelsize"] = 8
    plt.rcParams["ytick.labelsize"] = 8
    plt.rcParams["axes.labelsize"] = 11
    plt.rcParams["legend.fontsize"] = 11

    # Move each matrix to host memory once instead of once per row.
    modes_s = T_s[1].detach().cpu().numpy()
    modes_bar = T_bar[1].detach().cpu().numpy()
    modes_t = T_t[1].detach().cpu().numpy()
    # Extents are in mask pixels, so the (subsampled) fields are stretched to
    # line up with the full-resolution geometry overlay.
    extent_s = (0, mask_s.shape[1], 0, mask_s.shape[0])
    extent_t = (0, mask_t.shape[1], 0, mask_t.shape[0])

    # squeeze=False keeps ``ax`` 2D even if only one index is requested.
    fig, ax = plt.subplots(len(idxs), 3, figsize=(10, 3), squeeze=False)
    for row, idx in enumerate(idxs):
        ax[row, 0].imshow(modes_s[:, idx].real.reshape(shape_s), cmap='jet', extent=extent_s)
        ax[row, 0].imshow(mask_s, interpolation='bilinear', cmap='gray', alpha=1, extent=extent_s)
        ax[row, 1].imshow(modes_bar[:, idx].real.reshape(shape_s), cmap='jet', extent=extent_s)
        ax[row, 2].imshow(modes_t[:, idx].real.reshape(shape_t), cmap='jet', extent=extent_t)
        ax[row, 2].imshow(mask_t, interpolation='bilinear', cmap='gray', alpha=1, extent=extent_t)
        # NOTE(review): rows are numbered 1..len(idxs) although the default
        # idxs skip column 2 (e.g. a conjugate partner?) — confirm intended.
        ax[row, 0].set_ylabel(f'Eig. fct {row + 1}', fontsize=12)

    for a in ax.flat:
        a.set_xticks([])
        a.set_yticks([])

    ax[0, 0].set_title(r'$\mathbf{T}^{(0)}$ $(\gamma = 0.0)$', fontsize=14)
    ax[0, 1].set_title(r'SGOT barycenter $(\gamma = 0.5)$', fontsize=14)
    ax[0, 2].set_title(r'$\mathbf{T}^{(1)}$ $(\gamma = 1.0)$', fontsize=14)

    fig.tight_layout()
    fig.savefig(out_path, format='pdf', bbox_inches='tight')


if __name__ == "__main__":

    # Download data if needed.
    _download_missing(paths, save_path, repo_id, dataset_path, hf_token)

    # Load data.
    raw_data_s, mask_s = _load_case(save_path / "harmonics_93", "Re_145.npz")
    raw_data_t, mask_t = _load_case(save_path / "skelenton_48", "Re_156.npz")

    # Preprocess: keep index 0 of the last axis, subsample both spatial axes by
    # 4, then flatten every slice along axis 0 (presumably time) into one row.
    data_s = raw_data_s[:, ::4, ::4, 0]
    data_t = raw_data_t[:, ::4, ::4, 0]
    traj_s = data_s.reshape(data_s.shape[0], -1)
    traj_t = data_t.reshape(data_t.shape[0], -1)

    # Estimate operators (positional hyper-parameters of primal_fit_to are kept
    # exactly as before; see src.model for their meaning).
    print("Fitting source data")
    evd_s = primal_fit_to(traj_s, 1/100, 1, 1e-0, 4)
    print("Fitting target data")
    evd_t = primal_fit_to(traj_t, 1/100, 1, 1e-0, 4)
    print("Fitting done")

    # Estimate the barycenter.
    B_DTYPE = torch.cfloat
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Triples (eigenvalues, right, left) as NumPy arrays.
    spectra_s = (evd_s["values"], evd_s["right"], primal_left(evd_s, traj_s))
    spectra_t = (evd_t["values"], evd_t["right"], primal_left(evd_t, traj_t))
    T_s = [torch.tensor(x, dtype=B_DTYPE, device=DEVICE) for x in spectra_s]
    T_t = [torch.tensor(x, dtype=B_DTYPE, device=DEVICE) for x in spectra_t]
    # Initialise at the element-wise mean of the two triples (averaged in NumPy
    # precision, then cast), as before.
    T_init = tuple(
        torch.tensor((a + b) / 2, dtype=B_DTYPE, device=DEVICE)
        for a, b in zip(spectra_s, spectra_t)
    )
    cost_fn = ChordalCostFunction(alpha=0.01)
    soba = SpectralOTBarycenter(cost_fn, lr=1e-4, max_iter=10, max_epochs=100, tol=1e-6, verbose=True)
    weights = torch.tensor([0.5, 0.5], dtype=torch.float32, device=DEVICE)
    T_bar, P_lst, losses = soba.fit([T_s, T_t], T_init, ponderations=weights)

    # Generate the figure.
    _plot_barycenter_figure(
        T_s, T_bar, T_t,
        shape_s=data_s.shape[1:],
        shape_t=data_t.shape[1:],
        mask_s=mask_s,
        mask_t=mask_t,
        out_path='fluid_barycenter.pdf',
    )

