import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
import argparse
import random
import numpy as np
from argparse import Namespace

# Custom Imports
from data_provider.bottom_up_mask_loader import BottomUpMaskDataset
from models.Fidelity_gate_net import FidelityGateNet
from exp.experiment_explanation import Explanation_Experiment

def bottom_up_dataset_collate_fn(data, max_len=None):
    """Collate ``BottomUpMaskDataset`` samples into mini-batch tensors.

    Args:
        data: list (length batch_size) of 4-tuples ``(x_attr, y_attr, features, label)``:
            - x_attr, y_attr: tensors that must share one shape across the batch
              (they are stacked as-is, without padding; ``torch.stack`` raises otherwise).
            - features: tensor of shape (seq_length, feat_dim); seq_length may vary per sample.
            - label: tensor; all labels must share one shape so they can be stacked.
        max_len: fixed padded sequence length for ``features``. Longer sequences are
            clipped, shorter ones are zero-padded. Defaults to the longest sequence in
            the batch.

    Returns:
        X_ATTR: stacked x_attr, shape (batch_size, *x_attr.shape)
        Y_ATTR: stacked y_attr, shape (batch_size, *y_attr.shape)
        X: (batch_size, max_len, feat_dim) float tensor of zero-padded features
        targets: stacked labels, shape (batch_size, *label.shape)
        padding_masks: (batch_size, max_len) boolean tensor; True marks a real time
            step, False marks padding
    """
    batch_size = len(data)
    x_attr_tuple, y_attr_tuple, features, labels = zip(*data)

    # Attribution inputs and targets are stacked directly along a new batch dim.
    X_ATTR = torch.stack(x_attr_tuple, dim=0)
    Y_ATTR = torch.stack(y_attr_tuple, dim=0)

    # Zero-pad (or clip) the raw feature sequences to a common length.
    lengths = [feat.shape[0] for feat in features]  # original length of each series
    if max_len is None:
        max_len = max(lengths)

    X = torch.zeros(batch_size, max_len, features[0].shape[-1])  # (batch, max_len, feat_dim)
    for i, (feat, length) in enumerate(zip(features, lengths)):
        end = min(length, max_len)
        X[i, :end, :] = feat[:end, :]

    targets = torch.stack(labels, dim=0)

    # int64 lengths: an int16 tensor cannot hold lengths above 32767.
    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.long), max_len=max_len)

    return X_ATTR, Y_ATTR, X, targets, padding_masks


def padding_mask(lengths, max_len=None):
    """Build a boolean mask marking the real (non-padded) time steps of each sequence.

    Args:
        lengths: 1-D integer tensor of sequence lengths, one entry per batch element.
        max_len: width of the mask. Defaults to the largest value in ``lengths``
            (0 when ``lengths`` is empty).

    Returns:
        (batch_size, max_len) bool tensor on ``lengths.device``; entry ``[i, t]`` is
        True iff ``t < lengths[i]`` (1 means keep, 0 means padding). A length larger
        than ``max_len`` yields an all-True row.
    """
    if max_len is None:
        # Tensor has no ``max_val``; ``max()`` returns a 0-d tensor, converted to int.
        max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(max_len, device=lengths.device).type_as(lengths)
    # (1, max_len) vs (batch, 1) broadcasts to (batch, max_len).
    return positions.unsqueeze(0).lt(lengths.unsqueeze(1))

