import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *


# ========== Minimal GPU Adaptation ==========
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")


def load_on_device(path):
    """Load an object saved with ``torch.save`` and map its tensors onto ``DEVICE``.

    ``weights_only=False`` lets arbitrary pickled objects (e.g. whole models)
    be restored, so only use this on trusted, locally produced files.
    """
    return torch.load(path, weights_only=False, map_location=DEVICE)
# ============================================

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


# --- Command-line configuration ---------------------------------------------
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
parser.add_argument(
    "--dataset_nam",
    type=str,
    required=True,
    help="Name of the dataset, e.g. mosi, mosei, etc.",
)
parser.add_argument("--idx", type=int, required=True, help="index")
parser.add_argument("--lam", type=float, required=True, help="parameter")
args = parser.parse_args()

# `idx` offsets the random seed below and, together with `lam`, is part of
# every cache filename written by this script; `lam` is also forwarded to
# train_disentangle_raw.
dataset_nam, idx, lam = args.dataset_nam, args.idx, args.lam

print(f"Using dataset: {dataset_nam}")

# --- Fixed experiment settings ----------------------------------------------
ROOT_dir = './'
max_ep = 2000          # epoch budget for both CLIP and disentanglement training
batch_size = 128       # training mini-batch size (the DataLoaders use their own)
arch = "transformer"   # any other value selects the NonLinearNetD encoders

# Seed torch and numpy from the run index.
seed = idx + 2025
torch.manual_seed(seed)
np.random.seed(seed)




# === Dataset ===
class MOSILocalDataset(Dataset):
    """Map-style dataset over a directory of per-sample ``torch.save`` files.

    The directory must contain an ``index.jsonl`` file with one JSON object per
    line. Each record must have an ``"id"`` and may map any requested modality
    name (and ``"label"``) to a file path, relative to the directory, that was
    written with ``torch.save``.

    Args:
        root: Directory holding ``index.jsonl`` and the referenced files.
        modalities: Record keys to load for each sample; a key missing from a
            record is simply skipped for that sample.
        use_label: Whether to load the ``"label"`` entry when a record has one.
    """

    def __init__(self, root, modalities=("text", "audio", "vision"), use_label=True):
        self.root = Path(root)
        self.modalities = modalities
        self.use_label = use_label
        # Explicit encoding so the index parses identically on every platform;
        # blank lines (e.g. an extra trailing newline) are ignored instead of
        # making json.loads raise.
        with open(self.root / "index.jsonl", encoding="utf-8") as f:
            self.records = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.records)

    def _load(self, rel_path):
        """Load one saved object relative to ``root``, keeping it on CPU."""
        # weights_only=False allows arbitrary pickled objects: only point this
        # dataset at trusted, locally produced files.
        return torch.load(self.root / rel_path, map_location="cpu", weights_only=False)

    def __getitem__(self, idx):
        """Return ``{"id": ..., <modality>: ..., ["label": ...]}`` for record *idx*."""
        rec = self.records[idx]
        out = {"id": rec["id"]}

        for m in self.modalities:
            if m in rec:
                out[m] = self._load(rec[m])

        if self.use_label and "label" in rec:
            out["label"] = self._load(rec["label"])

        return out

# === Collate function that moves to GPU ===
from torch.nn.utils.rnn import pad_sequence

def mosi_collate_fn(batch, modalities=("text", "audio", "vision")):
    """Merge a list of dataset samples into one batch dict on ``DEVICE``.

    Which modalities (and whether a label) are collated is decided by the first
    sample only. Each collated modality is zero-padded to the longest sample
    along dim 0 via ``pad_sequence(batch_first=True)``. Labels are combined
    with ``torch.stack``. Sample ids stay as a plain Python list.
    """
    collated = {"id": [sample["id"] for sample in batch]}
    first = batch[0]

    present = [name for name in modalities if name in first]
    for name in present:
        padded = pad_sequence([sample[name] for sample in batch], batch_first=True)
        collated[name] = padded.to(DEVICE, non_blocking=True)

    if "label" in first:
        stacked = torch.stack([sample["label"] for sample in batch])
        collated["label"] = stacked.to(DEVICE, non_blocking=True)

    return collated

