"""Training script for the Interpreter Model.

This script trains the interpreter to explain which input sentences
are most important for generating specific target output sentences.
Supports alternating EBM finetuning for better alignment.
"""

import argparse
import logging
import random
import json

# Add src to path for imports
import sys
import time
from pathlib import Path
from typing import List, Tuple, Dict

import pandas as pd
import torch
from torch.nn import functional
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
import nltk
from nltk import sent_tokenize
from nltk.tokenize import word_tokenize
import matplotlib.pyplot as plt
import os
from src.energy_model.utils.energy_network import semantic_sentence_split, normalize_sentences
from src.interpreter_model.utils.interpreter_utils import clone_encoder_from_ebm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("CUDA devices:", torch.cuda.device_count())
# torch.cuda.current_device() initialises CUDA and raises on hosts without a
# usable GPU, which would abort the script at import time on CPU-only machines.
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
else:
    print("Current device: none (CUDA unavailable)")


# Download required NLTK data
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)

# Make the repository root importable before the project imports that follow.
sys.path.append(str(Path(__file__).parent.parent))

from src.energy_model.config import EBMConfig
from src.energy_model.models import EnergyModel
from src.interpreter_model.config.interpreter_configs import InterpreterConfig
from src.interpreter_model.interpreter import InterpreterModel
from src.interpreter_model.losses import create_interpreter_loss
from src.ebm_training.utils import plot_timeseries, upload_file_to_drive

from src.ebm_training.utils import (
    get_cached_batch as get_cached_batch_cpu, # Alias to clarify it returns CPU tensors
)
from src.ebm_training.finetune import (
    contrastive_loss,
    update_encoding_cache,
    get_cached_batch as get_cached_batch_device, # Alias to clarify it moves to device
)


def setup_logging(log_file: str):
    """Configure the root logger to emit INFO records to stdout and to ``log_file``.

    Handlers already attached to the root logger are dropped first, so calling
    this function repeatedly does not duplicate log lines.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Start from a clean slate: stale handlers would repeat every record.
    if root.hasHandlers():
        root.handlers.clear()

    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    def _attach(handler: logging.Handler) -> None:
        """Apply the shared level/format and register the handler on the root logger."""
        handler.setLevel(logging.INFO)
        handler.setFormatter(fmt)
        root.addHandler(handler)

    # Console handler is registered before the file handler is even created.
    _attach(logging.StreamHandler(sys.stdout))
    _attach(logging.FileHandler(log_file))


# ──────────────────── EBM Finetuning (Corrected Logic) ────────────────────

def finetune_ebm_step(
    energy_model: EnergyModel,
    batch: Tuple[List[str], List[str], List[str], List[str], List[int]],
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    margin: float,
    grad_clip: float = 1.0,
) -> float:
    """Perform one EBM finetuning step using the contrastive objective from finetune.py.

    Intended objective: energy(question_related, answer) < energy(question, answer).

    Args:
        energy_model: Model being finetuned; put into train mode here.
        batch: Five-element batch as produced by ``collate_fn``:
            ``(prompts, responses, questions_related, questions_negative,
            target_indices)``. Only the first three entries are used.
        optimizer: Optimizer stepping ``energy_model``'s parameters.
        device: Device handed to the cached-batch helper.
        cache: Text -> encoding cache, updated in place for any missing text.
        margin: Margin forwarded to ``contrastive_loss``.
        grad_clip: Max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        The scalar loss value of this step.
    """
    energy_model.train()
    optimizer.zero_grad(set_to_none=True)

    # The batch carries five fields; the last two are irrelevant for this objective.
    prompts, responses, questions_related, _, _ = batch

    # Ensure all necessary texts are in the cache
    all_texts_needed = list(set(prompts + responses + questions_related))
    update_encoding_cache(all_texts_needed, cache, energy_model)

    # Positive pair: (question_related, answer)
    x_pos_q, x_pos_m = get_cached_batch_device(questions_related, cache, device)
    y_pos_q, y_pos_m = get_cached_batch_device(responses, cache, device)

    # Negative pair: (question, answer) -- the answer side is shared with the positive pair
    x_neg_q, x_neg_m = get_cached_batch_device(prompts, cache, device)
    y_neg_q, y_neg_m = y_pos_q, y_pos_m

    # Forward pass
    pos_e = energy_model.forward_from_encoded((x_pos_q, x_pos_m), (y_pos_q, y_pos_m)).flatten()
    neg_e = energy_model.forward_from_encoded((x_neg_q, x_neg_m), (y_neg_q, y_neg_m)).flatten()
    loss = contrastive_loss(pos_e, neg_e, margin)

    # Backward pass
    loss.backward()
    torch.nn.utils.clip_grad_norm_(energy_model.parameters(), grad_clip)
    optimizer.step()

    return loss.item()

def finetune_ebm_imdb_step(
    energy_model: EnergyModel,
    batch: Tuple[List[str], List[str], List[str], List[str], List[int]],
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    margin: float,
    grad_clip: float = 1.0,
) -> float:
    """Run a single contrastive EBM update for IMDB mode and return its loss.

    Intended objective:
    energy(high_importance_sentences, prediction) < energy(low_importance_sentences, prediction)

    The batch layout is ``(prompts, responses, questions_related,
    questions_negative, target_indices)``; the first and last entries are not
    used here. ``cache`` is updated in place with any text not yet encoded.
    """
    energy_model.train()
    optimizer.zero_grad(set_to_none=True)

    _, predictions, salient_texts, filler_texts, _ = batch

    # Every text touched by this step must have an entry in the cache.
    combined_texts = predictions + salient_texts + filler_texts
    update_encoding_cache(list(set(combined_texts)), cache, energy_model)

    # Positive side: high-importance text paired with the prediction.
    salient_emb, salient_mask = get_cached_batch_device(salient_texts, cache, device)
    pred_emb, pred_mask = get_cached_batch_device(predictions, cache, device)

    # Negative side: low-importance text paired with the very same prediction.
    filler_emb, filler_mask = get_cached_batch_device(filler_texts, cache, device)

    prediction_side = (pred_emb, pred_mask)
    energy_pos = energy_model.forward_from_encoded((salient_emb, salient_mask), prediction_side).flatten()
    energy_neg = energy_model.forward_from_encoded((filler_emb, filler_mask), prediction_side).flatten()
    step_loss = contrastive_loss(energy_pos, energy_neg, margin)

    # Backpropagate, clip, and apply the update.
    step_loss.backward()
    torch.nn.utils.clip_grad_norm_(energy_model.parameters(), grad_clip)
    optimizer.step()

    return step_loss.item()

# ──────────────────── Utils ────────────────────

def save_checkpoint(
    epoch: int,
    model: InterpreterModel,
    optimizer: torch.optim.Optimizer,
    training_state: dict,
    output_dir: Path,
) -> Path:
    """Write a full training checkpoint and return the path it was saved to.

    The payload holds the epoch number, the model and optimizer state dicts and
    the complete ``training_state`` dictionary. The file name embeds the epoch
    and ``training_state['global_step']``.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        training_state=training_state,  # stored whole, metric histories included
    )
    file_name = f"checkpoint_epoch_{epoch}_step_{training_state['global_step']}.pt"
    checkpoint_path = output_dir / file_name
    torch.save(payload, checkpoint_path)
    logging.info(f"Saved checkpoint to: {checkpoint_path}")
    return checkpoint_path

