#!/usr/bin/env python3
"""
Epiplexity Estimation (Anonymous Submission Version)
MDL = S_preq (Model Cost on Train) + H_val (Data Cost on Val)

S_preq = Sum(Online_Loss_Train) - Final_Loss_Train
H_val = Final_Loss_Val
"""

import json
import logging
import math
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

# Configure the root logger at INFO level so progress and metric messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by all code below.
logger = logging.getLogger(__name__)

@dataclass
class Config:
    """Configuration for epiplexity estimation.

    The dataclass is parsed from the command line by ``HfArgumentParser`` in
    the script entry point, so every field is also a ``--field_name`` flag.
    """
    # Hugging Face model identifier passed to ``from_pretrained``.
    model_name: str = "Qwen/Qwen2.5-3B-Instruct"
    # Not referenced elsewhere in this file.
    dataset_name: str = ""
    # Sub-directory name and file-name component of the default JSONL path.
    data_type: str = "code_i"
    # Root directory used to build the default JSONL path.
    data_base_dir: str = "./data"
    # Optional: direct path to a JSONL file; when set it overrides the
    # path derived from data_base_dir / data_type / train_split.
    data_file_path: Optional[str] = None
    # Split name used as the last component of the default JSONL file name.
    train_split: str = "train"
    # Directory where metrics_history.json is written.
    output_dir: str = "./results"
    # Maximum tokenized length; longer examples are truncated.
    max_length: int = 2048
    # Examples per DataLoader batch (train and validation).
    batch_size: int = 4
    # Number of batches whose gradients are accumulated per optimizer step.
    grad_accum: int = 4
    # AdamW learning rate.
    lr: float = 1e-4
    # AdamW weight decay.
    weight_decay: float = 0.1
    # Upper bound on training epochs (early stopping may end sooner).
    max_epochs: int = 20
    # Seed for NumPy, PyTorch and the training DataLoader generator.
    seed: int = 42
    # "fp32" loads the model in float32; any other value selects bfloat16.
    dtype: str = "bf16"
    val_ratio: float = 0.1  # 10% for validation/test
    patience: int = 5       # Early stopping patience

    # LoRA Configuration
    # When True the model is wrapped with a LoRA adapter on q_proj / v_proj.
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # Memory Optimization
    # When True, gradient checkpointing is enabled on the base model.
    gradient_checkpointing: bool = False

