import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

import skdim
from skdim.id import MLE, KNN, lPCA, FisherS

from utils import *

ROOT_dir = "./"


# Use the CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")

def load_on_device(path):
    """Load a ``torch.save`` file and map every stored tensor onto ``DEVICE``.

    ``weights_only=False`` lets torch unpickle arbitrary Python objects, so only
    call this on files you trust.
    """
    loaded = torch.load(path, map_location=DEVICE, weights_only=False)
    return loaded


import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


# Command-line interface: all four options are mandatory.
parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
for _flag, _kind, _help in (
    ("--dataset_nam", str, "setting x"),
    ("--arch", str, "deep or transformer"),
    ("--idx", int, "index"),
    ("--lam", float, "parameter"),
):
    parser.add_argument(_flag, type=_kind, required=True, help=_help)
del _flag, _kind, _help
args = parser.parse_args()

# Expose the parsed values under the module-level names the script uses:
# dataset_nam selects the simulation setting, idx offsets the random seed,
# lam weights the reconstruction term, arch selects the network family.
dataset_nam, idx, lam, arch = args.dataset_nam, args.idx, args.lam, args.arch




def gen_simu_data(N, n, d_x, d_y, d_z, sett, noise_scale=None):
    """Generate a synthetic (X, Y, F) dataset and split it into train/test parts.

    Sampling scheme:
      * Y ~ N(0, I_{d_y}), one row per sample.
      * X = Y @ A + noise_scale * noise, where A is the d_y x d_x matrix with ones
        on its main diagonal (X copies the first min(d_x, d_y) coordinates of Y)
        and noise ~ N(0, I_{d_x}).
      * In 'setting6'/'setting7' (only when d_z < d_x), columns d_z: of X are
        replaced by i.i.d. draws from {-1, -0.5, 0.5, 1, -sqrt(7)/2, sqrt(7)/2}.
      * F0 = X @ A_f, where A_f is the d_x x d_z matrix with ones on its main
        diagonal (F0 copies the first min(d_z, d_x) columns of X).
      * F, and possibly further columns of X, are then set per setting:
          - 'setting1', 'setting6': F = 0.5*F0 + 0.2*sin(F0) + 0.2 + F0**3.
          - 'setting4', 'setting7': F = F0; X[:, 2] = 0.2*(F0_0 + F0_1) and
            X[:, 3] = F0_0 * F0_1 (requires d_x >= 4 and d_z >= 2).
          - 'setting5': F = F0; X[:, 2:6] are overwritten with polynomial
            functions of F0_0 and F0_1 (requires d_x >= 6 and d_z >= 2).

    Args:
        N: total number of samples (train + test).
        n: number of training samples; rows [:n] are train, rows [n:] are test.
        d_x, d_y, d_z: dimensions of X, Y and F.
        sett: one of 'setting1', 'setting4', 'setting5', 'setting6', 'setting7'.
        noise_scale: multiplier of the Gaussian noise added to X. ``None``
            (default) uses the module-level ``eps``, as the original code did.

    Returns:
        Tuple ``(X, Y, F, X_test, Y_test, F_test)`` of float32 torch tensors.

    Raises:
        ValueError: if ``sett`` is not a supported setting. (Previously this
            surfaced later as an UnboundLocalError on ``F_all``.)
    """
    supported = ('setting1', 'setting4', 'setting5', 'setting6', 'setting7')
    if sett not in supported:
        raise ValueError(f"Unknown setting {sett!r}; expected one of {supported}.")

    # The global `eps` is only looked up when no explicit noise scale is given.
    scale = eps if noise_scale is None else noise_scale

    # Keep the exact order of RNG draws so seeded runs stay reproducible.
    Y_all = np.random.multivariate_normal(mean=np.zeros(d_y), cov=np.eye(d_y), size=N)
    noise_all = np.random.multivariate_normal(mean=np.zeros(d_x), cov=np.eye(d_x), size=N)

    # Rectangular identity: X copies the leading coordinates of Y.
    A = np.eye(d_y, d_x)
    X_all = Y_all @ A + scale * noise_all

    if sett in ('setting6', 'setting7') and d_z < d_x:
        choices = np.array([-1.0, -0.5, 0.5, 1.0, -np.sqrt(7)/2, np.sqrt(7)/2])
        X_all[:, d_z:] = np.random.choice(choices, size=(N, d_x - d_z))

    # Rectangular identity: F0 copies the leading columns of X. F0_all is a new
    # array, so the in-place edits of X_all below do not alter it.
    A_f = np.eye(d_x, d_z)
    F0_all = X_all @ A_f

    if sett in ('setting1', 'setting6'):
        F_all = 0.5 * F0_all + 0.2 * np.sin(F0_all) + 0.2 + (F0_all)**3
    elif sett in ('setting4', 'setting7'):
        F_all = F0_all
        X_all[:, 2] = 0.2 * (F0_all[:, 0] + F0_all[:, 1])
        X_all[:, 3] = F0_all[:, 0] * F0_all[:, 1]
    else:  # 'setting5'
        F_all = F0_all
        f0, f1 = F0_all[:, 0], F0_all[:, 1]
        X_all[:, 2] = f0 * f1
        X_all[:, 3] = f0 + f1
        X_all[:, 4] = 0.2 * f0 ** 3 + 0.2 * f1 ** 3 + 0.1 * f0 * f1
        X_all[:, 5] = 0.2 * f0 ** 2 + 0.2 * f1 ** 2 + 0.1 * f0 * f1 * (f0 + f1)

    def to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32)

    return (to_tensor(X_all[:n]), to_tensor(Y_all[:n]), to_tensor(F_all[:n]),
            to_tensor(X_all[n:]), to_tensor(Y_all[n:]), to_tensor(F_all[n:]))

