import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

import skdim
from skdim.id import MLE, KNN, lPCA, FisherS

from utils import *

ROOT_dir = "./"


# Pick the first CUDA device when PyTorch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[run.py] Using device: {DEVICE}")


def load_on_device(path):
    """Load a ``torch.save`` file from *path*, remapping its tensors onto ``DEVICE``.

    ``weights_only=False`` lets arbitrary pickled objects be restored, so only
    call this on files you trust.
    """
    return torch.load(path, weights_only=False, map_location=DEVICE)


import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from pathlib import Path
import json


parser = argparse.ArgumentParser(description="Run experiment with chosen dataset")
# Dataset / simulation setting name, forwarded to gen_simu_data and used in file names.
parser.add_argument("--dataset_nam", type=str, required=True, help="setting x")
# Architecture key; the script branches on "transformer" and looks it up in arch_map.
parser.add_argument("--arch", type=str, required=True, help="deep or transformer")
# Run index: offsets the random seed and tags every result/figure file name.
parser.add_argument("--idx", type=int, required=True, help="index")
# Scalar hyper-parameter; within this script it only appears in file names.
parser.add_argument("--lam", type=float, required=True, help="parameter")
args = parser.parse_args()

# Expose the parsed options as the module-level names used by the rest of the script.
dataset_nam, idx, lam, arch = args.dataset_nam, args.idx, args.lam, args.arch




from typing import Iterable, Optional

@torch.no_grad()
def feature_importance_weight(model: "Transformer_mat") -> torch.Tensor:
    """Return weight-based importance scores for each input feature.

    The score of feature ``j`` is the L2 norm of column ``j`` of
    ``model.linear.weight`` (for ``nn.Linear`` the weight is ``[out, in]``, so a
    column collects every weight reading that input). Scores are normalised to
    sum to 1; the tiny epsilon guards an all-zero weight matrix.

    The annotation is a string so the function can be defined even when
    ``Transformer_mat`` is not importable here (the script itself builds
    ``Transformer_ma`` models).
    """
    col_norms = torch.linalg.vector_norm(model.linear.weight, ord=2, dim=0)  # [F]
    return col_norms / (col_norms.sum() + 1e-12)


def feature_importance_grad(
    model: "Transformer_mat",
    data_iter: Iterable[torch.Tensor],
    device: Optional[torch.device] = None,
    max_batches: Optional[int] = None,
) -> torch.Tensor:
    """Return gradient-based importance scores for each input feature.

    For every batch ``x`` the objective ``J = mean(||model(x)||^2 over the last
    dim)`` is differentiated w.r.t. ``x``. The absolute gradient is averaged over
    the first two axes, then averaged over batches and normalised to sum to 1.

    Args:
        model: module exposing ``model.linear.in_features``; it is run in eval
            mode and its previous training mode is restored afterwards (also on
            error).
        data_iter: iterable of 3-D tensors ``[S, N, F]`` (or lists/tuples whose
            first element is such a tensor).
        device: where inputs and accumulators live; defaults to the device of
            the model's first parameter.
        max_batches: if given, use at most this many batches. Exactly that many
            items are pulled from ``data_iter``, no more.

    Returns:
        Tensor of shape ``[F]`` on ``device`` that sums to ~1.

    Raises:
        ValueError: if no batch is consumed (empty iterator or
            ``max_batches == 0``), or if ``max_batches`` is negative.
    """
    from itertools import islice

    if device is None:
        device = next(model.parameters()).device

    n_feat = model.linear.in_features
    running = torch.zeros(n_feat, device=device)

    # islice stops before fetching batch max_batches+1, so the average below
    # divides by the number of batches that were actually used.
    batches = data_iter if max_batches is None else islice(data_iter, max_batches)

    was_training = model.training
    model.eval()
    num = 0
    try:
        # enable_grad keeps autograd working even if the caller wrapped this
        # call in torch.no_grad() (as the sibling importance functions are).
        with torch.enable_grad():
            for x in batches:
                if isinstance(x, (list, tuple)):
                    x = x[0]
                assert x.dim() == 3 and x.size(-1) == n_feat, "Input must be [S, N, F] to match model."

                x = x.to(device).detach().requires_grad_(True)
                y = model(x)
                J = y.pow(2).sum(dim=-1).mean()

                # dJ/dx only; parameter .grad fields are left untouched.
                (grad_x,) = torch.autograd.grad(J, x, retain_graph=False, create_graph=False)
                running += grad_x.abs().mean(dim=(0, 1))  # [F]
                num += 1
    finally:
        # Restore the caller's mode even when we raise below or fail mid-loop.
        if was_training:
            model.train()

    if num == 0:
        raise ValueError("data_iter produced no batches.")

    imp = running / num
    return imp / (imp.sum() + 1e-12)


