import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from utils import *



# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")

def load_on_device(path):
    """Load the object saved at ``path`` and map its storages onto ``DEVICE``.

    The file is read with ``weights_only=False``, i.e. ``torch.load`` is
    allowed to unpickle arbitrary Python objects rather than tensors only,
    so ``path`` must point at a trusted file.
    """
    load_options = {"map_location": DEVICE, "weights_only": False}
    return torch.load(path, **load_options)


import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


# --- Command-line interface --------------------------------------------------
# All three options are mandatory; argparse aborts with a usage message when
# one of them is missing.
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
parser.add_argument(
    "--dataset_nam",
    required=True,
    type=str,
    help="Name of the dataset, e.g. mosi, mosei, etc.",
)
parser.add_argument(
    "--idx",
    required=True,
    type=int,
    help="index",
)
parser.add_argument(
    "--lam",
    required=True,
    type=float,
    help="parameter",
)
args = parser.parse_args()

# Expose the parsed options under the module-level names used further down
# (dataset folder / result file names).
dataset_nam, idx, lam = args.dataset_nam, args.idx, args.lam

print(f"Using dataset: {dataset_nam}")


# --- Fixed run configuration -------------------------------------------------
# Prefix for both the dataset folder and the results folder.
ROOT_dir = './'

# NOTE(review): `max_ep` and `batch_size` are not read anywhere in the visible
# part of this script -- presumably kept for parity with the training script.
max_ep = 2000

batch_size = 128

# Architecture tag; it is embedded in every result file name.
arch = "transformer"

# Seed the torch and NumPy global RNGs with the same fixed value.
seed = 2025
torch.manual_seed(seed)
np.random.seed(seed)






# === Dataset ===
class MIMICDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a directory of per-sample ``torch.save`` files.

    The directory ``root_dir`` must contain:

    * ``meta.json``   -- a JSON document, kept unchanged in ``self.meta``;
    * ``index.jsonl`` -- one JSON object per line, kept in ``self.records``.
      Every record needs an ``"id"`` entry and may have ``"tab"``,
      ``"timeseries"`` and ``"label"`` entries, each a file path relative to
      ``root_dir``.
    """

    def __init__(self, root_dir: str):
        """Read ``meta.json`` and ``index.jsonl`` found under ``root_dir``.

        Raises:
            FileNotFoundError: if either of the two files is missing.
        """
        self.root = Path(root_dir)
        self.index_file = self.root / "index.jsonl"
        self.meta_file = self.root / "meta.json"
        if not self.index_file.exists():
            raise FileNotFoundError(f"Index file {self.index_file} not found")
        if not self.meta_file.exists():
            raise FileNotFoundError(f"Meta file {self.meta_file} not found")

        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open(self.meta_file, encoding="utf-8") as f:
            self.meta = json.load(f)
        with open(self.index_file, encoding="utf-8") as f:
            # Blank / whitespace-only lines (e.g. a trailing empty line) are
            # not records; feeding them to json.loads would raise.
            self.records = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of records listed in the index."""
        return len(self.records)

    def _load_optional(self, rec, key):
        """Load the file referenced by ``rec[key]``, or return ``None`` if the key is absent.

        NOTE: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
        objects, so the dataset directory must come from a trusted source.
        """
        if key not in rec:
            return None
        return torch.load(self.root / rec[key], weights_only=False)

    def __getitem__(self, idx):
        """Return ``(tab, timeseries, y, uid)`` for record ``idx``.

        ``uid`` is the record's ``"id"``; each of the other three elements is
        whatever object was saved in the referenced file, or ``None`` when the
        record has no entry for it.
        """
        rec = self.records[idx]
        uid = rec["id"]

        tab = self._load_optional(rec, "tab")
        timeseries = self._load_optional(rec, "timeseries")
        y = self._load_optional(rec, "label")

        return tab, timeseries, y, uid

from torch.utils.data import DataLoader

# Keyword arguments shared by the train and the test loader.
loader_args = {"batch_size": 32}

# Each split lives in "<dataset_nam>_all_<split>" below the MultiBench folder.
root_dir = f"{ROOT_dir}FactorCL/Multibench/MultiBench/"