# === DataLoaders (the collate function moves each batch to DEVICE) ===
from torch.utils.data import DataLoader

# Batch size used only while iterating the splits below; training uses `batch_size`.
loader_args = {"batch_size": 32}

root_dir = ROOT_dir + "FactorCL/Multibench/MultiBench/"

train_set, valid_set, test_set = (
    MOSILocalDataset(root_dir + f"{dataset_nam}_{split}")
    for split in ("train", "valid", "test")
)

# mosi_collate_fn's default modalities are ("text", "audio", "vision"), so it
# can be passed directly as a (picklable) collate function.
train_loader = DataLoader(train_set, shuffle=True, collate_fn=mosi_collate_fn, **loader_args)
valid_loader = DataLoader(valid_set, shuffle=False, collate_fn=mosi_collate_fn, **loader_args)
test_loader = DataLoader(test_set, shuffle=False, collate_fn=mosi_collate_fn, **loader_args)




def pad_and_concat(seq_list, M=None):
    """Fit every (B_i, T_i, D) tensor to length M on dim 1, then concatenate on dim 0.

    Shorter tensors are zero-padded at the end of the time axis and longer ones
    are truncated to their first M steps. When M is None the longest T_i is used.

    Returns:
        Tensor of shape (sum(B_i), M, D).
    """
    target_len = max(t.size(1) for t in seq_list) if M is None else M

    def _fit(t):
        # Unpacking also rejects anything that is not 3-D.
        _, steps, _ = t.shape
        if steps < target_len:
            return F.pad(t, (0, 0, 0, target_len - steps))
        if steps > target_len:
            return t[:, :target_len, :]
        return t

    return torch.cat([_fit(t) for t in seq_list], dim=0)


# === Collect every collated batch from the train and test loaders ===
# mosi_collate_fn has already padded each batch to its own longest sequence
# and moved it to DEVICE.
input_text_list, input_visual_list, lab_list = [], [], []
for batch in tqdm(train_loader, desc="train collect"):
    text, visual, label = batch['text'], batch['vision'], batch['label']
    input_text_list.append(text)
    input_visual_list.append(visual)
    lab_list.append(label)

input_text_test_list, input_visual_test_list, lab_test_list = [], [], []
for batch in tqdm(test_loader, desc="test collect"):
    text, visual, label = batch['text'], batch['vision'], batch['label']
    input_text_test_list.append(text)
    input_visual_test_list.append(visual)
    lab_test_list.append(label)

# === One sequence length per modality, shared by train AND test, so both
# splits end up with the same time dimension ===
M_text = max(
    max(t.shape[1] for t in input_text_list),
    max(t.shape[1] for t in input_text_test_list),
)
M_visual = max(
    max(t.shape[1] for t in input_visual_list),
    max(t.shape[1] for t in input_visual_test_list),
)

# === Pad every batch to the shared length and stack into full-split tensors ===
lab = torch.cat(lab_list, dim=0)
input_text = pad_and_concat(input_text_list, M_text).float()
input_visual = pad_and_concat(input_visual_list, M_visual).float()

lab_test = torch.cat(lab_test_list, dim=0)
input_text_test = pad_and_concat(input_text_test_list, M_text).float()
input_visual_test = pad_and_concat(input_visual_test_list, M_visual).float()



# === Dataset sizes and binarised labels ===
print("\nPreparing CLIP training...")

n, n_test = input_text.shape[0], input_text_test.shape[0]  # train / test sample counts
N = n + n_test

# Binary targets: 1 where the raw label is strictly positive, otherwise 0.
lab1_train, lab1_test = ((y > 0).long() for y in (lab, lab_test))

print(
    input_text.shape,
    input_visual.shape,
    input_text_test.shape,
    input_visual_test.shape,
)

from train import *

# === Encoder construction ===
d_x = input_text.shape[2]    # text feature width
d_y = input_visual.shape[2]  # visual feature width

tau_lower = 1e-4
outdim = 50
middim = max(d_x, d_y)