# --- Early Stopping ---
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the validation loss. The first call, and
    any call whose loss is not greater than ``best_loss - delta``, counts as an
    improvement: ``best_loss`` is replaced and ``counter`` reset. Any other call
    increments ``counter``; once it reaches ``patience`` the ``early_stop`` flag is
    raised and stays raised.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        delta: margin below ``best_loss`` a new loss must reach to count as improvement.
    """

    def __init__(self, patience=15, delta=0):
        self.patience, self.delta = patience, delta
        self.counter = 0          # consecutive non-improving calls
        self.best_loss = None     # best loss so far; None until the first call
        self.early_stop = False   # set once counter >= patience

    def __call__(self, val_loss):
        stalled = self.best_loss is not None and val_loss > self.best_loss - self.delta
        if not stalled:
            self.best_loss = val_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def get_args():
    """Parse the gate-training command-line options.

    Unrecognised arguments are ignored (``parse_known_args``) so extra options on
    the command line do not cause a parse error.

    Returns:
        argparse.Namespace with ``dataset``, ``model``, ``dnn_type``, ``seed``, ``k``,
        ``kl_weight`` (masking / target selection) and ``batch_size``, ``lr``,
        ``epochs``, ``patience`` (training loop).
    """
    option_table = [
        # Masking specific
        ("--dataset", str, "BasicMotions"),
        ("--model", str, "DNN"),
        ("--dnn_type", str, "FCN"),
        ("--seed", int, 42),
        ("--k", int, 5),
        ("--kl_weight", float, 0.1),
        # Training loop
        ("--batch_size", int, 32),
        ("--lr", float, 1e-3),
        ("--epochs", int, 200),
        ("--patience", int, 100),
    ]
    parser = argparse.ArgumentParser()
    for flag, arg_type, default in option_table:
        parser.add_argument(flag, type=arg_type, default=default)
    known_args, _ignored = parser.parse_known_args()
    return known_args



def get_target_model_defaults(gate_args):
    """Build the Namespace of hyper-parameters used to instantiate the target model.

    ``dataset``, ``model``, ``dnn_type`` and ``seed`` are copied from ``gate_args``
    so the target model matches the gate being trained; every other value is a
    fixed default. ``root_path`` is set to ``"<data_root>/<dataset>"``.
    """
    identification = {
        "data": 'UEA',
        "data_root": "./data/UEA_multivariate",
        "dataset": gate_args.dataset,
        "model": gate_args.model,
        "dnn_type": gate_args.dnn_type,
        "seed": gate_args.seed,
    }
    # SBM and InterpGN hyper-parameters
    shapelet_config = {
        "lambda_reg": 0.1,
        "lambda_div": 0.1,
        "epsilon": 1.0,
        "num_shapelet": 10,
        "gating_value": None,
        "pos_weight": False,
        "sbm_cls": 'linear',
        "distance_func": 'euclidean',
        "beta_schedule": 'constant',
        "memory_efficient": False,
    }
    # Experiment / optimisation settings
    experiment_config = {
        "lr": 5e-3,
        "lr_decay": False,
        "gradient_accumulation_steps": 1,
        "gradient_clip": 0,
        "batch_size": 16,
        "log_interval": 20,
        "min_epochs": 0,
        "train_epochs": 500,
        "num_workers": 0,
        "patience": 50,
        "multi_gpu": False,
        "test_only": True,
        "amp": True,
    }
    basic_config = {
        "task_name": 'explanation',
        "model_id": 'test',
        "embed": 'timeF',
        "freq": 'h',
    }
    # DNN backbone settings.
    # NOTE(review): enc_in/dec_in/c_out look like placeholders the data provider
    # overwrites — confirm.
    dnn_config = {
        "top_k": 5,
        "num_kernels": 6,
        "enc_in": 7,
        "dec_in": 7,
        "c_out": 7,
        "d_model": 512,
        "n_heads": 8,
        "e_layers": 2,
        "d_layers": 1,
        "d_ff": 2048,
        "moving_avg": 25,
        "factor": 1,
        "distil": True,
        "dropout": 0.0,
        "activation": 'gelu',
        "output_attention": False,
    }
    # TimesNet / sequence-task settings
    sequence_config = {
        "label_len": 48,
        "pred_len": 96,
        "seasonal_patterns": 'Monthly',
        "inverse": False,
        "is_training": False,
    }

    # Multiple ** unpackings keep key order and raise on accidental duplicates.
    args = Namespace(**identification, **shapelet_config, **experiment_config,
                     **basic_config, **dnn_config, **sequence_config)
    args.root_path = f"{args.data_root}/{args.dataset}"
    return args

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (incl. the current CUDA device) RNGs.

    Also sets ``torch.backends.cudnn.deterministic`` so cuDNN picks deterministic kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

def _target_logits(target_model, masked_ts, pad_mask):
    """Run ``target_model(x, pad_mask, None, None)``; if it returns a tuple, take element 0 as logits."""
    out = target_model(masked_ts, pad_mask, None, None)
    return out[0] if isinstance(out, tuple) else out


def _fidelity_kl_loss(target_model, raw_ts, pad_mask, pred_mask, gt_mask, criterion_kl):
    """KL between the target model's class distributions under the expert and predicted masks.

    The expert-mask branch runs without gradients and serves as the probability
    target. Gradients reach the gate only through ``pred_mask``; the caller has
    frozen the target model's parameters.
    NOTE(review): this backpropagates through a model in eval() mode; cuDNN RNN
    layers do not support backward in eval mode — confirm for non-conv dnn_types.
    """
    with torch.no_grad():
        gt_probs = F.softmax(_target_logits(target_model, raw_ts * gt_mask, pad_mask), dim=-1)
    pred_log_probs = F.log_softmax(_target_logits(target_model, raw_ts * pred_mask, pad_mask), dim=-1)
    # KLDivLoss(log_target=False) expects log-probabilities as input and probabilities as target.
    return criterion_kl(pred_log_probs, gt_probs)


def _mean_val_bce(gate_net, loader, criterion_bce, device):
    """Average per-batch BCE between the gate's predicted masks and the ground-truth masks."""
    gate_net.eval()
    total = 0.0
    with torch.no_grad():
        for x_attr, y_mask_gt, *_ in loader:
            x_attr, y_mask_gt = x_attr.to(device), y_mask_gt.to(device)
            total += criterion_bce(gate_net(x_attr), y_mask_gt).item()
    return total / len(loader)


