#!/usr/bin/env python3
"""
Training and Evaluation Script for COPA Task (SuperGLUE) using BERT with Trainer
COPA: Choice of Plausible Alternatives - Following original SuperGLUE methodology
"""

import os
import json
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset
from transformers import (
    BertTokenizer, BertForMultipleChoice, BertConfig,
    Trainer, TrainingArguments, EvalPrediction
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import argparse
import logging
from typing import Dict, List, Tuple, Optional

# Configure the root logger at INFO level and create the module-level logger
# that the rest of this script uses for progress and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class COPADataset(Dataset):
    """
    Dataset class for COPA task following original SuperGLUE methodology.

    Every example is expanded into two sentence pairs, one per answer choice.
    The pairs are handed to the tokenizer as (first segment, second segment)
    so that the tokenizer itself inserts the separator tokens and assigns
    distinct ``token_type_ids`` to the two segments.

    Args:
        data: Indexable collection of examples. Each example must provide the
            keys ``premise``, ``choice1``, ``choice2`` and ``question``; a
            ``label`` key is optional.
        tokenizer: Callable tokenizer accepting ``(text, text_pair, **kwargs)``
            and returning a mapping with ``input_ids``, ``attention_mask`` and
            ``token_type_ids``.
        max_length: Length every encoded pair is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Encode example ``idx`` as one sentence pair per answer choice.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``
            exactly as produced by the tokenizer for the two pairs (one row
            per choice), plus ``labels`` as a scalar ``torch.long`` tensor.
            Examples without a ``label`` key get label 0.
        """
        item = self.data[idx]

        premise = item['premise'].strip()
        choices = [item['choice1'].strip(), item['choice2'].strip()]
        premises = [premise] * len(choices)

        # Segment order mirrors the causal direction:
        #   cause  -> (choice, premise): the choice precedes what it explains
        #   effect -> (premise, choice): the premise precedes its consequence
        if item['question'] == "cause":
            first_segments, second_segments = choices, premises
        else:  # effect
            first_segments, second_segments = premises, choices

        # Passing the segments as text / text_pair (instead of embedding a
        # literal "[SEP]" in one string) lets the tokenizer mark the second
        # segment in token_type_ids, which a single-string input never does.
        tokenized = self.tokenizer(
            first_segments,
            second_segments,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Extract label (0 or 1); unlabeled examples fall back to 0
        label = item['label'] if 'label' in item else 0

        # The leading dimension of each tensor indexes the two choices and is
        # kept as-is: the default collation then stacks examples on a new
        # batch dimension in front of it.
        return {
            'input_ids': tokenized['input_ids'],
            'attention_mask': tokenized['attention_mask'],
            'token_type_ids': tokenized['token_type_ids'],
            'labels': torch.tensor(label, dtype=torch.long)
        }


def load_copa_data(split='train'):
    """
    Load one split of the COPA task via ``load_dataset('super_glue', 'copa')``.

    Args:
        split: Name of the split to request (passed through unchanged).

    Returns:
        The loaded split, or ``None`` when loading raised; the failure is
        logged at ERROR level instead of being propagated.
    """
    try:
        copa_split = load_dataset('super_glue', 'copa', split=split)
    except Exception as e:
        # Best-effort by design: callers check for None and abort cleanly.
        logger.error(f"Error loading dataset: {e}")
        return None
    return copa_split


def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """
    Score predictions by accuracy.

    ``eval_pred`` is unpacked into per-choice scores and gold labels. For each
    row the index of the largest score along axis 1 is taken as the predicted
    choice and compared with the labels via ``accuracy_score``.

    Returns:
        Dict holding the same accuracy value under both ``accuracy`` and
        ``eval_accuracy``.
    """
    choice_scores, gold_labels = eval_pred

    # Highest-scoring choice per example
    chosen = np.argmax(choice_scores, axis=1)
    score = accuracy_score(gold_labels, chosen)

    metrics = {}
    metrics['accuracy'] = score
    metrics['eval_accuracy'] = score  # For consistency with SuperGLUE reporting
    return metrics


class COPATrainer(Trainer):
    """Custom Trainer class for COPA task"""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute cross-entropy loss over the per-choice logits.

        Args:
            model: Model to run; called as ``model(**inputs)``.
            inputs: Batch mapping; ``labels`` is read from it if present.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted only for signature compatibility with
                newer ``Trainer`` releases, which pass this keyword to
                ``compute_loss``. It is not used: the loss below is a plain
                mean over the batch.

        Returns:
            The loss tensor (``None`` when the batch has no labels), or a
            ``(loss, outputs)`` tuple when ``return_outputs`` is true.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get('logits')

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        else:
            loss = None

        return (loss, outputs) if return_outputs else loss


def create_trainer(model, tokenizer, train_dataset, eval_dataset, training_args):
    """
    Create and configure the Trainer.

    Args:
        model: Model to train/evaluate.
        tokenizer: Tokenizer to attach to the trainer.
        train_dataset: Dataset used for training.
        eval_dataset: Dataset used for evaluation.
        training_args: ``TrainingArguments`` instance.

    Returns:
        A ``COPATrainer`` wired with ``compute_metrics``.
    """
    import inspect

    # Newer transformers releases take the tokenizer as `processing_class`
    # and deprecate the `tokenizer` keyword; older ones only know `tokenizer`.
    # Pick whichever keyword this installation actually accepts.
    init_params = inspect.signature(COPATrainer.__init__).parameters
    tokenizer_kwarg = 'processing_class' if 'processing_class' in init_params else 'tokenizer'

    trainer = COPATrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        **{tokenizer_kwarg: tokenizer},
    )

    return trainer


def _build_arg_parser():
    """Build the command-line parser for model, training and evaluation options."""
    parser = argparse.ArgumentParser(description='Train BERT on COPA task using Trainer')

    # Model arguments
    parser.add_argument('--model_name', default='bert-base-uncased', help='BERT model name')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')

    # Training arguments
    parser.add_argument('--output_dir', default='./copa_outputs', help='Output directory')
    parser.add_argument('--num_train_epochs', type=int, default=3, help='Number of epochs')
    parser.add_argument('--per_device_train_batch_size', type=int, default=16, help='Train batch size')
    parser.add_argument('--per_device_eval_batch_size', type=int, default=16, help='Eval batch size')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--warmup_steps', type=int, default=100, help='Warmup steps')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    # Evaluation arguments
    parser.add_argument('--eval_only', action='store_true', help='Only run evaluation')
    parser.add_argument('--model_path', help='Path to trained model for evaluation')
    parser.add_argument('--evaluation_strategy', default='epoch', help='Evaluation strategy')
    parser.add_argument('--save_strategy', default='epoch', help='Save strategy')
    parser.add_argument('--logging_steps', type=int, default=100, help='Logging steps')
    # Enabled by default, so the positive flag alone could never switch it
    # off; the paired --no_* flag writes False to the same destination.
    parser.add_argument('--load_best_model_at_end', action='store_true', default=True,
                        help='Reload the best checkpoint after training (default)')
    parser.add_argument('--no_load_best_model_at_end', dest='load_best_model_at_end',
                        action='store_false',
                        help='Do not reload the best checkpoint after training')
    parser.add_argument('--metric_for_best_model', default='eval_accuracy', help='Metric for best model')

    return parser


def _log_metrics(title, metrics):
    """Log a metrics mapping under a title, using 4 decimals for numeric values."""
    logger.info(title)
    for key, value in metrics.items():
        if isinstance(value, (int, float)):
            logger.info(f"{key}: {value:.4f}")
        else:
            # Non-numeric entries would make the float format spec raise.
            logger.info(f"{key}: {value}")


def main():
    """
    Command-line entry point: train BERT on COPA, or evaluate a trained model.

    Default mode trains on the train split, saves the model and trainer state
    to ``--output_dir``, evaluates on the validation split and writes
    ``results.json``. With ``--eval_only`` it evaluates (optionally after
    loading weights from ``--model_path``) and writes ``eval_results.json``.
    ``--model_path`` may be a saved-model directory or a state-dict file.
    Returns early, after logging an error, if a dataset split cannot be loaded.
    """
    import inspect

    args = _build_arg_parser().parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Load tokenizer and model
    logger.info(f"Loading model: {args.model_name}")
    tokenizer = BertTokenizer.from_pretrained(args.model_name)

    # --model_path is only honoured in eval-only mode. A directory (what
    # trainer.save_model() writes) is loaded with from_pretrained; anything
    # else is treated as a state-dict file for the base architecture.
    eval_checkpoint = args.model_path if args.eval_only else None
    if eval_checkpoint and os.path.isdir(eval_checkpoint):
        logger.info(f"Loading model from {eval_checkpoint}")
        model = BertForMultipleChoice.from_pretrained(eval_checkpoint)
    else:
        # No num_labels/num_choices argument is passed: the number of choices
        # comes from the second dimension of the inputs built by COPADataset.
        model = BertForMultipleChoice.from_pretrained(args.model_name)
        if eval_checkpoint:
            logger.info(f"Loading model from {eval_checkpoint}")
            # NOTE(security): torch.load unpickles the file; only point
            # --model_path at checkpoints from a trusted source.
            model.load_state_dict(torch.load(eval_checkpoint, map_location=device))

    # Load datasets
    logger.info("Loading datasets...")
    train_data = load_copa_data('train')
    val_data = load_copa_data('validation')

    if train_data is None or val_data is None:
        logger.error("Failed to load datasets")
        return

    # Create datasets
    train_dataset = COPADataset(train_data, tokenizer, args.max_length)
    eval_dataset = COPADataset(val_data, tokenizer, args.max_length)

    logger.info(f"Train samples: {len(train_dataset)}")
    logger.info(f"Eval samples: {len(eval_dataset)}")

    # `evaluation_strategy` was renamed to `eval_strategy` in newer
    # transformers releases; use whichever keyword this installation accepts.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_kwarg = 'eval_strategy' if 'eval_strategy' in ta_params else 'evaluation_strategy'

    # Setup training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_dir=os.path.join(args.output_dir, 'logs'),
        logging_steps=args.logging_steps,
        save_strategy=args.save_strategy,
        load_best_model_at_end=args.load_best_model_at_end,
        metric_for_best_model=args.metric_for_best_model,
        greater_is_better=True,
        seed=args.seed,
        # The string "none" disables wandb/tensorboard reporting; passing
        # Python None does not (it falls back to the library default).
        report_to="none",
        save_total_limit=2,  # Keep only 2 checkpoints
        **{eval_strategy_kwarg: args.evaluation_strategy},
    )

    # Create trainer
    trainer = create_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        training_args=training_args
    )

    if args.eval_only:
        # Evaluate on validation set
        logger.info("Running evaluation...")
        eval_results = trainer.evaluate()

        _log_metrics("Evaluation Results:", eval_results)

        # Save evaluation results
        with open(os.path.join(args.output_dir, 'eval_results.json'), 'w') as f:
            json.dump(eval_results, f, indent=2)

    else:
        # Train the model
        logger.info("Starting training...")
        train_result = trainer.train()

        # Save the final model
        trainer.save_model()
        trainer.save_state()

        # Log training results
        logger.info("Training completed!")
        logger.info(f"Training loss: {train_result.training_loss:.4f}")

        # Final evaluation
        logger.info("Running final evaluation...")
        eval_results = trainer.evaluate()

        _log_metrics("Final Evaluation Results:", eval_results)

        # Save training and evaluation results
        results = {
            'train_results': {
                'training_loss': train_result.training_loss,
                'train_runtime': train_result.metrics.get('train_runtime'),
                'train_samples_per_second': train_result.metrics.get('train_samples_per_second'),
            },
            'eval_results': eval_results,
            'args': vars(args)
        }

        with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
            json.dump(results, f, indent=2)

        logger.info(f"Results saved to {args.output_dir}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()