# One encoder per modality, built directly on DEVICE. Keep the x-then-y order:
# weight initialisation presumably draws from the seeded torch RNG.
if arch != "transformer":
    model_x = NonLinearNetD(d_x, middim, outdim, tau_lower=tau_lower).to(DEVICE)
    model_y = NonLinearNetD(d_y, middim, outdim, tau_lower=tau_lower).to(DEVICE)
else:
    model_x = Transformer_tens(d_x, outdim).to(DEVICE)
    model_y = Transformer_tens(d_y, outdim).to(DEVICE)

print(f"\n== CLIP Training ({arch}) ==\n")

save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
os.makedirs(save_dir, exist_ok=True)

# === Train CLIP (or reuse a cached run) ===
# NOTE(review): `lam` is part of this cache filename but is not passed to
# train_clip_raw, so every new `lam` value retrains CLIP from scratch. Kept
# as-is so existing cache files stay valid.
clip_path = f"{save_dir}/clip_results_{arch}_{idx}_{lam}.pt"

if os.path.exists(clip_path):
    clip_results = load_on_device(clip_path)
else:
    clip_results = train_clip_raw(
        input_text,
        input_visual,
        model_x,
        model_y,
        max_epochs=max_ep,
        batch_size=batch_size,
        lr=1e-4,
        wd=1e-4,
        tau_fix=1.0,
        tau_tune=True,
        tau_lr_fac=5,
        spectral=False,
        device=DEVICE,
    )
    torch.save(clip_results, clip_path)

# === Encode training features with the trained CLIP encoders ===
print("Encoding representations...")


def _encode_frozen(model, inputs):
    """Run *model* on *inputs* without gradients and in eval mode; return the output on DEVICE.

    If the encoder contains dropout-style layers, running it in training mode
    would make the embeddings (later passed to train_disentangle_raw) random
    and non-reproducible. The module's previous train/eval mode is restored
    afterwards, so the encoder object is left as it was found.
    """
    is_module = isinstance(model, nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            return model(inputs).to(DEVICE)
    finally:
        if is_module:
            model.train(was_training)


XX_clip = _encode_frozen(clip_results["model_x"], input_text)
YY_clip = _encode_frozen(clip_results["model_y"], input_visual)

def _train_or_load_disentangle(comp, raw, embed, seq_len, objective):
    """Train, or load from cache, the disentanglement result for one modality.

    Returns ``(path, result)``. *path* is
    ``result_{comp}_{objective}_{arch}_{idx}_{lam}.pt`` inside ``save_dir``.
    When that file exists it is loaded with ``load_on_device``. Otherwise
    ``train_disentangle_raw`` is run on *raw* / *embed* with augmentation
    disabled, and the result is saved to *path*.
    """
    path = os.path.join(save_dir, f"result_{comp}_{objective}_{arch}_{idx}_{lam}.pt")
    if os.path.exists(path):
        return path, load_on_device(path)

    result = train_disentangle_raw(
        raw,
        embed,
        outdim=outdim,
        seq_len=seq_len,
        arch=arch_disentg,
        max_epochs=max_ep,
        batch_size=batch_size,
        objective=objective,
        lam=lam,
        aug=False,
        device=DEVICE,
    )
    torch.save(result, path)
    return path, result


# Disentangle each modality under every objective; results are cached per
# (component, objective, arch, idx, lam).
for objective in ['recons', 'fact', 'disen']:
    print(f"\n== Disentangled method: {objective} ===\n")

    arch_disentg = arch
    save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
    os.makedirs(save_dir, exist_ok=True)

    # comp='X': raw text sequences paired with their CLIP embeddings.
    X_raw, X_embed, seq_len = input_text, XX_clip, input_text.shape[1]
    result_path_X, result_X = _train_or_load_disentangle("X", X_raw, X_embed, seq_len, objective)

    # comp='Y': raw visual sequences paired with their CLIP embeddings.
    Y_raw, Y_embed, seq_len = input_visual, YY_clip, input_visual.shape[1]
    result_path_Y, result_Y = _train_or_load_disentangle("Y", Y_raw, Y_embed, seq_len, objective)