def train_gate():
    """Train a FidelityGateNet to predict attribution masks for a frozen target model.

    Parses CLI options (``get_args``), loads the target model through
    ``Explanation_Experiment`` and freezes it, then trains the gate on
    ``BottomUpMaskDataset`` with:
      * a BCE loss between predicted and ground-truth masks, plus
      * ``kl_weight`` times a KL fidelity loss between the target model's predictions
        on gate-masked vs. ground-truth-masked inputs. With ``--kl_weight 0`` the KL
        term and its target-model forward passes are skipped (BCE-only objective).

    The LR scheduler, early stopping and checkpoint selection use the mean BCE on
    the ``test`` split. The best state_dict (``.pt``) and a CSV-style ``.log`` are
    written to ``fidelity_gate_checkpoints/<dataset>/<model>/<dnn_type>/``.
    """
    gate_args = get_args()
    tm_args = get_target_model_defaults(gate_args)
    set_seed(gate_args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load the target model via the experiment wrapper and freeze it.
    print(f"Loading Target Oracle: {tm_args.model}-{tm_args.dnn_type}...")
    experiment = Explanation_Experiment(args=tm_args)
    target_model = experiment.model.to(device)
    target_model.eval()
    for param in target_model.parameters():
        param.requires_grad = False

    # 2. Datasets & loaders; both splits are padded to one shared sequence length.
    train_ds = BottomUpMaskDataset(gate_args, tm_args, k_percentage=gate_args.k, flag='train')
    test_ds = BottomUpMaskDataset(gate_args, tm_args, k_percentage=gate_args.k, flag='test')
    max_seq_len = max(train_ds.raw_dataset.max_seq_len, test_ds.raw_dataset.max_seq_len)

    def collate(batch):
        return bottom_up_dataset_collate_fn(batch, max_len=max_seq_len)

    enc_in = train_ds[0][0].shape[-1]  # channel count of the attribution input
    train_loader = DataLoader(train_ds, batch_size=gate_args.batch_size, shuffle=True, collate_fn=collate)
    test_loader = DataLoader(test_ds, batch_size=gate_args.batch_size, shuffle=False, collate_fn=collate)

    # 3. Gate network & optimisation.
    gate_net = FidelityGateNet(c_in=enc_in).to(device)
    optimizer = torch.optim.Adam(gate_net.parameters(), lr=gate_args.lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    early_stopping = EarlyStopping(patience=gate_args.patience)

    # 4. Output paths: <root>/<dataset>/<model>/<dnn_type>/, with k and seed in the
    #    file name so runs with different k or seed do not overwrite each other.
    checkpoint_dir = os.path.join(
        "fidelity_gate_checkpoints",
        gate_args.dataset,   # e.g., BasicMotions
        gate_args.model,     # e.g., DNN
        gate_args.dnn_type,  # e.g., FCN
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    base_name = f"{gate_args.dataset}-{gate_args.model}-{gate_args.dnn_type}-k{gate_args.k}-seed{gate_args.seed}"
    save_path = os.path.join(checkpoint_dir, f"{base_name}.pt")
    log_path = os.path.join(checkpoint_dir, f"{base_name}.log")

    criterion_bce = nn.BCELoss()
    criterion_kl = nn.KLDivLoss(reduction='batchmean', log_target=False)
    kl_weight = gate_args.kl_weight

    with open(log_path, "w") as f:
        f.write(f"Training Log for {tm_args.dataset} | k={gate_args.k} | KL_wt={gate_args.kl_weight}\n")
        f.write("Epoch,Train_Total_Loss,Val_BCE_Loss\n")

    for epoch in range(gate_args.epochs):
        gate_net.train()
        train_loss = 0.0
        for x_attr, y_mask_gt, raw_ts, _label, pad_mask in train_loader:
            x_attr, y_mask_gt = x_attr.to(device), y_mask_gt.to(device)
            optimizer.zero_grad()

            # Structural loss: the predicted soft mask should match the expert mask.
            pred_mask = gate_net(x_attr)
            total_loss = criterion_bce(pred_mask, y_mask_gt)

            # Functional (fidelity) loss: the target model should behave the same under
            # the predicted mask as under the expert mask.
            if kl_weight != 0:
                raw_ts, pad_mask = raw_ts.to(device), pad_mask.to(device)
                loss_kl = _fidelity_kl_loss(target_model, raw_ts, pad_mask, pred_mask, y_mask_gt, criterion_kl)
                total_loss = total_loss + kl_weight * loss_kl

            total_loss.backward()
            optimizer.step()
            train_loss += total_loss.item()

        avg_train = train_loss / len(train_loader)
        avg_val = _mean_val_bce(gate_net, test_loader, criterion_bce, device)
        scheduler.step(avg_val)
        early_stopping(avg_val)

        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | Total Loss: {avg_train:.5f} | Val BCE: {avg_val:.5f}")
            with open(log_path, "a") as f:
                f.write(f"{epoch},{avg_train:.6f},{avg_val:.6f}\n")

        # EarlyStopping stores best_loss = avg_val on improvement, so exact equality
        # identifies an epoch that set (or tied) the best validation loss.
        if avg_val == early_stopping.best_loss:
            torch.save(gate_net.state_dict(), save_path)
            print(f">>> Saved Best Model for k={gate_args.k} to {save_path}")
            print(f"Epoch {epoch:03d} | Total Loss: {avg_train:.5f} | Val BCE: {avg_val:.5f} (Saved)")

        if early_stopping.early_stop:
            break

if __name__ == "__main__":
    # Script entry point: train_gate() parses the CLI options and runs gate training.
    train_gate()
