"""
Universal Model Interface and Grokking Detection
Works with any model - just pass the model name
"""

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Config, GPT2LMHeadModel
from dataclasses import dataclass
from typing import List, Optional, Tuple, Any
import logging

logger = logging.getLogger(__name__)

@dataclass
class GrokkingMetrics:
    """Snapshot of the metrics recorded at one evaluation point.

    Instances are built by ``GrokkingDetector.update_metrics`` and appended to
    its ``metrics_history``.
    """
    step: int  # training step at which this evaluation was run
    train_loss: float
    val_loss: float
    memorization_score: float  # output of GrokkingDetector.compute_memorization_score
    generalization_score: float  # output of GrokkingDetector.compute_generalization_score
    grokking_signal: float  # output of GrokkingDetector.detect_grokking_signal
    train_acc: Optional[float] = None  # None when no training accuracy was supplied
    val_acc: Optional[float] = None  # None when no validation accuracy was supplied


def _scratch_gpt2_config(vocab_size: int) -> GPT2Config:
    """Return the GPT-2 config used for randomly initialised models.

    12 layers, 768 hidden units, 12 heads and 512 positions.
    """
    return GPT2Config(
        vocab_size=vocab_size,
        n_positions=512,
        n_embd=768,
        n_layer=12,
        n_head=12,
    )


def create_model_and_tokenizer(
    model_name: str,
    vocab_size: Optional[int] = None,
    scratch_tokenizer_name: str = "gpt2",
):
    """
    Create a causal-LM model and its tokenizer.

    Args:
        model_name: Hugging Face model id or local path to load, or the special
            value ``"scratch"`` to build a randomly initialised GPT-2 model.
        vocab_size: Vocabulary size override. For pretrained models the token
            embeddings are resized when it differs from the checkpoint's; for
            scratch models it sets the config's vocabulary size (default 50257).
        scratch_tokenizer_name: Tokenizer to load when ``model_name`` is
            ``"scratch"``. There is no checkpoint named "scratch" to take a
            tokenizer from.

    Returns:
        (model, tokenizer)

    Raises:
        Exception: Whatever ``AutoTokenizer.from_pretrained`` raises when the
            tokenizer cannot be loaded. Model-loading failures do not raise:
            they are logged and replaced by a from-scratch GPT-2 model.
    """
    tokenizer_name = scratch_tokenizer_name if model_name == "scratch" else model_name

    # Create tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

        # Many causal-LM tokenizers have no pad token; reuse EOS so padded
        # batches can be built.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    except Exception as e:
        logger.error(f"Error loading tokenizer for {model_name}: {e}")
        raise

    # Create model
    try:
        if model_name == "scratch":
            logger.info("Creating model from scratch")
            model = GPT2LMHeadModel(_scratch_gpt2_config(vocab_size or 50257))

        else:
            logger.info(f"Loading model: {model_name}")

            # No torch_dtype / device_map is passed, so transformers' loading
            # defaults apply.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
            )

            # Resize if needed
            if vocab_size and vocab_size != model.config.vocab_size:
                logger.info(f"Resizing embeddings: {model.config.vocab_size} -> {vocab_size}")
                model.resize_token_embeddings(vocab_size)

        logger.info(f"Model loaded successfully: {get_model_info(model)}")
        return model, tokenizer

    except Exception as e:
        # Best-effort: keep going with a randomly initialised model.
        logger.error(f"Error loading model {model_name}: {e}")
        logger.info("Falling back to scratch model")

        # len(tokenizer) includes added special tokens, which tokenizer.vocab_size
        # excludes; sizing by vocab_size could leave those ids without embeddings.
        model = GPT2LMHeadModel(_scratch_gpt2_config(vocab_size or len(tokenizer)))
        return model, tokenizer


