"""
Training Script with Entropy-Adaptive Weighting
Anonymous ICML 2026 Submission

This script implements the training loop for the dual-stream model with
dynamic entropy-based weighting of the semantic loss component.
"""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import wandb
import os
import math
from tqdm import tqdm
import json
import numpy as np

from model import IdeaGatedModel
from data import StreamDataset, create_idea_target


# --- CONFIGURATION ---
# Single source of hyper-parameters for this script. The dict is read by
# train(), handed to evaluate() as `conf`, and passed to wandb.init() as the
# run config.
CONF = {
    "model_name": "mistralai/Mistral-7B-v0.1",  # Or any causal LM; used for both the tokenizer and IdeaGatedModel
    "block_size": 512,          # Forwarded to StreamDataset as block_size
    "batch_size": 4,            # DataLoader batch size (one micro-batch)
    "grad_accum": 32,           # Micro-batches accumulated per optimizer step (4 * 32 = 128 samples/step)
    "max_steps": 50000,         # Upper bound on optimizer steps
    "eval_steps": 200,          # Run validation every N optimizer steps
    "save_steps": 1000,         # Write a regular checkpoint every N optimizer steps
    "lr": 2e-4,                 # AdamW learning rate
    "run_name": "entropy_adaptive_training",  # wandb run name
    "output_dir": "./checkpoints",            # Checkpoints and training_log.jsonl are written here
    "val_batches": 50,          # Number of validation batches per evaluation pass
    
    # Reproducibility
    "seed": 42,
    
    # Semantic Loss Tuning
    "pos_weight": 200,         # pos_weight for the idea BCE-with-logits loss (addresses class imbalance)
    "stopword_cutoff": 250,    # Token IDs below this value are zeroed in the idea-loss mask
                               # NOTE(review): presumably low IDs are special/very common tokens - confirm for the tokenizer in use
    
    # Entropy-Adaptive Weighting
    # idea-loss weight = base_idea_weight * (1 + entropy_sensitivity * normalized_entropy)
    "use_entropy_weighting": True,  # If False, the weight is just base_idea_weight
    "base_idea_weight": 0.3,       # Base semantic loss weight
    "entropy_sensitivity": 2.0     # Entropy scaling factor
}


def calculate_entropy(logits):
    """
    Calculates normalized entropy (0 to 1) over the last dimension of logits.

    The entropy is computed from ``log_softmax`` rather than
    ``log(softmax + eps)``: the epsilon form underflows in float16 (yielding
    NaN) and slightly biases the result, while ``log_softmax`` is stable for
    large-magnitude logits.

    Args:
        logits (torch.Tensor): Model logits whose last dimension is the
            vocabulary, e.g. [batch, seq_len, vocab_size].

    Returns:
        torch.Tensor: Normalized entropy with the last dimension reduced,
            e.g. [batch, seq_len]. float16/bfloat16 inputs produce a float32
            result. If the vocabulary dimension has fewer than two entries
            the entropy is identically zero and zeros are returned.
    """
    # Half-precision types are too coarse for a sum over a large vocabulary.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()

    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()

    # Masked logits (-inf) give log_prob = -inf with prob = 0; clamping makes
    # the product 0 * finite = 0 instead of 0 * -inf = NaN.
    finite_log_probs = log_probs.clamp(min=torch.finfo(log_probs.dtype).min)
    entropy = -(probs * finite_log_probs).sum(dim=-1)

    # Normalize by maximum possible entropy, log(vocab_size). For a vocabulary
    # of size <= 1 that maximum is 0, so avoid the 0/0 division.
    vocab_size = logits.size(-1)
    if vocab_size <= 1:
        return torch.zeros_like(entropy)

    return entropy / math.log(vocab_size)


def _to_jsonable(value):
    """Converts a tensor / NumPy value to a plain Python object for JSON."""
    if isinstance(value, torch.Tensor):
        # Single-element tensors become a scalar; anything larger becomes a
        # (nested) list, since .item() would raise on it.
        if value.numel() == 1:
            return value.item()
        return value.detach().tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        # json cannot serialize ndarrays directly.
        return value.tolist()
    return value


