import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os
from argparse import Namespace

# Custom project imports
from exp.experiment_explanation_v2 import Explanation_Experiment
from models.AmortizedFCN import AmortizedExplainerFCN
from data_provider.explanation_data_loader import AmortizedExplanationLoader_V2, padding_mask
from torch.optim.lr_scheduler import ReduceLROnPlateau
from models.Fidelity_gate_net import FidelityGateNet

def amortized_collate_fn(data, max_len=None):
    """
    Collate variable-length samples into padded batch tensors.

    Each element of ``data`` is unpacked as a 4-tuple
    ``(amortized, target, raw, label)``. The ``amortized`` slot is ignored and
    returned as ``None``. Targets and labels are stacked as-is, which assumes
    they already share a shape across the batch. Raw series are zero-padded,
    or truncated, to ``max_len`` along their first axis.

    Args:
        data: list of samples as described above. ``raw`` is sliced as
            ``raw[:T, :]``, so it is presumably a ``(T, C)`` tensor -- TODO confirm.
        max_len: target sequence length. Defaults to the longest ``raw`` in
            this batch.

    Returns:
        ``(None, Y_TARGET, X_RAW, labels, p_masks)`` where ``X_RAW`` is a float32
        ``(B, max_len, C)`` tensor and ``p_masks`` comes from ``padding_mask``.
    """
    batch_size = len(data)
    _, target_tuple, raw_tuple, labels_tuple = zip(*data)

    # 1. Sequence lengths and the padding length for this batch.
    lengths = [x.shape[0] for x in raw_tuple]
    if max_len is None:
        max_len = max(lengths)
    # Series longer than max_len are truncated below, so the mask must use the
    # same clipped lengths to stay consistent with X_RAW.
    kept_lengths = [min(n, max_len) for n in lengths]

    # 2. Feature dimension (taken from the first sample).
    feat_dim = raw_tuple[0].shape[-1]

    # 3. Meta-targets are stacked directly (assumed fixed length upstream).
    Y_TARGET = torch.stack(target_tuple, dim=0)

    # 4. Zero-pad / truncate raw series to max_len.
    X_RAW = torch.zeros(batch_size, max_len, feat_dim)
    for i, (x, n) in enumerate(zip(raw_tuple, kept_lengths)):
        X_RAW[i, :n, :] = x[:n, :]

    # 5. Labels and padding mask. int64 lengths avoid the overflow int16 hits
    # for sequences longer than 32767 steps.
    labels = torch.stack(labels_tuple, dim=0)
    p_masks = padding_mask(torch.tensor(kept_lengths, dtype=torch.long), max_len=max_len)

    return None, Y_TARGET, X_RAW, labels, p_masks