def get_model_info(model) -> dict:
    """Summarise a model's size and, when available, key config fields.

    Always reports total and trainable parameter counts plus the model's class
    name. If the model has a ``config`` attribute, it also reports the
    vocabulary size, hidden size and layer count. For hidden size and layer
    count, the generic Hugging Face attribute names are tried before the
    GPT-2 style ones (``n_embd`` / ``n_layer``). Missing fields are None.
    """
    params = list(model.parameters())
    summary = {
        "total_params": sum(t.numel() for t in params),
        "trainable_params": sum(t.numel() for t in params if t.requires_grad),
        "model_type": model.__class__.__name__,
    }

    if not hasattr(model, 'config'):
        return summary

    cfg = model.config

    def lookup(*names):
        # First attribute that exists wins, even if its value is None.
        for attr in names:
            if hasattr(cfg, attr):
                return getattr(cfg, attr)
        return None

    summary.update(
        vocab_size=lookup('vocab_size'),
        hidden_size=lookup('hidden_size', 'n_embd'),
        num_layers=lookup('num_hidden_layers', 'n_layer'),
    )
    return summary


def _shifted_causal_lm_loss(logits, labels):
    """Next-token cross-entropy: logits at position t are scored against labels[t + 1].

    Positions whose label is -100 are ignored.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )


def model_forward(model, input_ids, attention_mask=None, labels=None):
    """Universal forward pass for any causal LM.

    The model is first called with ``labels``. If it does not accept a
    ``labels`` keyword (``TypeError``), it is called again without it. In
    either case, when ``labels`` is given but the outputs carry no loss, the
    shifted next-token cross-entropy is computed from ``outputs.logits`` and
    stored as ``outputs.loss``.

    Args:
        model: Callable model returning an object with a ``logits`` attribute.
        input_ids: Token ids passed straight to the model.
        attention_mask: Optional mask passed straight to the model.
        labels: Optional targets aligned with ``input_ids`` (not pre-shifted);
            -100 marks positions to ignore.

    Returns:
        The model's output object, with ``loss`` filled in when needed.
    """
    try:
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
    except TypeError:
        # Only an unsupported keyword argument should trigger the retry; other
        # errors (OOM, shape mismatches, ...) must surface rather than re-run
        # the whole forward pass.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Some models accept `labels` (e.g. via **kwargs) but still return no loss.
    if labels is not None and getattr(outputs, "loss", None) is None:
        outputs.loss = _shifted_causal_lm_loss(outputs.logits, labels)

    return outputs


def calculate_accuracy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    shift_labels: bool = True,
    ignore_index: int = -100,
) -> float:
    """Token-level accuracy of ``argmax(logits)`` against ``labels``.

    For causal LMs the logits at position t predict token t + 1. This is the
    same alignment ``model_forward`` uses for its manual loss, so by default
    predictions ``[..., :-1]`` are compared with labels ``[..., 1:]``. The
    mask is shifted with the labels, so only positions whose *target* is
    unmasked count.

    Args:
        logits: Scores with the class dimension last.
        labels: Target ids aligned with the model inputs (not pre-shifted).
        attention_mask: Optional 0/1 mask with the same shape as ``labels``;
            None counts every position.
        shift_labels: Set False to compare position t with label t, as
            done by the previous (unshifted) implementation.
        ignore_index: Label value excluded from the count, matching the
            loss's ``ignore_index``.

    Returns:
        correct / counted positions, or 0.0 when no position is counted.
    """
    predictions = logits.argmax(dim=-1)
    if attention_mask is None:
        mask = torch.ones_like(labels, dtype=torch.bool)
    else:
        mask = attention_mask.bool()

    if shift_labels:
        predictions = predictions[..., :-1]
        labels = labels[..., 1:]
        mask = mask[..., 1:]

    mask = mask & (labels != ignore_index)
    correct = ((predictions == labels) & mask).sum().item()
    total = mask.sum().item()
    return correct / total if total > 0 else 0.0


class GrokkingDetector:
    """Track train/validation metrics over time and flag likely grokking.

    Each call to ``update_metrics`` records one evaluation point in
    ``metrics_history``. ``window_size`` therefore counts *evaluations*, not
    training steps. ``min_steps_for_detection`` is compared with the training
    ``step`` passed to ``update_metrics``.
    """

    def __init__(self,
                 window_size: int = 50,
                 grokking_threshold: float = 0.05,
                 min_steps_for_detection: int = 100):
        """
        Args:
            window_size: Number of most recent evaluations over which the
                memorization / generalization trends are fitted.
            grokking_threshold: ``grokking_signal`` value above which a step
                is recorded as a grokking event.
            min_steps_for_detection: Training step before which no grokking
                event is recorded.
        """
        self.window_size = window_size
        self.grokking_threshold = grokking_threshold
        self.min_steps_for_detection = min_steps_for_detection

        self.metrics_history = []  # GrokkingMetrics, one per update_metrics call
        self.grokking_events = []  # steps at which the signal exceeded the threshold

    def compute_memorization_score(self, train_loss, val_loss, train_acc=None, val_acc=None):
        """Return the train/validation gap, clipped at 0.

        Uses the accuracy gap (train_acc - val_acc) when both accuracies are
        given, otherwise the loss gap (val_loss - train_loss).
        """
        if train_acc is not None and val_acc is not None:
            return max(0, train_acc - val_acc)
        return max(0, val_loss - train_loss)

    def compute_generalization_score(self, val_loss, val_loss_history, window_size=20):
        """Return how much the validation loss dropped between two windows.

        The current ``val_loss`` is appended to ``val_loss_history``. The mean
        of the last ``window_size`` values is compared with the mean of the
        ``window_size`` values before them.

        Returns:
            ``max(0, earlier_mean - recent_mean)``, or 0.0 until two full
            windows of values are available.
        """
        series = list(val_loss_history) + [val_loss]
        # Both windows must be full: a partial "earlier" slice is biased, and
        # np.mean of an empty slice is NaN (with a RuntimeWarning).
        if window_size <= 0 or len(series) < 2 * window_size:
            return 0.0

        recent_avg = np.mean(series[-window_size:])
        earlier_avg = np.mean(series[-2 * window_size:-window_size])
        return float(max(0.0, earlier_avg - recent_avg))

    def detect_grokking_signal(self, memorization_history, generalization_history):
        """Return the grokking signal for the last ``window_size`` scores.

        Fits a straight line to the most recent ``window_size`` memorization
        and generalization scores. The signal is
        ``max(0, generalization_slope - memorization_slope)``: it is large when
        memorization falls while generalization rises.

        Returns:
            The signal, or 0.0 until both histories hold a full window. Also
            0.0 for windows shorter than 2, where a slope is undefined.
        """
        window = self.window_size
        if (window < 2
                or len(memorization_history) < window
                or len(generalization_history) < window):
            return 0.0

        xs = np.arange(window)
        mem_trend = np.polyfit(xs, memorization_history[-window:], 1)[0]
        gen_trend = np.polyfit(xs, generalization_history[-window:], 1)[0]
        return float(max(0.0, gen_trend - mem_trend))

    def update_metrics(self, step, train_loss, val_loss, train_acc=None, val_acc=None):
        """Record one evaluation point and check it for grokking.

        Computes this point's memorization and generalization scores, then the
        grokking signal over the score histories *including* this point. The
        resulting ``GrokkingMetrics`` is appended to ``metrics_history``. Once
        ``step >= min_steps_for_detection``, ``step`` is added to
        ``grokking_events`` whenever the signal exceeds ``grokking_threshold``.

        Returns:
            The ``GrokkingMetrics`` that was recorded.
        """
        memorization_score = self.compute_memorization_score(
            train_loss, val_loss, train_acc, val_acc
        )

        val_loss_history = [m.val_loss for m in self.metrics_history]
        generalization_score = self.compute_generalization_score(
            val_loss, val_loss_history
        )

        # Include the current point so the signal reflects this step instead
        # of lagging one evaluation behind.
        mem_history = [m.memorization_score for m in self.metrics_history]
        mem_history.append(memorization_score)
        gen_history = [m.generalization_score for m in self.metrics_history]
        gen_history.append(generalization_score)

        grokking_signal = self.detect_grokking_signal(mem_history, gen_history)

        metrics = GrokkingMetrics(
            step=step,
            train_loss=train_loss,
            val_loss=val_loss,
            memorization_score=memorization_score,
            generalization_score=generalization_score,
            grokking_signal=grokking_signal,
            train_acc=train_acc,
            val_acc=val_acc
        )
        self.metrics_history.append(metrics)

        if step >= self.min_steps_for_detection and grokking_signal > self.grokking_threshold:
            self.grokking_events.append(step)
            logger.info(f"Grokking detected at step {step}! Signal: {grokking_signal:.4f}")

        return metrics

    def plot_metrics(self, save_path: Optional[str] = None):
        """Plot losses, scores and the grokking signal on a 2x2 grid.

        Grokking events are marked as vertical lines on the signal panel. The
        figure is saved to ``save_path`` when given, shown, then closed so
        repeated calls do not accumulate open figures. Does nothing when no
        metrics have been recorded.
        """
        if not self.metrics_history:
            return

        steps = [m.step for m in self.metrics_history]
        train_losses = [m.train_loss for m in self.metrics_history]
        val_losses = [m.val_loss for m in self.metrics_history]
        mem_scores = [m.memorization_score for m in self.metrics_history]
        gen_scores = [m.generalization_score for m in self.metrics_history]
        grokking_signals = [m.grokking_signal for m in self.metrics_history]

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Loss plot
        axes[0, 0].plot(steps, train_losses, label='Train Loss', alpha=0.7)
        axes[0, 0].plot(steps, val_losses, label='Val Loss', alpha=0.7)
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Memorization plot
        axes[0, 1].plot(steps, mem_scores, label='Memorization', color='red', alpha=0.7)
        axes[0, 1].set_title('Memorization Score')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Generalization plot
        axes[1, 0].plot(steps, gen_scores, label='Generalization', color='green', alpha=0.7)
        axes[1, 0].set_title('Generalization Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Grokking signal plot
        axes[1, 1].plot(steps, grokking_signals, label='Grokking Signal', color='purple', alpha=0.7)
        axes[1, 1].axhline(y=self.grokking_threshold, color='orange', linestyle='--',
                           label=f'Threshold ({self.grokking_threshold})')

        # Mark grokking events
        for event_step in self.grokking_events:
            axes[1, 1].axvline(x=event_step, color='red', linestyle=':', alpha=0.8)

        axes[1, 1].set_title('Grokking Detection Signal')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Plot saved to {save_path}")

        plt.show()
        plt.close(fig)

    def get_summary(self):
        """Return a summary dict of the detection run.

        ``grokking_events`` is a copy, so callers cannot mutate the detector's
        state through it.
        """
        if not self.metrics_history:
            return {"error": "No metrics available"}

        return {
            "total_steps": len(self.metrics_history),
            "grokking_events": list(self.grokking_events),
            "num_grokking_events": len(self.grokking_events),
            "first_grokking_step": self.grokking_events[0] if self.grokking_events else None,
        }


def _infinite_batches(dataloader):
    """Yield batches from ``dataloader`` forever, starting a new pass after each one.

    Raises:
        ValueError: If a complete pass yields no batches. The loop would
            otherwise spin forever without producing anything.
    """
    while True:
        yielded_any = False
        for batch in dataloader:
            yielded_any = True
            yield batch
        if not yielded_any:
            raise ValueError("train_loader yielded no batches")


def _evaluate_on_validation(model, val_loader, device, max_batches):
    """Return (mean loss, token accuracy) over at most ``max_batches`` validation batches.

    The loss is the unweighted mean of per-batch losses. The accuracy is the
    per-batch ``calculate_accuracy`` value weighted by each batch's
    attention-mask token count. Runs under ``torch.no_grad()``; the caller
    switches the model between eval and train mode.

    Raises:
        ValueError: If no validation batch was evaluated.
    """
    loss_sum = 0.0
    weighted_correct = 0.0
    token_total = 0
    n_batches = 0

    with torch.no_grad():
        for batch_idx, val_batch in enumerate(val_loader):
            if batch_idx >= max_batches:
                break

            val_input_ids = val_batch['input_ids'].to(device)
            val_attention_mask = val_batch['attention_mask'].to(device)
            val_labels = val_batch['labels'].to(device)

            val_outputs = model_forward(
                model, val_input_ids, val_attention_mask, val_labels
            )
            loss_sum += val_outputs.loss.item()
            n_batches += 1

            batch_tokens = val_attention_mask.sum().item()
            batch_acc = calculate_accuracy(
                val_outputs.logits, val_labels, val_attention_mask
            )
            weighted_correct += batch_acc * batch_tokens
            token_total += batch_tokens

    if n_batches == 0:
        raise ValueError("val_loader yielded no batches to evaluate")

    val_loss = loss_sum / n_batches
    val_acc = weighted_correct / token_total if token_total > 0 else 0
    return val_loss, val_acc


def train_with_grokking_detection(
    model,
    train_loader,
    val_loader,
    max_steps: int = 5000,
    learning_rate: float = 5e-5,
    eval_interval: int = 100,
    device: Optional[str] = None,
    warmup_steps: int = 100,
    weight_decay: float = 0.01,
    gradient_clip: float = 1.0,
    max_eval_batches: int = 20,
    detector: Optional["GrokkingDetector"] = None,
):
    """
    One-pass pretraining with real-time grokking detection

    Args:
        model: The model to train
        train_loader: Training data loader (batches with 'input_ids',
            'attention_mask' and 'labels')
        val_loader: Validation data loader (same batch keys)
        max_steps: Maximum training (optimizer) steps
        learning_rate: Peak learning rate
        eval_interval: Steps between evaluations; must be positive
        device: Device to use (auto-detect if None)
        warmup_steps: Linear learning-rate warmup steps before cosine decay
        weight_decay: Weight decay for AdamW
        gradient_clip: Gradient-norm clipping value
        max_eval_batches: Maximum validation batches per evaluation
        detector: GrokkingDetector to update. When None, one is created whose
            trend window is sized to the number of evaluations in the run.

    Returns:
        GrokkingDetector with training history

    Raises:
        ValueError: If ``eval_interval`` is not positive, or a loader yields
            no batches.
    """
    import math
    import torch.optim as optim

    if eval_interval <= 0:
        raise ValueError(f"eval_interval must be positive, got {eval_interval}")

    # Setup device
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = model.to(device)

    # Setup optimizer with weight decay
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.95)  # Common for pretraining
    )

    # Avoid dividing by zero when max_steps == warmup_steps.
    decay_steps = max(1, max_steps - warmup_steps)

    def lr_lambda(step):
        # Linear warmup to the peak LR, then cosine decay reaching 0 at max_steps.
        if step < warmup_steps:
            return step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    if detector is None:
        # Detector windows count evaluations (one per eval_interval steps), not
        # training steps. A fixed 100-evaluation window never fills with the
        # defaults (5000 / 100 = 50 evaluations), so size it to the run.
        num_evals = max_steps // eval_interval
        detector = GrokkingDetector(
            window_size=max(2, min(100, num_evals // 2)),
            grokking_threshold=0.02,  # More sensitive for pretraining
            min_steps_for_detection=200
        )

    # Epoch bookkeeping needs the loader length; iterable-style loaders may not have one.
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:
        steps_per_epoch = None

    # Lightweight progress line every quarter eval interval (at most once per step).
    progress_interval = max(1, eval_interval // 4)

    # Training state
    model.train()
    step = 0
    epoch = 0
    total_tokens = 0

    logger.info(f"Starting one-pass pretraining for {max_steps} steps")
    logger.info(f"Warmup steps: {warmup_steps}, Peak LR: {learning_rate}")

    train_iter = _infinite_batches(train_loader)

    # Main pretraining loop
    while step < max_steps:
        batch = next(train_iter)

        # `step` batches have been consumed, so a multiple of the epoch length
        # means this batch starts a new pass over the data.
        if steps_per_epoch and step > 0 and step % steps_per_epoch == 0:
            epoch += 1
            logger.info(f"Completed epoch {epoch}, continuing pretraining...")

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Count tokens for throughput tracking
        total_tokens += attention_mask.sum().item()

        # Forward pass
        optimizer.zero_grad()
        outputs = model_forward(model, input_ids, attention_mask, labels)
        loss = outputs.loss

        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        scheduler.step()

        step += 1

        # Evaluation and grokking detection
        if step % eval_interval == 0:
            current_lr = scheduler.get_last_lr()[0]
            tokens_per_step = total_tokens / step

            print(f"\n--- Step {step} (Epoch {epoch + 1}) ---")
            print(f"Learning Rate: {current_lr:.2e}")
            print(f"Tokens processed: {total_tokens:,} (avg {tokens_per_step:.1f}/step)")

            model.eval()
            val_loss, val_acc = _evaluate_on_validation(
                model, val_loader, device, max_eval_batches
            )

            # Training accuracy on the most recent training batch only
            train_acc = calculate_accuracy(outputs.logits, labels, attention_mask)

            metrics = detector.update_metrics(
                step=step,
                train_loss=loss.item(),
                val_loss=val_loss,
                train_acc=train_acc,
                val_acc=val_acc
            )

            print(f"Train Loss: {loss.item():.4f} | Val Loss: {val_loss:.4f}")
            print(f"Train Acc:  {train_acc:.4f} | Val Acc:  {val_acc:.4f}")
            print(f"Memorization Score: {metrics.memorization_score:.4f}")
            print(f"Generalization Score: {metrics.generalization_score:.4f}")
            print(f"🎯 Grokking Signal: {metrics.grokking_signal:.4f}")

            if metrics.grokking_signal > detector.grokking_threshold:
                print("🔥 GROKKING DETECTED! 🔥")

            progress = (step / max_steps) * 100
            print(f"Progress: {progress:.1f}% ({step}/{max_steps} steps)")

            model.train()

        # Periodic progress updates (without full evaluation)
        elif step % progress_interval == 0:
            current_lr = scheduler.get_last_lr()[0]
            progress = (step / max_steps) * 100
            print(f"Step {step}: Loss {loss.item():.4f}, LR {current_lr:.2e}, Progress {progress:.1f}%")

    logger.info(f"Pretraining completed! Total tokens processed: {total_tokens:,}")
    logger.info(f"Final epoch: {epoch + 1}, Total steps: {step}")

    return detector


if __name__ == "__main__":
    # Smoke test: load each model, print its summary, and run one forward pass
    # on random token ids. Failures are reported per model and do not stop the run.
    print("Testing model creation...")

    for candidate in ("gpt2", "microsoft/phi-1_5"):
        try:
            mdl, tok = create_model_and_tokenizer(candidate)
            print(f"{candidate}: {get_model_info(mdl)}")

            dummy_ids = torch.randint(0, 1000, (1, 10))
            out = model_forward(mdl, dummy_ids, torch.ones_like(dummy_ids))
            print(f"Forward pass successful, logits shape: {out.logits.shape}")
        except Exception as err:
            print(f"Error with {candidate}: {err}")

    print("Model tests completed!")