import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Optional
import logging
from dataclasses import dataclass
from collections import defaultdict
import json
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
import random

# Module-wide logging: the root logger is configured at INFO level (only if it
# has no handlers yet), and this file logs through a logger named after itself.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

@dataclass()
class GrokkingMetrics:
    """Snapshot of the grokking-detector state recorded at one evaluation.

    Instances are built by ``GrokkingDetector.update_metrics`` and appended to
    its ``metrics_history``.
    """

    step: int                   # training step at which the snapshot was taken
    train_loss: float           # training loss reported for this step
    val_loss: float             # validation loss reported for this step
    memorization_score: float   # non-negative train/val gap (accuracy- or loss-based)
    generalization_score: float  # non-negative drop of windowed validation loss
    grokking_signal: float      # non-negative trend score compared to the threshold
    
class GrokkingDetector:
    """Heuristic detector for grokking-like transitions during LLM training.

    At every evaluation the caller reports train/validation losses (and
    optionally accuracies) through ``update_metrics``.  The detector derives

    * a *memorization score*: the non-negative train/validation gap,
    * a *generalization score*: the non-negative drop of the windowed mean
      validation loss, and
    * a *grokking signal*: slope of the generalization score minus slope of
      the memorization score over the last ``window_size`` evaluations,
      clipped at zero.

    Every step whose signal exceeds ``grokking_threshold`` is appended to
    ``grokking_events``.
    """

    def __init__(self,
                 window_size: int = 100,
                 grokking_threshold: float = 0.1,
                 patience: int = 50):
        """
        Args:
            window_size: Number of most recent evaluations used by
                ``update_metrics`` to fit the memorization/generalization
                trends.
            grokking_threshold: Signal level above which a step is recorded
                as a grokking event.
            patience: Stored on the instance; not read by any detection logic
                in this class.
        """
        self.window_size = window_size
        self.grokking_threshold = grokking_threshold
        self.patience = patience
        self.metrics_history = []   # GrokkingMetrics, one per update_metrics call
        self.grokking_events = []   # steps whose signal exceeded the threshold

    def compute_memorization_score(self,
                                   train_loss: float,
                                   val_loss: float,
                                   train_acc: Optional[float] = None,
                                   val_acc: Optional[float] = None) -> float:
        """Return the non-negative train/validation gap.

        When both accuracies are given the gap is ``train_acc - val_acc``;
        otherwise it is ``val_loss - train_loss``.  Higher means more
        memorization.
        """
        if train_acc is not None and val_acc is not None:
            return max(0, train_acc - val_acc)
        return max(0, val_loss - train_loss)

    def compute_generalization_score(self,
                                     val_loss: float,
                                     val_loss_history: List[float],
                                     window_size: int = 20) -> float:
        """Return how much the windowed validation loss has dropped.

        ``val_loss_history`` holds the losses of *previous* evaluations; the
        current ``val_loss`` is appended so the newest measurement belongs to
        the recent window.  The score is ``mean(earlier) - mean(recent)``
        clipped at zero, where *recent* is the last ``window_size`` losses and
        *earlier* the (up to) ``window_size`` losses before it.

        Returns 0.0 until at least one loss precedes the recent window.
        """
        series = list(val_loss_history) + [val_loss]
        if window_size < 1 or len(series) <= window_size:
            # The earlier window would be empty and np.mean would return NaN.
            return 0.0

        recent_avg = np.mean(series[-window_size:])
        earlier_avg = np.mean(series[-2 * window_size:-window_size])
        # Positive only when validation loss is going down.
        return max(0, earlier_avg - recent_avg)

    def detect_grokking_signal(self,
                               memorization_history: List[float],
                               generalization_history: List[float],
                               window_size: int = 50) -> float:
        """Score a memorization-to-generalization transition.

        Fits a straight line to the last ``window_size`` entries of each
        history and returns ``slope(generalization) - slope(memorization)``
        clipped at zero.  A falling memorization score combined with a rising
        generalization score yields a positive signal.

        Returns 0.0 when either history is shorter than ``window_size`` or
        when ``window_size < 2``, because a line cannot be fit to one point.
        """
        if window_size < 2:
            return 0.0
        if len(memorization_history) < window_size or len(generalization_history) < window_size:
            return 0.0

        x = np.arange(window_size)
        mem_slope = np.polyfit(x, memorization_history[-window_size:], 1)[0]
        gen_slope = np.polyfit(x, generalization_history[-window_size:], 1)[0]
        return max(0, gen_slope - mem_slope)

    def update_metrics(self,
                       step: int,
                       train_loss: float,
                       val_loss: float,
                       train_acc: Optional[float] = None,
                       val_acc: Optional[float] = None) -> "GrokkingMetrics":
        """Record one evaluation and return its ``GrokkingMetrics``.

        All scores for this step combine the stored history *with* the
        current measurement.  A transition is therefore flagged at the
        evaluation where it first shows up rather than one evaluation later.
        The trend window is ``self.window_size``.  Steps whose signal exceeds
        ``self.grokking_threshold`` are appended to ``grokking_events`` and
        logged.
        """
        memorization_score = self.compute_memorization_score(
            train_loss, val_loss, train_acc, val_acc
        )

        # compute_generalization_score appends the current val_loss itself.
        past_val_losses = [m.val_loss for m in self.metrics_history]
        generalization_score = self.compute_generalization_score(
            val_loss, past_val_losses
        )

        mem_history = [m.memorization_score for m in self.metrics_history]
        mem_history.append(memorization_score)
        gen_history = [m.generalization_score for m in self.metrics_history]
        gen_history.append(generalization_score)

        grokking_signal = self.detect_grokking_signal(
            mem_history, gen_history, window_size=self.window_size
        )

        metrics = GrokkingMetrics(
            step=step,
            train_loss=train_loss,
            val_loss=val_loss,
            memorization_score=memorization_score,
            generalization_score=generalization_score,
            grokking_signal=grokking_signal
        )
        self.metrics_history.append(metrics)

        if grokking_signal > self.grokking_threshold:
            self.grokking_events.append(step)
            logger.info(f"Grokking detected at step {step}! Signal: {grokking_signal:.4f}")

        return metrics

    def plot_metrics(self, save_path: Optional[str] = None):
        """Plot losses, scores and the grokking signal on a 2x2 grid.

        Does nothing if no metrics were recorded.  Grokking events appear as
        dotted vertical lines on the signal panel.  If ``save_path`` is given,
        the figure is written there at 300 dpi.  The figure is shown and then
        closed, so repeated calls (one per experiment) do not accumulate open
        figures.
        """
        if not self.metrics_history:
            return

        steps = [m.step for m in self.metrics_history]
        train_losses = [m.train_loss for m in self.metrics_history]
        val_losses = [m.val_loss for m in self.metrics_history]
        mem_scores = [m.memorization_score for m in self.metrics_history]
        gen_scores = [m.generalization_score for m in self.metrics_history]
        grokking_signals = [m.grokking_signal for m in self.metrics_history]

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Panel 1: training and validation loss
        loss_ax = axes[0, 0]
        loss_ax.plot(steps, train_losses, label='Training Loss', alpha=0.7)
        loss_ax.plot(steps, val_losses, label='Validation Loss', alpha=0.7)
        loss_ax.set_xlabel('Training Steps')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Loss')
        loss_ax.legend()
        loss_ax.grid(True, alpha=0.3)

        # Panels 2 and 3: one score series each, same layout
        for ax, values, name, color in (
            (axes[0, 1], mem_scores, 'Memorization Score', 'red'),
            (axes[1, 0], gen_scores, 'Generalization Score', 'green'),
        ):
            ax.plot(steps, values, label=name, color=color, alpha=0.7)
            ax.set_xlabel('Training Steps')
            ax.set_ylabel(name)
            ax.set_title(name)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Panel 4: grokking signal, threshold and detected events
        signal_ax = axes[1, 1]
        signal_ax.plot(steps, grokking_signals, label='Grokking Signal', color='purple', alpha=0.7)
        signal_ax.axhline(y=self.grokking_threshold, color='orange', linestyle='--',
                          label=f'Threshold ({self.grokking_threshold})')
        for event_step in self.grokking_events:
            signal_ax.axvline(x=event_step, color='red', linestyle=':', alpha=0.8)
        signal_ax.set_xlabel('Training Steps')
        signal_ax.set_ylabel('Grokking Signal')
        signal_ax.set_title('Grokking Detection Signal')
        signal_ax.legend()
        signal_ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Metrics plot saved to {save_path}")

        plt.show()
        # Release the figure; otherwise every experiment leaves one open.
        plt.close(fig)

