import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os
import time
from argparse import Namespace
from einops import rearrange
import matplotlib.pyplot as plt  # Added for visualization


# Custom project imports
from exp.experiment_explanation_v2 import Explanation_Experiment
from models.AmortizedFCN import AmortizedExplainerFCN
from data_provider.explanation_data_loader import AmortizedExplanationLoader_V2, padding_mask
from torch.optim.lr_scheduler import ReduceLROnPlateau
from models.Fidelity_gate_net import FidelityGateNet

def amortized_collate_fn(data, max_len=None):
    """Collate amortized-explanation samples into one zero-padded batch.

    Each item of ``data`` is a 4-tuple ``(amortized, target, raw, label)``.
    The amortized inputs are not used; the first returned element is always
    ``None``. Raw series are copied into a zero tensor of length ``max_len``
    along their first axis, and longer series are truncated. When ``max_len``
    is None, the longest raw series in the batch sets the padded length.

    Args:
        data: Sequence of ``(amortized, target, raw, label)`` tuples. Targets
            and labels must be stackable tensors. Each ``raw`` is indexed as
            ``[time, feature]``.
        max_len: Optional fixed padded length.

    Returns:
        ``(None, targets, raw_batch, labels, mask)``.
        ``raw_batch`` has shape ``[batch, max_len, feat]`` in torch's default dtype.
        ``mask`` is whatever the project helper ``padding_mask`` builds from the
        original (un-truncated) lengths. Per the original note, 1 = keep and 0 = pad.
    """
    n_items = len(data)
    _, target_items, raw_items, label_items = zip(*data)

    seq_lengths = [series.shape[0] for series in raw_items]
    padded_len = max(seq_lengths) if max_len is None else max_len
    n_features = raw_items[0].shape[-1]

    # Targets are assumed to already share a fixed length (produced upstream).
    stacked_targets = torch.stack(target_items, dim=0)

    raw_batch = torch.zeros(n_items, padded_len, n_features)
    for row, (series, length) in enumerate(zip(raw_items, seq_lengths)):
        n_keep = min(length, padded_len)
        raw_batch[row, :n_keep, :] = series[:n_keep, :]

    stacked_labels = torch.stack(label_items, dim=0)
    mask = padding_mask(torch.tensor(seq_lengths, dtype=torch.int16), max_len=padded_len)

    return None, stacked_targets, raw_batch, stacked_labels, mask