def log_to_file(log_dict, output_dir):
    """
    Appends one record to ``<output_dir>/training_log.jsonl``.

    Args:
        log_dict (dict): Metric name -> value. torch tensors, NumPy scalars
            and NumPy arrays are converted to Python scalars / lists; every
            other value is passed to ``json.dumps`` unchanged.
        output_dir (str): Directory for the log file; created if missing.

    Raises:
        TypeError: If a value is still not JSON-serializable after conversion.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, "training_log.jsonl")

    # Convert tensors / NumPy values to Python types
    clean_dict = {key: _to_jsonable(value) for key, value in log_dict.items()}

    # Serialize before opening so a failure never leaves a partial line.
    line = json.dumps(clean_dict) + "\n"
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(line)


class EarlyStopping:
    """
    Tracks the best validation loss and flags when training should stop.

    Calling the instance with a new validation loss returns True when the loss
    beat the previous best by more than ``min_delta``. After ``patience``
    consecutive calls without such an improvement, ``early_stop`` is set to
    True (and is never reset by this class).
    """

    def __init__(self, patience=5, min_delta=0.0):
        # Configuration
        self.patience = patience
        self.min_delta = min_delta
        # Running state
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Records ``val_loss``; returns True on improvement, else False."""
        threshold = self.best_loss - self.min_delta
        improved = val_loss < threshold

        if improved:
            self.best_loss, self.counter = val_loss, 0
            return True

        # No improvement: count it and trip the flag once patience runs out.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False


def save_checkpoint(model, output_dir, step, is_best=False):
    """
    Writes a checkpoint directory under ``output_dir``.

    The directory is named ``best_model`` when ``is_best`` is set, otherwise
    ``checkpoint_<step>``. Two things are stored in it: whatever
    ``model.base_model.save_pretrained`` writes, and the state dict of
    ``model.idea_head`` as ``idea_head.pt``.
    """
    if is_best:
        folder_name = "best_model"
    else:
        folder_name = f"checkpoint_{step}"

    target_dir = os.path.join(output_dir, folder_name)
    os.makedirs(target_dir, exist_ok=True)
    print(f"\nSaving model to {target_dir}...")

    # Base model weights (the LoRA adapters, per the original author's note)
    model.base_model.save_pretrained(target_dir)

    # Idea head is a separate module, so it is saved on its own
    head_state = model.idea_head.state_dict()
    head_file = os.path.join(target_dir, "idea_head.pt")
    torch.save(head_state, head_file)


def evaluate(model, val_loader, device, tokenizer, conf):
    """
    Evaluates model on the validation set.

    At most ``conf["val_batches"]`` batches are drawn from ``val_loader``.
    Metrics are averaged over the number of batches actually evaluated, so a
    loader that yields fewer batches than requested does not bias the result.
    The model is put in eval mode for the pass and back in train mode after.

    Args:
        model: Called as ``model(x, alpha=0.5, return_s1=True)`` and expected
            to return ``(final_logits, idea_logits, token_logits)``.
        val_loader: Iterable of ``(x, y, raw_future)`` tensor batches.
        device: Device the batches and loss helpers are moved to.
        tokenizer: Only ``vocab_size`` is read, to size the stopword mask.
        conf (dict): Reads "val_batches", "pos_weight", "stopword_cutoff",
            "use_entropy_weighting", "base_idea_weight" and
            "entropy_sensitivity".

    Returns:
        dict: Validation metrics "loss", "ppl", "entropy" and "idea_loss".
            All values are NaN if no batch could be evaluated.
    """
    from itertools import islice

    model.eval()
    total_loss = 0.0
    total_token_loss = 0.0
    total_idea_loss = 0.0
    total_entropy = 0.0
    num_batches = 0

    # Recreate training artifacts (float pos_weight so the loss never has to
    # promote an integer tensor).
    pos_weight = torch.tensor([conf["pos_weight"]], dtype=torch.float32).to(device)
    stopword_mask = torch.ones(tokenizer.vocab_size).to(device)
    stopword_mask[:conf["stopword_cutoff"]] = 0.0

    max_batches = max(conf["val_batches"], 0)

    with torch.no_grad():
        # islice stops after max_batches without pulling (and discarding) an
        # extra batch from the loader.
        limited_loader = islice(val_loader, max_batches)
        for batch in tqdm(limited_loader, total=max_batches, desc="Validating"):
            x, y, raw_future = [b.to(device) for b in batch]

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                # Get all streams. alpha is fixed at 0.5, the value the
                # training warmup saturates at.
                final_logits, idea_logits, token_logits = model(x, alpha=0.5, return_s1=True)

                # 1. Token Loss (L_NTP) - From System 1 ONLY
                loss_token = F.cross_entropy(
                    token_logits.view(-1, token_logits.size(-1)),
                    y.view(-1)
                )

                # 2. Idea Loss (L_Idea) - From System 2 ONLY
                # NOTE(review): training sizes the target with
                # tokenizer.vocab_size; the two must agree for the mask
                # multiplication below to broadcast - confirm they always do.
                y_idea = create_idea_target(raw_future, final_logits.size(-1), 20, device)

                # Clamp idea logits to prevent numerical instability
                clamped_idea_logits = torch.clamp(idea_logits, min=-100, max=100)
                loss_idea = F.binary_cross_entropy_with_logits(
                    clamped_idea_logits * stopword_mask,
                    y_idea * stopword_mask,
                    pos_weight=pos_weight
                )

                # Entropy calculation
                norm_entropy = calculate_entropy(token_logits).mean()

                # Dynamic weighting (same formula as the training loop)
                if conf["use_entropy_weighting"]:
                    dynamic_weight = conf["base_idea_weight"] * (
                        1.0 + (conf["entropy_sensitivity"] * norm_entropy.item())
                    )
                else:
                    dynamic_weight = conf["base_idea_weight"]

                loss = loss_token + (dynamic_weight * loss_idea)

            total_loss += loss.item()
            total_token_loss += loss_token.item()
            total_idea_loss += loss_idea.item()
            total_entropy += norm_entropy.item()
            num_batches += 1

    model.train()

    if num_batches == 0:
        # Nothing was evaluated: report NaN rather than a fake "perfect" loss
        # of 0 that would be recorded as the best checkpoint.
        nan = float("nan")
        return {"loss": nan, "ppl": nan, "entropy": nan, "idea_loss": nan}

    return {
        "loss": total_loss / num_batches,
        "ppl": math.exp(total_token_loss / num_batches),
        "entropy": total_entropy / num_batches,
        "idea_loss": total_idea_loss / num_batches,
    }