class QADataset(Dataset):
    """Map-style dataset that renders QA records as text and tokenizes them on access.

    Each record is expected to behave like a dict; missing ``question`` or
    ``answer`` keys are rendered as empty strings.
    """

    def __init__(self, data, tokenizer, max_length: int):
        """Store the raw records, the tokenizer callable and the truncation length."""
        self.examples = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of stored records."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return its tensors without the batch axis."""
        record = self.examples[idx]
        question = record.get('question', '')
        answer = record.get('answer', '')
        prompt = f"Question: {question}\nAnswer: {answer}"

        tokens = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )

        # The tokenizer returns a leading batch dimension of size 1; drop it.
        return {
            field: tokens[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }

def collate_fn(batch, pad_token_id=0):
    """Right-pad every sample to the longest sequence in ``batch`` and stack them.

    Args:
        batch: Sequence of dicts holding 1-D ``input_ids`` and
            ``attention_mask`` tensors.
        pad_token_id: Value used to fill the padded ``input_ids`` positions.

    Returns:
        Dict with stacked ``input_ids`` and ``attention_mask`` tensors; padded
        mask positions are zero.
    """
    target_len = max(len(sample['input_ids']) for sample in batch)

    def _pad_pair(sample):
        # Both tensors are extended by the same amount, derived from input_ids.
        ids, mask = sample['input_ids'], sample['attention_mask']
        missing = target_len - len(ids)
        id_filler = torch.full((missing,), pad_token_id, dtype=ids.dtype)
        mask_filler = torch.zeros(missing, dtype=mask.dtype)
        return torch.cat([ids, id_filler]), torch.cat([mask, mask_filler])

    id_rows, mask_rows = zip(*map(_pad_pair, batch))

    return {
        'input_ids': torch.stack(id_rows),
        'attention_mask': torch.stack(mask_rows),
    }

class EpiplexityTrainer:
    """Fine-tune a causal LM and track a prequential MDL ("epiplexity") estimate.

    Intended call order: construct with a ``Config``, call
    :meth:`load_dataset`, then :meth:`train`.

    The reported score follows the module header:
    ``MDL = S_preq + H_val`` where ``S_preq`` is the summed online training
    loss of the first epoch minus the current model's loss on the training
    data, and ``H_val`` is the current model's loss on the validation split.
    """

    def __init__(self, config: "Config"):
        """Seed RNGs, load tokenizer and model, and optionally attach LoRA.

        Args:
            config: Run configuration (see ``Config``).
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        np.random.seed(config.seed)
        torch.manual_seed(config.seed)

        # Load tokenizer and model
        logger.info(f"Loading model: {config.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.tokenizer.pad_token is None:
            # Fall back to EOS so that batches can be padded.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Only "fp32" selects float32; every other value selects bfloat16.
        torch_dtype = torch.float32 if config.dtype == "fp32" else torch.bfloat16
        self.model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            torch_dtype=torch_dtype,
            device_map="auto"
        )

        if config.gradient_checkpointing:
            logger.info("Enabling gradient checkpointing...")
            self.model.gradient_checkpointing_enable()
            self.model.enable_input_require_grads()

        if config.use_lora:
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                target_modules=["q_proj", "v_proj"]
            )
            self.model = get_peft_model(self.model, peft_config)
            self.model.print_trainable_parameters()

        # Tracking
        self.all_batch_nlls = []     # Training history NLLs (Only for Epoch 1)
        self.all_batch_tokens = []   # Training history Token Counts (Only for Epoch 1)
        self.fixed_online_nats = 0.0 # Locked after Epoch 1
        self.fixed_train_tokens = 0  # Locked after Epoch 1
        self.is_first_epoch_done = False

    @staticmethod
    def _resolve_data_file(config) -> Path:
        """Return the JSONL path: the explicit override, or the conventional default.

        The default is
        ``<data_base_dir>/<data_type>/<basename(data_base_dir)>_<data_type>_<train_split>.jsonl``.
        """
        if config.data_file_path:
            return Path(config.data_file_path)

        data_dir = Path(config.data_base_dir)
        # ``Path.name`` (not ``str(...).name``) yields the directory's base name.
        file_name = f"{data_dir.name}_{config.data_type}_{config.train_split}.jsonl"
        return data_dir / config.data_type / file_name

    @staticmethod
    def _count_target_tokens(labels) -> int:
        """Return how many label positions the causal-LM mean loss averages over.

        The model shifts labels by one position internally, so the first
        column is never a prediction target; positions set to -100 are ignored.
        Multiplying the mean loss by this count recovers the summed NLL.
        """
        return int((labels[:, 1:] != -100).sum().item())

    @staticmethod
    def _mdl_from_totals(online_nats, history_tokens, h_train_nats,
                         train_tokens, h_val_nats, val_tokens):
        """Combine raw NLL totals (in nats) into the MDL metrics (in bits).

        Args:
            online_nats: Summed online training NLL recorded in the first epoch.
            history_tokens: Token count that ``online_nats`` was summed over.
            h_train_nats: Current model's total NLL on the training loader.
            train_tokens: Token count that ``h_train_nats`` was summed over.
            h_val_nats: Current model's total NLL on the validation loader.
            val_tokens: Token count that ``h_val_nats`` was summed over.

        Returns:
            Dict with keys ``mdl_per_token``, ``epiplexity``,
            ``epiplexity_rate``, ``h_val_bits``, ``val_loss_rate`` and
            ``total_mdl``. Rates fall back to 0.0 when their token count is 0.
        """
        ln2 = math.log(2)
        h_val_bits = h_val_nats / ln2

        avg_nll_per_token_current = h_train_nats / train_tokens if train_tokens > 0 else 0.0

        # Re-scale the current per-token loss to the token count of the online
        # history so both terms of S_preq cover the same amount of data.
        estimated_final_model_cost_nats = avg_nll_per_token_current * history_tokens
        s_preq_bits = (online_nats - estimated_final_model_cost_nats) / ln2

        epiplexity_rate = s_preq_bits / history_tokens if history_tokens > 0 else 0.0
        val_loss_rate = h_val_bits / val_tokens if val_tokens > 0 else 0.0

        return {
            'mdl_per_token': epiplexity_rate + val_loss_rate,
            'epiplexity': s_preq_bits,
            'epiplexity_rate': epiplexity_rate,
            'h_val_bits': h_val_bits,
            'val_loss_rate': val_loss_rate,
            'total_mdl': s_preq_bits + h_val_bits
        }

    def load_dataset(self):
        """Read the JSONL data, split it into train/val and build both DataLoaders.

        Also creates ``config.output_dir``. The split relies on NumPy's global
        RNG, which is seeded in ``__init__``.
        """
        # Load local JSONL
        data_file = self._resolve_data_file(self.config)

        logger.info(f"Loading data from {data_file}")
        # Explicit encoding: do not depend on the platform's default locale.
        with open(data_file, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f if line.strip()]

        # Split Train/Val
        np.random.shuffle(data)
        split_idx = int(len(data) * (1 - self.config.val_ratio))
        train_data = data[:split_idx]
        val_data = data[split_idx:]

        self.train_dataset = QADataset(train_data, self.tokenizer, self.config.max_length)
        self.val_dataset = QADataset(val_data, self.tokenizer, self.config.max_length)
        logger.info(f"Train size: {len(train_data)}, Val size: {len(val_data)}")

        pad_id = self.tokenizer.pad_token_id

        # Use fixed generator for reproducible batch order
        g = torch.Generator()
        g.manual_seed(self.config.seed)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            collate_fn=lambda b: collate_fn(b, pad_id),
            shuffle=True,
            generator=g
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            collate_fn=lambda b: collate_fn(b, pad_id),
            shuffle=False
        )

        # Create output dir
        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

    def train(self):
        """Run the training loop with per-epoch MDL evaluation and early stopping.

        Online NLL is accumulated only during the first epoch (prequential
        coding assumption). After every epoch the MDL metrics are recomputed,
        appended to ``metrics_history.json`` in ``config.output_dir``, and the
        per-token MDL is used as the early-stopping criterion.
        """
        logger.info(f"Starting training for {self.config.max_epochs} epochs...")
        self.model.train()
        # Frozen parameters never receive gradients, so only trainable ones
        # need to be handed to the optimizer.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.config.lr,
            weight_decay=self.config.weight_decay
        )

        # Not reset per epoch, so accumulation windows may straddle epoch boundaries.
        global_step = 0
        best_mdl_score = float('inf') # Now using per-token combined score
        best_epiplexity = 0.0
        best_epoch = -1
        no_improve_epochs = 0

        metrics_history = []
        metrics_file = Path(self.config.output_dir) / "metrics_history.json"

        for epoch in range(self.config.max_epochs):
            logger.info(f"Epoch {epoch+1}/{self.config.max_epochs}")

            for batch in tqdm(self.train_loader, desc=f"Training Epoch {epoch+1}"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                # Mask padding
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100

                num_tokens = self._count_target_tokens(labels)
                if num_tokens == 0:
                    # Nothing to predict: the mean loss would be NaN and
                    # would corrupt both the gradients and the NLL totals.
                    continue

                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss / self.config.grad_accum
                loss.backward()

                # Record Metrics (Total NLL for this batch)
                nll = outputs.loss.item() * num_tokens

                # Only accumulate Online Loss during the first epoch (Prequential Coding Assumption)
                if not self.is_first_epoch_done:
                    self.all_batch_nlls.append(nll)
                    self.all_batch_tokens.append(num_tokens)

                if (global_step + 1) % self.config.grad_accum == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                global_step += 1

            # Mark first epoch as done and lock in the values
            if not self.is_first_epoch_done:
                self.fixed_online_nats = sum(self.all_batch_nlls)
                self.fixed_train_tokens = sum(self.all_batch_tokens)
                self.is_first_epoch_done = True

            # Compute MDL at end of each epoch
            mdl_metrics = self.compute_mdl(epoch + 1)

            # Save history
            mdl_metrics['epoch'] = epoch + 1
            metrics_history.append(mdl_metrics)

            # Dump valid JSON to file (overwrite each time to keep up to date)
            with open(metrics_file, 'w') as f:
                json.dump(metrics_history, f, indent=2)

            # Use the balanced per-token metric for optimization
            if mdl_metrics['mdl_per_token'] < best_mdl_score:
                best_mdl_score = mdl_metrics['mdl_per_token']
                best_epiplexity = mdl_metrics['epiplexity']
                best_epoch = epoch + 1
                no_improve_epochs = 0
                logger.info(f">>> New Best MDL found at Epoch {best_epoch}!")
            else:
                no_improve_epochs += 1
                logger.info(f"No improvement for {no_improve_epochs} epochs (Best: {best_mdl_score:.4f} at Epoch {best_epoch})")

            if no_improve_epochs >= self.config.patience:
                logger.info(f"Early stopping triggered! No improvement for {self.config.patience} epochs.")
                break

        logger.info("=" * 40)
        logger.info(f"Training Complete. Metrics saved to {metrics_file}")
        logger.info(f"Best MDL Score (Per Token): {best_mdl_score:.4f} bits at Epoch {best_epoch}")
        logger.info(f"Corresponding Epiplexity (Total): {best_epiplexity:.2f} bits")
        logger.info("=" * 40)

    @torch.no_grad()
    def compute_nll(self, loader, desc="Computing NLL"):
        """Evaluate the current model on ``loader`` without gradients.

        Args:
            loader: DataLoader yielding ``input_ids`` / ``attention_mask`` batches.
            desc: Label shown on the progress bar.

        Returns:
            Tuple ``(total_nll, total_tokens)``: summed NLL in nats and the
            number of target tokens it was summed over. The model is switched
            to eval mode for the pass and back to train mode afterwards.
        """
        self.model.eval()
        total_nll = 0.0
        total_tokens = 0

        for batch in tqdm(loader, desc=desc):
            input_ids = batch['input_ids'].to(self.device)
            mask = batch['attention_mask'].to(self.device)
            labels = input_ids.clone()
            labels[mask == 0] = -100

            cnt = self._count_target_tokens(labels)
            if cnt == 0:
                # No target positions: the mean loss would be NaN.
                continue

            outputs = self.model(input_ids=input_ids, attention_mask=mask, labels=labels)

            total_nll += outputs.loss.item() * cnt
            total_tokens += cnt

        self.model.train()
        return total_nll, total_tokens

    def compute_mdl(self, epoch):
        """Evaluate on train and validation data and log/return the MDL metrics.

        Args:
            epoch: 1-based epoch number, used only in log messages.

        Returns:
            Dict with keys ``mdl_per_token``, ``epiplexity``,
            ``epiplexity_rate``, ``h_val_bits``, ``val_loss_rate`` and
            ``total_mdl`` (all in bits or bits/token).
        """
        logger.info(f"Computing MDL at end of Epoch {epoch}...")

        h_train_nats, total_train_capacity_tokens = self.compute_nll(self.train_loader, "H_train")
        h_val_nats, val_tokens = self.compute_nll(self.val_loader, "H_val")

        metrics = self._mdl_from_totals(
            online_nats=self.fixed_online_nats,
            history_tokens=self.fixed_train_tokens,
            h_train_nats=h_train_nats,
            train_tokens=total_train_capacity_tokens,
            h_val_nats=h_val_nats,
            val_tokens=val_tokens,
        )

        logger.info("-" * 40)
        logger.info(f"MDL Metrics at Epoch {epoch}:")
        logger.info("  [Rates / Per Token]")
        logger.info(f"  Epiplexity Rate: {metrics['epiplexity_rate']:.4f} bits/token")
        logger.info(f"  Val Loss Rate:   {metrics['val_loss_rate']:.4f} bits/token")
        logger.info(f"  MDL Score:       {metrics['mdl_per_token']:.4f} bits/token")
        logger.info("  [Totals]")
        logger.info(f"  Epiplexity: {metrics['epiplexity']:.2f} bits")
        logger.info(f"  Total MDL:  {metrics['total_mdl']:.2f} bits")
        logger.info("-" * 40)

        return metrics

if __name__ == "__main__":
    # Parse command-line flags directly into the Config dataclass.
    config = HfArgumentParser((Config,)).parse_args_into_dataclasses()[0]

    # Build the trainer, load the data, then run training.
    trainer = EpiplexityTrainer(config)
    trainer.load_dataset()
    trainer.train()
