#!/usr/bin/env python
# coding: utf-8

"""Script to run experiments on GLUE benchmark tasks (MNLI, RTE, QNLI)."""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    set_seed
)
import evaluate # Use evaluate library for metrics

# Add the project root to sys.path so the project-local ``src`` package is importable.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

# Import ECAM components (conceptual integration).
# ECAM is not wired into the transformer model yet: that will likely need a
# custom model class or modified attention layers. For now this script only
# runs the baseline fine-tuning framework, and ECAM is None if the import fails.
try:
    from src.ecam import ECAM  # Assuming ECAM modifies attention
except ImportError as err:
    # Include the underlying cause: an ImportError raised *inside* src.ecam
    # (e.g. a missing third-party dependency) would otherwise be misreported
    # as missing source files. Warnings go to stderr, away from result output.
    print(f"Warning: ECAM source files not found or import failed ({err}).", file=sys.stderr)
    ECAM = None

# --- Configuration ---
# Text column names per GLUE task: (first_sentence_key, second_sentence_key).
TASK_TO_KEYS = {
    "mnli": ("premise", "hypothesis"),
    "rte": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
}
# Tasks to run, in the insertion order of TASK_TO_KEYS.
GLUE_TASKS = list(TASK_TO_KEYS)
# Pretrained checkpoint to fine-tune (another encoder such as RoBERTa also fits).
MODEL_NAME = "bert-base-uncased"

# --- Helper Functions ---

def preprocess_function(examples, tokenizer, task_name, max_length=128, padding=False, keys=None):
    """Tokenize a batch of GLUE examples (single sentences or sentence pairs).

    Args:
        examples: Batch dict as passed by ``datasets.Dataset.map(batched=True)``,
            mapping column names to lists of strings.
        tokenizer: Hugging Face tokenizer (or any callable with the same call
            signature).
        task_name: Key into ``TASK_TO_KEYS``; ignored when ``keys`` is given.
        max_length: Truncation length in tokens.
        padding: Forwarded to the tokenizer. Defaults to ``False`` because the
            Trainer pads each batch dynamically with ``DataCollatorWithPadding``.
            Pre-padding every example to ``max_length`` only wastes compute on
            pad tokens.
        keys: Optional ``(sentence1_key, sentence2_key)`` override for tasks not
            in ``TASK_TO_KEYS``; use ``sentence2_key=None`` for single-sentence
            tasks.

    Returns:
        The tokenizer output for the batch.
    """
    sentence1_key, sentence2_key = keys if keys is not None else TASK_TO_KEYS[task_name]
    texts = (examples[sentence1_key],)
    if sentence2_key is not None:
        texts += (examples[sentence2_key],)
    return tokenizer(*texts, truncation=True, padding=padding, max_length=max_length)

