import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os

# Custom imports
from models.AmortizedFCN import AmortizedExplainerFCN
from data_provider.explanation_data_loader import AmortizedExplanationLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau


# --- Early Stopping Class ---
class EarlyStopping:
    """Signal when a monitored validation loss stops improving.

    Call the instance once per epoch with the current loss. A loss counts as
    an improvement when it is at most ``best_loss - delta``. Any other value,
    including NaN, increments a counter. Once the counter reaches
    ``patience``, ``early_stop`` becomes True.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        delta: minimum decrease below ``best_loss`` required to count as an
            improvement.
    """

    def __init__(self, patience=15, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0  # consecutive non-improving calls
        self.best_loss = None  # lowest loss seen so far; None until the first non-NaN loss
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``best_loss`` / ``early_stop``."""
        # NaN is the only value not equal to itself. Work for floats, NumPy
        # scalars and 0-d tensors without extra imports. Without this check a
        # NaN would be stored as best_loss, and afterwards no loss could ever
        # count as non-improving.
        is_nan = val_loss != val_loss
        if is_nan:
            improved = False
        elif self.best_loss is None:
            improved = True
        else:
            improved = val_loss <= self.best_loss - self.delta

        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

def get_args():
    """Build and parse the command-line interface for explainer training.

    Returns:
        An ``argparse.Namespace`` holding every option below, plus a derived
        ``root_path`` attribute: ``--data_root`` joined with ``--dataset``.
    """
    cli = argparse.ArgumentParser(description="Train Amortized Explainer")
    add = cli.add_argument

    # Where the dataset lives on disk
    add("--dataset", type=str, default="UWaveGestureLibrary", help="UEA dataset name")
    add("--data_root", type=str, default="./data/UEA_multivariate", help="Path to data")

    # Target model description and explainer width
    add("--model", type=str, default="DNN", help="Target model being explained")
    add("--dnn_type", type=str, default="FCN", help="Architecture of target model")
    add("--enc_in", type=int, required=True, help="Number of input channels/features")
    add("--hidden_dim", type=int, default=128, help="Hidden dimensions for explainer FCN")

    # Optimisation settings
    add("--batch_size", type=int, default=16)
    add("--lr", type=float, default=1e-3)
    add("--train_epochs", type=int, default=100)
    add("--seed", type=int, default=42)

    opts = cli.parse_args()
    opts.root_path = os.path.join(opts.data_root, opts.dataset)
    return opts

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    If ``optimizer`` is given, the model is put in train mode and updated
    batch by batch. Otherwise it is put in eval mode and gradients are
    disabled.

    Args:
        model: the explainer network.
        loader: DataLoader yielding ``(batch_x, batch_y)`` pairs.
        criterion: loss with mean reduction (``nn.MSELoss()`` in this script).
        device: device each batch is moved to.
        optimizer: optimizer to step, or ``None`` for evaluation only.

    Returns:
        Mean loss over all samples. Each batch's mean loss is weighted by its
        batch size, so a smaller final batch does not skew the average. This
        assumes every sample has the same number of target elements.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device).float()
            batch_y = batch_y.to(device).float()
            loss = criterion(model(batch_x), batch_y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = batch_x.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("DataLoader yielded no samples; check the dataset split.")
    return total_loss / n_samples


def _predict(model, loader, device):
    """Run ``model`` in eval mode over ``loader``'s inputs and concatenate outputs along axis 0."""
    model.eval()
    with torch.no_grad():
        maps = [model(batch_x.to(device).float()).cpu().numpy() for batch_x, _ in loader]
    return np.concatenate(maps, axis=0)


def train_explainer():
    """Train the amortized explainer, then export its attribution maps.

    Steps:
      1. Parse CLI args and seed the torch and NumPy RNGs with ``--seed``.
      2. Train ``AmortizedExplainerFCN`` to regress the loader's target maps
         with MSE, checkpointing the weights with the lowest validation MSE.
      3. Reload the best checkpoint, predict maps for every validation input
         and ``np.save`` them as ``{'attributions': ndarray}``.

    Raises:
        ValueError: if a data loader yields no samples.
        RuntimeError: if no epoch produced a finite validation loss, so no
            checkpoint was written during this run.
    """
    args = get_args()
    # Seed before building the loaders and model so weight init and shuffle
    # order are reproducible. The seed is also part of the output filename.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Initialize Loaders
    # NOTE(review): the TEST split serves as validation data (LR scheduling,
    # early stopping, checkpoint selection) and is also the split explained
    # below. Confirm this is intended.
    train_ds = AmortizedExplanationLoader(args, args.root_path, flag='TRAIN', target_set='train')
    val_ds = AmortizedExplanationLoader(args, args.root_path, flag='TEST', target_set='test')

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)

    # 2. Build Model: input channels are raw TS + saliency, hence enc_in * 2.
    model = AmortizedExplainerFCN(c_in=args.enc_in * 2, c_out=args.enc_in, hidden_dim=args.hidden_dim).to(device)

    # 3. Setup Regression Objective
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    # NOTE(review): with the default --train_epochs 100, a patience of 100 can
    # never be reached (the first call only sets the baseline). Early stopping
    # only matters when training for more epochs.
    early_stopping = EarlyStopping(patience=100)

    checkpoint_dir = f"./explainer_checkpoints/{args.dataset}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, f"Amortized_{args.dataset}_{args.model}_{args.dnn_type}.pth")

    print(f"--- Training Amortized Explainer: {args.dataset} ---")
    best_val_loss = float('inf')
    # Tracks whether THIS run wrote a checkpoint, so a stale file from an
    # earlier run is never loaded by mistake.
    saved_checkpoint = False

    for epoch in range(args.train_epochs):
        train_mse = _run_epoch(model, train_loader, criterion, device, optimizer)
        avg_val = _run_epoch(model, val_loader, criterion, device)

        scheduler.step(avg_val)
        early_stopping(avg_val)

        if epoch % 10 == 0 or epoch == args.train_epochs - 1:
            print(f"Epoch {epoch+1}/{args.train_epochs} | Train MSE: {train_mse:.6f} | Val MSE: {avg_val:.6f}")
        # A NaN avg_val compares False here, so it never becomes "best".
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            torch.save(model.state_dict(), best_model_path)
            saved_checkpoint = True
            print(f"Epoch {epoch+1:03d} | New Best Val MSE: {best_val_loss:.6f}")
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}. Best Val MSE: {best_val_loss:.6f}")
            break

    if not saved_checkpoint:
        raise RuntimeError(
            "No finite validation loss was observed; no checkpoint was saved in this run."
        )

    # Inference Phase: reload the best weights onto the current device.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    final_maps = _predict(model, val_loader, device)

    save_dir = os.path.join(f"./explanations/{args.dataset}", args.model, args.dnn_type)
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, f"{args.dataset}-{args.seed}-{args.model}-{args.dnn_type}-test_amortized_explanation.npy")
    # Saved as a pickled dict (0-d object array); loading needs allow_pickle=True.
    np.save(save_path, {'attributions': final_maps})
    print(f"Successfully saved Amortized Map to {save_path}")

# Script entry point: run training and export only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    train_explainer()