# TOFU Dataset classes
class TOFUDataset(Dataset):
    """Tokenized TOFU question/answer pairs for causal language modeling.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask``
    (padded/truncated to ``max_length``) and ``labels``.  ``labels`` is a copy
    of ``input_ids`` with padding positions set to -100, so the loss ignores
    padding.
    """

    def __init__(self,
                 subset: str = "forget01",  # forget01, forget05, forget10, retain90, retain95, retain99
                 tokenizer_name: str = "gpt2",
                 max_length: int = 512,
                 split: str = "train"):
        """
        Args:
            subset: TOFU configuration name passed to ``load_dataset``.
            tokenizer_name: Tokenizer to load; if it has no pad token, its EOS
                token is used for padding.
            max_length: Length every example is padded/truncated to.
            split: Dataset split to load.

        Raises:
            Exception: any error from ``load_dataset`` is logged and re-raised.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.max_length = max_length

        logger.info(f"Loading TOFU dataset subset: {subset}")
        try:
            self.dataset = load_dataset("locuslab/TOFU", subset, split=split)
            logger.info(f"Loaded {len(self.dataset)} samples from TOFU {subset}")
        except Exception as e:
            logger.error(f"Error loading TOFU dataset: {e}")
            raise

        self.processed_data = self._preprocess_data()

    @staticmethod
    def _build_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``input_ids`` with padded positions set to -100.

        -100 is the default ``ignore_index`` of PyTorch's cross-entropy loss.
        This matters because pad == EOS here: without masking, the loss would
        be dominated by predicting padding.
        """
        return input_ids.masked_fill(attention_mask == 0, -100)

    def _preprocess_data(self):
        """Tokenize every item into fixed-length causal-LM training examples.

        Text is ``"Question: ...\\nAnswer: ..."``, using whichever fields are
        non-empty.  Items with neither a question nor an answer are skipped.
        """
        processed = []

        for item in self.dataset:
            question = item.get('question', '')
            answer = item.get('answer', '')

            parts = []
            if question:
                parts.append(f"Question: {question}")
            if answer:
                parts.append(f"Answer: {answer}")
            if not parts:
                continue
            text = "\n".join(parts)

            tokens = self.tokenizer(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Drop only the batch dimension added by return_tensors='pt'.
            input_ids = tokens['input_ids'].squeeze(0)
            attention_mask = tokens['attention_mask'].squeeze(0)

            processed.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': self._build_labels(input_ids, attention_mask),
            })

        return processed

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        return self.processed_data[idx]

class TOFULanguageModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a Hugging Face causal LM.

    Loads ``model_name`` with ``AutoModelForCausalLM``.  If loading fails, a
    randomly initialised 12-layer GPT-2 configuration is built instead.
    """

    def __init__(self,
                 model_name: str = "gpt2",
                 vocab_size: Optional[int] = None,
                 freeze_base: bool = False):
        """
        Args:
            model_name: Pretrained model identifier to load.
            vocab_size: If given and different from the loaded model's, the
                token embeddings are resized to it.  Also used as the
                vocabulary size of the from-scratch fallback (default 50257).
            freeze_base: If True, every parameter is frozen except those of
                the model's output embedding (LM head).
        """
        super().__init__()

        # Keep the try around loading only: a failure in the resize step below
        # should surface, not silently swap in an untrained model.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
        except Exception as e:
            logger.warning(f"Could not load pretrained model {model_name}: {e}")
            logger.info("Creating model from scratch")
            from transformers import GPT2Config, GPT2LMHeadModel
            config = GPT2Config(
                vocab_size=vocab_size or 50257,
                n_positions=512,
                n_embd=768,
                n_layer=12,
                n_head=12
            )
            self.model = GPT2LMHeadModel(config)
        else:
            if vocab_size and vocab_size != self.model.config.vocab_size:
                self.model.resize_token_embeddings(vocab_size)

        if freeze_base:
            for param in self.model.parameters():
                param.requires_grad = False
            # get_output_embeddings() locates the LM head generically; not
            # every architecture exposes it as `lm_head`.
            # NOTE(review): with tied input/output embeddings (GPT-2's default)
            # this is the same weight as the token embedding, so unfreezing it
            # also trains the input embeddings. Confirm this is intended.
            head = self.model.get_output_embeddings()
            if head is None:
                logger.warning("Model exposes no output embeddings; all parameters remain frozen")
            else:
                for param in head.parameters():
                    param.requires_grad = True

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate to the wrapped model and return its output object.

        The output exposes ``.loss`` (when ``labels`` is given) and ``.logits``,
        which is how the training loop in this file uses it.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

def _causal_lm_token_counts(logits: torch.Tensor,
                            labels: torch.Tensor,
                            attention_mask: torch.Tensor) -> Tuple[int, int]:
    """Count correct next-token predictions of a causal LM batch.

    In a causal LM the logits at position ``t`` predict the token at ``t + 1``.
    Predictions are therefore compared with labels shifted one step left.
    Target positions that are padding (``attention_mask == 0``) or ignored
    (``label == -100``) are excluded.

    Returns:
        ``(correct, total)`` over the non-excluded target positions.
    """
    with torch.no_grad():
        predictions = logits[:, :-1, :].argmax(dim=-1)
        targets = labels[:, 1:]
        valid = attention_mask[:, 1:].bool() & (targets != -100)
        correct = ((predictions == targets) & valid).sum().item()
        total = valid.sum().item()
    return int(correct), int(total)


def _evaluate_causal_lm(model: nn.Module,
                        loader: DataLoader,
                        device: torch.device) -> Tuple[float, float]:
    """Return ``(mean per-batch loss, next-token accuracy)`` of ``model`` on ``loader``.

    The model runs in eval mode without gradients; its previous train/eval
    mode is restored afterwards.  Accuracy is 0 if no target token was scored.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                total_loss += outputs.loss.item()
                batch_correct, batch_total = _causal_lm_token_counts(
                    outputs.logits, labels, attention_mask
                )
                correct += batch_correct
                total += batch_total
    finally:
        model.train(was_training)

    mean_loss = total_loss / len(loader) if len(loader) > 0 else float('nan')
    accuracy = correct / total if total > 0 else 0
    return mean_loss, accuracy