from typing import Iterable, Optional

# ---------- Feature importance utilities ----------
@torch.no_grad()
def feature_importance_weight(model: "Transformer_mat") -> torch.Tensor:
    """Return per-input-feature importances read off ``model.linear.weight``.

    The importance of input feature j is the L2 norm of column j of the weight
    matrix (an ``nn.Linear`` stores it as [out_features, in_features]). The
    result is normalised so the importances sum to one.

    The annotation is a string so it is not evaluated when the function is
    defined. ``Transformer_mat`` is not defined in this file, so an eager
    annotation would make module import fail if ``utils`` does not export it.

    Args:
        model: module exposing a ``linear`` layer with a 2-D ``weight``.

    Returns:
        Tensor of shape [in_features] with non-negative entries summing to ~1.
    """
    weight = model.linear.weight                            # [E, F]
    col_norms = torch.linalg.vector_norm(weight, ord=2, dim=0)  # [F]
    # The epsilon guards against an all-zero weight matrix.
    return col_norms / (col_norms.sum() + 1e-12)


def feature_importance_grad(
    model: "Transformer_mat",
    data_iter: Iterable[torch.Tensor],
    device: Optional[torch.device] = None,
    max_batches: Optional[int] = None,
) -> torch.Tensor:
    """Estimate per-feature importance from input gradients of the output energy.

    For each batch ``x`` of shape [S, N, F], this computes
    ``J = mean_{s,n} ||model(x)[s, n]||^2`` and accumulates the mean of
    ``|dJ/dx|`` per feature. The average over batches is normalised to sum to 1.

    The model runs in eval mode. Its previous train/eval mode is restored even
    if an error is raised. Gradients are enabled explicitly, so this also works
    when called inside a ``torch.no_grad()`` context.

    Args:
        model: module with a ``linear`` layer whose ``in_features`` is F.
        data_iter: iterable of tensors, or of lists/tuples whose first element
            is the input tensor.
        device: device to compute on; defaults to that of the first parameter.
        max_batches: if given, consume at most this many batches from
            ``data_iter`` (no extra batch is pulled from the iterator).

    Returns:
        Tensor of shape [F] with non-negative entries summing to ~1.

    Raises:
        ValueError: if a batch is not [S, N, F], or no batch was consumed.
    """
    import itertools

    was_training = model.training
    if device is None:
        device = next(model.parameters()).device

    n_features = model.linear.in_features
    running = torch.zeros(n_features, device=device)

    # islice stops before pulling batch max_batches + 1, which fixes the old
    # off-by-one (one extra batch consumed, average divided by max_batches + 1).
    batches = data_iter if max_batches is None else itertools.islice(data_iter, max_batches)

    num = 0
    model.eval()
    try:
        with torch.enable_grad():
            for x in batches:
                if isinstance(x, (list, tuple)):
                    x = x[0]
                # Expect [S, N, F]. Raise instead of assert so -O cannot strip the check.
                if x.dim() != 3 or x.size(-1) != n_features:
                    raise ValueError("Input must be [S, N, F] to match model.")

                x = x.to(device).detach().requires_grad_(True)
                y = model(x)                                # [S, N, E]
                J = y.pow(2).sum(dim=-1).mean()             # mean ||y||^2 over S, N
                (grad_x,) = torch.autograd.grad(J, x, retain_graph=False, create_graph=False)
                running += grad_x.abs().mean(dim=(0, 1))    # [F]
                num += 1
    finally:
        if was_training:
            model.train()

    if num == 0:
        raise ValueError("data_iter produced no batches.")

    imp = running / num
    return imp / (imp.sum() + 1e-12)