# --- Fidelity Loss: Penalizes explainers that produce unfaithful maps ---
class FidelityConfidence_Mask_Robust_Loss(nn.Module):
    """Fidelity and mask-robustness objective for an attribution explainer.

    For every sparsity level ``k``, a frozen gate network turns the predicted
    attribution into a soft mask, and the frozen target model classifies the
    masked input.

    - The fidelity term is ``mean(1 - p(true class))``. It is low when the
      features kept by the mask preserve the model's prediction.
    - The robustness term is the MSE between masks produced from the clean
      and the noisy attribution.

    Both terms are averaged over the sparsity levels.
    """

    def __init__(self, target_model, fidelity_gates):
        """
        Args:
            target_model: The frozen classification model being explained. It is
                called as ``target_model(x, pad_mask, None, None)`` and may return
                logits or a tuple whose first element is the logits.
            fidelity_gates: Mapping ``{k: gate}`` of gate networks, one per sparsity
                level. It is stored as given. If it is a plain dict, the gates are
                not registered submodules, so ``.to()`` on this loss does not move them.
        """
        super().__init__()
        self.target_model = target_model
        self.fidelity_gates = fidelity_gates
        self.softmax = nn.Softmax(dim=-1)

        # Freeze every evaluator: only the explainer producing the attributions should train.
        self.target_model.eval()
        self.target_model.requires_grad_(False)
        for gate in self.fidelity_gates.values():
            gate.eval()
            gate.requires_grad_(False)

    def forward(self, raw_x, pad_mask, true_labels, predicted_attr, noisy_predicted_attr, k_list=(80, 85, 90)):
        """Compute the averaged fidelity and mask-robustness losses.

        Args:
            raw_x: Model input. It is multiplied element-wise by each soft mask, so
                the gate output must broadcast against it (presumably
                [batch, seq, feat]; TODO confirm).
            pad_mask: Forwarded unchanged to the target model.
            true_labels: Class indices of shape [batch] or [batch, 1].
            predicted_attr: Attribution from the explainer on clean input.
            noisy_predicted_attr: Attribution from the explainer on perturbed input.
            k_list: Sparsity levels to average over. Each must be a key of
                ``fidelity_gates``.

        Returns:
            ``(fidelity_loss, mask_robust_loss)``, both averaged over ``k_list``.

        Raises:
            ValueError: If ``k_list`` is empty.
            KeyError: If a level in ``k_list`` has no gate.
        """
        n_levels = len(k_list)
        if n_levels == 0:
            raise ValueError("k_list must contain at least one sparsity level")

        # torch.gather along dim 1 needs an int64 index with the same rank as probs.
        if true_labels.dim() == 1:
            true_labels = true_labels.unsqueeze(1)
        true_labels = true_labels.long()

        total_faithfulness_error = 0.0
        total_mask_robust_error = 0.0
        for k in k_list:
            gate = self.fidelity_gates[k]

            # 1. Continuous attribution -> soft mask via the learned gate.
            soft_mask = gate(predicted_attr)
            noisy_soft_mask = gate(noisy_predicted_attr)

            # 2. Soft removal: scale the input by the mask, then query the target model.
            logits = self.target_model(raw_x * soft_mask, pad_mask, None, None)
            if isinstance(logits, tuple):
                logits = logits[0]
            probs = self.softmax(logits)

            # 3. Confidence on the true class should stay high when only masked-in features remain.
            true_class_conf = torch.gather(probs, 1, true_labels)
            total_faithfulness_error += torch.mean(1.0 - true_class_conf)

            # 4. Masks from clean and noisy attributions should agree.
            total_mask_robust_error += nn.functional.mse_loss(soft_mask, noisy_soft_mask)

        return total_faithfulness_error / n_levels, total_mask_robust_error / n_levels

# --- NEW FUNCTION: Stratified Segment Merging Algorithm ---
def _rank_band_edges(seq_len, step):
    """Return consecutive non-empty ``(lo, hi)`` rank ranges that partition ``[0, seq_len)``."""
    ratios = np.arange(step, 1.0 + step / 10, step)
    edges = []
    prev_k = 0
    for r in ratios:
        k = min(int(np.round(r * seq_len)), seq_len)
        # Rounding can map several ratios to the same k; skip the resulting empty bands.
        if k > prev_k:
            edges.append((prev_k, k))
            prev_k = k
    # When step does not divide 1.0 evenly, arange stops short of 1.0.
    # Close the ranking so every time step lands in some band.
    if prev_k < seq_len:
        edges.append((prev_k, seq_len))
    return edges


def _contiguous_runs(sorted_positions):
    """Split ascending integer positions into inclusive ``(start, end)`` runs of consecutive values."""
    positions = sorted_positions.tolist()
    if not positions:
        return []
    runs = []
    start = prev = positions[0]
    for pos in positions[1:]:
        if pos > prev + 1:  # gap found: close the current run
            runs.append((start, prev))
            start = pos
        prev = pos
    runs.append((start, prev))
    return runs