@torch.no_grad()
def feature_importance_mask(
    model: "Transformer_mat",
    x: torch.Tensor,
    mask_value: float = 0.0,
    agg: str = "l2"
) -> torch.Tensor:
    """Score each input feature by how much masking it changes the model output.

    For each feature index ``f`` a copy of ``x`` has ``x[..., f]`` set to
    ``mask_value``. The masked output is then compared with the unmasked output.
    The model is evaluated in eval mode so stochastic layers (e.g. dropout) do
    not add noise to the comparison. Its previous training mode is restored
    afterwards.

    Args:
        model: module with at least one parameter; ``x`` is moved to that
            parameter's device.
        x: input batch whose last dimension indexes features (the comments
            below assume ``[S, N, F]`` -> ``[S, N, E]``; TODO confirm for other models).
        mask_value: value written into the masked feature.
        agg: ``"l2"`` gives the squared L2 change over the output's last
            dimension, averaged over the remaining positions. ``"cos"`` gives
            ``1 - mean cosine similarity`` between the flattened base and
            masked outputs, clamped at 0.

    Returns:
        Tensor of shape ``[F]`` divided by its mean, so the scores average to
        ~1. Note this differs from the sum-to-1 normalisation used by
        ``feature_importance_grad`` and ``feature_importance_weight``.

    Raises:
        ValueError: if ``agg`` is not ``"l2"`` or ``"cos"``.
    """
    # Validate before spending any forward passes.
    if agg not in ("l2", "cos"):
        raise ValueError("agg must be 'l2' or 'cos'.")

    device = next(model.parameters()).device
    x = x.to(device)

    was_training = model.training
    model.eval()
    try:
        base = model(x)                                  # [S, N, E]
        base_flat = base.reshape(-1, base.size(-1))      # hoisted: only depends on base
        n_feat = x.size(-1)
        scores = torch.zeros(n_feat, device=device)

        for f in range(n_feat):
            x_masked = x.clone()
            x_masked[..., f] = mask_value
            y = model(x_masked)
            if agg == "l2":
                scores[f] = (y - base).pow(2).sum(dim=-1).mean()  # mean over S,N
            else:
                masked_flat = y.reshape(-1, y.size(-1))
                cos = torch.nn.functional.cosine_similarity(base_flat, masked_flat, dim=-1).mean()
                scores[f] = (1 - cos).clamp_min(0)
    finally:
        if was_training:
            model.train()

    return scores / (scores.mean() + 1e-12)


def data_gen_1(num_batches=5, seq_len=100, batch=100, dim=None):
    """Yield ``num_batches`` standard-normal tensors of shape ``[seq_len, batch, dim]``.

    ``dim`` defaults to the module-level ``d_x``, read lazily when the first
    batch is drawn. ``data_gen_1(k)`` therefore still yields ``[100, 100, d_x]``
    batches exactly as before.
    """
    if dim is None:
        dim = d_x
    for _ in range(num_batches):
        yield torch.randn(seq_len, batch, dim)



print(f"Using dataset: {dataset_nam}")


ROOT_dir = './'

# ---- optimisation / training hyper-parameters ------------------------------
max_ep = 1000
batch_size = 128
lr = 1e-4
wd = 1e-4
ss = 20            # spacing (in epochs) between plotted HS-indep points
tau_lr_frac = 2
tau_fix = 1e-2
tau_tune = False
tau_lower = 1e-4
is_fix = True      # True -> Fnet is a fixed PadToDim projection
id_est = MLE(K=5)

seed = 2025 + idx
torch.manual_seed(seed)
np.random.seed(seed)

## parameters
n = 10000              # Number of training samples
n_test = 1000          # Number of test samples

# data generation
print("\nLoading data...")

# Dimensionality settings
d_x, d_y, d_z = 6, 100, 2
eps = 0.0

N = n + n_test         # Total number of samples

X, Y, F, X_test, Y_test, F_test = gen_simu_data(N, n, d_x, d_y, d_z, dataset_nam)

# Keep the nominal latent dimension (used in every file name) before d_z is
# replaced by the dimension actually returned by the generator.
d_z0 = d_z
d_z = F.shape[1]
if dataset_nam in ('setting3', 'setting4'):
    d_z = 4

outdim = 50
middim = max(d_x, d_y)

# Build models on GPU
if arch == "transformer":
    model_x = Transformer_ma(d_x, outdim).to(DEVICE)
    model_y = Transformer_ma(d_y, outdim).to(DEVICE)
else:
    model_x = NonLinearNetD(d_x, middim, outdim, tau_lower=tau_lower).to(DEVICE)
    model_y = NonLinearNetD(d_y, middim, outdim, tau_lower=tau_lower).to(DEVICE)