@torch.no_grad()
def feature_importance_mask(
    model: "Transformer_mat",
    x: torch.Tensor,
    mask_value: float = 0.0,
    agg: str = "l2"
) -> torch.Tensor:
    """Score each input feature by how much masking it changes the model output.

    For every feature f, ``x[..., f]`` is overwritten with ``mask_value`` and the
    output ``y`` is compared with the unmasked output ``base = model(x)``:

    * ``agg="l2"``: mean over all leading positions of ``||y - base||^2``.
    * ``agg="cos"``: ``1 - mean cosine similarity`` between rows of ``y`` and
      ``base`` (flattened over all leading dims), clamped at 0.

    The model is evaluated in eval mode, so dropout cannot add noise and
    batch-norm running statistics are not updated. Its previous train/eval mode
    is restored afterwards.

    Args:
        model: module whose first parameter determines the device; presumably
            maps ``[..., F]`` inputs to ``[..., out_dim]`` outputs.
        x: probe inputs; the last dimension indexes features.
        mask_value: value written into the masked feature.
        agg: ``"l2"`` or ``"cos"``.

    Returns:
        Tensor of shape [F] with the scores divided by their mean (so they
        average to ~1, unlike the sum-to-1 normalisation of the other helpers).

    Raises:
        ValueError: if ``agg`` is not ``"l2"`` or ``"cos"``.
    """
    if agg not in ("l2", "cos"):
        raise ValueError("agg must be 'l2' or 'cos'.")

    device = next(model.parameters()).device
    x = x.to(device)

    was_training = model.training
    model.eval()
    try:
        base = model(x)
        base_flat = base.reshape(-1, base.size(-1))   # only needed for "cos"
        n_features = x.size(-1)
        scores = torch.zeros(n_features, device=device)

        for f in range(n_features):
            x_masked = x.clone()
            x_masked[..., f] = mask_value
            y = model(x_masked)
            if agg == "l2":
                scores[f] = (y - base).pow(2).sum(dim=-1).mean()
            else:
                y_flat = y.reshape(-1, y.size(-1))
                cos = torch.nn.functional.cosine_similarity(base_flat, y_flat, dim=-1).mean()
                scores[f] = (1 - cos).clamp_min(0)
    finally:
        if was_training:
            model.train()

    return scores / (scores.mean() + 1e-12)


def data_gen_1(num_batches=5, n_features=None, half_width=4.0, batch_shape=(100, 100)):
    """Yield probe batches drawn uniformly from [-half_width, half_width).

    These batches are used as probe inputs for ``feature_importance_mask``.
    Sampling is centred on 0, which matches that function's default mask value.
    The previous ``torch.randn(...) * 2 * L - L`` applied the uniform-rescaling
    idiom to a Gaussian, giving N(-L, (2L)^2) samples instead.

    Args:
        num_batches: number of batches to yield.
        n_features: size of the last dimension; ``None`` (default) uses the
            module-level ``d_x``, as the original code did.
        half_width: half-width L of the sampling interval (default 4).
        batch_shape: leading dimensions of each batch (default (100, 100)).

    Yields:
        Float tensors of shape ``(*batch_shape, n_features)``.
    """
    dim = d_x if n_features is None else n_features
    for _ in range(num_batches):
        yield torch.rand(*batch_shape, dim) * 2 * half_width - half_width



print(f"Using dataset: {dataset_nam}")



# Output root; the results/ and figs/ subdirectories are created under it later.
ROOT_dir = './'

# ---- Training hyper-parameters ----
max_ep = 1000          # training epochs per objective

batch_size = 128       # DataLoader mini-batch size

