import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *


# ========== Minimal GPU Adaptation ==========
# Use CUDA when PyTorch reports an available GPU; fall back to the CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")


def load_on_device(path):
    """Load an object written by ``torch.save`` and map its tensors to ``DEVICE``.

    Args:
        path: Location of the serialized file, as accepted by ``torch.load``.

    Returns:
        The deserialized object, with tensor storages remapped to ``DEVICE``.

    Note:
        ``weights_only=False`` enables full unpickling, which can execute
        arbitrary code. Only use this on files produced by a trusted run.
    """
    restored = torch.load(path, map_location=DEVICE, weights_only=False)
    return restored
# ============================================

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


# Command-line interface: all three options are mandatory.
parser = argparse.ArgumentParser(
    description="Run experiment with chosen dataset",
)
parser.add_argument("--dataset_nam", required=True, type=str, help="mimic")
parser.add_argument("--idx", required=True, type=int, help="index")
parser.add_argument("--lam", required=True, type=float, help="parameter")
args = parser.parse_args()

# Expose the parsed options under the module-level names the script reads.
dataset_nam, idx, lam = args.dataset_nam, args.idx, args.lam

print(f"Using dataset: {dataset_nam}")

# Fixed experiment settings.
ROOT_dir = './'        # prefix for the data directory and the results directory
max_ep = 2000          # forwarded as ``max_epochs`` to the training routines
batch_size = 128       # forwarded as ``batch_size`` to the training routines
arch = "transformer"   # selects which encoder class is instantiated

# The run index offsets the base seed, so each index gets its own RNG stream.
seed = 2025 + idx
np.random.seed(seed)
torch.manual_seed(seed)