def run_tofu_grokking_experiment(
    num_steps: int = 5000,
    eval_interval: int = 50,
    different_step_sizes: Optional[List[int]] = None,
    save_results: bool = True,
    model_name: str = "gpt2",
    tofu_subset: str = "forget01",
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    max_length: int = 512,
    grokking_threshold: float = 0.05,
    detector_window_size: int = 50,
):
    """Run the grokking-detection experiment on a TOFU subset.

    The subset is split 80/20 into train and validation sets.  If TOFU
    cannot be loaded, synthetic sentences are used instead.  For each entry
    of ``different_step_sizes`` a fresh model is trained for that many
    optimizer steps and evaluated every ``eval_interval`` steps, and a
    ``GrokkingDetector`` tracks the results.

    Args:
        num_steps: Not read; the run lengths come from ``different_step_sizes``.
            Kept for backward compatibility.
        eval_interval: Evaluate every this many optimizer steps (must be > 0).
        different_step_sizes: Training lengths to run; defaults to
            ``[1000, 2000, 3000, 5000]``.
        save_results: Also write a JSON summary next to the plots.
        model_name: Model/tokenizer identifier.
        tofu_subset: TOFU configuration name.
        batch_size: Batch size for both loaders.
        learning_rate: AdamW learning rate.
        max_length: Padding/truncation length of every example.
        grokking_threshold: Detector threshold (previously hard-coded to 0.05).
        detector_window_size: Detector trend window (previously hard-coded to 50).

    Returns:
        Dict mapping each step count to ``{'detector', 'grokking_events',
        'final_metrics'}``.

    Side effects:
        Saves ``tofu_grokking_detection_{tofu_subset}_{steps}_steps.png`` per
        run, plus ``tofu_grokking_detection_{tofu_subset}_summary.json`` when
        ``save_results`` is set.

    Raises:
        ValueError: if ``eval_interval`` is not positive or a split is empty.
    """
    if different_step_sizes is None:
        different_step_sizes = [1000, 2000, 3000, 5000]
    if eval_interval <= 0:
        raise ValueError(f"eval_interval must be positive, got {eval_interval}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading TOFU datasets...")
    try:
        train_dataset = TOFUDataset(
            subset=tofu_subset,
            tokenizer_name=model_name,
            max_length=max_length,
            split="train"
        )

        # TOFU is loaded as a single split; carve a validation set out of it.
        train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, val_size]
        )

        logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    except Exception as e:
        # Deliberate best effort: keep the demo runnable without TOFU.
        logger.error(f"Error loading TOFU dataset: {e}")
        logger.info("Falling back to dummy data for demonstration")

        class DummyDataset(Dataset):
            """Fixed-template sentences used when TOFU cannot be loaded."""

            def __init__(self, size):
                self.size = size
                self.tokenizer = tokenizer

            def __len__(self):
                return self.size

            def __getitem__(self, idx):
                text = f"This is sample text number {idx} for demonstration purposes."
                tokens = self.tokenizer(
                    text, max_length=max_length, padding='max_length',
                    truncation=True, return_tensors='pt'
                )
                input_ids = tokens['input_ids'].squeeze(0)
                attention_mask = tokens['attention_mask'].squeeze(0)
                return {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    # Padding -> -100 so the loss ignores it.
                    'labels': input_ids.masked_fill(attention_mask == 0, -100),
                }

        train_dataset = DummyDataset(800)
        val_dataset = DummyDataset(200)

    # An empty train loader would make the `while step < max_steps` loop spin
    # forever, and an empty val loader yields no loss; fail fast instead.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Need non-empty train and validation sets, got "
            f"{len(train_dataset)} train / {len(val_dataset)} val samples"
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    results = {}

    for max_steps in different_step_sizes:
        logger.info(f"\nRunning experiment with {max_steps} training steps")

        model = TOFULanguageModel(
            model_name=model_name,
            vocab_size=tokenizer.vocab_size
        ).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        detector = GrokkingDetector(
            window_size=detector_window_size,
            grokking_threshold=grokking_threshold,
            patience=50
        )

        step = 0
        # Training losses since the last evaluation. Their mean is reported,
        # which is less noisy than the single most recent batch loss.
        interval_loss = 0.0
        interval_batches = 0
        model.train()

        while step < max_steps:
            for batch in train_loader:
                if step >= max_steps:
                    break

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                optimizer.zero_grad()
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                interval_loss += loss.item()
                interval_batches += 1
                step += 1

                if step % eval_interval != 0:
                    continue

                val_loss, val_acc = _evaluate_causal_lm(model, val_loader, device)

                # Training accuracy on the batch just optimized (its logits
                # come from the train-mode forward pass above).
                train_correct, train_total = _causal_lm_token_counts(
                    outputs.logits, labels, attention_mask
                )
                train_acc = train_correct / train_total if train_total > 0 else 0
                train_loss = interval_loss / interval_batches
                interval_loss = 0.0
                interval_batches = 0

                metrics = detector.update_metrics(
                    step=step,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_acc=train_acc,
                    val_acc=val_acc
                )

                if step % (eval_interval * 10) == 0:
                    logger.info(f"Step {step}: Train Loss: {train_loss:.4f}, "
                                f"Val Loss: {val_loss:.4f}, "
                                f"Train Acc: {train_acc:.4f}, "
                                f"Val Acc: {val_acc:.4f}, "
                                f"Grokking Signal: {metrics.grokking_signal:.4f}")

        results[max_steps] = {
            'detector': detector,
            'grokking_events': detector.grokking_events,
            'final_metrics': detector.metrics_history[-1] if detector.metrics_history else None
        }

        # Include the subset so runs on different subsets don't overwrite each other.
        detector.plot_metrics(save_path=f'tofu_grokking_detection_{tofu_subset}_{max_steps}_steps.png')

        logger.info(f"Experiment with {max_steps} steps completed. "
                    f"Grokking events detected at steps: {detector.grokking_events}")

    print("\n" + "="*60)
    print("TOFU GROKKING DETECTION SUMMARY")
    print("="*60)
    print(f"Dataset: TOFU ({tofu_subset})")
    print(f"Model: {model_name}")

    for max_steps, result in results.items():
        events = result['grokking_events']
        print(f"\nTraining Steps: {max_steps}")
        print(f"Grokking Events: {len(events)}")
        if events:
            print(f"First Grokking at Step: {events[0]}")
            print(f"All Grokking Steps: {events}")
        else:
            print("No grokking detected")

    if save_results:
        summary = {
            'dataset': f'TOFU ({tofu_subset})',
            'model': model_name,
            'experiments': {}
        }

        for max_steps, result in results.items():
            events = result['grokking_events']
            summary['experiments'][max_steps] = {
                'grokking_events': events,
                'num_grokking_events': len(events),
                'first_grokking_step': events[0] if events else None
            }

        with open(f'tofu_grokking_detection_{tofu_subset}_summary.json', 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        logger.info(f"Results saved to tofu_grokking_detection_{tofu_subset}_summary.json")

    return results

# Run the experiment with TOFU dataset
if __name__ == "__main__":
    # Example usage with TOFU dataset
    print("Starting TOFU Grokking Detection Experiment")
    print("This will train models on TOFU dataset with different numbers of steps")

    # TOFU configurations that can be passed as `tofu_subset`:
    # - forget01 / forget05 / forget10: the 1% / 5% / 10% forget split
    # - retain99 / retain95 / retain90: the complementary 99% / 95% / 90% retain split
    # NOTE(review): per the TOFU dataset card, each forgetXX config holds only the
    # forget examples (its complement is retainYY), not a forget/retain mixture.
    # Confirm before interpreting per-subset results.

    # Run experiments with different step sizes on TOFU dataset
    results = run_tofu_grokking_experiment(
        num_steps=5000,
        eval_interval=50,
        different_step_sizes=[1000, 2000, 3000, 5000],
        save_results=True,
        model_name="gpt2",  # You can change to "microsoft/DialoGPT-small" or other models
        tofu_subset="forget01",  # Change to other subsets as needed
        batch_size=4,  # Adjust based on your GPU memory
        learning_rate=5e-5,
        max_length=512
    )

    print("\nTOFU Grokking Detection Experiment completed!")
    print("Check the generated plots and JSON summary.")

    # Additional analysis: compare grokking behaviour across forget-split sizes.
    print("\nRunning comparison across different TOFU subsets...")
    comparison_steps = 2000  # shorter run; also used to look up results below
    subset_results = {}

    for subset in ["forget01", "forget05", "forget10"]:
        print(f"\nTesting subset: {subset}")
        try:
            subset_results[subset] = run_tofu_grokking_experiment(
                num_steps=comparison_steps,
                eval_interval=50,
                different_step_sizes=[comparison_steps],
                save_results=True,
                model_name="gpt2",
                tofu_subset=subset,
                batch_size=4,
                learning_rate=5e-5,
                max_length=512
            )
        except Exception as e:
            # Best effort: one failing subset must not abort the comparison.
            # logger.exception keeps the traceback for debugging.
            logger.exception(f"Error with subset {subset}: {e}")

    print("\n" + "="*60)
    print("TOFU SUBSET COMPARISON")
    print("="*60)

    for subset, result in subset_results.items():
        if comparison_steps in result:
            events = result[comparison_steps]['grokking_events']
            print(f"\nSubset {subset}:")
            print(f"  Grokking Events: {len(events)}")
            if events:
                print(f"  First Grokking: Step {events[0]}")
            else:
                print("  No grokking detected")

    print("\nAll experiments completed!")

    print("\nTips for customization:")
    print("1. Adjust 'tofu_subset' to test different forget/retain ratios")
    print("2. Modify 'model_name' to use different base models")
    print("3. Change 'batch_size' based on your GPU memory")
    print("4. Adjust 'grokking_threshold' in GrokkingDetector for sensitivity")
    print("5. Increase 'max_length' for longer sequences")
    print("6. Try different 'learning_rate' values for different grokking patterns")