def _seed_everything(seed):
    """Seeds the Python, NumPy and PyTorch random number generators."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def train():
    """
    Main training loop.

    Builds the tokenizer, data loaders, model and optimizer from CONF, then
    runs up to CONF["max_steps"] optimizer steps. Each step accumulates
    gradients over CONF["grad_accum"] micro-batches of the combined loss

        L = L_NTP + w * L_Idea,
        w = base_idea_weight * (1 + entropy_sensitivity * normalized_entropy)

    (w is just base_idea_weight when entropy weighting is disabled).
    Metrics are logged to wandb and to training_log.jsonl, both keyed by the
    optimizer step. Validation runs every CONF["eval_steps"] steps and drives
    best-model checkpointing and early stopping; regular checkpoints are
    written every CONF["save_steps"] steps.
    """
    # Apply the configured seed before anything random (model init, data
    # order) is created; otherwise CONF["seed"] has no effect.
    _seed_everything(CONF["seed"])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(CONF["output_dir"], exist_ok=True)

    # Setup tokenizer
    tokenizer = AutoTokenizer.from_pretrained(CONF["model_name"])
    tokenizer.pad_token = tokenizer.eos_token

    # Data loaders. The validation stream starts 7,000,000 samples in.
    # NOTE(review): presumably this keeps it disjoint from the training
    # stream - confirm against StreamDataset.
    train_dataset = StreamDataset(tokenizer, block_size=CONF["block_size"], skip_samples=0)
    train_loader = DataLoader(train_dataset, batch_size=CONF["batch_size"])

    val_dataset = StreamDataset(tokenizer, block_size=CONF["block_size"], skip_samples=7000000)
    val_loader = DataLoader(val_dataset, batch_size=CONF["batch_size"])

    # Initialize model
    model = IdeaGatedModel(CONF["model_name"], device).to(device)

    # Loss components (float pos_weight so the loss never has to promote an
    # integer tensor).
    pos_weight = torch.tensor([CONF["pos_weight"]], dtype=torch.float32).to(device)
    stopword_mask = torch.ones(tokenizer.vocab_size).to(device)
    stopword_mask[:CONF["stopword_cutoff"]] = 0.0

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONF["lr"])

    # Initialize WandB (optional - remove if not using)
    wandb.init(project="dual-stream-lm", name=CONF["run_name"], config=CONF)

    early_stopper = EarlyStopping(patience=10, min_delta=0.001)

    # Training loop
    model.train()
    iter_loader = iter(train_loader)
    step = 0
    pbar = tqdm(total=CONF["max_steps"], desc="Training")

    while step < CONF["max_steps"]:
        optimizer.zero_grad()
        accum_token_loss = 0
        accum_idea_loss = 0
        accum_entropy = 0

        # Linear alpha warmup: 0 -> 0.5 over the first 1000 steps
        alpha = min(0.5, (step / 1000) * 0.5)

        # Gradient accumulation
        for _ in range(CONF["grad_accum"]):
            try:
                x, y, raw_future = next(iter_loader)
            except StopIteration:
                # Loader exhausted: restart it and keep going.
                iter_loader = iter(train_loader)
                x, y, raw_future = next(iter_loader)

            x, y, raw_future = x.to(device), y.to(device), raw_future.to(device)

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                # Get all three streams; only token_logits and idea_logits
                # enter the loss, final_logits is not used here.
                final_logits, idea_logits, token_logits = model(x, alpha=alpha, return_s1=True)

                # A. Token Loss (L_NTP) - Computed from System 1 ONLY
                # Paper Eq. 8: L = L_NTP + λ_base · w_t · L_Idea
                loss_token = F.cross_entropy(
                    token_logits.view(-1, token_logits.size(-1)),
                    y.view(-1)
                )

                # B. Idea Loss (L_Idea) - Computed from System 2 ONLY
                y_idea = create_idea_target(raw_future, tokenizer.vocab_size, 20, device)

                # Clamp idea logits to prevent numerical instability
                clamped_idea_logits = torch.clamp(idea_logits, min=-100, max=100)
                loss_idea_raw = F.binary_cross_entropy_with_logits(
                    clamped_idea_logits * stopword_mask,
                    y_idea * stopword_mask,
                    pos_weight=pos_weight
                )

                # Entropy-adaptive weighting. The entropy is computed under
                # no_grad so the weight acts as a constant in backward().
                if CONF["use_entropy_weighting"]:
                    with torch.no_grad():
                        norm_entropy = calculate_entropy(token_logits).mean()

                    dynamic_weight = CONF["base_idea_weight"] * (
                        1.0 + (CONF["entropy_sensitivity"] * norm_entropy)
                    )
                else:
                    dynamic_weight = CONF["base_idea_weight"]
                    norm_entropy = torch.tensor(0.0)

                # Total loss, scaled so accumulated gradients average over
                # the micro-batches.
                total_loss = loss_token + (dynamic_weight * loss_idea_raw)
                total_loss_scaled = total_loss / CONF["grad_accum"]

            total_loss_scaled.backward()

            accum_token_loss += loss_token.item() / CONF["grad_accum"]
            accum_idea_loss += loss_idea_raw.item() / CONF["grad_accum"]
            accum_entropy += norm_entropy.item() / CONF["grad_accum"]

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        step += 1
        pbar.update(1)

        # Logging
        if step % 100 == 0:
            # These three statistics come from the LAST micro-batch of the
            # step only; the loss/entropy values are step averages.
            pos_ratio = (idea_logits > 0).float().mean().item()

            log_data = {
                "step": step,
                "train/loss_token": accum_token_loss,
                "train/loss_idea": accum_idea_loss,
                "train/entropy": accum_entropy,
                # Plain float so wandb and the JSONL log both record a scalar
                # (dynamic_weight is a tensor when entropy weighting is on).
                "train/dynamic_weight": float(dynamic_weight),
                "train/positivity_ratio": pos_ratio,
                "train/logit_mean": idea_logits.mean().item()
            }

            # Explicit step keeps train metrics on the same x-axis as the
            # validation metrics logged below.
            wandb.log(log_data, step=step)
            log_to_file(log_data, CONF["output_dir"])

            pbar.set_postfix({
                "tk": f"{accum_token_loss:.2f}",
                "ent": f"{accum_entropy:.2f}",
                "pos": f"{pos_ratio:.3f}"
            })

        # Validation
        if step % CONF["eval_steps"] == 0:
            metrics = evaluate(model, val_loader, device, tokenizer, CONF)

            print(f"\nStep {step} | Val Loss: {metrics['loss']:.4f} | "
                  f"PPL: {metrics['ppl']:.2f} | Ent: {metrics['entropy']:.2f}")

            val_log_data = {
                "step": step,
                "val/loss": metrics['loss'],
                "val/ppl": metrics['ppl'],
                "val/entropy": metrics['entropy']
            }

            wandb.log(val_log_data, step=step)
            log_to_file(val_log_data, CONF["output_dir"])

            # Checkpoint best model (the stopper returns True on improvement)
            if early_stopper(metrics['loss']):
                save_checkpoint(model, CONF["output_dir"], step, is_best=True)

            if early_stopper.early_stop:
                print("Early stopping triggered.")
                break

        # Regular checkpoints
        if step % CONF["save_steps"] == 0:
            save_checkpoint(model, CONF["output_dir"], step)

    pbar.close()
    wandb.finish()
    print("Training complete.")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    train()