from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

tau_seq     = []
loss_clip   = []
loss_recons = []
loss_align  = []

indep_seq = []
id1_seq   = []
id2_seq   = []

device = DEVICE
X = torch.tensor(X, dtype=torch.float32, device=device)
F = torch.tensor(F, dtype=torch.float32, device=device)

dataset = TensorDataset(F, X)
loader  = DataLoader(dataset, batch_size=batch_size, shuffle=True)


arch_map = {
    'linear':      (PadToDim, LinearNet, LinearNet),
    'nonlinear':   (PadToDim, NonLinearNet, LinearNet),
    'deep':        (PadToDim, NonLinearNetD, NonLinearNetD),
    'transformer': (PadToDim, Transformer_ma, Transformer_ma),
}
# Unknown architectures fall back to the 'deep' triple (previously the default
# was the whole dict, which crashed on unpacking).
Pad, Wnet, Unet = arch_map.get(arch, arch_map['deep'])

if is_fix:
    Fnet = PadToDim(outdim)  # Fixed projection
else:
    Fnet = Wnet(d_z, 50, outdim, tau_lower)


import matplotlib.pyplot as plt
from matplotlib import gridspec

plot_dir = os.path.join(ROOT_dir, f"figs/{dataset_nam}")
os.makedirs(plot_dir, exist_ok=True)


def _plot_importance(ax, vals):
    """Draw per-feature importances as bars; the first ``d_z`` features are orange."""
    colors = ['orange' if i < d_z else 'tab:blue' for i in range(d_x)]
    ax.bar(np.arange(1, d_x + 1), vals, color=colors)
    ax.set_title('Feature Importance (mask-based)')
    ax.set_xlabel('Feature Index')
    ax.set_ylabel('Importance')
    ax.set_xticks(np.arange(1, d_x + 1))


for objective in ['recons', 'fact', 'disen']:

    save_path = f'./results/rep_{objective}_{dataset_nam}_{n}_{d_z0}_{d_x}_{outdim}_{arch}_{idx}_{lam}.pt'
    # load_on_device passes map_location=DEVICE, so GPU-saved results also load
    # on CPU-only hosts.
    results = load_on_device(save_path)
    tau_seq     = results['tau_seq']
    loss_clip   = results['loss_clip']
    loss_recons = results['loss_recons']
    loss_align  = results['loss_align']
    indep_seq   = results['indep_seq']
    id1_seq     = results['id1_seq']
    id2_seq     = results['id2_seq']
    models      = results['models']

    w_imp = feature_importance_mask(models['w'], next(data_gen_1(1)))     # [F]
    vals = w_imp.cpu().numpy()

    # ---- figure 1: training curves + feature importance -------------------
    fig = plt.figure(figsize=(20, 4))
    gs  = gridspec.GridSpec(1, 4, wspace=0.3, hspace=0.4)

    metrics = {
        'MI(Z;C)':       loss_clip,
        # 'similarity':  loss_align,
        'HS-indep':      indep_seq,
        'MI(C,Z;input)': loss_recons,
    }
    for i, (title, data) in enumerate(metrics.items()):
        ax = fig.add_subplot(gs[i])
        if title == 'MLE-ID':
            y, x = data
            ax.plot(x, y, 'b+--', label='h(X)')
            ax.axhline(d_z0,     color='r', linestyle='--', label=r'$d_z$')
            ax.axhline(d_x-d_z, color='b', linestyle='--', label=r'$d_x-d_z$')
            ax.legend(loc='best')
            ax.set_ylim((0, 8))
        elif title == 'HS-indep':
            # HS-indep is recorded every `ss` epochs.
            ax.plot(np.arange(0, ss * len(data), ss), data, 'ro--')
        else:
            ax.plot(data)
        ax.set_title(title)
        ax.set_xlabel('epochs')

    ax = fig.add_subplot(gs[3])
    _plot_importance(ax, vals)

    plt.tight_layout()
    # d_z0 (not d_z) keeps this name consistent with the .pt and import_*.png files.
    outfile = os.path.join(plot_dir, f'loss_{objective}_{dataset_nam}_{n}_{d_z0}_{d_x}_{outdim}_{arch}_{idx}_{lam}.png')
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across objectives

    # ---- figure 2: feature importance only --------------------------------
    fig = plt.figure(figsize=(5, 4))
    gs  = gridspec.GridSpec(1, 1)

    ax = fig.add_subplot(gs[0])
    _plot_importance(ax, vals)

    plt.tight_layout()
    outfile = os.path.join(plot_dir, f'import_{objective}_{dataset_nam}_{n}_{d_z0}_{d_x}_{outdim}_{arch}_{idx}_{lam}.png')
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(fig)



