import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *

# Select the compute device: CUDA when torch reports it as available, else CPU.
# NOTE: the experiment loop further down rebinds `device` to 'cpu'.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"
print("Device is :: {}...".format(device))



# Command-line interface: both options are mandatory.
#   --idx  integer run index; also offsets the random seed below and is part
#          of every result file name written by this script.
#   --lam  float passed as `lam` to the disentanglement training.
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
parser.add_argument("--idx", required=True, type=int, help="index")
parser.add_argument("--lam", required=True, type=float, help="parameter")
args = parser.parse_args()

idx, lam = args.idx, args.lam

# Seed torch and NumPy with the same run-specific value (base seed + run index).
seed = 2025
torch.manual_seed(seed + idx)
np.random.seed(seed + idx)

def load_on_device(path):
    """Load an object saved with ``torch.save`` from ``path``.

    Tensors are mapped onto the module-level ``device`` as it is bound at
    call time. ``weights_only=False`` makes ``torch.load`` unpickle arbitrary
    Python objects, so this must only be used on trusted files (here: the
    result files this script writes itself).
    """
    checkpoint = torch.load(path, weights_only=False, map_location=device)
    return checkpoint


# One experiment per (architecture, normalization flag) combination.
for arch in ['deep']:

    for norm_ in [True]:

        print(f"Using CITE-seq dataset: normalization is {norm_}, training arch {arch}")

        # `data_dir` is handed to the data loader, which reads the CSV files
        # from its `citeseq` sub-folder.
        root_dir = "../citeseq_data"
        data_dir = f"{root_dir}/data_unnormalized"





        def data_gen_citeseq_dtm(root_dir, n):
            """Load the CITE-seq CSV tables and split them into train/test parts.

            The split is positional: the first ``n`` rows (in file order) of every
            table form the training part and the remaining rows form the test part.

            Parameters
            ----------
            root_dir : str
                Directory containing a ``citeseq`` sub-folder with ``rna_pca.csv``,
                ``adt_pca.csv``, ``lab1.csv``, ``lab2.csv`` and ``rna_wts.csv``.
            n : int
                Number of leading rows assigned to the training part.

            Returns
            -------
            tuple
                ``(X, Y, X_test, Y_test, lab1_train, lab1_test, lab2_train,
                lab2_test, wts_train, wts_test)``. ``X``/``Y`` come from
                ``rna_pca.csv``/``adt_pca.csv`` and the labels from
                ``lab1.csv``/``lab2.csv`` (flattened to 1-D); all of these are
                NumPy arrays with the first CSV column removed. The weights are
                the ``'x'`` column of ``rna_wts.csv`` as pandas Series.
            """
            folder = f'{root_dir}/citeseq'

            def _load_values(stem):
                # Read one table and discard its first column (presumably a
                # saved row index -- TODO confirm against the data export).
                table = pd.read_csv(f'{folder}/{stem}.csv')
                return np.array(table.drop(table.columns[0], axis=1))

            X_all = _load_values('rna_pca')
            Y_all = _load_values('adt_pca')
            lab1 = _load_values('lab1').ravel()
            lab2 = _load_values('lab2').ravel()

            # NOTE(review): this permutation is drawn but never applied, so the
            # split below is deterministic. The draw is kept only so the NumPy
            # RNG stream seen by later code is unchanged; confirm whether a
            # shuffled split was intended before removing or applying it.
            row_order = np.arange(X_all.shape[0])
            np.random.shuffle(row_order)

            X, X_test = X_all[:n, :], X_all[n:, :]
            Y, Y_test = Y_all[:n, :], Y_all[n:, :]

            lab1_train, lab1_test = lab1[:n], lab1[n:]
            lab2_train, lab2_test = lab2[:n], lab2[n:]

            rna_wts = pd.read_csv(f'{folder}/rna_wts.csv')['x']
            wts_train, wts_test = rna_wts.iloc[:n], rna_wts.iloc[n:]

            return (X, Y, X_test, Y_test,
                    lab1_train, lab1_test, lab2_train, lab2_test,
                    wts_train, wts_test)


        import numpy as np

        def normalize_datasets(XX, YY, XX_test, YY_test, norm_=True):
            """Standardize each modality column-wise using training statistics.

            For X and Y separately, the per-column mean and population standard
            deviation (``ddof=0``) of the training array are computed and applied
            to both the training and the test array. Columns whose standard
            deviation is exactly zero are only centered (their divisor is 1).

            When ``norm_`` is false the four inputs are returned untouched.

            Returns ``(XX_norm, YY_norm, XX_test_norm, YY_test_norm)``.
            """
            if norm_:
                scaled = []
                for train, test in ((XX, XX_test), (YY, YY_test)):
                    center = train.mean(axis=0)
                    scale = train.std(axis=0, ddof=0)
                    # avoid division by zero for constant columns
                    scale[scale == 0] = 1.0
                    scaled.append(((train - center) / scale, (test - center) / scale))

                (XX_norm, XX_test_norm), (YY_norm, YY_test_norm) = scaled
                return XX_norm, YY_norm, XX_test_norm, YY_test_norm

            return XX, YY, XX_test, YY_test


        ## parameters
        n = 15000       # requested number of training rows (passed to the loader)
        n_test = 2000   # nominal test size; replaced by the actual size below

        print("\nLoading data...")

        ## dataset
        dataset_nam = "citeseq"   # used in the results directory name
        is_pca = True

        (XX, YY, XX_test, YY_test,
         lab1_train, lab1_test, lab2_train, lab2_test,
         wts_train, wts_test) = data_gen_citeseq_dtm(data_dir, n)

        # Replace the nominal sizes with what the loader actually returned.
        n, n_test = XX.shape[0], XX_test.shape[0]
        N = n + n_test

        # Standardization is a no-op when norm_ is False.
        XX, YY, XX_test, YY_test = normalize_datasets(XX, YY, XX_test, YY_test, norm_)

        print(XX.shape, YY.shape, XX_test.shape, YY_test.shape)




        from train import *

        # Input widths of the two modalities.
        d_x, d_y = XX.shape[1], YY.shape[1]

        tau_lower = 1e-4
        # Everything from here on runs on the CPU: this rebinding overrides the
        # device chosen at the top of the script and is also what
        # load_on_device() maps checkpoints onto from now on.
        device = 'cpu'
        max_ep = 2000
        batch_size = 128

        outdim = 50
        # middim = max(d_x, d_y)
        middim = 50

        # One encoder per modality; only the input width differs between them.
        if arch != "transformer":
            model_x = NonLinearNetD(d_x, middim, outdim, tau_lower=tau_lower).to(device)
            model_y = NonLinearNetD(d_y, middim, outdim, tau_lower=tau_lower).to(device)
        else:
            # model_x = TransformerEncoderNet(d_x, embed_dim=middim, output_dim=outdim, tau_lower=tau_lower, num_heads=5, ff_dim=64, num_layers=2).to(device)
            # model_y = TransformerEncoderNet(d_y, embed_dim=middim, output_dim=outdim, tau_lower=tau_lower, num_heads=5, ff_dim=64, num_layers=2).to(device)
            model_x = Transformer_ma(d_x, middim=middim, dim=outdim).to(device)
            model_y = Transformer_ma(d_y, middim=middim, dim=outdim).to(device)


        # =============================
        #        CLIP Training
        # =============================

        print(f"\n== CLIP Training ({arch}) ==\n")

        save_dir = f'./results/{dataset_nam}_norm{norm_}'
        # exist_ok=True instead of a check-then-create: runs launched with
        # different --idx / --lam share this directory, and a separate
        # os.path.exists() test can race with another process creating it.
        os.makedirs(save_dir, exist_ok=True)

        # CLIP results are cached per (arch, idx); lam is not part of the file
        # name, so runs that differ only in --lam reuse the same file.
        clip_path = f'{save_dir}/clip_results_{arch}_{idx}.pt'

        if os.path.exists(clip_path):
            clip_results = load_on_device(clip_path)
        else:
            clip_results = train_clip(
                XX, YY,
                model_x, model_y,
                max_epochs=max_ep, batch_size=batch_size, lr=1e-4, wd=1e-4,
                tau_fix=1.0, tau_tune=True, tau_lr_fac=5, spectral=False, device='cpu'
            )
            torch.save(clip_results, clip_path)



        # =============================
        #        Disentanglement
        # =============================


        ## CLIP embeddings of the training data, used as targets below

        # Inference only: no_grad() avoids building an autograd graph for the
        # full training set, and .cpu() keeps the NumPy conversion valid even
        # if a model were to live on another device.
        # NOTE(review): the encoders are used in whatever train/eval mode
        # train_clip / torch.load left them in -- confirm whether .eval() is
        # wanted here.
        with torch.no_grad():
            XX_clip = clip_results['model_x'](torch.Tensor(XX)).detach().cpu().numpy()
            YY_clip = clip_results['model_y'](torch.Tensor(YY)).detach().cpu().numpy()


        print(f"training with lam={lam}")

        # The output directory and architecture do not depend on the objective,
        # so they are resolved once before the loop.
        save_dir = f'./results/{dataset_nam}_norm{norm_}'
        # exist_ok=True avoids the check-then-create race between concurrent
        # runs that share this results directory.
        os.makedirs(save_dir, exist_ok=True)
        arch_disentg = arch

        for objective in ['fact', 'disen', 'recons']:

            print(f"\n== Disentangled method: {objective} ===\n")

            # Each (modality, objective, arch, idx, lam) result is cached on
            # disk; an existing file is loaded instead of retraining.

            # --- X modality: raw features against their CLIP embedding ---
            result_path_X = f'{save_dir}/result_X_{objective}_{arch}_{idx}_{lam}.pt'

            if os.path.exists(result_path_X):
                result_X = load_on_device(result_path_X)
            else:
                result_X = train_disentangle(
                    XX, XX_clip,
                    outdim=outdim, arch=arch_disentg, max_epochs=max_ep, batch_size=batch_size,
                    objective=objective, lam=lam
                )
                torch.save(result_X, result_path_X)

            # --- Y modality: same procedure with the Y features ---
            result_path_Y = f'{save_dir}/result_Y_{objective}_{arch}_{idx}_{lam}.pt'

            if os.path.exists(result_path_Y):
                result_Y = load_on_device(result_path_Y)
            else:
                result_Y = train_disentangle(
                    YY, YY_clip,
                    outdim=outdim, arch=arch_disentg, max_epochs=max_ep, batch_size=batch_size,
                    objective=objective, lam=lam
                )
                torch.save(result_Y, result_path_Y)