lr = 1e-4              # Adam learning rate
wd = 1e-4              # Adam weight decay
ss = 20                # HSIC / intrinsic-dimension diagnostics are recorded every `ss` epochs
tau_lr_frac = 2        # scale on models['w'].logit_scale when the temperature is learned
tau_fix = 1e-2         # temperature used when tau_tune is False
tau_tune = False       # if True: tau = exp(tau_lr_frac * logit_scale) + tau_lower
tau_lower = 1e-4       # additive floor of the learned temperature; also passed to the networks
is_fix = True          # if True, the F-encoder is the fixed PadToDim map and gets no optimizer
id_est = MLE(K=5)      # skdim MLE intrinsic-dimension estimator (K presumably = #neighbours; see skdim docs)

# Seed torch and numpy; different --idx values give different, reproducible streams.
seed = 2025 + idx
torch.manual_seed(seed)
np.random.seed(seed)








## parameters
n = 10000              # Number of training samples
n_test = 1000         # Number of test samples


# data generation
print("\nLoading data...")

# Dimensionality settings
d_x, d_y, d_z = 6, 100, 2
outdim = 10               # NOTE(review): overwritten by `outdim = 50` below before any use
eps = 0.0                 # noise scale on X; read as a global inside gen_simu_data


N = n + n_test                # Total number of samples


# Rows [:n] form the training split, rows [n:] the test split.
# NOTE: this rebinds `F`, which was the torch.nn.functional alias imported above;
# functional ops elsewhere in the script go through `Fnn`.
X, Y, F, X_test, Y_test, F_test = gen_simu_data(N, n, d_x, d_y, d_z, dataset_nam)


# d_z0 keeps the configured latent dimension (used in most output file names);
# d_z becomes the number of columns of F and decides how many leading input
# features are highlighted in the importance plots.
d_z0 = d_z
d_z = F.shape[1]
# In setting4, X[:, 2] and X[:, 3] are also functions of F0 (see gen_simu_data),
# so the first 4 features are treated as relevant.
# NOTE(review): gen_simu_data has no 'setting3' branch — confirm this check.
if dataset_nam == 'setting3' or dataset_nam == 'setting4':
    d_z = 4



# Representation dimension actually used by all networks below.
outdim = 50
middim = max(d_x, d_y)

# Build models on DEVICE (the GPU when available, otherwise the CPU).
# NOTE(review): model_x / model_y are not used later in this script. Their
# construction presumably still advances the torch RNG (random init), so deleting
# them would change later initialisations — confirm before removing.
if arch == "transformer":
    model_x = Transformer_ma(d_x, middim=middim, dim=outdim).to(DEVICE)
    model_y = Transformer_ma(d_y, middim=middim, dim=outdim).to(DEVICE)
else:
    model_x = NonLinearNetD(d_x, middim, outdim, tau_lower=tau_lower).to(DEVICE)
    model_y = NonLinearNetD(d_y, middim, outdim, tau_lower=tau_lower).to(DEVICE)

from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# X and F are already tensors. `.to()` moves/casts them without the extra copy
# and UserWarning that `torch.tensor(existing_tensor)` produces.
X = X.to(device=device, dtype=torch.float32)
F = F.to(device=device, dtype=torch.float32)

# Each mini-batch is an (F_batch, X_batch) pair.
dataset = TensorDataset(F, X)
loader  = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# arch -> (padding map, encoder class, decoder class)
arch_map = {
    'linear':    (PadToDim, LinearNet, LinearNet),       # All linear
    'nonlinear': (PadToDim, NonLinearNet, LinearNet),    # Nonlinear encoder, linear decoder
    'deep':      (PadToDim, NonLinearNetD, NonLinearNetD),  # Both encoder and decoder deep
    'transformer': (PadToDim, Transformer_ma, Transformer_ma),  # Both encoder and decoder transformer
}
# Unknown architectures fall back to 'deep'. The previous default was the dict
# itself, and unpacking its 4 keys into 3 names raised a ValueError.
Pad, Wnet, Unet = arch_map.get(arch, arch_map['deep'])

if is_fix:
    Fnet = PadToDim(outdim)  # Fixed projection (no trainable optimizer is attached)
else:
    Fnet = Wnet(d_z, 50, outdim, tau_lower)

import matplotlib.pyplot as plt
from matplotlib import gridspec