def faithfulness_preserving_segmentation(raw_attribution, step=0.1):
    """
    Convert point-wise attribution maps into piecewise-constant segment maps.

    Stratified set-difference strategy, applied independently to every
    (sample, feature) series:
    1. Rank time steps by attribution, in descending order.
    2. Cut the ranking into consecutive bands at ratios step, 2*step, ... of the
       sequence length. A final band always runs to 100%, so every time step is
       assigned even when ``step`` does not divide 1.0 evenly.
    3. Within each band (top_k minus top_prev_k), find maximal runs of
       contiguous time steps.
    4. Give every run the mean attribution of its points.

    Bands are disjoint, so each time step receives exactly one score.

    Args:
        raw_attribution: Tensor or array-like, [Batch, Seq_Len, Feat] or
            [Seq_Len, Feat] (a batch axis is then added). Non-float input is
            cast to float32.
        step: Band width as a fraction of the sequence length. Must be > 0.

    Returns:
        tuple: (final_segmented_map, segments_dict)
               - final_segmented_map: torch.Tensor [Batch, Seq_Len, Feat]
               - segments_dict: {b: {f: [(start, end, score), ...]}}. ``start``
                 and ``end`` are inclusive. Runs are ordered from the
                 highest-ranked band down, and by ``start`` within a band.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if not step > 0:
        raise ValueError(f"step must be positive, got {step}")

    if not isinstance(raw_attribution, torch.Tensor):
        raw_attribution = torch.tensor(raw_attribution)
    if not torch.is_floating_point(raw_attribution):
        # torch.mean is undefined for integer tensors.
        raw_attribution = raw_attribution.float()

    # Handle single instance vs batch.
    if raw_attribution.dim() == 2:
        raw_attribution = raw_attribution.unsqueeze(0)  # [1, Seq, Feat]

    batch_size, seq_len, n_feats = raw_attribution.shape
    final_segmented_map = torch.zeros_like(raw_attribution)
    segments_dict = {}

    band_edges = _rank_band_edges(seq_len, step)

    for b in range(batch_size):
        segments_dict[b] = {}
        for f in range(n_feats):
            attr = raw_attribution[b, :, f]  # [Seq_Len]
            indices_sorted = torch.sort(attr, descending=True).indices

            feat_segments = []
            for lo, hi in band_edges:
                # Time positions whose rank falls in [lo, hi), in ascending time order.
                band_positions = torch.sort(indices_sorted[lo:hi]).values
                for start, end in _contiguous_runs(band_positions):
                    score = attr[start:end + 1].mean().item()
                    final_segmented_map[b, start:end + 1, f] = score
                    feat_segments.append((start, end, score))
            segments_dict[b][f] = feat_segments

    return final_segmented_map, segments_dict

# --- NEW VISUALIZATION FUNCTION ---
def visualize_top_segments(raw_data, segments_dict, save_dir, top_percent=0.2, num_samples=5):
    """Plot samples with their highest-scoring segments highlighted.

    For each of the first ``min(num_samples, raw_data.shape[0])`` samples, one
    figure is created with a subplot per feature channel. Each channel is drawn
    as a light-grey line. The top ``top_percent`` fraction (at least one) of
    that channel's segments, ranked by score, is overlaid as thick, round-capped
    ``tab10``-coloured strokes. One PNG per sample is written to ``save_dir``
    at 300 dpi.

    Args:
        raw_data: Array indexed as ``[sample, time, feature]``.
        segments_dict: ``segments_dict[b][f]`` is a list of ``(start, end, score)``.
        save_dir: Output directory. It is created if missing.
        top_percent: Fraction of each channel's segments to highlight.
        num_samples: Maximum number of samples to plot.
    """
    os.makedirs(save_dir, exist_ok=True)

    sample_count = min(raw_data.shape[0], num_samples)
    palette = plt.get_cmap('tab10')

    for b in range(sample_count):
        n_feats = raw_data.shape[2]

        # squeeze=False always yields a 2-D axes grid, even for a single channel.
        fig, axes_grid = plt.subplots(n_feats, 1, figsize=(12, 2.5 * n_feats), sharex=True, squeeze=False)
        fig.patch.set_facecolor('white')

        for f, ax in enumerate(axes_grid[:, 0]):
            ax.set_facecolor('white')

            # Base signal underneath everything else.
            signal = raw_data[b, :, f]
            ax.plot(signal, color='#cccccc', alpha=1.0, linewidth=2.0, label='Raw Signal', zorder=1)

            ranked = sorted(segments_dict[b][f], key=lambda seg: seg[2], reverse=True)
            if ranked:
                keep = max(1, int(len(ranked) * top_percent))
                for rank, (start, end, _score) in enumerate(ranked[:keep]):
                    # Run one point past the inclusive segment end so the stroke joins the next
                    # sample, clipped to the series length.
                    span = range(start, min(end + 2, len(signal)))
                    ax.plot(
                        span,
                        signal[span],
                        color=palette(rank % 10),
                        linewidth=4.5,
                        solid_capstyle='round',
                        alpha=0.9,
                        zorder=10 + rank,
                    )

            ax.set_ylabel(f'Feat {f}', fontsize=12, fontweight='bold', color='#333333')

            # Minimal frame: hide the top/right spines, soften the rest.
            for side in ('top', 'right'):
                ax.spines[side].set_visible(False)
            for side in ('left', 'bottom'):
                ax.spines[side].set_color('#888888')

            ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='#e0e0e0', alpha=0.8)

            if f == 0:
                ax.set_title(f'Sample {b}: Top {int(top_percent*100)}% Significant Segments',
                             fontsize=14, fontweight='bold', pad=15, color='#2c3e50')

        fig.tight_layout()
        out_path = os.path.join(save_dir, f"sample_{b}_top_segments_thick.png")
        fig.savefig(out_path, dpi=300)
        plt.close(fig)
        print(f"Saved attractive visualization for sample {b} to {out_path}")

class EarlyStopping:
    """Flag training for termination once a monitored loss stops improving.

    Call the instance once per evaluation with the current loss.
    - The first call only records the reference value.
    - A later value stalls when ``val_loss > best_loss - delta``. Otherwise it
      becomes the new best and the stall counter resets.
    - After ``patience`` consecutive stalls, ``early_stop`` is set to True and
      is never cleared.
    """

    def __init__(self, patience=15, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        # Keep the exact `>` test (not `<=`) so NaN losses are treated the same way.
        stalled = val_loss > self.best_loss - self.delta
        if not stalled:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def get_args():
    """Parse the explainer-training command line.

    ``--enc_in`` (number of input channels) is required. Every other option has
    a default. The returned namespace also carries ``root_path``, which is
    ``os.path.join(data_root, dataset)``. ``lambda_fid`` and ``lambda_rob`` are
    parsed but not read elsewhere in this file.
    """
    cli_options = (
        ("--dataset", dict(type=str, default="UWaveGestureLibrary")),
        ("--data_root", dict(type=str, default="./data/UEA_multivariate")),
        ("--model", dict(type=str, default="DNN")),
        ("--dnn_type", dict(type=str, default="FCN")),
        ("--enc_in", dict(type=int, required=True)),
        ("--hidden_dim", dict(type=int, default=128)),
        ("--batch_size", dict(type=int, default=8)),
        ("--lr", dict(type=float, default=1e-3)),
        ("--train_epochs", dict(type=int, default=100)),
        ("--seed", dict(type=int, default=42)),
        ("--lambda_fid", dict(type=float, default=None)),
        ("--lambda_rob", dict(type=float, default=None)),
    )

    parser = argparse.ArgumentParser(description="Train Amortized Explainer")
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()
    parsed.root_path = os.path.join(parsed.data_root, parsed.dataset)
    return parsed

def get_target_model_defaults(new_args):
    """Build the config namespace used to load the pre-trained classifier.

    ``dataset``, ``model``, ``dnn_type``, ``seed``, ``lr``, ``batch_size`` and
    ``train_epochs`` are copied from ``new_args``. Everything else is a fixed
    default, including ``test_only=True`` and the PatchTST-style transformer
    sizes. ``root_path`` is set to ``f"{data_root}/{dataset}"``.

    NOTE(review): ``data_root`` is hard-coded here rather than taken from
    ``new_args.data_root``. Confirm this is intended.
    """
    # Keys keep the original insertion order so the printed Namespace is unchanged.
    config = dict(
        data='UEA',
        data_root="./data/UEA_multivariate",
        dataset=new_args.dataset,
        model=new_args.model,
        dnn_type=new_args.dnn_type,
        seed=new_args.seed,
        lambda_reg=0.1,
        lambda_div=0.1,
        epsilon=1.0,
        num_shapelet=10,
        gating_value=None,
        pos_weight=False,
        sbm_cls='linear',
        distance_func='euclidean',
        beta_schedule='constant',
        memory_efficient=False,
        lr=new_args.lr,
        lr_decay=False,
        gradient_accumulation_steps=1,
        gradient_clip=0,
        batch_size=new_args.batch_size,
        train_epochs=new_args.train_epochs,
        num_workers=0,
        patience=50,
        multi_gpu=False,
        test_only=True,
        amp=True,
        task_name='explanation',
        model_id='test',
        embed='timeF',
        freq='h',
        dropout=0.0,
        activation='gelu',
        output_attention=False,
        is_training=False,
        # Default config for PatchTST.
        d_model=512,
        e_layers=2,
        factor=1,
        n_heads=8,
        d_ff=2048,
    )
    target_cfg = Namespace(**config)
    target_cfg.root_path = f"{target_cfg.data_root}/{target_cfg.dataset}"
    return target_cfg

def _load_fidelity_gates(args, tm_args, device, k_values):
    """Load one FidelityGateNet checkpoint per sparsity level and return ``{k: gate}`` on ``device``."""
    gate_dir = f"./fidelity_gate_checkpoints/{tm_args.dataset}/{tm_args.model}/{tm_args.dnn_type}"
    gates = {}
    for k in k_values:
        gate_path = os.path.join(
            gate_dir, f"{args.dataset}-{args.model}-{args.dnn_type}-k{k}-seed{tm_args.seed}.pt"
        )
        gate = FidelityGateNet(c_in=args.enc_in).to(device)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        gate.load_state_dict(torch.load(gate_path, map_location=device))
        gates[k] = gate
    return gates


def _explain_batch(model, target_dnn, raw_x, labels, pad_mask, device):
    """Run the explainer on one batch and return its attribution map as a numpy array.

    The explainer input is the raw series concatenated, along the last axis,
    with Input x Gradient of the target model's true-class logit.
    """
    with torch.enable_grad():
        raw_x = raw_x.to(device).float().requires_grad_(True)
        labels = labels.to(device).long()
        target_logits = target_dnn(raw_x, pad_mask.to(device), None, None)
        if isinstance(target_logits, tuple):
            target_logits = target_logits[0]
        score = torch.gather(target_logits, 1, labels.view(-1, 1)).sum()
        grads = torch.autograd.grad(score, raw_x)[0].detach()
        enriched_input = torch.cat([raw_x.detach(), raw_x.detach() * grads], dim=-1)
    with torch.no_grad():
        return model(enriched_input).cpu().numpy()


def train_explainer():
    """Set up the amortized explainer, load its checkpoint, and export explanations.

    Pipeline:
    1. Parse CLI args.
    2. Load and freeze the target classifier.
    3. Load the fidelity gates.
    4. Build TRAIN/TEST loaders and the explainer.
    5. Construct the training objects.
    6. Load ``best_model_path`` and run timed inference on the TEST split.

    Attributions are saved as ``{'attributions': maps}`` in a ``.npy`` file.
    The top-20% segments of the first samples are plotted.

    NOTE(review): despite the name, no optimisation loop runs here. Steps 6-7
    (training/validation) are absent, so ``best_model_path`` must already
    exist. The optimizer, scheduler, early stopping and criterion are built but
    not used below.
    """
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- STEP 1: LOAD TARGET ORACLE MODEL ---
    tm_args = get_target_model_defaults(args)
    print(tm_args)
    exp = Explanation_Experiment(args=tm_args)
    target_dnn = exp.model.to(device)
    target_dnn.eval()
    for param in target_dnn.parameters():
        param.requires_grad = False

    # --- STEP 2: LOAD FIDELITY GATES ---
    k_to_train_on = [80, 85, 90, 95]
    fidelity_gates = _load_fidelity_gates(args, tm_args, device, k_to_train_on)

    # --- STEP 3: INITIALIZE DATA LOADERS ---
    train_ds = AmortizedExplanationLoader_V2(args, args.root_path, flag='TRAIN', target_set='train')
    val_ds = AmortizedExplanationLoader_V2(args, args.root_path, flag='TEST', target_set='test')
    # Pad both splits to one common length.
    max_seq_len = max(train_ds.max_seq_len, val_ds.max_seq_len)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=lambda x: amortized_collate_fn(x, max_len=max_seq_len))
    val_loader = DataLoader(val_ds, batch_size=8, shuffle=False,
                            collate_fn=lambda x: amortized_collate_fn(x, max_len=max_seq_len))

    # --- STEP 4: BUILD EXPLAINER MODEL ---
    # c_in = enc_in * 2 (Channel 1: Raw, Channel 2: Dynamic InputXGradient)
    model = AmortizedExplainerFCN(c_in=args.enc_in * 2, c_out=args.enc_in, hidden_dim=args.hidden_dim).to(device)

    # --- STEP 5: SETUP OPTIMIZER & LOSS ---
    distillation_criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    early_stopping = EarlyStopping(patience=100)
    fidelity_mask_robust_criterion = FidelityConfidence_Mask_Robust_Loss(target_dnn, fidelity_gates).to(device)

    ckpt_dir = f"./explainer_checkpoints/{args.dataset}"
    os.makedirs(ckpt_dir, exist_ok=True)
    best_model_path = os.path.join(ckpt_dir, f"Amortized_{args.dataset}_{args.model}_{args.dnn_type}_v2.pth")
    best_val_loss = float('inf')

    # --- STEP 8: FINAL INFERENCE & SAVE ---
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.eval()
    method = 'amortized_attr_v2'
    n_timing_runs = 1
    all_raw_inputs = []  # raw inputs, kept for visualization
    all_amortized_maps = []
    run_times = []
    for run_idx in range(n_timing_runs):
        # Only the first pass contributes outputs. Later passes are timing-only,
        # so raising n_timing_runs does not duplicate rows in the saved arrays.
        collect_outputs = run_idx == 0
        start_time = time.time()
        with torch.no_grad():
            for _, _, raw_x, labels, pad_mask in val_loader:
                raw_np = raw_x.cpu().numpy()
                attr_map = _explain_batch(model, target_dnn, raw_x, labels, pad_mask, device)
                if collect_outputs:
                    all_raw_inputs.append(raw_np)
                    all_amortized_maps.append(attr_map)
        average_time = (time.time() - start_time) / len(val_ds)
        run_times.append(average_time)
        print(f"Run {run_idx+1}: Average InputXGradient time per instance:   {average_time:.6f} seconds")
    avg_run_time = sum(run_times) / len(run_times)
    std_run_time = np.std(run_times)
    print(f"Average inference time over {n_timing_runs} runs: {avg_run_time:.6f} ± {std_run_time:.6f} seconds")

    final_maps = np.concatenate(all_amortized_maps, axis=0)
    final_raw_inputs = np.concatenate(all_raw_inputs, axis=0)  # [B, T, C]
    save_dir = os.path.join(f"./times_explanations/{args.dataset}", args.model, args.dnn_type)
    os.makedirs(save_dir, exist_ok=True)
    filename = f"{args.dataset}-{args.seed}-{args.model}-{args.dnn_type}-test_{method}.npy"
    save_path = os.path.join(save_dir, filename)
    # np.save of a dict stores a pickled object array; readers need allow_pickle=True.
    np.save(save_path, {'attributions': final_maps})
    print(f"Amortized Explanation saved to {save_path}")

    segmented_attribution, segment_set = faithfulness_preserving_segmentation(raw_attribution=final_maps, step=0.1)

    # --- VISUALIZATION OF TOP 20% SEGMENTS ---
    print("Visualizing Top 20% Segments...")
    vis_dir = os.path.join(save_dir, "visualizations")
    visualize_top_segments(
        raw_data=final_raw_inputs,
        segments_dict=segment_set,
        save_dir=vis_dir,
        top_percent=0.2,
        num_samples=5
    )


if __name__ == "__main__":
    # Script entry point: parse CLI args, load the target model, gates and explainer
    # checkpoint, then run inference and save attributions and segment plots.
    train_explainer()
