import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *


# Pick the compute device once at import time: CUDA when torch reports it as
# available, CPU otherwise. Everything below moves tensors/models onto DEVICE.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")


def load_on_device(path):
    """Load an object written by ``torch.save`` and map its storages to ``DEVICE``.

    Args:
        path: Location of the serialized file, passed straight to ``torch.load``.

    Returns:
        Whatever object was stored in the file.

    NOTE: ``weights_only=False`` allows full pickle deserialization, which can
    execute arbitrary code. Only call this on files produced by a trusted run.
    """
    loaded = torch.load(path, map_location=DEVICE, weights_only=False)
    return loaded

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


# --- Command-line interface -------------------------------------------------
# All three options are required; parse_args() exits the process if any is
# missing. The parsed values are used further down to build dataset and
# checkpoint paths.
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
parser.add_argument("--dataset_nam", type=str, required=True,
                    help="Name of the dataset, e.g. mosi, mosei, etc.")
parser.add_argument("--idx", type=int, required=True,
                    help="index")
parser.add_argument("--lam", type=float, required=True,
                    help="parameter")
args = parser.parse_args()

# Module-level aliases for the parsed options.
dataset_nam = args.dataset_nam
idx = args.idx
lam = args.lam

print(f"Using dataset: {dataset_nam}")


# Base directory that dataset and result paths are built from.
ROOT_dir = './'

# NOTE(review): max_ep is not read anywhere else in this file; it may be
# consumed by star-imported code or be a leftover — confirm.
max_ep = 2000

# NOTE(review): the DataLoaders in this file are built with a literal
# batch_size=32, so this value is not what they use — confirm intent.
batch_size = 128

# Architecture tag; it is embedded in the checkpoint/result file names below.
arch = "transformer"

# Seed the torch and NumPy global RNGs for reproducibility.
seed = 2025
torch.manual_seed(seed)
np.random.seed(seed)