def _plot_feature_importance(ax, importances, n_highlight, n_features):
    """Draw per-feature importances as a bar chart on ``ax``.

    Features with index < ``n_highlight`` are drawn orange, the rest blue.
    """
    colors = ['orange' if i < n_highlight else 'tab:blue' for i in range(n_features)]
    ticks = np.arange(1, n_features + 1)
    ax.bar(ticks, importances, color=colors)
    ax.set_title('Feature Importance (mask-based)')
    ax.set_xlabel('Feature Index')
    ax.set_ylabel('Importance')
    ax.set_xticks(ticks)


# Held-out data for the periodic diagnostics, moved to the device once.
# (Re-wrapping with torch.tensor(...) on every evaluation copied the data and warned.)
X_test_dev = X_test.to(device)
F_test_dev = F_test.to(device)

# Train a fresh set of networks for each objective, then save curves, models and figures.
for objective in ['recons', 'fact', 'disen']:

    # Per-epoch training curves.
    tau_seq     = []
    loss_clip   = []
    loss_recons = []
    loss_align  = []

    # Diagnostics recorded every `ss` epochs.
    indep_seq = []
    id1_seq   = []
    id2_seq   = []

    # NOTE(review): 'f' is the same Fnet object for every objective; when is_fix is
    # False it is trained and carried over between objectives — confirm intended.
    if arch == 'transformer':
        models = {
            'f': Fnet.to(device),
            'w': Wnet(d_x, middim=50, dim=outdim).to(device),
            'u': Unet(2*outdim, middim=50, dim=d_x).to(device),
            'club': CLUB(outdim, outdim, 50).to(device),
            'club2': CLUB(outdim, outdim, 50).to(device),
        }
    else:
        models = {
            'f': Fnet.to(device),
            'w': Wnet(d_x, 50, outdim, tau_lower).to(device),
            'u': Unet(2*outdim, 50, d_x, tau_lower).to(device),
            'club': CLUB(outdim, outdim, 50).to(device),
            'club2': CLUB(outdim, outdim, 50).to(device),
        }

    # A fixed 'f' has no optimizer.
    optimizers = {
        name: optim.Adam(m.parameters(), lr=lr, weight_decay=wd)
        for name, m in models.items()
        if not (name == 'f' and is_fix)
    }
    mse_loss = nn.MSELoss()
    # Cross-entropy whose target for row i is column i (matched pairs on the diagonal).
    ce_loss  = lambda logits: Fnn.cross_entropy(logits, torch.arange(logits.size(0), device=device))

    for epoch in tqdm(range(max_ep), desc='Epoch'):
        # Named batch_losses so the imported `losses` module is not shadowed.
        batch_losses = {'clip': [], 'recons': [], 'align': []}
        tau_vals = []

        for F_batch, X_batch in loader:
            # --- Forward pass ---
            h_f = models['f'](F_batch)
            h_w = models['w'](X_batch)
            h_u = models['u'](torch.cat([h_f, h_w], dim=1))

            # Normalize for contrastive objectives
            h_w_s = Fnn.normalize(h_w, dim=-1)
            h_f_s = Fnn.normalize(h_f, dim=-1)

            # Temperature: learned from models['w'].logit_scale if tau_tune, else fixed.
            tau = ((tau_lr_frac * models['w'].logit_scale).exp() + tau_lower
                   if tau_tune else tau_fix)
            tau_vals.append(tau.item() if isinstance(tau, torch.Tensor) else tau)

            # The CLUB critics are trained inside train_club_batch (called for every
            # objective, so their training does not depend on the objective).
            loss_club_1 = train_club_batch(models['club'],
                                           h_f_s, h_w_s,
                                           club_steps=10,
                                           club_lr=1e-3)
            loss_club_2 = train_club_batch(models['club2'],
                                           h_w_s, h_f_s,
                                           club_steps=10,
                                           club_lr=1e-3)
            loss_club = loss_club_1 + loss_club_2

            loss_mse = mse_loss(h_u, X_batch)

            pos_sim = torch.norm(torch.matmul(h_f_s.T, h_w_s))

            # Choose training objective
            if objective == 'disen':
                clip = pos_sim.mean()
            else:  # 'recons' and 'fact'
                clip = loss_club

            if objective == 'recons':
                recon = loss_mse
            else:  # 'fact' and 'disen': symmetric contrastive reconstruction loss
                h_u_n = Fnn.normalize(h_u, 1)
                x_n = Fnn.normalize(X_batch, 1)
                recon = 0.5 * (ce_loss((h_u_n @ x_n.T) / tau) +
                               ce_loss((x_n @ h_u_n.T) / tau))

            loss = clip + lam * recon

            # Logging. The mean of diag(h_f @ h_w.T) is computed as a row-wise dot
            # product: the same value in O(B) instead of an O(B^2) matrix.
            batch_losses['clip'].append(clip.item())
            batch_losses['recons'].append(recon.item())
            batch_losses['align'].append((h_f.detach() * h_w.detach()).sum(dim=1).mean().item())

            # Optimization step
            for opt in optimizers.values():
                opt.zero_grad()
            loss.backward()
            for opt in optimizers.values():
                opt.step()

        tau_seq.append(np.mean(tau_vals))
        loss_clip.append(np.mean(batch_losses['clip']))
        loss_recons.append(np.mean(batch_losses['recons']))
        loss_align.append(np.mean(batch_losses['align']))

        if epoch % ss == 0:
            # Diagnostics run in inference mode: no autograd graph and eval-mode
            # layers. Results are moved to the CPU before .numpy(), which raises
            # on CUDA tensors.
            w_net = models['w']
            w_net.eval()
            with torch.no_grad():
                W_rep = w_net(X_test_dev).cpu().numpy()
                F_rep = pad_to_dim(F_test_dev, outdim).cpu().numpy()
            w_net.train()

            indep_seq.append([hsic(W_rep, F_rep)])
            id1_seq.append(id_est.fit_transform(W_rep))
            id2_seq.append(id_est.fit_transform(F_rep))


    results = {
        'tau_seq': tau_seq,
        'loss_clip': loss_clip,
        'loss_recons': loss_recons,
        'loss_align': loss_align,
        'indep_seq': indep_seq,
        'id1_seq': id1_seq,
        'id2_seq': id2_seq,
        'models': models
    }

    save_dir = os.path.join(ROOT_dir, f"results/{dataset_nam}")
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, f'rep_{objective}_{dataset_nam}_{n}_{d_z0}_{d_x}_{outdim}_{arch}_{idx}_{lam}.pt')
    torch.save(results, save_path)

    # Mask-based importance of each input feature for the trained X-encoder.
    w_imp = feature_importance_mask(models['w'], next(data_gen_1(1)))     # [d_x]
    imp_vals = w_imp.cpu().numpy()

    plot_dir = os.path.join(ROOT_dir, f"figs/{dataset_nam}")
    os.makedirs(plot_dir, exist_ok=True)

    # Summary figure: 1 row x 4 panels (three training curves + feature importances).
    fig = plt.figure(figsize=(20, 4))
    gs  = gridspec.GridSpec(1, 4, wspace=0.3, hspace=0.4)

    metrics = {
        'MI(Z;C)':   loss_clip,
        'HS-indep': indep_seq,
        'MI(C,Z;input)':   loss_recons,
    }
    for i, (title, data) in enumerate(metrics.items()):
        ax = fig.add_subplot(gs[i])
        if title == 'HS-indep':
            # HSIC was recorded every `ss` epochs.
            ax.plot(np.arange(0, ss*len(indep_seq), ss), indep_seq, 'ro--')
        else:
            ax.plot(data)
        ax.set_title(title)
        ax.set_xlabel('epochs')

    _plot_feature_importance(fig.add_subplot(gs[3]), imp_vals, d_z, d_x)

    plt.tight_layout()
    # NOTE(review): this file name uses d_z while the other outputs use d_z0; kept
    # unchanged so existing paths stay stable.
    outfile = os.path.join(plot_dir, f'loss_{objective}_{dataset_nam}_{n}_{d_z}_{d_x}_{outdim}_{arch}_{idx}_{lam}.png')
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(fig)  # release the figure; otherwise figures accumulate across objectives

    # Stand-alone feature-importance figure (single panel).
    fig = plt.figure(figsize=(5, 4))
    gs  = gridspec.GridSpec(1, 1)
    _plot_feature_importance(fig.add_subplot(gs[0]), imp_vals, d_z, d_x)

    plt.tight_layout()
    outfile = os.path.join(plot_dir, f'import_{objective}_{dataset_nam}_{n}_{d_z0}_{d_x}_{outdim}_{arch}_{idx}_{lam}.png')
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(fig)