import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *



# Compute device for this script: CUDA when PyTorch can see a GPU, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")

def load_on_device(path, device=None):
    """Load a ``torch.save`` file with every tensor mapped onto one device.

    Args:
        path: file written by ``torch.save``.
        device: ``map_location`` target; ``None`` (the default) uses the
            module-level ``DEVICE``.

    Returns:
        The deserialized object (this script stores dicts holding whole
        model objects, which is why full unpickling is enabled).

    Note:
        ``weights_only=False`` runs the full pickle loader, which can execute
        arbitrary code — only point this at trusted files.
    """
    map_location = DEVICE if device is None else device
    return torch.load(path, map_location=map_location, weights_only=False)

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


# --- Command-line interface ---------------------------------------------------
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
parser.add_argument(
    "--dataset_nam",
    type=str,
    required=True,
    help="Name of the dataset, e.g. mosi, mosei, etc.",
)
parser.add_argument("--idx", type=int, required=True, help="index")
parser.add_argument("--lam", type=float, required=True, help="parameter")
args = parser.parse_args()

# idx and lam only appear in the result/checkpoint file names further down.
dataset_nam, idx, lam = args.dataset_nam, args.idx, args.lam

print(f"Using dataset: {dataset_nam}")

# --- Run configuration ----------------------------------------------------------
ROOT_dir = './'

# NOTE(review): max_ep and batch_size are not referenced in the visible part of
# this script (the DataLoaders are built with batch_size=32) — confirm intent.
max_ep = 2000
batch_size = 128

arch = "transformer"

# Seed torch and numpy so that the shuffled train loader is reproducible.
seed = 2025
torch.manual_seed(seed)
np.random.seed(seed)




# === Dataset ===
class MOSILocalDataset(Dataset):
    """Map-style dataset over a directory of per-sample ``torch.save`` files.

    ``root`` must contain an ``index.jsonl`` file with one JSON object per
    line. Each record needs an ``"id"`` and may hold a path (relative to
    ``root``) for each modality key and for ``"label"``.

    Args:
        root: directory holding ``index.jsonl`` and the referenced files.
        modalities: modality keys to load; keys missing from a record are
            silently skipped for that sample.
        use_label: load the ``"label"`` file when the record has one.
    """

    def __init__(self, root, modalities=("text", "audio", "vision"), use_label=True):
        self.root = Path(root)
        self.modalities = modalities
        self.use_label = use_label
        with open(self.root / "index.jsonl", encoding="utf-8") as f:
            # Ignore blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            self.records = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Return ``{"id": ..., <modality>: obj, ..., "label": obj}`` for record ``idx``."""
        rec = self.records[idx]
        out = {"id": rec["id"]}

        for m in self.modalities:
            if m in rec:
                out[m] = self._load(rec[m])

        if self.use_label and "label" in rec:
            out["label"] = self._load(rec["label"])

        return out

    def _load(self, rel_path):
        """Load one file relative to ``root`` onto the CPU.

        ``weights_only=False`` runs the full pickle loader, so the dataset
        directory must be trusted. Device placement is left to the caller.
        """
        return torch.load(self.root / rel_path, map_location="cpu", weights_only=False)


from torch.nn.utils.rnn import pad_sequence

def mosi_collate_fn(batch, modalities=("text", "audio", "vision"), device=None):
    """Collate ``MOSILocalDataset`` samples into one zero-padded batch.

    Args:
        batch: non-empty list of sample dicts. Which modalities and whether a
            label are collated is decided by the FIRST sample, so every sample
            is expected to carry the same keys.
        modalities: modality keys to collate.
        device: where the collated tensors are placed; ``None`` (the default)
            uses the module-level ``DEVICE``. Passing ``"cpu"`` keeps
            collation off the GPU, e.g. when loading in worker processes.

    Returns:
        Dict with ``"id"`` (list of ids), one ``pad_sequence(..., batch_first=True)``
        tensor per collated modality, and a stacked ``"label"`` tensor when the
        first sample has a label.
    """
    target = DEVICE if device is None else device
    first = batch[0]
    out = {"id": [b["id"] for b in batch]}

    for m in modalities:
        if m in first:
            # Right-pad every sequence with zeros up to the longest one in the batch.
            seqs = [b[m] for b in batch]
            out[m] = pad_sequence(seqs, batch_first=True).to(target, non_blocking=True)

    if "label" in first:
        out["label"] = torch.stack([b["label"] for b in batch]).to(target, non_blocking=True)

    return out