# --- Fidelity Loss: Penalizes explainers that produce unfaithful maps ---
class FidelityConfidence_Mask_Robust_Loss(nn.Module):
    """Fidelity and mask-robustness losses computed through frozen evaluators.

    For every sparsity level ``k`` in ``k_list``, the frozen gate
    ``fidelity_gates[k]`` turns an attribution tensor into a soft mask. That mask
    is multiplied into the raw input, and the result is classified by the
    frozen ``target_model``:

    * fidelity error: ``mean(1 - softmax(logits)[true_label])`` on the masked
      input. It is low when the kept features still support the true class.
    * mask-robustness error: MSE between the gate masks produced from the
      clean and from the noisy attribution.

    Both terms are averaged over ``k_list``.

    NOTE: ``fidelity_gates`` is stored as a plain ``dict``, so the gates are
    not registered as submodules. ``.to(device)`` on this module moves only
    ``target_model``; place the gates on the correct device before passing them.
    """

    def __init__(self, target_model, fidelity_gates):
        """
        Args:
            target_model: frozen classifier being explained. It is called as
                ``target_model(x, pad_mask, None, None)`` and may return the
                logits or a tuple whose first element is the logits.
            fidelity_gates: mapping ``k -> gate module`` for each sparsity level.
        """
        super().__init__()
        self.target_model = target_model
        self.fidelity_gates = fidelity_gates
        self.softmax = nn.Softmax(dim=-1)

        # Freeze every evaluator (eval mode + no parameter gradients) so only
        # the explainer producing the attributions is trained through this loss.
        self.target_model.eval()
        for param in self.target_model.parameters():
            param.requires_grad = False
        for gate in self.fidelity_gates.values():
            gate.eval()
            for param in gate.parameters():
                param.requires_grad = False

    def forward(self, raw_x, pad_mask, true_labels, predicted_attr,
                noisy_predicted_attr, k_list=(80, 85, 90)):
        """Return ``(fidelity_error, mask_robust_error)`` averaged over ``k_list``.

        Args:
            raw_x: raw input. It is multiplied element-wise by each soft mask,
                so the gate output must broadcast against it.
            pad_mask: padding mask forwarded unchanged to ``target_model``.
            true_labels: class indices of shape ``(B,)`` or ``(B, 1)``.
            predicted_attr: explainer attribution for ``raw_x``.
            noisy_predicted_attr: explainer attribution for a noisy copy of ``raw_x``.
            k_list: iterable of keys into ``fidelity_gates``.

        Raises:
            ValueError: if ``k_list`` is empty.
            KeyError: if some ``k`` has no gate in a dict ``fidelity_gates``.
        """
        # Materialize once so generators work and len() is well defined.
        k_list = tuple(k_list)
        if not k_list:
            raise ValueError("k_list must contain at least one sparsity level")

        # gather() along dim 1 needs a (B, 1) int64 index.
        if true_labels.dim() == 1:
            true_labels = true_labels.unsqueeze(1)
        true_labels = true_labels.long()

        total_faithfulness_error = 0.0
        total_mask_robust_error = 0.0
        for k in k_list:
            gate = self.fidelity_gates[k]

            # 1. Continuous attributions -> soft keep-masks via the learned gate.
            soft_mask = gate(predicted_attr)
            noisy_soft_mask = gate(noisy_predicted_attr)

            # 2. Soft removal: keep only what the mask deems important.
            logits = self.target_model(raw_x * soft_mask, pad_mask, None, None)
            if isinstance(logits, tuple):
                logits = logits[0]
            probs = self.softmax(logits)

            # 3. Confidence in the true class should stay high after removal.
            true_class_conf = torch.gather(probs, 1, true_labels)
            total_faithfulness_error += torch.mean(1.0 - true_class_conf)

            # 4. Masks from clean vs. noisy attributions should agree.
            total_mask_robust_error += nn.functional.mse_loss(soft_mask, noisy_soft_mask)

        n_levels = len(k_list)
        return total_faithfulness_error / n_levels, total_mask_robust_error / n_levels

class EarlyStopping:
    """Patience-based stop signal for a monitored loss (lower is better).

    On each call, a value that is *not* greater than ``best_loss - delta``
    becomes the new best and resets ``counter``. Any other value increments
    ``counter``. When ``counter`` reaches ``patience``, ``early_stop`` is set
    to True and is never cleared afterwards.
    """

    def __init__(self, patience=15, delta=0):
        # Configuration.
        self.patience, self.delta = patience, delta
        # Running state.
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # The first observation only seeds the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        # Phrased as "stalled" (not "improved") so NaN comparisons follow the
        # same branch as before: a NaN loss is treated as a new best.
        stalled = val_loss > self.best_loss - self.delta
        if not stalled:
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def get_args(argv=None):
    """Parse command-line options for training the amortized explainer.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the parsed options plus ``root_path``, which is
        ``os.path.join(data_root, dataset)``.

    Raises:
        SystemExit: via argparse, when a required option is missing or invalid.
    """
    parser = argparse.ArgumentParser(description="Train Amortized Explainer")
    parser.add_argument("--dataset", type=str, default="UWaveGestureLibrary")
    parser.add_argument("--data_root", type=str, default="./data/UEA_multivariate")
    parser.add_argument("--model", type=str, default="DNN")
    parser.add_argument("--dnn_type", type=str, default="FCN")
    parser.add_argument("--enc_in", type=int, required=True)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--train_epochs", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    # Both loss weights multiply tensors unconditionally during training and
    # validation. A None default would only fail (TypeError) after models and
    # data are loaded, so they are required up front instead.
    parser.add_argument("--lambda_fid", type=float, required=True)
    parser.add_argument("--lambda_rob", type=float, required=True)

    args = parser.parse_args(argv)
    args.root_path = os.path.join(args.data_root, args.dataset)
    return args

