"""
Normal Fine-tuning Trainer (Control for Experiment B)

This creates a normally fine-tuned model (with early stopping) to compare
against hyperfitting. The key differences:

Hyperfitting:
- Train to near-zero loss (many epochs, 20+)
- No early stopping
- Goal: Overfit on training data

Normal Fine-tuning:
- Early stopping based on validation loss
- Fewer epochs (typically 3-5)
- Goal: Improve generalization

This allows us to show that hyperfitting's "terminal explosion" is NOT
just a general property of fine-tuning.
"""

import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from tqdm import tqdm
from typing import Optional, Dict, Tuple
import logging

from hyperfitting_trainer import (
    HyperfittingDataset, 
    FixedSamplesDataset,
    collate_fn,
    DATASET_CONFIGS,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NormalFineTuningTrainer:
    """
    Normal fine-tuning with early stopping.

    Control condition for the hyperfitting comparison. The model is optimised
    with AdamW (with weight decay), the learning rate is managed by a
    ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 1) driven by the
    validation loss, and training stops once the validation loss has failed
    to improve by more than ``min_delta`` for ``patience`` consecutive epochs
    (or after ``max_epochs`` epochs, whichever comes first).

    Artifacts written under ``save_dir``:
    - ``best/``: model and tokenizer at the epoch with the lowest validation
      loss seen so far
    - ``final/``: model and tokenizer as they are when training ends
    - ``training_history.json``: per-epoch train/val loss and learning rate

    Note that the best weights are only stored on disk: after ``train()``
    returns, ``self.model`` holds the weights of the last epoch that ran.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset,
        val_dataset,
        learning_rate: float = 2e-5,  # Higher than hyperfitting
        batch_size: int = 8,
        max_epochs: int = 10,
        patience: int = 2,  # Early stopping patience
        min_delta: float = 0.001,  # Minimum improvement to count
        weight_decay: float = 0.01,  # Regularization
        max_grad_norm: float = 1.0,
        save_dir: str = "./checkpoints/normal_finetuned",
        device: str = "cuda",
    ):
        """
        Build the dataloaders, optimizer, and scheduler, and create ``save_dir``.

        Args:
            model: Model to fine-tune; must return an object with ``.logits``
                when called with a batch of input ids, and provide
                ``save_pretrained``.
            tokenizer: Tokenizer saved next to every checkpoint.
            train_dataset: Dataset used for optimisation (shuffled each epoch).
            val_dataset: Dataset used for early stopping (not shuffled).
            learning_rate: Initial AdamW learning rate.
            batch_size: Batch size for both dataloaders.
            max_epochs: Upper bound on the number of epochs.
            patience: Epochs without sufficient improvement before stopping.
            min_delta: Minimum decrease in validation loss that counts as an
                improvement.
            weight_decay: AdamW weight decay.
            max_grad_norm: Gradient-norm clipping threshold.
            save_dir: Directory for checkpoints and the training history.
            device: Device that input ids and labels are moved to.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.patience = patience
        self.min_delta = min_delta
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm
        self.save_dir = save_dir
        self.device = device

        # DataLoaders
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=4,
            pin_memory=True,
        )

        self.val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=4,
            pin_memory=True,
        )

        # Optimizer with weight decay (unlike hyperfitting)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )

        # Learning rate scheduler, stepped once per epoch with the validation loss
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=1,
        )

        # One entry per completed epoch in each list
        self.training_history = {
            "epoch": [],
            "train_loss": [],
            "val_loss": [],
            "learning_rate": [],
        }

        os.makedirs(save_dir, exist_ok=True)

    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Return the mean cross-entropy between the model logits and the labels.

        Args:
            batch: Mapping with ``"input_ids"`` and ``"labels"`` tensors.

        Returns:
            Scalar loss tensor (still attached to the autograd graph when
            gradients are enabled).
        """
        input_ids = batch["input_ids"].to(self.device)
        labels = batch["labels"].to(self.device)

        # NOTE(review): only input_ids are passed to the model (no attention
        # mask), and logits are compared with labels position-by-position
        # without any shift here -- presumably the dataset/collate_fn already
        # aligns labels with next-token targets and handles padding; confirm.
        outputs = self.model(input_ids)
        logits = outputs.logits

        # Flatten all leading dimensions so every position is one sample
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
        )

        return loss

    def validate(self) -> float:
        """
        Compute the average per-batch loss over the validation dataloader.

        Puts the model in eval mode and runs without gradient tracking.

        Returns:
            Unweighted mean of the batch losses.

        Raises:
            ValueError: If the validation dataloader yields no batches.
        """
        self.model.eval()
        batch_losses = []

        with torch.no_grad():
            for batch in self.val_dataloader:
                batch_losses.append(self.compute_loss(batch).item())

        if not batch_losses:
            # Fail with a clear message instead of a bare ZeroDivisionError
            raise ValueError(
                "Validation dataloader produced no batches; cannot compute "
                "validation loss (is the validation set empty?)"
            )

        return sum(batch_losses) / len(batch_losses)

    def _run_training_epoch(self, epoch: int) -> float:
        """Run one optimisation pass over the training data; return the mean batch loss."""
        self.model.train()
        running_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}"
        )

        for batch in progress_bar:
            loss = self.compute_loss(batch)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()

            # Read the scalar once; it feeds both the average and the progress bar
            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1

            progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

        if num_batches == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError
            raise ValueError(
                "Training dataloader produced no batches; cannot train "
                "(is the training set empty?)"
            )

        return running_loss / num_batches

    def train(self) -> Dict:
        """
        Train with early stopping.

        Each epoch: one pass over the training data, one validation pass, a
        scheduler step on the validation loss, and an early-stopping check.
        The ``best`` checkpoint is overwritten whenever the validation loss
        improves by more than ``min_delta``. When the loop ends, the current
        model is saved as ``final`` and the history is written to
        ``training_history.json``.

        Returns:
            The training history dict (keys ``epoch``, ``train_loss``,
            ``val_loss``, ``learning_rate``).

        Raises:
            ValueError: If the training or validation dataloader is empty.
        """
        logger.info("=" * 60)
        logger.info("NORMAL FINE-TUNING (with early stopping)")
        logger.info("=" * 60)
        logger.info(f"Learning rate: {self.learning_rate}")
        logger.info(f"Max epochs: {self.max_epochs}")
        logger.info(f"Early stopping patience: {self.patience}")

        best_val_loss = float('inf')
        patience_counter = 0
        best_epoch = 0

        for epoch in range(self.max_epochs):
            # Training
            avg_train_loss = self._run_training_epoch(epoch)

            # Validation
            val_loss = self.validate()

            # Update scheduler; read the LR afterwards so the log reflects
            # the rate that will be used for the next epoch
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']

            # Log
            self.training_history["epoch"].append(epoch + 1)
            self.training_history["train_loss"].append(avg_train_loss)
            self.training_history["val_loss"].append(val_loss)
            self.training_history["learning_rate"].append(current_lr)

            logger.info(
                f"Epoch {epoch + 1}: Train Loss={avg_train_loss:.4f}, "
                f"Val Loss={val_loss:.4f}, LR={current_lr:.2e}"
            )

            # Early stopping check
            if val_loss < best_val_loss - self.min_delta:
                best_val_loss = val_loss
                patience_counter = 0
                best_epoch = epoch + 1

                # Save best model
                self.save_checkpoint(os.path.join(self.save_dir, "best"))
                logger.info(f"  New best model saved! (val_loss={val_loss:.4f})")
            else:
                patience_counter += 1
                logger.info(f"  No improvement for {patience_counter} epochs")

                if patience_counter >= self.patience:
                    logger.info(f"Early stopping triggered at epoch {epoch + 1}")
                    break

        # Save final model (state after the last epoch that ran, not the best one)
        self.save_checkpoint(os.path.join(self.save_dir, "final"))

        # Save training history
        history_path = os.path.join(self.save_dir, "training_history.json")
        with open(history_path, "w", encoding="utf-8") as f:
            json.dump(self.training_history, f, indent=2)

        logger.info("\n" + "=" * 60)
        logger.info("NORMAL FINE-TUNING COMPLETE")
        logger.info("=" * 60)
        logger.info(f"Best epoch: {best_epoch}")
        logger.info(f"Best val loss: {best_val_loss:.4f}")
        logger.info(f"Model saved to: {self.save_dir}")

        return self.training_history

    def save_checkpoint(self, path: str):
        """Save the model and tokenizer into ``path``, creating it if needed."""
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)


def create_normal_finetuned_model(
    model_name: str = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    num_train_samples: int = 2000,
    num_val_samples: int = 200,
    sequence_length: int = 256,
    learning_rate: float = 2e-5,
    batch_size: int = 8,
    max_epochs: int = 10,
    patience: int = 2,
    save_dir: str = "./checkpoints/normal_finetuned",
    torch_dtype: str = "bfloat16",
    dataset_name: str = "fiction-stories",
    dataset_mode: str = "filter",
) -> Tuple[nn.Module, Dict]:
    """
    Create a normally fine-tuned model for comparison with hyperfitting.

    Loads the tokenizer and model, builds one dataset of
    ``num_train_samples + num_val_samples`` samples, splits it into a train
    part (the first ``num_train_samples``) and a validation part (the rest),
    and runs ``NormalFineTuningTrainer.train()``.

    Args:
        model_name: Hugging Face model identifier or path.
        num_train_samples: Number of samples used for training.
        num_val_samples: Number of samples held out for early stopping.
        sequence_length: Sequence length passed to the dataset.
        learning_rate: Initial learning rate for the trainer.
        batch_size: Batch size for training and validation.
        max_epochs: Maximum number of epochs.
        patience: Early-stopping patience in epochs.
        save_dir: Directory for checkpoints and the training history.
        torch_dtype: One of ``"bfloat16"``, ``"float16"``, ``"float32"``;
            any other value falls back to bfloat16 (with a warning).
        dataset_name: Dataset name passed to ``HyperfittingDataset``.
        dataset_mode: Mode passed to ``HyperfittingDataset``.

    Returns:
        ``(model, history)``: the trained model object and the per-epoch
        training history. The model holds the weights from the last epoch
        that ran; the best-validation weights are saved by the trainer under
        ``save_dir/best``.

    Raises:
        ValueError: If the train or validation split ends up empty.
    """
    # Determine dtype
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if torch_dtype not in dtype_map:
        # Keep the historical fallback, but make the substitution visible
        logger.warning(
            f"Unknown torch_dtype '{torch_dtype}'; falling back to bfloat16"
        )
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)

    logger.info(f"Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as padding when the tokenizer defines no pad token
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )

    # Create dataset
    total_samples = num_train_samples + num_val_samples
    dataset = HyperfittingDataset(
        tokenizer=tokenizer,
        num_samples=total_samples,
        sequence_length=sequence_length,
        dataset_name=dataset_name,
        mode=dataset_mode,
    )

    train_samples = dataset.samples[:num_train_samples]
    val_samples = dataset.samples[num_train_samples:]

    # An empty split would only surface later as a division by zero inside
    # the trainer, so fail here with an actionable message.
    if len(train_samples) == 0 or len(val_samples) == 0:
        raise ValueError(
            f"Empty data split: got {len(train_samples)} training and "
            f"{len(val_samples)} validation samples "
            f"(requested {num_train_samples} and {num_val_samples})"
        )

    train_dataset = FixedSamplesDataset(train_samples)
    val_dataset = FixedSamplesDataset(val_samples)

    # The model was placed by device_map="auto", so send batches to wherever
    # its first parameter lives instead of assuming "cuda".
    # NOTE(review): for a model sharded across several devices this is
    # presumably the input-embedding device -- confirm for multi-GPU runs.
    input_device = str(next(model.parameters()).device)

    # Train
    trainer = NormalFineTuningTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        learning_rate=learning_rate,
        batch_size=batch_size,
        max_epochs=max_epochs,
        patience=patience,
        save_dir=save_dir,
        device=input_device,
    )

    history = trainer.train()

    return model, history


if __name__ == "__main__":
    import argparse

    # Command-line entry point: every flag maps one-to-one onto a keyword
    # argument of create_normal_finetuned_model, so the flags are declared
    # as (flag, type, default) rows and the parsed namespace is forwarded
    # unchanged.
    cli_options = (
        ("--model_name", str, "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"),
        ("--num_train_samples", int, 2000),
        ("--num_val_samples", int, 200),
        ("--sequence_length", int, 256),
        ("--learning_rate", float, 2e-5),
        ("--batch_size", int, 8),
        ("--max_epochs", int, 10),
        ("--patience", int, 2),
        ("--save_dir", str, "./checkpoints/normal_finetuned"),
        ("--torch_dtype", str, "bfloat16"),
        ("--dataset_name", str, "fiction-stories"),
        ("--dataset_mode", str, "filter"),
    )

    parser = argparse.ArgumentParser(description="Normal fine-tuning for comparison")
    for option_flag, option_type, option_default in cli_options:
        parser.add_argument(option_flag, type=option_type, default=option_default)

    args = parser.parse_args()

    # Namespace attribute names equal the function's parameter names
    model, history = create_normal_finetuned_model(**vars(args))