from torch.utils.data import DataLoader

# Every split is collated with all three modalities.
_ALL_MODALITIES = ("text", "audio", "vision")


def _collate_all_modalities(batch):
    """Collate ``batch`` keeping the text, audio and vision streams."""
    return mosi_collate_fn(batch, modalities=_ALL_MODALITIES)


loader_args = {"batch_size": 32}

root_dir = ROOT_dir + "FactorCL/Multibench/MultiBench/"

# One on-disk directory per split: <root_dir><dataset>_{train,valid,test}.
train_set, valid_set, test_set = (
    MOSILocalDataset(f"{root_dir}{dataset_nam}_{split}")
    for split in ("train", "valid", "test")
)

train_loader = DataLoader(train_set, shuffle=True, collate_fn=_collate_all_modalities, **loader_args)
valid_loader = DataLoader(valid_set, shuffle=False, collate_fn=_collate_all_modalities, **loader_args)
test_loader = DataLoader(test_set, shuffle=False, collate_fn=_collate_all_modalities, **loader_args)




def pad_and_concat(seq_list, M=None):
    """Pad/truncate 3-D batches to one length along dim 1 and concatenate them.

    Every tensor of shape ``(B_i, T_i, D)`` is right-padded with zeros, or
    truncated, along dim 1 to length ``M``. The results are then concatenated
    along dim 0.

    Args:
        seq_list: iterable of 3-D tensors that agree on their last dimension.
        M: target length along dim 1; ``None`` uses the largest ``T_i``.

    Returns:
        Tensor of shape ``(sum(B_i), M, D)``.

    Raises:
        ValueError: if ``seq_list`` is empty or contains a tensor that is not 3-D.
    """
    seq_list = list(seq_list)  # may be traversed more than once below
    if not seq_list:
        raise ValueError("pad_and_concat() requires at least one tensor")
    for x in seq_list:
        if x.dim() != 3:
            raise ValueError(f"expected 3-D tensors (B, T, D), got shape {tuple(x.shape)}")

    if M is None:
        M = max(x.size(1) for x in seq_list)

    padded = []
    for x in seq_list:
        T = x.size(1)
        if T < M:
            # F.pad's tuple runs from the last dim backwards: (D_before, D_after, T_before, T_after).
            x = F.pad(x, (0, 0, 0, M - T))
        elif T > M:
            x = x[:, :M, :]
        padded.append(x)

    return torch.cat(padded, dim=0)  # (sum(B), M, D)



def _collect_split(loader, desc):
    """Iterate ``loader`` once; return its text, vision and label batches as three lists."""
    texts, visuals, labels = [], [], []
    for batch in tqdm(loader, desc=desc):
        text, visual, label = batch['text'], batch['vision'], batch['label']
        texts.append(text)      # (B, T, D_text)
        visuals.append(visual)  # (B, T, D_visual)
        labels.append(label)    # (B, ...)
    return texts, visuals, labels


def _longest(batches):
    """Largest dim-1 length among a list of batched tensors."""
    return max(x.size(1) for x in batches)


input_text_list, input_visual_list, lab_list = _collect_split(train_loader, "train collect")
input_text_test_list, input_visual_test_list, lab_test_list = _collect_split(test_loader, "test collect")

# One target length per modality, shared by train and test, so both splits
# end up with the same padded shape.
M_text = max(_longest(input_text_list), _longest(input_text_test_list))
M_visual = max(_longest(input_visual_list), _longest(input_visual_test_list))

# Merge the batches of each split; padded inputs are cast to float.
lab = torch.cat(lab_list, dim=0)
input_text = pad_and_concat(input_text_list, M_text).float()
input_visual = pad_and_concat(input_visual_list, M_visual).float()