def get_target_model_defaults(new_args):
    """Build the Namespace used to restore the pre-trained target classifier.

    The dataset/model identity, seed, learning rate, batch size and epoch
    count are copied from ``new_args``. Everything else is a fixed inference
    configuration (``test_only=True``, ``is_training=False``, and the listed
    transformer defaults).

    ``data_root`` is taken from ``new_args`` when present, so the target model
    reads the same data directory as the explainer (previously the CLI
    ``--data_root`` was ignored here). Callers whose namespace lacks it fall
    back to ``"./data/UEA_multivariate"``.

    Returns:
        argparse.Namespace including ``root_path = os.path.join(data_root, dataset)``.
    """
    data_root = getattr(new_args, "data_root", "./data/UEA_multivariate")
    args = Namespace(
        data='UEA', data_root=data_root,
        dataset=new_args.dataset, model=new_args.model,
        dnn_type=new_args.dnn_type, seed=new_args.seed,
        lambda_reg=0.1, lambda_div=0.1, epsilon=1.0, num_shapelet=10,
        gating_value=None, pos_weight=False, sbm_cls='linear',
        distance_func='euclidean', beta_schedule='constant',
        memory_efficient=False, lr=new_args.lr, lr_decay=False,
        gradient_accumulation_steps=1, gradient_clip=0,
        batch_size=new_args.batch_size, train_epochs=new_args.train_epochs, num_workers=0,
        patience=50, multi_gpu=False, test_only=True, amp=True,
        task_name='explanation', model_id='test', embed='timeF', freq='h',
        dropout=0.0, activation='gelu', output_attention=False, is_training=False,
        d_model=512, e_layers=2, factor=1, n_heads=8, d_ff=2048  # default config for PatchTST
    )
    # Same path construction as get_args, so both configs agree on root_path.
    args.root_path = os.path.join(args.data_root, args.dataset)
    return args

def _input_x_gradient(target_dnn, x, pad_mask, labels):
    """Return the explainer meta-input ``cat([x, x * d(score)/dx], dim=-1)``, detached.

    ``score`` is the batch sum of the target model's logit at each sample's
    label. The gradient is taken with ``torch.autograd.grad`` inside
    ``torch.enable_grad()``. This works from ``no_grad`` contexts, keeps no
    graph alive afterwards, and does not touch any ``.grad`` buffer.
    """
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        logits = target_dnn(x, pad_mask, None, None)
        if isinstance(logits, tuple):
            logits = logits[0]
        score = torch.gather(logits, 1, labels.view(-1, 1)).sum()
        (grads,) = torch.autograd.grad(score, x)
    x = x.detach()
    return torch.cat([x, x * grads], dim=-1)


def _add_instance_noise(x, rel_std=0.2):
    """Return ``x`` plus Gaussian noise with std ``rel_std`` times each instance's std.

    The std is taken over dims 1 and 2 of the padded batch, so zero-padded
    positions both contribute to it and receive noise.
    """
    instance_std = torch.std(x, dim=(1, 2), keepdim=True)
    return x + torch.randn_like(x) * (instance_std * rel_std)


def _load_fidelity_gates(args, tm_args, k_list, device):
    """Load one FidelityGateNet checkpoint per sparsity level ``k`` onto ``device``."""
    gate_dir = f"./fidelity_gate_checkpoints/{tm_args.dataset}/{tm_args.model}/{tm_args.dnn_type}"
    gates = {}
    for k in k_list:
        gate_path = os.path.join(
            gate_dir, f"{args.dataset}-{args.model}-{args.dnn_type}-k{k}-seed{tm_args.seed}.pt")
        gate = FidelityGateNet(c_in=args.enc_in).to(device)
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        gate.load_state_dict(torch.load(gate_path, map_location=device))
        gates[k] = gate
    return gates