# === Dataset ===
class MIMICDataset(torch.utils.data.Dataset):
    """Dataset backed by a directory holding ``index.jsonl`` and ``meta.json``.

    ``index.jsonl`` contains one JSON object per line. Each record must have an
    ``"id"`` key and may have ``"tab"``, ``"timeseries"`` and ``"label"`` keys
    whose values are paths (relative to the root directory) of files written
    with ``torch.save``.

    Attributes are documented inline in ``__init__``.
    """

    # Record keys that point at optional per-sample files, in return order.
    _OPTIONAL_KEYS = ("tab", "timeseries", "label")

    def __init__(self, root_dir: str):
        """Read the metadata and the sample index from ``root_dir``.

        Args:
            root_dir: Directory containing ``index.jsonl`` and ``meta.json``.

        Raises:
            FileNotFoundError: If either of the two files is missing.
        """
        self.root = Path(root_dir)
        self.index_file = self.root / "index.jsonl"
        self.meta_file = self.root / "meta.json"
        if not self.index_file.exists():
            raise FileNotFoundError(f"Index file {self.index_file} not found")
        if not self.meta_file.exists():
            raise FileNotFoundError(f"Meta file {self.meta_file} not found")

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.meta_file, encoding="utf-8") as f:
            # Parsed content of meta.json, stored as-is.
            self.meta = json.load(f)
        with open(self.index_file, encoding="utf-8") as f:
            # One dict per non-blank line. Blank lines (e.g. a stray empty line
            # in the middle or at the end of the file) are skipped instead of
            # crashing json.loads.
            self.records = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of records read from the index file."""
        return len(self.records)

    def _load_optional(self, rec, key):
        """Load the file referenced by ``rec[key]``, or return None if absent."""
        if key not in rec:
            return None
        # SECURITY: weights_only=False performs full pickle deserialization and
        # can execute arbitrary code; only use with trusted dataset files.
        return torch.load(self.root / rec[key], weights_only=False)

    def __getitem__(self, idx):
        """Return ``(tab, timeseries, y, uid)`` for record ``idx``.

        Each of the first three entries is the object loaded from the
        corresponding file, or ``None`` when the record has no such key.
        ``uid`` is the record's ``"id"`` value.
        """
        rec = self.records[idx]
        uid = rec["id"]
        tab, timeseries, y = (
            self._load_optional(rec, key) for key in self._OPTIONAL_KEYS
        )
        return tab, timeseries, y, uid


from torch.utils.data import DataLoader

# Keyword arguments shared by all three DataLoaders.
# NOTE(review): this uses a literal 32 rather than the module-level
# `batch_size` constant defined above — confirm which value is intended.
loader_args = dict(
    batch_size=32,
)

# Split directories are "<root_dir><dataset_nam>_{train,valid,test}".
root_dir = ROOT_dir + "FactorCL/Multibench/MultiBench/"

# Each constructor raises FileNotFoundError if index.jsonl or meta.json is
# missing from the split directory.
train_set = MIMICDataset(root_dir + f"{dataset_nam}_train")
valid_set = MIMICDataset(root_dir + f"{dataset_nam}_valid")
test_set  = MIMICDataset(root_dir + f"{dataset_nam}_test")


# Only the training loader shuffles.
train_loader = DataLoader(train_set, shuffle=True,
                          **loader_args)
# NOTE(review): valid_loader is created but never iterated in this file.
valid_loader = DataLoader(valid_set, shuffle=False,
                          **loader_args)
test_loader  = DataLoader(test_set, shuffle=False,
                          **loader_args)





def pad_and_concat(seq_list, M=None):
    """Concatenate tensors along dim 0 after equalizing their dim-1 length.

    If every tensor has at least three dimensions, each one is zero-padded at
    the end of dim 1 (or truncated) to length ``M`` before concatenation.
    Otherwise the tensors are concatenated unchanged.

    Args:
        seq_list: Non-empty iterable of tensors.
        M: Target size of dim 1. Defaults to the largest dim-1 size found in
            ``seq_list``.

    Returns:
        A single tensor, the concatenation of all inputs along dim 0.

    Raises:
        ValueError: If ``seq_list`` is empty.
    """
    # Materialize once: the input is traversed several times below.
    seq_list = list(seq_list)
    if not seq_list:
        raise ValueError("pad_and_concat() requires at least one tensor")

    # Tensors without a dim-1 "sequence" axis to equalize are joined as-is.
    if not all(x.dim() >= 3 for x in seq_list):
        return torch.cat(seq_list, dim=0)

    if M is None:
        M = max(x.size(1) for x in seq_list)

    padded = []
    for x in seq_list:
        T = x.size(1)
        if T < M:
            # F.pad takes (left, right) pairs starting from the LAST dim, so
            # reaching dim 1 needs one (0, 0) pair per trailing dim. For a 3-D
            # tensor this is (0, 0, 0, M - T); higher ranks get more pairs so
            # the padding still lands on dim 1.
            pad_spec = (0, 0) * (x.dim() - 2) + (0, M - T)
            x = Fnn.pad(x, pad_spec)
        elif T > M:
            # Slicing only dim 1 works for any rank >= 2.
            x = x[:, :M]
        padded.append(x)

    return torch.cat(padded, dim=0)



def _collect_modalities(loader, desc):
    """Iterate ``loader`` once and gather its batches into three lists.

    Returns ``(text_batches, visual_batches, label_batches)``. Position 1 of
    each batch is stored as "text", position 0 as "visual" and position 2 as
    the label; with MIMICDataset these positions are the ``timeseries``,
    ``tab`` and label entries respectively.
    """
    text_batches, visual_batches, label_batches = [], [], []
    for batch in tqdm(loader, desc=desc):
        visual_batches.append(batch[0])
        text_batches.append(batch[1])
        label_batches.append(batch[2])
    return text_batches, visual_batches, label_batches


def _longest_dim1(batches):
    """Largest size along dim 1 among the tensors in ``batches``."""
    return max(t.size(1) for t in batches)


# Gather every batch of the train and test splits.
input_text_list, input_visual_list, lab_list = _collect_modalities(
    train_loader, "train collect"
)
input_text_test_list, input_visual_test_list, lab_test_list = _collect_modalities(
    test_loader, "test collect"
)

# Common dim-1 targets shared by train and test so both splits end up with
# the same size along that axis.
M_text = max(_longest_dim1(input_text_list), _longest_dim1(input_text_test_list))
M_visual = max(_longest_dim1(input_visual_list), _longest_dim1(input_visual_test_list))

# Train split: labels are joined directly; inputs are padded/concatenated,
# cast to float and moved to DEVICE.
lab = torch.cat(lab_list, dim=0)
input_text = pad_and_concat(input_text_list, M_text).float().to(DEVICE)
input_visual = pad_and_concat(input_visual_list, M_visual).float().to(DEVICE)

# Test split: identical treatment, reusing the train/test-wide M_text and M_visual.
lab_test = torch.cat(lab_test_list, dim=0)
input_text_test = pad_and_concat(input_text_test_list, M_text).float().to(DEVICE)
input_visual_test = pad_and_concat(input_visual_test_list, M_visual).float().to(DEVICE)



## Parameters
print("\nPreparing CLIP training...")

n = input_text.shape[0]        # train samples
n_test = input_text_test.shape[0]  # test samples
# NOTE(review): N is not read anywhere else in this file.
N = n + n_test


# Integer (int64) copies of the labels; these are what the evaluation loop
# below passes to the post-processing helper.
lab1_train = lab.long()
lab1_test = lab_test.long()


# Insert a singleton axis at position 1 of the "visual" inputs.
# NOTE(review): presumably this gives the tabular features a length-1
# sequence axis so they match the (batch, seq, dim) layout of input_text —
# confirm against the model's expected input.
input_visual = input_visual.unsqueeze(1)
input_visual_test = input_visual_test.unsqueeze(1)

print(input_text.shape, input_visual.shape, input_text_test.shape, input_visual_test.shape)


# Star import executed mid-file: any name exported by `train` that matches a
# name defined above is rebound from this point on.
from train import *

# Size of axis 2 of each input (axis 2 exists for input_visual because of the
# unsqueeze above).
d_x = input_text.shape[2]
d_y = input_visual.shape[2]

# NOTE(review): tau_lower, outdim and middim are not read anywhere else in
# this file — possibly leftovers from the training script; confirm.
tau_lower = 1e-4
outdim = 50
middim = max(d_x, d_y)




# Evaluate each disentanglement objective in turn. Every iteration loads the
# saved CLIP models and the per-objective X/Y results from disk, encodes the
# train/test inputs with the CLIP models, runs the joint post-processing
# helper, prints the accuracies it reports and saves its output.
for objective in ['recons', 'fact', 'disen']:

    print(f"\n== Disentangled method: {objective} ===\n")

    # NOTE(review): arch_disentg is assigned but not read in this file.
    arch_disentg = arch


    print(f"\n== CLIP Training ({arch}) ==\n")
    # Save directory
    save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
    os.makedirs(save_dir, exist_ok=True)

    # === Train or load clip results ===
    # Only loading happens here: if the checkpoint is missing, the load fails
    # rather than triggering training.
    # NOTE(review): clip_path does not depend on `objective`, so the same file
    # is re-read on every iteration.
    clip_path = f"{save_dir}/clip_results_{arch}_{idx}_{lam}.pt"

    clip_results = load_on_device(clip_path)

    # === Encode training features ===
    print("Encoding representations...")


    # ========= For comp='X' =========
    X_raw = input_text.to(DEVICE)             # (n, seq_len, d_x), on DEVICE
    X_test_raw = input_text_test.to(DEVICE)

    # Per-objective result file for the X component; its 'models' entry is
    # passed to the post-processing helper below.
    result_path_X = os.path.join(save_dir, f"result_X_{objective}_{arch}_{idx}_{lam}.pt")
    
    result_X = load_on_device(result_path_X)

    # ========= For comp='Y' =========
    Y_raw = input_visual.to(DEVICE)
    Y_test_raw = input_visual_test.to(DEVICE)

    # Per-objective result file for the Y component.
    result_path_Y = os.path.join(save_dir, f"result_Y_{objective}_{arch}_{idx}_{lam}.pt")
    
    result_Y = load_on_device(result_path_Y)


    # Move models to DEVICE
    clip_results['model_x'] = clip_results['model_x'].to(DEVICE)
    clip_results['model_y'] = clip_results['model_y'].to(DEVICE)

    # Inference on DEVICE
    # NOTE(review): the models are not switched to eval() here; whether they
    # were saved in eval mode cannot be determined from this file — confirm.
    with torch.no_grad():
        X_clip = clip_results["model_x"](X_raw).to(DEVICE)
        Y_clip = clip_results["model_y"](Y_raw).to(DEVICE)
        X_clip_test = clip_results['model_x'](X_test_raw).to(DEVICE)
        Y_clip_test = clip_results['model_y'](Y_test_raw).to(DEVICE)

    

    # The second label set is the same tensors as the first.
    lab2_train, lab2_test = lab1_train, lab1_test

    
    # postprocess_disentangle_joint_raw is not defined in this file (it
    # presumably comes from one of the star imports). From its use below, the
    # return value has an 'accs' mapping of label -> {'clip', 'concat'} scores.
    result_post = postprocess_disentangle_joint_raw(
        result_X['models'], result_Y['models'],
        X_raw, X_test_raw, X_clip, X_clip_test,
        Y_raw, Y_test_raw, Y_clip, Y_clip_test,
        lab1_train, lab1_test, lab2_train, lab2_test,
        device=DEVICE
    )

    print("=== Classification Accuracies ===")
    for label, scores in result_post['accs'].items():
        print(f"{label}: [clip]={scores['clip']:.4f}")
        print(f"{label}: [clip + disentangled]={scores['concat']:.4f}")

    # Save results
    # save_dir is rebuilt with a hard-coded './' prefix; this is the same
    # directory as above as long as ROOT_dir is './'.
    save_dir = f'./results/{dataset_nam}'
    os.makedirs(save_dir, exist_ok=True)

    torch.save(result_post, f'{save_dir}/result_post_{objective}_{arch}_{idx}_{lam}.pt')