# === Dataset ===
class MIMICDataset(torch.utils.data.Dataset):
    """Dataset backed by a directory of per-sample files plus a JSONL index.

    ``root_dir`` must contain:

    * ``index.jsonl`` -- one JSON object per line. Every record has an
      ``"id"`` entry and may have ``"tab"``, ``"timeseries"`` and ``"label"``
      entries whose values are file paths relative to ``root_dir``.
    * ``meta.json`` -- a JSON document, stored unmodified in ``self.meta``.

    Indexing yields ``(tab, timeseries, y, uid)``. A component whose key is
    missing from the record is returned as ``None``.
    """

    # Record keys of the per-sample files, in the order they are returned.
    _MODALITY_KEYS = ("tab", "timeseries", "label")

    def __init__(self, root_dir: str):
        """Read the metadata and the index found under ``root_dir``.

        Args:
            root_dir: Directory containing ``index.jsonl`` and ``meta.json``.

        Raises:
            FileNotFoundError: If either of the two files is missing.
        """
        self.root = Path(root_dir)
        self.index_file = self.root / "index.jsonl"
        self.meta_file = self.root / "meta.json"
        if not self.index_file.exists():
            raise FileNotFoundError(f"Index file {self.index_file} not found")
        if not self.meta_file.exists():
            raise FileNotFoundError(f"Meta file {self.meta_file} not found")

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.meta_file, encoding="utf-8") as f:
            self.meta = json.load(f)
        with open(self.index_file, encoding="utf-8") as f:
            # Blank lines (for example a stray trailing newline) are not
            # records; skipping them avoids a JSONDecodeError.
            self.records = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of records listed in the index."""
        return len(self.records)

    def __getitem__(self, idx):
        """Load the files referenced by record ``idx``.

        Returns:
            ``(tab, timeseries, y, uid)`` where each of the first three is the
            object stored in the referenced file, or ``None`` when the record
            has no entry for it, and ``uid`` is the record's ``"id"``.

        Raises:
            KeyError: If the record has no ``"id"`` entry.
        """
        rec = self.records[idx]
        uid = rec["id"]
        # NOTE(review): a ``None`` component cannot be batched by DataLoader's
        # default collate function -- confirm every record used with the
        # default collate lists all three files.
        tab, timeseries, y = (self._load(rec, key) for key in self._MODALITY_KEYS)
        return tab, timeseries, y, uid

    def _load(self, rec, key):
        """Return the object stored at ``rec[key]``, or ``None`` if ``key`` is absent."""
        if key not in rec:
            return None
        # weights_only=False performs a full unpickle, which can run arbitrary
        # code; the dataset directory must come from a trusted source.
        return torch.load(self.root / rec[key], weights_only=False)

# === Data loaders (only the batch size is configured) ===
from torch.utils.data import DataLoader

# Keyword arguments shared by all three loaders.
loader_args = {"batch_size": 32}

# Each split is read from "<root_dir><dataset_nam>_<split>".
root_dir = f"{ROOT_dir}FactorCL/Multibench/MultiBench/"

train_set = MIMICDataset(f"{root_dir}{dataset_nam}_train")
valid_set = MIMICDataset(f"{root_dir}{dataset_nam}_valid")
test_set = MIMICDataset(f"{root_dir}{dataset_nam}_test")

# Only the training split is shuffled; validation and test keep index order.
train_loader = DataLoader(train_set, **loader_args, shuffle=True)
valid_loader = DataLoader(valid_set, **loader_args, shuffle=False)
test_loader = DataLoader(test_set, **loader_args, shuffle=False)




def pad_and_concat(seq_list, M=None):
    """Concatenate tensors along dim 0 after equalizing their size along dim 1.

    When every tensor has at least three dimensions, each one is zero-padded
    at the end of dim 1 (if shorter than ``M``) or truncated to its first
    ``M`` entries along dim 1 (if longer), then all are concatenated along
    dim 0. If any tensor has fewer than three dimensions, the inputs are
    concatenated as they are and ``M`` is ignored.

    Args:
        seq_list: Sequence of tensors to combine.
        M: Target size of dim 1. Defaults to the largest dim-1 size found in
            ``seq_list``.

    Returns:
        A single tensor; for inputs of rank >= 3 its dim 0 is the sum of the
        inputs' dim-0 sizes and its dim 1 has size ``M``.
    """
    if not all(x.dim() >= 3 for x in seq_list):
        # Nothing to equalize: concatenate directly.
        return torch.cat(seq_list, dim=0)

    if M is None:
        M = max(x.size(1) for x in seq_list)

    padded = []
    for x in seq_list:
        T = x.size(1)
        if T < M:
            # F.pad reads (before, after) pairs starting from the LAST dim,
            # so emit a zero pair for each trailing dim and finish with the
            # pair for dim 1. For 3-D input this is (0, 0, 0, M - T).
            pad_spec = (0, 0) * (x.dim() - 2) + (0, M - T)
            x = Fnn.pad(x, pad_spec)
        elif T > M:
            # Slicing only dims 0 and 1 keeps all trailing dims, whatever the rank.
            x = x[:, :M]
        padded.append(x)

    return torch.cat(padded, dim=0)


# === Gather every training batch ===
# MIMICDataset items are (tab, timeseries, label, uid). The script keeps the
# first entry under its "visual" names and the second under its "text" names;
# the uid (batch[3]) is not used.
input_text_list = []
input_visual_list = []
lab_list = []

for batch in tqdm(train_loader, desc="train collect"):
    visual, text, label = batch[0], batch[1], batch[2]

    input_visual_list.append(visual)
    input_text_list.append(text)
    lab_list.append(label)

# === Gather every test batch (same layout as above) ===
input_text_test_list = []
input_visual_test_list = []
lab_test_list = []

for batch in tqdm(test_loader, desc="test collect"):
    visual, text, label = batch[0], batch[1], batch[2]

    input_visual_test_list.append(visual)
    input_text_test_list.append(text)
    lab_test_list.append(label)

# === Largest dim-1 size per modality, taken over train and test together ===
M_text = max(
    max(t.size(1) for t in input_text_list),
    max(t.size(1) for t in input_text_test_list),
)

M_visual = max(
    max(t.size(1) for t in input_visual_list),
    max(t.size(1) for t in input_visual_test_list),
)

# === Merge the batches ===
# pad_and_concat pads/truncates dim 1 to the given size only when all inputs
# have at least three dims; otherwise it concatenates them unchanged.
# Labels are concatenated as loaded: no dtype cast and no device move.
lab = torch.cat(lab_list, dim=0)
input_text = pad_and_concat(input_text_list, M_text)
input_text = input_text.float().to(DEVICE)
input_visual = pad_and_concat(input_visual_list, M_visual)
input_visual = input_visual.float().to(DEVICE)

# The test split reuses the train-derived sizes so both splits share a layout.
lab_test = torch.cat(lab_test_list, dim=0)
input_text_test = pad_and_concat(input_text_test_list, M_text)
input_text_test = input_text_test.float().to(DEVICE)
input_visual_test = pad_and_concat(input_visual_test_list, M_visual)
input_visual_test = input_visual_test.float().to(DEVICE)



## Sample counts and derived tensors
print("\nPreparing CLIP training...")

# Number of samples in the train split, the test split, and both combined.
n, n_test = input_text.shape[0], input_text_test.shape[0]
N = n + n_test

# Binary labels: 1 where the raw label is strictly positive, else 0.
lab1_train = (lab > 0).to(torch.long)
lab1_test = (lab_test > 0).to(torch.long)

# Insert a size-1 axis at position 1 of the "visual" tensors.
input_visual = input_visual.unsqueeze(dim=1)
input_visual_test = input_visual_test.unsqueeze(dim=1)

print(input_text.shape, input_visual.shape, input_text_test.shape, input_visual_test.shape)


from train import *

# Feature sizes are read from axis 2 of each input, so both tensors must have
# at least three dimensions at this point.
d_x = input_text.shape[2]
d_y = input_visual.shape[2]

tau_lower = 1e-4  # forwarded to NonLinearNetD only (non-transformer branch)
outdim = 50  # second constructor argument of both encoders; also passed as ``outdim`` to train_disentangle_raw
middim = max(d_x, d_y)  # middle argument of NonLinearNetD; unused by the transformer branch

# Build one encoder per modality and move it to DEVICE.
# Transformer_tens / NonLinearNetD are not defined in this file -- presumably
# they come from the star imports of ``utils`` / ``train``; TODO confirm.
# NOTE: construction happens after the RNGs were seeded, so the order of these
# statements may affect parameter initialization.
if arch == "transformer":
    model_x = Transformer_tens(d_x, outdim).to(DEVICE)
    model_y = Transformer_tens(d_y, outdim).to(DEVICE)
else:
    model_x = NonLinearNetD(d_x, middim, outdim, tau_lower=tau_lower).to(DEVICE)
    model_y = NonLinearNetD(d_y, middim, outdim, tau_lower=tau_lower).to(DEVICE)

print(f"\n== CLIP Training ({arch}) ==\n")

# Results directory for this dataset; created if it does not exist yet.
save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
os.makedirs(save_dir, exist_ok=True)

# === Train or load clip results ===
# The cache file name is keyed on arch, idx and lam only.
# NOTE(review): ``lam`` is part of the file name but is not passed to
# train_clip_raw, while max_ep / batch_size / lr are passed but are not part
# of the name -- changing those reuses an existing cache file. Confirm this
# is intended.
clip_path = f"{save_dir}/clip_results_{arch}_{idx}_{lam}.pt"

if not os.path.exists(clip_path):
    # No cached result: train on the training tensors and store the result.
    clip_results = train_clip_raw(
        input_text, input_visual,
        model_x, model_y,
        max_epochs=max_ep, batch_size=batch_size,
        lr=1e-4, wd=1e-4,
        tau_fix=1.0, tau_tune=True,
        tau_lr_fac=5, spectral=False,
        device=DEVICE,
    )
    torch.save(clip_results, clip_path)
else:
    # Cached result found: the freshly built model_x / model_y are not used.
    clip_results = load_on_device(clip_path)

# === Encode training features ===
print("Encoding representations...")

# The result is indexed with "model_x" / "model_y" and the values are called
# on the full training tensors in a single forward pass, without gradients.
# NOTE(review): the models are not switched to eval mode here -- confirm that
# train_clip_raw returns them in the mode intended for encoding.
with torch.no_grad():
    XX_clip = clip_results["model_x"](input_text).to(DEVICE)
    YY_clip = clip_results["model_y"](input_visual).to(DEVICE)

# Objective names accepted by the sweep that follows:
# objective = 'disen'
# objective = 'fact'
# objective = 'recons'

print(XX_clip.shape, YY_clip.shape)



def _train_or_load_disentangled(raw, embed, seq_len, arch_name, objective, aug, result_path):
    """Return the result cached at ``result_path``, training and saving it first if absent."""
    if os.path.exists(result_path):
        return load_on_device(result_path)

    result = train_disentangle_raw(
        raw, embed,
        outdim=outdim,
        seq_len=seq_len,
        arch=arch_name,
        max_epochs=max_ep,
        batch_size=batch_size,
        objective=objective,
        lam=lam,
        aug=aug,
        device=DEVICE,
    )
    torch.save(result, result_path)
    return result


# Sweep every (aug, objective) pair; within a pair, component X is handled
# before component Y. Each result is cached in its own file.
for aug in (True, False):
    for objective in ('recons', 'fact', 'disen'):

        print(f"\n== Disentangled method: {objective} ===\n")

        arch_disentg = arch

        save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
        os.makedirs(save_dir, exist_ok=True)

        # ---- Component X: raw "text" input paired with its CLIP embedding ----
        X_raw, X_embed = input_text, XX_clip
        seq_len = input_text.shape[1]
        result_path_X = os.path.join(save_dir, f"result_X_{objective}_{arch}_{idx}_{lam}_aug{aug}.pt")
        result_X = _train_or_load_disentangled(
            X_raw, X_embed, seq_len, arch_disentg, objective, aug, result_path_X,
        )

        # ---- Component Y: raw "visual" input paired with its CLIP embedding ----
        Y_raw, Y_embed = input_visual, YY_clip
        seq_len = input_visual.shape[1]
        result_path_Y = os.path.join(save_dir, f"result_Y_{objective}_{arch}_{idx}_{lam}_aug{aug}.pt")
        result_Y = _train_or_load_disentangled(
            Y_raw, Y_embed, seq_len, arch_disentg, objective, aug, result_path_Y,
        )