def _predict_maps(model, target_dnn, loader, device, noise_rel_std=None):
    """Run the explainer over ``loader``; return all attribution maps as one numpy array.

    When ``noise_rel_std`` is given, each batch is first perturbed with
    ``_add_instance_noise`` and its attribution input is built from the noisy data.
    """
    maps = []
    with torch.no_grad():
        for _, _, raw_x, labels, pad_mask in loader:
            raw_x = raw_x.to(device).float()
            labels = labels.to(device).long()
            pad_mask = pad_mask.to(device).float()
            if noise_rel_std is not None:
                raw_x = _add_instance_noise(raw_x, noise_rel_std)
            enriched = _input_x_gradient(target_dnn, raw_x, pad_mask, labels)
            maps.append(model(enriched).cpu().numpy())
    return np.concatenate(maps, axis=0)


def train_explainer():
    """Train the amortized explainer and export its test-set attribution maps.

    Steps:
      1. Load and freeze the target classifier via ``Explanation_Experiment``.
      2. Load one fidelity gate per sparsity level in ``k_to_train_on``.
      3. Build TRAIN/TEST loaders padded to a shared ``max_seq_len``.
      4-6. Train ``AmortizedExplainerFCN`` on ``[x, x * grad]`` inputs. The loss
         is distillation MSE + ``lambda_fid`` * fidelity
         + ``lambda_rob`` * (mask robustness + map robustness).
      7. Validate every epoch and checkpoint the best model under
         ``./explainer_checkpoints/<dataset>/``.
      8-9. Save clean and noisy test-set maps under
         ``./explanations/<dataset>/<model>/<dnn_type>/``.
    """
    from functools import partial

    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- STEP 1: LOAD TARGET ORACLE MODEL ---
    tm_args = get_target_model_defaults(args)
    print(tm_args)
    exp = Explanation_Experiment(args=tm_args)
    target_dnn = exp.model.to(device)
    target_dnn.eval()
    for param in target_dnn.parameters():
        param.requires_grad = False

    # --- STEP 2: LOAD FIDELITY GATES ---
    k_to_train_on = [80, 85, 90, 95]
    fidelity_gates = _load_fidelity_gates(args, tm_args, k_to_train_on, device)

    # --- STEP 3: INITIALIZE DATA LOADERS ---
    train_ds = AmortizedExplanationLoader_V2(args, args.root_path, flag='TRAIN', target_set='train')
    val_ds = AmortizedExplanationLoader_V2(args, args.root_path, flag='TEST', target_set='test')
    # Pad both splits to the same length so the explainer always sees one T.
    max_seq_len = max(train_ds.max_seq_len, val_ds.max_seq_len)
    collate = partial(amortized_collate_fn, max_len=max_seq_len)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate)

    # --- STEP 4: BUILD EXPLAINER MODEL ---
    # c_in = enc_in * 2 (channels: raw data, then dynamic InputXGradient)
    model = AmortizedExplainerFCN(c_in=args.enc_in * 2, c_out=args.enc_in, hidden_dim=args.hidden_dim).to(device)

    # --- STEP 5: SETUP OPTIMIZER & LOSS ---
    distillation_criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    early_stopping = EarlyStopping(patience=100)
    fidelity_mask_robust_criterion = FidelityConfidence_Mask_Robust_Loss(target_dnn, fidelity_gates).to(device)

    ckpt_dir = f"./explainer_checkpoints/{args.dataset}"
    os.makedirs(ckpt_dir, exist_ok=True)
    best_model_path = os.path.join(ckpt_dir, f"Distill_only_Amortized_{args.dataset}_{args.model}_{args.dnn_type}_v2_test.pth")
    best_val_loss = float('inf')
    # Tracks whether *this run* wrote a checkpoint, so a stale file from an
    # earlier run is never silently loaded in STEP 8.
    saved_checkpoint = False

    # --- STEP 6: TRAINING LOOP ---
    for epoch in range(args.train_epochs):
        model.train()
        train_loss = 0.0
        for _, batch_y, raw_x, labels, pad_mask in train_loader:
            raw_x = raw_x.to(device).float()
            batch_y = batch_y.to(device).float()
            labels = labels.to(device).long()
            pad_mask = pad_mask.to(device).float()

            # 6A/6B: clean meta-input [B, T, 2C] = [raw, InputXGradient].
            enriched_input = _input_x_gradient(target_dnn, raw_x, pad_mask, labels)
            # Pass 2: the same construction on a noisy copy (robustness terms).
            noisy_raw_x = _add_instance_noise(raw_x)
            enriched_noisy_input = _input_x_gradient(target_dnn, noisy_raw_x, pad_mask, labels)

            # 6C: explainer optimization.
            optimizer.zero_grad()
            pred_map = model(enriched_input)
            noisy_pred_map = model(enriched_noisy_input)

            # Distillation towards the ensemble target + fidelity through the gates.
            mse_loss = distillation_criterion(pred_map, batch_y)
            loss_fid, loss_mask_rob = fidelity_mask_robust_criterion(
                raw_x, pad_mask, labels, pred_map, noisy_pred_map, k_list=k_to_train_on)
            # Map-level robustness: clean and noisy predictions should agree.
            loss_rob = distillation_criterion(pred_map, noisy_pred_map)

            loss = mse_loss + args.lambda_fid * loss_fid + args.lambda_rob * (loss_mask_rob + loss_rob)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # --- STEP 7: VALIDATION PHASE ---
        model.eval()
        mse_val_loss, val_fid = 0.0, 0.0
        with torch.no_grad():
            for _, batch_y, raw_x, labels, pad_mask in val_loader:
                raw_x = raw_x.to(device).float()
                batch_y = batch_y.to(device).float()
                labels = labels.to(device).long()
                pad_mask = pad_mask.to(device).float()

                # _input_x_gradient re-enables grad internally for the attribution.
                enriched_input = _input_x_gradient(target_dnn, raw_x, pad_mask, labels)
                noisy_raw_x = _add_instance_noise(raw_x)
                enriched_noisy_input = _input_x_gradient(target_dnn, noisy_raw_x, pad_mask, labels)

                pred_map = model(enriched_input)
                noisy_pred_map = model(enriched_noisy_input)
                mse_val_loss += distillation_criterion(pred_map, batch_y).item()
                # Only the fidelity term enters the validation objective.
                val_fid += fidelity_mask_robust_criterion(
                    raw_x, pad_mask, labels, pred_map, noisy_pred_map, k_list=k_to_train_on)[0].item()

        # Step schedulers and save the best model.
        avg_val = (mse_val_loss / len(val_loader)) + args.lambda_fid * (val_fid / len(val_loader))
        scheduler.step(avg_val)
        early_stopping(avg_val)

        if epoch % 10 == 0:
            print(f"Epoch {epoch+1} | Train Loss: {train_loss/len(train_loader):.6f} | Val Loss: {avg_val:.6f}")

        if avg_val < best_val_loss:
            best_val_loss = avg_val
            torch.save(model.state_dict(), best_model_path)
            saved_checkpoint = True
            print(f"New Best Val: {best_val_loss:.6f} -- Model Saved.")

        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch}.")
            break

    # --- STEP 8: FINAL INFERENCE & SAVE ---
    if saved_checkpoint:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("Warning: no checkpoint was saved in this run; using the final explainer weights.")
    model.eval()
    method = 'Distill_only_amortized_attr_v2'

    final_maps = _predict_maps(model, target_dnn, val_loader, device)
    save_dir = os.path.join(f"./explanations/{args.dataset}", args.model, args.dnn_type)
    # np.save does not create parent directories.
    os.makedirs(save_dir, exist_ok=True)
    filename = f"{args.dataset}-{args.seed}-{args.model}-{args.dnn_type}-test_{method}.npy"
    save_path = os.path.join(save_dir, filename)
    np.save(save_path, {'attributions': final_maps})
    print(f"Amortized Explanation saved to {save_path}")

    # --- STEP 9: GENERATE NOISY AMORTIZED MAPS FOR ROBUSTNESS ---
    print("--- STEP 9: Generating Noisy Amortized Maps (Noise std = 0.2 * data_std) ---")
    final_noisy_maps = _predict_maps(model, target_dnn, val_loader, device, noise_rel_std=0.2)
    stem, ext = os.path.splitext(save_path)
    noisy_save_path = f"{stem}_noisy{ext}"
    np.save(noisy_save_path, {'attributions': final_noisy_maps})
    print(f"Noisy Amortized Explanation saved to {noisy_save_path}")

if __name__ == "__main__":
    # Script entry point: train only when executed directly, not when imported.
    train_explainer()