def load_checkpoint(
    checkpoint_path: str,
    model: InterpreterModel,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> dict:
    """Restore model and optimizer from a checkpoint and return its training state.

    The returned dictionary is the checkpoint's ``training_state`` (empty if the
    key is absent) with ``"start_epoch"`` set to the stored epoch (0 if absent).
    """
    logging.info(f"Loading checkpoint from: {checkpoint_path}")
    payload = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(payload['model_state_dict'])
    optimizer.load_state_dict(payload['optimizer_state_dict'])

    state = payload.get('training_state', {})
    start_epoch = payload.get('epoch', 0)
    state["start_epoch"] = start_epoch

    resumed_step = state.get("global_step", 0)
    logging.info(f'Resuming from epoch {start_epoch + 1}, global step {resumed_step}.')
    return state


def set_seed(seed: int = 42) -> None:
    """Seed Python's ``random`` and PyTorch (CPU, plus CUDA when present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


def plot_interpreter_metrics(
    training_state: dict,
    output_dir: Path,
) -> None:
    """Render the interpreter training curves stored in ``training_state``.

    Args:
        training_state: Must provide the keys ``epoch_train_losses``,
            ``epoch_val_losses``, ``epoch_train_metrics``,
            ``epoch_val_metrics`` and ``ebm_finetune_losses``.
        output_dir: Directory passed (as a string) to ``plot_timeseries``.

    Plotting is best-effort: any exception is logged and swallowed.
    """

    def _draw(series: dict, title: str, y_label: str, x_label: str) -> None:
        """Forward one chart request to ``plot_timeseries`` for ``output_dir``."""
        plot_timeseries(series, str(output_dir), title, y_label, x_label)

    try:
        # Read every history up front so a missing key fails before any plot is made.
        loss_train = training_state["epoch_train_losses"]
        loss_val = training_state["epoch_val_losses"]
        metrics_train = training_state["epoch_train_metrics"]
        metrics_val = training_state["epoch_val_metrics"]
        ebm_losses = training_state["ebm_finetune_losses"]

        # 1) Combined train/val loss curve
        if loss_train and loss_val:
            _draw({"train": loss_train, "val": loss_val}, "Train vs Val", "Loss", "Epoch #")

        # 2) Separate loss curves
        if loss_train:
            _draw({"train": loss_train}, "Train", "Loss", "Epoch #")

        if loss_val:
            _draw({"val": loss_val}, "Val", "Loss", "Epoch #")

        # 3) Mean energy difference per epoch, train first and then val
        for label, title, history in (
            ("train", "Train", metrics_train),
            ("val", "Val", metrics_val),
        ):
            if not history:
                continue
            diffs = [entry.get('energy_diff_mean', 0.0) for entry in history]
            if diffs:
                _draw({label: diffs}, title, "Energy Difference", "Epoch #")

        # 4) EBM finetuning loss per step
        if ebm_losses:
            _draw({"ebm_finetune": ebm_losses}, "EBM Finetuning", "Loss", "Step #")

        logging.info(f"Saved interpreter training plots to: {output_dir}")

    except Exception as e:
        logging.error(f"Could not generate interpreter training plots: {e}")


def _format_loss_stat(values, reducer) -> str:
    """Return ``reducer(values)`` with 4 decimals, or ``"N/A"`` for an empty history."""
    return f"{reducer(values):.4f}" if values else "N/A"


def save_training_summary(
    training_state: dict,
    output_dir: Path,
    training_time: float,
    config_info: Dict[str, object],
) -> None:
    """Save a training summary as CSV files plus a human-readable text file.

    Files written into ``output_dir``:
      * ``epoch_metrics.csv`` -- one row per training epoch.
      * ``ebm_finetune_losses.csv`` -- only when EBM finetune losses exist.
      * ``training_summary.txt`` -- headline numbers and the configuration.

    Args:
        training_state: Must provide the keys ``epoch_train_losses``,
            ``epoch_val_losses``, ``epoch_train_metrics``,
            ``epoch_val_metrics`` and ``ebm_finetune_losses``.
        output_dir: Directory to save the summary files in.
        training_time: Total training time in seconds.
        config_info: Configuration values, written as ``key: value`` lines.

    Saving is best-effort: any exception is logged and swallowed.
    """
    try:
        train_losses = training_state["epoch_train_losses"]
        val_losses = training_state["epoch_val_losses"]
        train_metrics = training_state["epoch_train_metrics"]
        val_metrics = training_state["epoch_val_metrics"]
        ebm_finetune_losses = training_state["ebm_finetune_losses"]

        # 1) Save epoch-level metrics to CSV
        epoch_data = []
        for epoch, train_loss in enumerate(train_losses):
            row = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                # Validation history may be shorter than the training history.
                'val_loss': val_losses[epoch] if epoch < len(val_losses) else float('nan'),
            }

            if epoch < len(train_metrics):
                for key, value in train_metrics[epoch].items():
                    row[f'train_{key}'] = value

            if epoch < len(val_metrics):
                for key, value in val_metrics[epoch].items():
                    row[f'val_{key}'] = value

            epoch_data.append(row)

        pd.DataFrame(epoch_data).to_csv(output_dir / "epoch_metrics.csv", index=False)

        # 2) Save EBM finetuning losses if available
        if ebm_finetune_losses:
            ebm_data = [
                {'step': i + 1, 'ebm_finetune_loss': loss}
                for i, loss in enumerate(ebm_finetune_losses)
            ]
            pd.DataFrame(ebm_data).to_csv(output_dir / "ebm_finetune_losses.csv", index=False)

        def last(values):
            return values[-1]

        # 3) Save text summary. Empty histories are reported as "N/A" instead of
        # raising IndexError/ValueError half-way through and truncating the file.
        summary_path = output_dir / "training_summary.txt"
        with open(summary_path, "w", encoding="utf-8") as f:
            f.write("Interpreter Training Summary\n")
            f.write("============================\n\n")

            # Training results
            f.write(f"Training Time: {training_time:.1f}s ({training_time / 60:.1f} minutes)\n")
            f.write(f"Total Epochs: {len(train_losses)}\n")
            f.write(f"Final Train Loss: {_format_loss_stat(train_losses, last)}\n")
            f.write(f"Final Val Loss: {_format_loss_stat(val_losses, last)}\n")
            f.write(f"Best Train Loss: {_format_loss_stat(train_losses, min)}\n")
            f.write(f"Best Val Loss: {_format_loss_stat(val_losses, min)}\n")

            if ebm_finetune_losses:
                f.write(f"\nEBM Finetuning Steps: {len(ebm_finetune_losses)}\n")
                f.write(f"Final EBM Finetune Loss: {ebm_finetune_losses[-1]:.4f}\n")
                f.write(f"Best EBM Finetune Loss: {min(ebm_finetune_losses):.4f}\n")

            # Configuration
            f.write("\nConfiguration:\n")
            for key, value in config_info.items():
                f.write(f"{key}: {value}\n")

        logging.info(f"Saved training summary to: {summary_path}")

    except Exception as e:
        logging.error(f"Could not save training summary: {e}")


class InterpreterDataset(Dataset):
    """Dataset for interpreter training.

    Each item is the 5-tuple
    ``(prompt, response, question_related, question_negative, target_index)``.
    """

    def __init__(self, csv_path: str, ebm_finetune_enabled: bool, imdb_mode: bool = False):
        """Load ``csv_path`` and build the per-sample lists.

        Args:
            csv_path: CSV file to read with ``pandas.read_csv``.
            ebm_finetune_enabled: In standard mode, require the
                ``question_related`` column to be present.
            imdb_mode: Parse the CSV in IMDB layout instead of the QA layout.

        Raises:
            ValueError: If a column required by the selected mode is missing.
        """
        self.df = pd.read_csv(csv_path)
        self.imdb_mode = imdb_mode
        # Whether the CSV itself ships the column used as the EBM positive text.
        self.has_ebm_cols = "question_related" in self.df.columns

        if self.imdb_mode:
            self._init_imdb(self.df)
        else:
            self._init_standard(self.df, ebm_finetune_enabled)

    def _init_imdb(self, df: pd.DataFrame):
        """Process data for IMDB sentiment analysis mode.

        Rows with fewer than 5 analysed sentences are dropped. For the rest, the
        sentences are sorted by ``importance_score`` (descending); the first
        ``int(0.8 * n)`` form the positive text and the remainder the negative
        text. Rows whose ``sentence_analysis`` cannot be parsed are skipped with
        a warning.
        """
        needed = ["review_text", "formatted_prediction", "sentence_analysis"]
        for c in needed:
            if c not in df.columns:
                raise ValueError(f"IMDB mode requires missing column '{c}' in CSV")

        df = df.dropna(subset=needed).copy()
        for c in needed:
            df[c] = df[c].astype(str).str.strip()

        self.prompts = []             # Full review_text for interpreter
        self.responses = []           # formatted_prediction
        self.questions_related = []   # High importance sentences (EBM positive)
        self.questions_negative = []  # Low importance sentences (EBM negative)
        self.target_indices = []      # Default to -1

        filtered_count = 0
        for idx, row in df.iterrows():
            try:
                sentence_data = json.loads(row["sentence_analysis"])
                # Valid JSON that is not a list (e.g. an object) has no .sort();
                # treat it as a parsing error instead of crashing the whole load.
                if not isinstance(sentence_data, list):
                    raise TypeError(
                        f"expected a JSON list, got {type(sentence_data).__name__}"
                    )
                if len(sentence_data) < 5:
                    filtered_count += 1
                    continue

                sentence_data.sort(key=lambda x: x["importance_score"], reverse=True)
                split_point = int(0.8 * len(sentence_data))

                high_importance_text = " ".join(
                    s["sentence"] for s in sentence_data[:split_point]
                )
                low_importance_text = " ".join(
                    s["sentence"] for s in sentence_data[split_point:]
                )

            except (json.JSONDecodeError, KeyError, TypeError) as e:
                logging.warning(f"Skipping row {idx} in IMDB mode due to parsing error: {e}")
                continue

            self.prompts.append(row["review_text"])
            self.responses.append(row["formatted_prediction"])
            self.questions_related.append(high_importance_text)
            self.questions_negative.append(low_importance_text)
            self.target_indices.append(-1)
        logging.info(f"IMDB mode: Filtered {filtered_count} records with < 5 sentences.")

    def _init_standard(self, df: pd.DataFrame, ebm_finetune_enabled: bool):
        """Process data for standard QA mode.

        ``question``/``answer`` columns are renamed to ``prompt``/``response``.
        A missing ``target_index`` column defaults to -1 and a missing
        ``question_related`` column defaults to empty strings, unless EBM
        finetuning is enabled, in which case it is required.
        """
        # rename() returns a new frame, so the columns added below never touch self.df.
        df = df.rename(columns={
            "question": "prompt",
            "answer": "response",
        })

        # Add target_index column if not present
        if "target_index" not in df.columns:
            df["target_index"] = -1  # Default to last sentence

        # Handle EBM finetuning columns
        has_ebm_cols = "question_related" in df.columns
        if ebm_finetune_enabled and not has_ebm_cols:
            raise ValueError(
                "EBM finetuning is enabled (--ebm_finetune), but the CSV is missing the "
                "required 'question_related' column."
            )

        # Add dummy column if not present and finetuning is off
        if not has_ebm_cols:
            df["question_related"] = ""

        self.prompts = df["prompt"].tolist()
        self.responses = df["response"].tolist()
        self.questions_related = df["question_related"].tolist()
        self.questions_negative = df["prompt"].tolist()  # For EBM, negative is the original prompt
        self.target_indices = df["target_index"].astype(int).tolist()

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Tuple[str, str, str, str, int]:
        """Return the sample at ``idx`` as a 5-tuple (see the class docstring)."""
        return (
            self.prompts[idx],
            self.responses[idx],
            self.questions_related[idx],
            self.questions_negative[idx],
            self.target_indices[idx],
        )


def collate_fn(batch: List[Tuple[str, str, str, str, int]]) -> Tuple[List[str], List[str], List[str], List[str], List[int]]:
    """Transpose a list of 5-field samples into five per-field lists."""
    columns = [list(column) for column in zip(*batch)]
    # Unpacking enforces that every sample contributed exactly five fields.
    prompts, responses, related, negative, targets = columns
    return prompts, responses, related, negative, targets

def save_explanation_csv(
    input_text: str,
    output_text: str,
    scores_tensor: torch.Tensor,  
    n_sentences: int,
    out_dir: Path,
    tag: str = "train",                # "train" or "val"
    epoch: int = 0,
    step: int = 0,
    sample_idx: int = 0,
) -> Path:
    """Write one sample's sentences and importance scores to a CSV file.

    The file has two rows, indexed ``sentence`` and ``score``, with columns
    ``s1`` .. ``s{n_sentences}`` plus an ``output_text`` column that holds the
    output text on the sentence row and NaN on the score row.

    Returns:
        Path of the CSV that was written (one file per sample).
    """
    # Split, then normalise to n_sentences entries so the columns line up.
    padded_sentences = normalize_sentences(semantic_sentence_split(input_text), n_sentences)

    # Scores as a plain Python list
    score_values = scores_tensor.detach().cpu().tolist()

    header = [f"s{k}" for k in range(1, n_sentences + 1)]
    table = pd.DataFrame(
        [padded_sentences, score_values],
        columns=header,
        index=["sentence", "score"],  # row labels, for readability in the CSV
    )
    table["output_text"] = [output_text, float("nan")]

    out_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"{tag}_explanation_epoch{epoch:03d}_step{step:06d}_idx{sample_idx:03d}.csv"
    csv_path = out_dir / file_name
    table.to_csv(csv_path, index=True)
    return csv_path


def train_epoch(
    model: InterpreterModel,
    dataloader: DataLoader,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    training_state: dict,
    args: argparse.Namespace,
) -> Tuple[float, dict]:
    """Train for one epoch and record the results in ``training_state``.

    Per-batch losses/metrics are appended to ``batch_train_losses`` /
    ``batch_train_metrics`` and their averages over this call to
    ``epoch_train_losses`` / ``epoch_train_metrics``. ``global_step`` grows by
    one per optimised batch and ``resume_batch_idx`` is reset to 0 at the end.

    Args:
        model: Interpreter model; put into train mode here.
        dataloader: Yields 5-field batches as produced by ``collate_fn``.
        loss_fn: Loss taking ``(selected_energies, unselected_energies)`` and,
            if its ``forward`` mentions ``importance_scores``, the scores too.
        optimizer: Optimizer for ``model``.
        device: Device the cached encodings are moved to.
        epoch: Current epoch number (used for the progress bar and file names).
        training_state: Mutable run state; also holds ``encoding_cache``.
        args: Needs ``save_every_n_epochs`` and, when explanations are saved,
            ``output_dir``.

    Returns:
        ``(mean_loss, mean_metrics)`` over the batches processed in this call;
        ``(0.0, {})`` when no batch was processed.
    """
    model.train()

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} [Train]", leave=False)
    num_batches_before = len(training_state["batch_train_losses"])

    # Explanations are dumped on every N-th epoch. N <= 0 disables the dump
    # rather than dividing by zero in the modulo below.
    save_every = args.save_every_n_epochs
    save_explanations = save_every > 0 and epoch % save_every == 0

    # Running total so the progress-bar average does not re-slice the history per batch.
    running_loss = 0.0
    batches_done = 0

    for batch_idx, (input_texts, output_texts, _, _, target_indices) in enumerate(pbar):
        # When resuming mid-epoch, fast-forward past batches that were already trained on.
        if batch_idx < training_state["resume_batch_idx"]:
            continue

        optimizer.zero_grad()

        cache = training_state["encoding_cache"]
        stable_texts = list(set(input_texts + output_texts))
        cache_misses = [t for t in stable_texts if t not in cache]
        if cache_misses:
            update_encoding_cache_cloned(cache_misses, cache, model)

        x_emb, x_mask = get_cached_batch_cpu(input_texts, cache)
        y_emb, y_mask = get_cached_batch_cpu(output_texts, cache)

        x_emb = x_emb.to(device, non_blocking=True)
        x_mask = x_mask.to(device, non_blocking=True)
        y_emb = y_emb.to(device, non_blocking=True)
        y_mask = y_mask.to(device, non_blocking=True)

        # Score sentences and compute energies from the pre-encoded inputs.
        importance_scores, selected_energies, unselected_energies = model.evaluate_from_encoded(
            (x_emb, x_mask),
            (y_emb, y_mask),
            target_indices=target_indices,
            input_texts=input_texts,
            output_texts=output_texts,
        )

        # Save per-sample 2-row CSVs with sentences & scores
        if save_explanations:
            explanations_dir = Path(args.output_dir) / "explanations"
            for i, text in enumerate(input_texts):
                save_explanation_csv(
                    input_text=text,
                    output_text=output_texts[i],
                    scores_tensor=importance_scores[i],
                    # Same source of truth as the validation loop.
                    n_sentences=model.n_sentences,
                    out_dir=explanations_dir,
                    tag="train",
                    epoch=epoch,
                    step=training_state["global_step"],
                    sample_idx=i,
                )

        # Compute loss; some losses additionally consume the importance scores.
        if hasattr(loss_fn, "forward"):
            if "importance_scores" in loss_fn.forward.__code__.co_varnames:
                loss, metrics = loss_fn(selected_energies, unselected_energies, importance_scores)
            else:
                loss, metrics = loss_fn(selected_energies, unselected_energies)
        else:
            loss = loss_fn(selected_energies, unselected_energies)
            metrics = {"loss": float(loss.detach().item())}

        # Backward + step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        training_state['global_step'] += 1

        # Append per-batch results for complete history
        loss_value = loss.item()
        training_state["batch_train_losses"].append(loss_value)
        training_state["batch_train_metrics"].append(metrics)

        running_loss += loss_value
        batches_done += 1

        # Update progress bar
        pbar.set_postfix(
            {
                "loss": f"{loss_value:.4f}",
                "avg_loss": f"{running_loss / batches_done:.4f}",
                "energy_diff": f"{metrics.get('energy_diff_mean', 0):.3f}",
            }
        )

    # After the epoch, calculate and store the epoch-level average metrics
    current_epoch_losses = training_state["batch_train_losses"][num_batches_before:]
    current_epoch_metrics = training_state["batch_train_metrics"][num_batches_before:]

    epoch_loss = sum(current_epoch_losses) / len(current_epoch_losses) if current_epoch_losses else 0.0
    training_state["epoch_train_losses"].append(epoch_loss)
    avg_metrics = (
        {
            k: sum(d[k] for d in current_epoch_metrics) / len(current_epoch_metrics)
            for k in current_epoch_metrics[0]
        }
        if current_epoch_metrics
        else {}
    )
    training_state["epoch_train_metrics"].append(avg_metrics)

    training_state["resume_batch_idx"] = 0  # Reset for the next epoch

    # Honour the declared return type; previously the function returned None.
    return epoch_loss, avg_metrics


def validate_epoch(
    model: InterpreterModel,
    dataloader: DataLoader,
    loss_fn: torch.nn.Module,
    device: torch.device,
    epoch: int,
    training_state: dict,
) -> Tuple[float, dict]:
    """Validate for one epoch and record the results in ``training_state``.

    Runs under ``torch.no_grad()`` with the model in eval mode. Per-batch
    losses/metrics are appended to ``batch_val_losses`` / ``batch_val_metrics``
    and their averages over this call to ``epoch_val_losses`` /
    ``epoch_val_metrics``. An explanation CSV is written for every sample of
    every batch, under ``training_state["output_dir"]`` (default
    ``"interpreter_results"``) in the ``explanations`` subdirectory.

    Returns:
        ``(mean_loss, mean_metrics)`` over the batches processed in this call;
        ``(0.0, {})`` when the dataloader yielded nothing.
    """
    model.eval()

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} [Val]", leave=False)
    num_batches_before = len(training_state["batch_val_losses"])

    # Running total so the progress-bar average does not re-slice the history per batch.
    running_loss = 0.0
    batches_done = 0

    with torch.no_grad():
        for input_texts, output_texts, _, _, target_indices in pbar:

            cache = training_state["encoding_cache"]
            stable_texts = list(set(input_texts + output_texts))
            cache_misses = [t for t in stable_texts if t not in cache]
            if cache_misses:
                update_encoding_cache_cloned(cache_misses, cache, model)

            x_emb, x_mask = get_cached_batch_cpu(input_texts, cache)
            y_emb, y_mask = get_cached_batch_cpu(output_texts, cache)

            x_emb = x_emb.to(device, non_blocking=True)
            x_mask = x_mask.to(device, non_blocking=True)
            y_emb = y_emb.to(device, non_blocking=True)
            y_mask = y_mask.to(device, non_blocking=True)

            importance_scores, selected_energies, unselected_energies = model.evaluate_from_encoded(
                (x_emb, x_mask),
                (y_emb, y_mask),
                target_indices=target_indices,
                input_texts=input_texts,
                output_texts=output_texts,
            )

            explanations_dir = Path(training_state.get("output_dir", "interpreter_results")) / "explanations"
            for i, text in enumerate(input_texts):
                save_explanation_csv(
                    input_text=text,
                    output_text=output_texts[i],
                    scores_tensor=importance_scores[i],
                    n_sentences=model.n_sentences,
                    out_dir=explanations_dir,
                    tag="val",
                    epoch=epoch,
                    # Number of validation batches recorded so far serves as the step id.
                    step=len(training_state["batch_val_losses"]),
                    sample_idx=i,
                )

            # Some losses additionally consume the importance scores.
            if hasattr(loss_fn, "forward"):
                if "importance_scores" in loss_fn.forward.__code__.co_varnames:
                    loss, metrics = loss_fn(selected_energies, unselected_energies, importance_scores)
                else:
                    loss, metrics = loss_fn(selected_energies, unselected_energies)
            else:
                loss = loss_fn(selected_energies, unselected_energies)
                metrics = {"loss": float(loss.detach().item())}

            # Update metrics
            loss_value = loss.item()
            training_state["batch_val_losses"].append(loss_value)
            training_state["batch_val_metrics"].append(metrics)

            running_loss += loss_value
            batches_done += 1

            # Update progress bar
            pbar.set_postfix(
                {
                    "val_loss": f"{loss_value:.4f}",
                    "avg_loss": f"{running_loss / batches_done:.4f}",
                    "energy_diff": f"{metrics.get('energy_diff_mean', 0):.3f}",
                }
            )

    # Average metrics
    current_epoch_losses = training_state["batch_val_losses"][num_batches_before:]
    current_epoch_metrics = training_state["batch_val_metrics"][num_batches_before:]

    epoch_loss = sum(current_epoch_losses) / len(current_epoch_losses) if current_epoch_losses else 0.0
    training_state["epoch_val_losses"].append(epoch_loss)
    avg_metrics = (
        {
            k: sum(d[k] for d in current_epoch_metrics) / len(current_epoch_metrics)
            for k in current_epoch_metrics[0]
        }
        if current_epoch_metrics
        else {}
    )
    training_state["epoch_val_metrics"].append(avg_metrics)

    # Honour the declared return type; previously the function returned None.
    return epoch_loss, avg_metrics

def update_encoding_cache_cloned(
    texts: List[str],
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
) -> None:
    """Encode ``texts`` with ``model.text_encoder`` and store the results in ``cache``.

    Each text maps to its ``(embedding, mask)`` pair. The tensors are detached
    from autograd and stored on the CPU: a cache entry is reused across batches
    and epochs, so keeping the graph alive would grow memory and make a later
    backward pass walk a graph that an earlier one already freed.

    Args:
        texts (List[str]): Text strings to encode and cache; an empty list is
            a no-op.
        cache (Dict[str, Tuple[torch.Tensor, torch.Tensor]]): The cache
            dictionary to update in-place.
        model (torch.nn.Module): The model, which must have a `text_encoder`
            attribute callable on a list of texts and returning
            ``(embeddings, masks)`` indexable per text.
    """
    if not texts:
        return

    # Cached encodings are treated as constants, so no graph is needed here.
    with torch.no_grad():
        encoded_embs, encoded_masks = model.text_encoder(texts)

    for i, text in enumerate(texts):
        cache[text] = (
            encoded_embs[i].detach().cpu(),
            encoded_masks[i].detach().cpu(),
        )

def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for interpreter training.

    Flags are grouped as data, training, model, loss, masking, EBM finetuning
    and run-management (output, checkpointing, Google Drive upload) options.
    """
    parser = argparse.ArgumentParser(description="Train Interpreter Model")

    # Data arguments
    parser.add_argument("--csv", required=True, help="CSV file with training data")
    parser.add_argument(
        "--ebm_checkpoint", required=True, help="Path to trained EBM checkpoint"
    )
    parser.add_argument("--ebm_self_attention_layers", type=int, default=2)
    parser.add_argument("--ebm_cross_attention_layers", type=int, default=6)
    parser.add_argument("--interp_cross_attention_layers", type=int, default=6)

    # Training arguments
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--val_split", type=float, default=0.2, help="Validation split ratio"
    )

    # Model arguments
    parser.add_argument(
        "--encoder_sharing",
        choices=["shared", "cloned"],
        default="shared",
        help="Encoder sharing strategy",
    )
    parser.add_argument(
        "--freeze_encoder",
        action="store_true",
        help="Freeze encoder weights during training",
    )
    parser.add_argument(
        "--target_embedding",
        choices=["isolated", "contextualized"],
        default="contextualized",
        help="Target embedding strategy",
    )

    # Loss arguments
    parser.add_argument(
        "--loss_type",
        choices=["contrastive", "regression", "regularized"],
        default="contrastive",
        help="Loss function type",
    )
    parser.add_argument("--margin", type=float, default=1.0, help="Contrastive loss margin")
    parser.add_argument(
        "--infonce_temperature",
        type=float,
        default=0.1,
        help="InfoNCE temperature parameter",
    )
    parser.add_argument(
        "--regression_loss_type",
        choices=["mse", "mae", "huber"],
        default="mse",
        help="Regression loss subtype (mse, mae, huber)",
    )

    # Masking arguments
    parser.add_argument(
        "--masking_type",
        choices=["hard", "soft"],
        default="hard",
        help="Main masking strategy (hard or soft)",
    )
    parser.add_argument(
        "--hard_mask_method",
        choices=["top_k", "threshold"],
        default="top_k",
        help="Hard masking method",
    )
    parser.add_argument("--top_k", type=int, default=8, help="Number of sentences to keep")
    parser.add_argument("--threshold", type=float, default=0.5, help="Importance threshold")
    parser.add_argument(
        "--soft_mask_method",
        choices=["multiply", "interpolate"],
        default="multiply",
        help="Soft masking method",
    )
    parser.add_argument(
        "--gumbel_temperature",
        type=float,
        default=1.0,
        help="Temperature for Gumbel Softmax in importance scoring (default: 1.0)",
    )
    parser.add_argument(
        "--gumbel_k",
        type=int,
        default=4,
        help="Number of times to sample Gumbel noise for robust top-k selection (default: 4)",
    )

    # EBM finetuning arguments.
    # NOTE: --ebm_loss_strategy and --ebm_off_context_weight are currently not
    # registered, so code must not assume those attributes exist on the namespace.
    parser.add_argument(
        "--ebm_finetune",
        action="store_true",
        help="Enable EBM finetuning during interpreter training",
    )
    parser.add_argument(
        "--ebm_finetune_steps",
        type=int,
        default=5,
        help="Number of EBM finetuning steps per interpreter training cycle",
    )
    parser.add_argument(
        "--ebm_lr",
        type=float,
        default=5e-6,
        help="Learning rate for EBM finetuning",
    )
    parser.add_argument(
        "--finetune_schedule",
        choices=["every_epoch", "every_n_steps", "end_of_training"],
        default="every_epoch",
        help="When to perform EBM finetuning",
    )
    parser.add_argument(
        "--finetune_step_interval",
        type=int,
        default=100,
        help="Interval for EBM finetuning when using every_n_steps schedule",
    )
    parser.add_argument(
        "--ebm_temperature",
        type=float,
        default=0.1,
        help="Temperature for InfoNCE loss in EBM finetuning",
    )
    parser.add_argument(
        "--ebm_margin",
        type=float,
        default=0.5,
        help="Margin for contrastive loss in EBM finetuning",
    )

    # Other arguments
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--output_dir", default="interpreter_results", help="Output directory"
    )

    # Checkpointing / upload arguments
    parser.add_argument(
        "--save-every-n-epochs",
        type=int,
        default=1,
        help="Save a full checkpoint every N epochs (default: 1). Set 0 to disable.",
    )
    parser.add_argument(
        "--resume-from-checkpoint",
        type=str,
        default=None,
        help="Path to checkpoint file to resume training from.",
    )
    parser.add_argument(
        "--upload-to-gdrive",
        action="store_true",
        help="Enable uploading checkpoints and models to Google Drive.",
    )
    parser.add_argument(
        "--gdrive-folder-id",
        type=str,
        default=None,
        help="The ID of the Google Drive folder to upload files to.",
    )
    parser.add_argument(
        "--gdrive-creds-path",
        type=str,
        default=None,
        help="Path to the PyDrive2 credentials file (creds.dat).",
    )

    parser.add_argument("--imdb-mode", action="store_true", help="Enable IMDB sentiment analysis mode.")
    parser.add_argument("--softmax-type", type=str, default="gumbel")
    parser.add_argument("--ebm-is-finetuned", action="store_true")

    return parser