def compute_metrics(eval_pred, metric):
    """Turn Trainer predictions into labels/scores and evaluate them with ``metric``.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair, e.g. a
            ``transformers.EvalPrediction``. ``predictions`` may be a tuple
            whose first element holds the logits (models that return extra
            outputs).
        metric: Object exposing ``compute(predictions=..., references=...)``,
            such as an ``evaluate`` metric.

    Returns:
        Whatever ``metric.compute`` returns.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)
    if logits.ndim == 2 and logits.shape[-1] == 1:
        # Single-output head means regression (e.g. STS-B). argmax would
        # collapse every prediction to 0, so use the raw scores instead.
        predictions = logits.squeeze(-1)
    else:
        predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# --- Main Experiment ---

def _build_training_args(task, task_output_dir, args):
    """Build the TrainingArguments for one task (per-epoch eval/save, best model kept)."""
    is_regression = task == "stsb"
    return TrainingArguments(
        output_dir=os.path.join(task_output_dir, "training_output"),
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        # Warm up for a fraction of total steps instead of a fixed 500 steps.
        # Small tasks such as RTE (~2.5k train examples) finish in fewer than
        # 500 steps at the default settings, so a fixed warmup would never
        # reach the peak learning rate.
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_dir=os.path.join(task_output_dir, "logs"),
        logging_steps=100,
        eval_strategy="epoch",  # Newer name for evaluation_strategy
        save_strategy="epoch",
        # Bound disk usage; with load_best_model_at_end the best checkpoint is
        # retained in addition to the most recent one.
        save_total_limit=1,
        load_best_model_at_end=True,
        # Keys as returned by the GLUE metric (stsb reports "pearson").
        metric_for_best_model="pearson" if is_regression else "accuracy",
        greater_is_better=True,
        # Trainer re-seeds from TrainingArguments.seed (default 42), so the CLI
        # seed must be forwarded here or --seed would not affect training.
        seed=args.seed,
        report_to="none",  # Disable wandb/tensorboard reporting for simplicity
    )


def _run_glue_task(task, args, tokenizer, base_data_dir, output_base_dir):
    """Fine-tune and evaluate the baseline on one task and return its result rows.

    Returns an empty list (after printing an error) when the processed dataset
    for ``task`` is missing. Otherwise returns the baseline row followed by a
    NaN placeholder row for the ECAM variant.
    """
    print(f"\n--- Running GLUE Task: {task.upper()} ---")
    task_data_dir = os.path.join(base_data_dir, task)
    task_output_dir = os.path.join(output_base_dir, task)

    try:
        raw_datasets = load_from_disk(task_data_dir)
    except FileNotFoundError:
        print(f"Error: Processed dataset for {task} not found at {task_data_dir}. Please run download_glue.py.")
        return []
    print(f"Loaded dataset for {task} from {task_data_dir}")
    # Create the output directory only once the task can actually run.
    os.makedirs(task_output_dir, exist_ok=True)

    encoded_datasets = raw_datasets.map(
        lambda examples: preprocess_function(examples, tokenizer, task),
        batched=True,
    )

    if task == "stsb":
        num_labels = 1  # Regression head
    else:
        num_labels = len(encoded_datasets["train"].features["label"].names)

    # MNLI has matched/mismatched validation splits; this reports the matched one.
    # GLUE test splits have no public labels, so validation is the eval set.
    validation_key = "validation_matched" if task == "mnli" else "validation"
    train_dataset = encoded_datasets["train"]
    eval_dataset = encoded_datasets[validation_key]

    # Re-seed per task so the randomly initialized classification head does not
    # depend on which tasks ran before this one.
    set_seed(args.seed)
    print(f"Loading baseline model: {MODEL_NAME}")
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)

    # TODO: Implement ECAM integration. This would likely involve a custom model
    # class inheriting from the base transformer, with ECAM logic in its
    # attention layers, loaded here instead, e.g.
    # model_ecam = CustomECAMModel.from_pretrained(MODEL_NAME, num_labels=num_labels, ecam_config=...)

    metric = evaluate.load("glue", "mnli_matched" if task == "mnli" else task)

    trainer = Trainer(
        model=model,
        args=_build_training_args(task, task_output_dir, args),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, metric),
        # NOTE(review): newer transformers releases rename this to processing_class.
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    )

    print("Training baseline model...")
    trainer.train()
    print("Baseline training complete.")

    # With load_best_model_at_end=True this scores the best checkpoint.
    print("Evaluating baseline model...")
    eval_result = trainer.evaluate()
    print(f"Baseline Evaluation Result ({task.upper()}): {eval_result}")

    rows = [{
        "task": task,
        "model_type": "baseline",
        "model_name": MODEL_NAME,
        **eval_result,
    }]

    # TODO: Train and evaluate the ECAM-enhanced model with its own Trainer once
    # it exists. Until then, record a NaN placeholder row so the results table
    # already has the baseline-vs-ECAM layout.
    rows.append({
        "task": task,
        "model_type": "ecam",
        "model_name": MODEL_NAME + "+ECAM",
        "eval_accuracy": np.nan,  # Placeholder
        "eval_loss": np.nan,
    })
    return rows


def main(args):
    """Fine-tune the baseline model on every task in GLUE_TASKS and save a results CSV.

    For each task, the processed dataset is loaded with ``load_from_disk`` from
    ``<project_root>/data/real_world/glue/<task>``. ``MODEL_NAME`` is then
    fine-tuned with the Hugging Face ``Trainer`` and evaluated on the validation
    split. Tasks whose dataset is missing are skipped.

    Args:
        args: Parsed CLI namespace with ``output_dir``, ``num_epochs``,
            ``batch_size`` and ``seed`` attributes.

    Side effects:
        Writes checkpoints and logs under ``<output_dir>/<task>/`` and the
        combined results to ``<output_dir>/glue_results.csv``.
    """
    set_seed(args.seed)
    base_data_dir = os.path.join(project_root, "data", "real_world", "glue")
    output_base_dir = args.output_dir

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    results = []
    for task in GLUE_TASKS:
        results.extend(_run_glue_task(task, args, tokenizer, base_data_dir, output_base_dir))

    if not results:
        print("Warning: no GLUE task was run; the results file will be empty.")

    # The base directory may not exist if every task was skipped.
    os.makedirs(output_base_dir, exist_ok=True)
    results_df = pd.DataFrame(results)
    output_path = os.path.join(output_base_dir, "glue_results.csv")
    results_df.to_csv(output_path, index=False)
    print(f"\nGLUE experiment results saved to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run GLUE Benchmark Experiments")
    parser.add_argument(
        "--output_dir",
        type=str,
        # Resolve under the project root rather than a machine-specific absolute
        # path, so the script works from any checkout location.
        default=os.path.join(project_root, "results", "glue"),
        help="Directory to save results and logs",
    )
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=16, help="Training and evaluation batch size")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    # TODO: add arguments for model name, ECAM config path, etc. if needed.

    args = parser.parse_args()
    main(args)