train_set = MIMICDataset(f"{root_dir}{dataset_nam}_all_train")
test_set = MIMICDataset(f"{root_dir}{dataset_nam}_all_test")

# Only the training split is shuffled; the test split keeps index order.
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
test_loader = DataLoader(test_set, shuffle=False, **loader_args)





def pad_and_concat(seq_list, M=None):
    """Concatenate tensors along dim 0, equalising dim 1 when all are >= 3-D.

    Args:
        seq_list: non-empty sequence of tensors.
        M: target size of dim 1. Defaults to the largest dim-1 size found in
            ``seq_list``. Ignored when any tensor has fewer than 3 dims.

    Returns:
        If at least one tensor has fewer than 3 dimensions, the plain
        ``torch.cat(seq_list, dim=0)``. Otherwise every tensor is zero-padded
        at the end of dim 1 (or truncated) to length ``M`` and the results are
        concatenated along dim 0.

    Raises:
        ValueError: if ``seq_list`` is empty.
    """
    # len() rather than truthiness so that sequence-like inputs whose
    # truth value is ambiguous do not trip the check.
    if len(seq_list) == 0:
        raise ValueError("pad_and_concat() requires at least one tensor")

    if not all(x.dim() >= 3 for x in seq_list):
        return torch.cat(seq_list, dim=0)

    if M is None:
        M = max(x.size(1) for x in seq_list)

    padded = []
    for x in seq_list:
        # Read dim 1 directly instead of unpacking the shape, so tensors with
        # more than 3 dims (which pass the check above) are handled too.
        T = x.size(1)
        if T < M:
            # F.pad takes (left, right) pairs starting from the LAST dim:
            # leave the trailing dims untouched and extend only the end of dim 1.
            pad_spec = (0, 0) * (x.dim() - 2) + (0, M - T)
            x = Fnn.pad(x, pad_spec)
        elif T > M:
            x = x[:, :M]
        padded.append(x)

    return torch.cat(padded, dim=0)




def _collect_batches(loader, desc):
    """Iterate ``loader`` once and return per-batch lists ``(text, visual, label)``.

    Batches are indexed positionally: element 0 goes to the "visual" list,
    element 1 to the "text" list and element 2 to the label list. With
    ``MIMICDataset`` samples being ``(tab, timeseries, y, uid)``, "visual"
    therefore holds the ``tab`` entry and "text" the ``timeseries`` entry.
    """
    texts, visuals, labels = [], [], []
    for batch in tqdm(loader, desc=desc):
        visual, text, label = batch[0], batch[1], batch[2]
        texts.append(text)
        visuals.append(visual)
        labels.append(label)
    return texts, visuals, labels


def _longest(batches):
    """Return the largest dim-1 size among ``batches``."""
    return max(x.size(1) for x in batches)


def _merge(batches, M):
    """Run ``pad_and_concat(batches, M)``, cast to float32 and move to ``DEVICE``."""
    return pad_and_concat(batches, M).float().to(DEVICE)


# Materialise both splits batch by batch (train first, then test).
input_text_list, input_visual_list, lab_list = _collect_batches(
    train_loader, "train collect"
)
input_text_test_list, input_visual_test_list, lab_test_list = _collect_batches(
    test_loader, "test collect"
)

# One common dim-1 size per modality, taken over train AND test, so both
# splits end up with the same dim-1 length after merging.
M_text = max(_longest(input_text_list), _longest(input_text_test_list))
M_visual = max(_longest(input_visual_list), _longest(input_visual_test_list))

# Train split.
lab = torch.cat(lab_list, dim=0)
input_text = _merge(input_text_list, M_text)
input_visual = _merge(input_visual_list, M_visual)

# Test split (same M_text / M_visual as the train split).
lab_test = torch.cat(lab_test_list, dim=0)
input_text_test = _merge(input_text_test_list, M_text)
input_visual_test = _merge(input_visual_test_list, M_visual)


## Parameters
print("\nPreparing CLIP training...")

# Sample counts per split and overall.
n = input_text.shape[0]
n_test = input_text_test.shape[0]
N = n + n_test

# Cast the labels to int64; no thresholding or binarisation happens here.
lab1_train_all = lab.long()
lab1_test_all = lab_test.long()