lab_test = torch.cat(lab_test_list, dim=0)
input_text_test = pad_and_concat(input_text_test_list, M_text).float()
input_visual_test = pad_and_concat(input_visual_test_list, M_visual).float()



## Parameters
print("\nPreparing CLIP training...")

# Sample counts: train, test, and both together.
n, n_test = input_text.shape[0], input_text_test.shape[0]
N = n + n_test

# Binarize the labels. mosi/mosei labels are thresholded at 0 (>= 0 -> 1);
# for any other dataset every label other than -1 becomes 1.
if dataset_nam in ("mosi", "mosei"):
    lab1_train, lab1_test = (lab >= 0).long(), (lab_test >= 0).long()
else:
    lab1_train, lab1_test = (lab != -1).long(), (lab_test != -1).long()

print(input_text.shape, input_visual.shape, input_text_test.shape, input_visual_test.shape)

from train import *

# Feature widths of the two modalities and the padded text length.
d_x, d_y = input_text.shape[2], input_visual.shape[2]
seq_len = input_text.shape[1]

# NOTE(review): tau_lower, outdim and middim are not referenced in the visible
# part of this script; presumably model hyper-parameters — confirm where used.
tau_lower = 1e-4
outdim = 50
middim = max(d_x, d_y)  # the wider of the two input feature dims


# Every artifact read or written below (CLIP checkpoint, per-objective results,
# cached post-processing output) lives in this single directory under ROOT_dir.
save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
os.makedirs(save_dir, exist_ok=True)

# Raw padded inputs on the compute device; identical for every objective.
X_raw = input_text.to(DEVICE)
X_test_raw = input_text_test.to(DEVICE)
Y_raw = input_visual.to(DEVICE)
Y_test_raw = input_visual_test.to(DEVICE)

# Both label slots of the post-processing receive the same binarized labels.
lab2_train, lab2_test = lab1_train, lab1_test

# The CLIP checkpoint does not depend on the objective: it is loaded at most
# once, and only if some objective still needs its post-processing computed.
clip_path = f"{save_dir}/clip_results_{arch}_{idx}_{lam}.pt"
clip_results = None

for objective in ['recons', 'fact', 'disen']:

    print(f"\n== Disentangled method: {objective} ===\n")
    print(f"\n== CLIP Training ({arch}) ==\n")

    result_post_path = f'{save_dir}/result_post_{objective}_{arch}_{idx}_{lam}.pt'

    if os.path.exists(result_post_path):
        # Cached result: CLIP features and disentanglement models are not needed.
        result_post = load_on_device(result_post_path)
    else:
        if clip_results is None:
            clip_results = load_on_device(clip_path)
            # eval() disables training-only behaviour (e.g. dropout) so the
            # encoded features are deterministic.
            clip_results['model_x'] = clip_results['model_x'].to(DEVICE).eval()
            clip_results['model_y'] = clip_results['model_y'].to(DEVICE).eval()

        print("Encoding representations...")
        # Encoded per objective so each call gets fresh tensors even if
        # postprocess_disentangle_joint_raw modifies its inputs in place.
        with torch.no_grad():
            X_clip = clip_results['model_x'](X_raw)
            Y_clip = clip_results['model_y'](Y_raw)
            X_clip_test = clip_results['model_x'](X_test_raw)
            Y_clip_test = clip_results['model_y'](Y_test_raw)

        result_X = load_on_device(
            os.path.join(save_dir, f"result_X_{objective}_{arch}_{idx}_{lam}.pt"))
        result_Y = load_on_device(
            os.path.join(save_dir, f"result_Y_{objective}_{arch}_{idx}_{lam}.pt"))

        result_post = postprocess_disentangle_joint_raw(
            result_X['models'], result_Y['models'],
            X_raw, X_test_raw, X_clip, X_clip_test,
            Y_raw, Y_test_raw, Y_clip, Y_clip_test,
            lab1_train, lab1_test, lab2_train, lab2_test,
            device=DEVICE
        )
        torch.save(result_post, result_post_path)

    print("=== Classification Accuracies ===")
    for label, scores in result_post['accs'].items():
        print(f"{label}: [clip]={scores['clip']:.4f}")
        print(f"{label}: [clip + disentangled]={scores['concat']:.4f}")