def _should_finetune_ebm(args: argparse.Namespace, epoch: int, steps_per_epoch: int) -> bool:
    """Return True when the configured schedule asks for EBM finetuning after ``epoch``.

    For ``every_n_steps`` the step interval is converted to a whole number of
    epochs. The result is clamped to at least 1 so that an interval smaller
    than one epoch finetunes every epoch instead of dividing by zero.
    """
    schedule = args.finetune_schedule
    if schedule == "every_epoch":
        return True
    if schedule == "every_n_steps":
        epoch_interval = max(1, args.finetune_step_interval // max(1, steps_per_epoch))
        return epoch % epoch_interval == 0
    return schedule == "end_of_training" and epoch == args.epochs


def _write_training_stats(
    stats_path: Path,
    training_state: dict,
    total_time: float,
    trainable_params: int,
    config_info: dict,
) -> None:
    """Write the human-readable end-of-run statistics file, one entry per line."""
    lines = [
        "Interpreter Training Statistics\n",
        "================================\n",
        f"Total training time: {total_time:.1f}s ({total_time / 60:.1f} minutes)\n",
        f"Best validation loss: {training_state['best_val_loss']:.4f}\n",
        f'Final train loss: {training_state["epoch_train_losses"][-1]:.4f}\n',
        f'Final val loss: {training_state["epoch_val_losses"][-1]:.4f}\n',
    ]

    ebm_losses = training_state["ebm_finetune_losses"]
    if ebm_losses:
        lines.append(f"EBM finetuning steps: {len(ebm_losses)}\n")
        lines.append(f"Final EBM finetune loss: {ebm_losses[-1]:.4f}\n")

    lines.append(f"Model parameters: {trainable_params:,}\n")
    lines.append("\nConfiguration:\n")
    lines.extend(f"{key}: {value}\n" for key, value in config_info.items())

    with open(stats_path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def main():
    """Main training function.

    Parses the command line, builds the train/validation loaders, loads the
    pre-trained EBM, creates the interpreter model, and runs the epoch loop
    (train, validate, save the best model, optional EBM finetuning, periodic
    checkpoints). After training it generates plots, a training summary, the
    final model file and a plain-text statistics file in the output directory.
    """
    args = _build_arg_parser().parse_args()

    # Setup logging
    setup_logging("interpreter_training.log")

    set_seed(args.seed)
    torch.backends.cudnn.benchmark = True
    device = torch.device(args.device)

    # Create output directory (parents=True so nested --output_dir paths work).
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    best_model_path = output_dir / "best_interpreter.pt"

    # Load dataset
    logging.info("Loading dataset...")
    if args.imdb_mode:
        logging.info("Running in IMDB mode.")
    full_dataset = InterpreterDataset(args.csv, args.ebm_finetune, args.imdb_mode)
    val_len = int(len(full_dataset) * args.val_split)
    train_len = len(full_dataset) - val_len
    train_dataset, val_dataset = random_split(full_dataset, [train_len, val_len])

    # Both loaders share everything except shuffling.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    steps_per_epoch = len(train_loader)

    logging.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Load pre-trained EBM
    logging.info(f"Loading EBM from {args.ebm_checkpoint}...")
    ebm_config = EBMConfig(
        self_attention_n_layers=args.ebm_self_attention_layers,
        cross_attention_n_layers=args.ebm_cross_attention_layers,
    )
    energy_model = EnergyModel(ebm_config).to(device)
    if args.ebm_is_finetuned:
        # Finetuned checkpoints wrap the weights under "model_state_dict".
        # NOTE(review): this call does not pass weights_only=True, so on
        # PyTorch versions where that is not the default the whole file is
        # unpickled; only load checkpoints from trusted sources.
        checkpoint = torch.load(args.ebm_checkpoint, map_location=device)
        energy_model.load_state_dict(checkpoint["model_state_dict"])
    else:
        energy_model.load_state_dict(
            torch.load(args.ebm_checkpoint, map_location=device, weights_only=True)
        )
    energy_model.eval()  # Keep EBM in eval mode

    # Freeze all EBM parameters unless EBM finetuning (--ebm_finetune) is enabled.
    if not args.ebm_finetune:
        logging.info("Freezing all EBM parameters...")
        for param in energy_model.parameters():
            param.requires_grad = False
        logging.info(f"EBM frozen: {sum(p.numel() for p in energy_model.parameters() if not p.requires_grad):,} parameters")
    else:
        logging.info(f"EBM unfrozen: {sum(p.numel() for p in energy_model.parameters() if p.requires_grad):,} trainable parameters")

    # Create interpreter model
    logging.info("Creating interpreter model...")
    interp_config = InterpreterConfig(
        encoder_sharing_strategy=args.encoder_sharing,
        freeze_encoder_in_interp=args.freeze_encoder,
        cross_attention_layers=args.interp_cross_attention_layers,
        target_embedding=args.target_embedding,
        masking_type=args.masking_type,
        hard_mask_method=args.hard_mask_method,
        soft_mask_method=args.soft_mask_method,
        top_k=args.top_k,
        threshold=args.threshold,
        loss_type=args.loss_type,
        regression_loss_type=args.regression_loss_type,
        energy_margin=args.margin,
        gumbel_temperature=args.gumbel_temperature,
        gumbel_k=args.gumbel_k,
        softmax_type=args.softmax_type,
    )

    # Published as a module-level global; kept because code outside this
    # function may read it.
    global n_sentences
    n_sentences = ebm_config.n_sentences

    interpreter = InterpreterModel(
        energy_model=energy_model, config=interp_config, n_sentences=n_sentences
    ).to(device)

    # Report parameter counts
    total_interpreter_params = sum(p.numel() for p in interpreter.parameters())
    trainable_interpreter_params = sum(p.numel() for p in interpreter.parameters() if p.requires_grad)

    logging.info(f"Interpreter total parameters: {total_interpreter_params:,}")
    logging.info(f"Interpreter trainable parameters: {trainable_interpreter_params:,}")

    # Additional debugging: check if EBM parameters are still trainable
    if not args.ebm_finetune:
        ebm_trainable = sum(p.numel() for p in energy_model.parameters() if p.requires_grad)
        if ebm_trainable > 0:
            logging.warning(f"WARNING: EBM still has {ebm_trainable:,} trainable parameters despite freeze_encoder=True")
        else:
            logging.info("✓ EBM is properly frozen (0 trainable parameters)")

    # Debugging: check encoder parameter sharing
    if interp_config.encoder_sharing_strategy == "shared":
        logging.info("Using shared encoders - EBM and Interpreter share the same encoder instances")
        if args.freeze_encoder:
            logging.info("✓ Shared encoders are frozen at EBM level")
    else:  # cloned
        logging.info("Using cloned encoders - Interpreter has independent copies")
        if interp_config.freeze_encoder_in_interp:
            logging.info("✓ Cloned encoders are frozen in interpreter")
        else:
            logging.info("Cloned encoders are trainable in interpreter")

    # Create loss function and optimizer
    loss_fn = create_interpreter_loss(
        args.loss_type,
        margin=args.margin,
        temperature=args.infonce_temperature,
        contrastive_weight=1.0,
        sparsity_weight=0.1,
        regression_loss_type=args.regression_loss_type,
    )

    optimizer = torch.optim.AdamW(interpreter.parameters(), lr=args.lr)

    # Mutable run state shared with train_epoch / validate_epoch and checkpoints.
    training_state = {
        "batch_train_losses": [], "batch_val_losses": [],
        "batch_train_metrics": [], "batch_val_metrics": [],
        "epoch_train_losses": [], "epoch_val_losses": [],
        "epoch_train_metrics": [], "epoch_val_metrics": [],
        "ebm_finetune_losses": [],
        "best_val_loss": float("inf"),
        "global_step": 0, "start_epoch": 1, "resume_batch_idx": 0,
        "encoding_cache": {},
    }

    if args.resume_from_checkpoint:
        loaded_state = load_checkpoint(args.resume_from_checkpoint, interpreter, optimizer, device)
        training_state.update(loaded_state)  # Overwrite defaults with loaded history

        global_step = training_state["global_step"]
        if global_step > 0 and global_step % steps_per_epoch != 0:
            # Checkpoint was taken mid-epoch: continue from the next batch.
            training_state["resume_batch_idx"] = global_step % steps_per_epoch
            logging.info(f'Resuming training from epoch {training_state["start_epoch"]}, batch {training_state["resume_batch_idx"] + 1}.')
        elif global_step > 0:
            # Checkpoint was taken at an epoch boundary: start the next epoch.
            training_state["start_epoch"] += 1
            logging.info(f'Resuming training from the start of epoch {training_state["start_epoch"]}.')

    logging.info(f'Starting training from epoch {training_state["start_epoch"]} for {args.epochs} total epochs...')
    start_time = time.time()

    for epoch in range(training_state["start_epoch"], args.epochs + 1):
        epoch_start = time.time()

        # Training
        train_epoch(
            interpreter, train_loader, loss_fn, optimizer, device, epoch, training_state, args
        )

        # Validation
        validate_epoch(
            interpreter, val_loader, loss_fn, device, epoch, training_state
        )

        current_val_loss = training_state["epoch_val_losses"][-1]
        logging.info(f"Epoch {epoch}/{args.epochs} | Train Loss: {training_state['epoch_train_losses'][-1]:.4f} | Val Loss: {current_val_loss:.4f}")

        if current_val_loss < training_state["best_val_loss"]:
            training_state["best_val_loss"] = current_val_loss
            torch.save(interpreter.state_dict(), best_model_path)
            logging.info(f"New best model saved with val_loss: {current_val_loss:.4f}")
            if args.upload_to_gdrive:
                upload_file_to_drive(args.gdrive_folder_id, best_model_path, args.gdrive_creds_path, output_dir)

        # EBM Finetuning (if enabled)
        ebm_epoch_losses = []
        if args.ebm_finetune and _should_finetune_ebm(args, epoch, steps_per_epoch):
            logging.info(f"Performing EBM finetuning at epoch {epoch}...")
            energy_model.train()  # Set to train mode for finetuning
            ebm_optimizer = torch.optim.Adam(energy_model.parameters(), lr=args.ebm_lr)

            # One iterator for the whole round: building a fresh iterator per
            # step would reset the worker pool and throw away prefetched
            # batches every time.
            finetune_batches = iter(train_loader)

            for finetune_step in range(args.ebm_finetune_steps):
                try:
                    batch = next(finetune_batches)
                except StopIteration:
                    # More finetune steps than batches: start another pass.
                    finetune_batches = iter(train_loader)
                    batch = next(finetune_batches)

                # Perform EBM finetuning step using exact EBM training pipeline
                ebm_loss = finetune_ebm_step(
                    energy_model=energy_model,
                    batch=batch,
                    optimizer=ebm_optimizer,
                    device=device,
                    cache=training_state["encoding_cache"],
                    margin=args.ebm_margin,
                )

                ebm_epoch_losses.append(ebm_loss)
                training_state["ebm_finetune_losses"].append(ebm_loss)

                if finetune_step % 2 == 0:  # Log every 2 steps
                    logging.info(
                        f"  EBM finetune step {finetune_step + 1}/"
                        f"{args.ebm_finetune_steps}, Loss: {ebm_loss:.4f}"
                    )

            energy_model.eval()  # Set back to eval mode

        epoch_time = time.time() - epoch_start

        # Get the latest results for the completed epoch from the training_state
        train_loss = training_state["epoch_train_losses"][-1]
        val_loss = training_state["epoch_val_losses"][-1]
        train_metrics = training_state["epoch_train_metrics"][-1]
        val_metrics = training_state["epoch_val_metrics"][-1]

        # Enhanced logging with additional metrics
        energy_diff_train = train_metrics.get("energy_diff_mean", 0.0)
        energy_diff_val = val_metrics.get("energy_diff_mean", 0.0)
        ebm_loss_info = f", EBM Loss: {ebm_epoch_losses[-1]:.4f}" if ebm_epoch_losses else ""

        logging.info(
            f"Epoch {epoch}/{args.epochs} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Energy Diff (T/V): {energy_diff_train:.3f}/{energy_diff_val:.3f}"
            f"{ebm_loss_info} | "
            f"Time: {epoch_time:.1f}s"
        )

        # Save metrics for plotting
        log_entry = (
            f"Epoch {epoch}/{args.epochs} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Energy Diff: {energy_diff_val:.3f}"
            f"{ebm_loss_info} | "
            f"Time: {epoch_time:.1f}s"
        )

        # A real newline terminates each entry so the log has one epoch per line.
        with open(output_dir / "training_log.txt", "a") as f:
            f.write(log_entry + "\n")

        if args.save_every_n_epochs and (epoch % args.save_every_n_epochs == 0):
            # Save checkpoint at the end of the epoch. The encoding cache is
            # swapped out first so it is not serialized, and restored even if
            # saving fails.
            cached_encodings = training_state.get("encoding_cache", None)
            training_state["encoding_cache"] = {}
            try:
                checkpoint_path = save_checkpoint(
                    epoch=epoch,
                    model=interpreter,
                    optimizer=optimizer,
                    training_state=training_state,
                    output_dir=output_dir,
                )
            finally:
                training_state["encoding_cache"] = cached_encodings or {}

            logging.info(f"Epoch {epoch}: checkpoint saved to {checkpoint_path}")

            if args.upload_to_gdrive:
                upload_file_to_drive(args.gdrive_folder_id, checkpoint_path, args.gdrive_creds_path, output_dir)

    total_time = time.time() - start_time
    logging.info(f"Training completed in {total_time:.1f}s ({total_time / 60:.1f} minutes)")

    # Generate comprehensive plots
    logging.info("Generating training plots...")
    plot_interpreter_metrics(
        training_state=training_state,
        output_dir=output_dir,
    )

    # Save detailed training summary
    trainable_params = sum(p.numel() for p in interpreter.parameters() if p.requires_grad)
    config_info = {
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "encoder_sharing": args.encoder_sharing,
        "freeze_encoder": args.freeze_encoder,
        "target_embedding": args.target_embedding,
        "loss_type": args.loss_type,
        "masking_type": args.masking_type,
        "hard_mask_method": args.hard_mask_method,
        "soft_mask_method": args.soft_mask_method,
        "gumbel_temperature": args.gumbel_temperature,
        "gumbel_k": args.gumbel_k,
        "ebm_finetune": args.ebm_finetune,
        # --ebm_loss_strategy is not a registered option, so read it defensively
        # instead of raising AttributeError after training has finished.
        "ebm_loss_strategy": (
            getattr(args, "ebm_loss_strategy", "unspecified") if args.ebm_finetune else "None"
        ),
        "ebm_finetune_steps": args.ebm_finetune_steps if args.ebm_finetune else 0,
        "total_params": trainable_params,
    }

    save_training_summary(
        training_state=training_state,
        output_dir=output_dir,
        training_time=total_time,
        config_info=config_info,
    )

    # Save final model and statistics
    final_model_path = output_dir / "final_interpreter.pt"
    torch.save(interpreter.state_dict(), final_model_path)
    logging.info(f"Saved final model to: {final_model_path}")
    if args.upload_to_gdrive:
        upload_file_to_drive(args.gdrive_folder_id, final_model_path, args.gdrive_creds_path, output_dir)

    stats_path = output_dir / "training_stats.txt"
    _write_training_stats(stats_path, training_state, total_time, trainable_params, config_info)

    logging.info(f"All results saved to {output_dir}")
    if best_model_path.exists():
        logging.info(f"Best model: {best_model_path}")
    else:
        logging.info("Best model: not saved (no improvement over baseline)")
    logging.info(f"Training statistics: {stats_path}")


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()