# Insert a singleton axis at position 1 of the "visual" inputs.
input_visual = input_visual.unsqueeze(1)
input_visual_test = input_visual_test.unsqueeze(1)

print(input_text.shape, input_visual.shape, input_text_test.shape, input_visual_test.shape)


from train import *

# Sizes of axis 2 of the two inputs, plus a few hyper-parameters.
# NOTE(review): d_x, d_y, tau_lower, outdim and middim are not read again in
# the visible part of this script -- confirm before removing.
d_x, d_y = input_text.shape[2], input_visual.shape[2]

tau_lower = 1e-4
outdim = 50
middim = max(d_x, d_y)


for task_id in range(21):

    # Column `task_id` of the label matrices is this task's target. The same
    # labels are handed to both label slots of the post-processing call.
    lab1_train = lab1_train_all[:, task_id]
    lab1_test = lab1_test_all[:, task_id]
    lab2_train = lab1_train
    lab2_test = lab1_test

    for objective in ['recons', 'fact', 'disen']:

        print(f"\n== Disentangled method: {objective} ===\n")

        arch_disentg = arch

        print(f"\n== CLIP Training ({arch}) ==\n")

        # Folder that holds the previously saved artefacts for this dataset.
        save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
        os.makedirs(save_dir, exist_ok=True)

        # Suffix shared by every artefact file name of this run.
        run_tag = f"{arch}_{idx}_{lam}"

        # === Load clip results ===
        # NOTE(review): this checkpoint path does not depend on `task_id` or
        # `objective`, yet the file is reloaded and re-encoded on every
        # iteration. Hoisting it would save work, but is only safe if the
        # models are deterministic and nothing downstream mutates them --
        # confirm first.
        clip_path = f"{save_dir}/clip_results_{run_tag}.pt"
        clip_results = load_on_device(clip_path)

        # === Encode training features ===
        print("Encoding representations...")

        # ========= For comp='X' =========
        X_raw, X_test_raw = input_text.to(DEVICE), input_text_test.to(DEVICE)
        result_path_X = os.path.join(save_dir, f"result_X_{objective}_{run_tag}.pt")
        result_X = load_on_device(result_path_X)

        # ========= For comp='Y' =========
        Y_raw, Y_test_raw = input_visual.to(DEVICE), input_visual_test.to(DEVICE)
        result_path_Y = os.path.join(save_dir, f"result_Y_{objective}_{run_tag}.pt")
        result_Y = load_on_device(result_path_Y)

        # Make sure both CLIP encoders live on DEVICE (stored back in the dict).
        model_x = clip_results['model_x'].to(DEVICE)
        clip_results['model_x'] = model_x
        model_y = clip_results['model_y'].to(DEVICE)
        clip_results['model_y'] = model_y

        # Gradient-free forward passes over the full train / test tensors.
        # NOTE(review): the models are used in whatever train/eval mode they
        # were saved in; no_grad() alone does not disable dropout -- confirm
        # the checkpoints were saved in eval mode.
        with torch.no_grad():
            X_clip = model_x(X_raw).to(DEVICE)
            Y_clip = model_y(Y_raw).to(DEVICE)
            X_clip_test = model_x(X_test_raw).to(DEVICE)
            Y_clip_test = model_y(Y_test_raw).to(DEVICE)

        # Joint postprocessing
        result_post = postprocess_disentangle_joint_raw(
            result_X['models'],
            result_Y['models'],
            X_raw, X_test_raw, X_clip, X_clip_test,
            Y_raw, Y_test_raw, Y_clip, Y_clip_test,
            lab1_train, lab1_test, lab2_train, lab2_test,
            device=DEVICE,
        )

        print(f"=== Classification Accuracies: task_id = {task_id} ===")
        for acc_name, acc in result_post['accs'].items():
            print(f"{acc_name}: [clip]={acc['clip']:.4f}")
            print(f"{acc_name}: [clip + disentangled]={acc['concat']:.4f}")

        # Save results (one file per objective and task).
        save_dir = f'./results/{dataset_nam}'
        os.makedirs(save_dir, exist_ok=True)

        out_path = f'{save_dir}/result_post_{objective}_{run_tag}_task{task_id}.pt'
        torch.save(result_post, out_path)


