import torch
import argparse
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from torch.utils.data import DataLoader
import os
import sys
import random
import numpy as np
from datetime import datetime
import torch.nn.functional as F

sys.path.append(str(Path(__file__).parent.parent))
from sft_trainer import dLLMSFTDataset, dLLMDataCollator, preprocess_dataset


def init_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (including the CUDA
    generator), exports ``PYTHONHASHSEED`` to the environment, and asks
    cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Each library keeps its own generator state, so every one is seeded.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


def parse_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        argparse.Namespace with the attributes ``model_name``, ``batch_size``,
        ``max_length``, ``train_data``, ``output_dir``, ``num_samples`` and
        ``debugging``.
    """
    cli = argparse.ArgumentParser(description="Analyze confidence vs loss for masked tokens")

    # (flag, add_argument keyword arguments), registered in listing order.
    option_table = [
        (
            "--model_name",
            dict(type=str, default="GSAI-ML/LLaDA-8B-Instruct", help="Name of the pretrained model"),
        ),
        (
            "--batch_size",
            dict(type=int, default=1, help="Batch size for inference"),
        ),
        (
            "--max_length",
            dict(type=int, default=4096, help="Maximum sequence length for tokenization"),
        ),
        (
            "--train_data",
            dict(type=str, default="simplescaling/s1K", help="Path to training data"),
        ),
        (
            "--output_dir",
            dict(type=str, default="./conf_vs_loss_logs", help="Directory to save analysis output"),
        ),
        (
            "--num_samples",
            dict(
                type=int,
                default=None,
                help="Number of samples to analyze (if not provided, processes entire dataset)",
            ),
        ),
        (
            "--debugging",
            dict(action="store_true", help="Debug mode"),
        ),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def load_model_and_tokenizer(args):
    """Load the pretrained tokenizer and model named by ``args.model_name``.

    Both are loaded with ``trust_remote_code=True``. The tokenizer is the fast
    variant padding on the right; the model is requested in ``torch.bfloat16``.

    Args:
        args: Parsed CLI namespace; only ``model_name`` is read.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    checkpoint = args.model_name

    tokenizer_options = {
        "padding_side": "right",
        "trust_remote_code": True,
        "use_fast": True,
    }
    tok = AutoTokenizer.from_pretrained(checkpoint, **tokenizer_options)

    net = AutoModel.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16)

    return tok, net


def load_data(args, tokenizer):
    """Load the ``train`` split of ``args.train_data`` as an SFT dataset.

    Args:
        args: Parsed CLI namespace; ``train_data`` and ``max_length`` are read.
        tokenizer: Tokenizer forwarded to preprocessing and to the dataset.

    Returns:
        A ``dLLMSFTDataset`` built from the training part of the preprocessed data.
    """
    raw_split = load_dataset(args.train_data, split="train")

    # preprocess_dataset yields a (train, eval) pair; the eval part is not used here.
    train_part, _eval_part = preprocess_dataset(raw_split, tokenizer, args.max_length)
    sft_dataset = dLLMSFTDataset(train_part, tokenizer, args.max_length)

    print('Loaded SFT data:')
    return sft_dataset


def _masked_token_stats(logits, labels):
    """Compute confidence and loss for every supervised position in one pass.

    Args:
        logits: Model outputs of shape (B, N, V).
        labels: Target token ids of shape (B, N); -100 marks ignored positions.

    Returns:
        Tuple ``(indices, confidences, losses)`` of plain Python lists.
        ``indices`` holds ``[batch_idx, position]`` pairs in row-major order and
        the other two lists are aligned with it. All three are empty when no
        label differs from -100.
    """
    keep = labels != -100
    if not keep.any():
        return [], [], []

    # nonzero() and boolean-mask indexing both enumerate positions in
    # row-major order, so the three outputs line up element for element.
    indices = keep.nonzero(as_tuple=False).tolist()
    kept_logits = logits[keep]  # (M, V)
    kept_labels = labels[keep]  # (M,)

    # Confidence is the probability of the argmax token, i.e. the row maximum.
    # Softmax is taken only over the selected rows, not the full (B, N, V) tensor.
    confidences = F.softmax(kept_logits, dim=-1).max(dim=-1).values
    losses = F.cross_entropy(kept_logits, kept_labels, reduction="none")

    # A single tolist() per tensor replaces one device sync per token.
    return indices, confidences.tolist(), losses.tolist()


def analyze_confidence_vs_loss(model, dataloader, device, output_file=None, num_samples=None):
    """
    Analyze confidence vs loss for masked tokens with dynamic logging.

    For every position whose label is not -100, records the probability the
    model assigns to its own argmax token (confidence) and the cross-entropy
    against the label (loss).

    Args:
        model: The language model; called as ``model(input_ids)`` and expected
            to return an object with a ``logits`` attribute of shape (B, N, V)
        dataloader: Iterable of batches with ``'input_ids'`` and ``'labels'``
            tensors whose first dimension is the batch
        device: Device to run inference on
        output_file: If provided, write results to JSONL file dynamically
            (one line per batch, flushed immediately)
        num_samples: Maximum number of samples (sequences, not batches) to
            process; the last batch is truncated if it would exceed the limit.
            ``None`` processes the whole dataloader.

    Returns:
        List of dicts with step, masked_indices, masked_confidences, masked_losses.
        Batches without any labelled position produce no entry.
    """
    model.eval()
    results = []
    samples_seen = 0

    # Open file for dynamic logging if output_file is provided
    file_handle = None
    if output_file:
        file_handle = open(output_file, 'w', encoding="utf-8")
        print(f"Logging to: {output_file}")

    try:
        with torch.no_grad():
            for step, batch in enumerate(dataloader):
                input_ids = batch['input_ids']
                labels = batch['labels']

                # Enforce the limit in samples rather than in batches so that
                # batch_size > 1 does not overshoot num_samples.
                if num_samples is not None:
                    remaining = num_samples - samples_seen
                    if remaining <= 0:
                        break
                    input_ids = input_ids[:remaining]
                    labels = labels[:remaining]

                input_ids = input_ids.to(device)
                labels = labels.to(device)
                samples_seen += input_ids.shape[0]

                logits = model(input_ids).logits  # (B, N, V)
                masked_indices, masked_confidences, masked_losses = _masked_token_stats(logits, labels)

                if masked_indices:
                    log_entry = {
                        "step": step,
                        "masked_indices": masked_indices,
                        "masked_confidences": masked_confidences,
                        "masked_losses": masked_losses
                    }
                    results.append(log_entry)

                    # Flush per entry so a crash mid-run keeps what was computed.
                    if file_handle:
                        file_handle.write(json.dumps(log_entry) + '\n')
                        file_handle.flush()

                if (step + 1) % 10 == 0:
                    print(f"Processed {samples_seen} samples...")

    finally:
        if file_handle:
            file_handle.close()
            print("Closed log file")

    return results


def save_results(results, output_dir):
    """Save results to a timestamped JSONL file inside ``output_dir``.

    Args:
        results: Iterable of JSON-serialisable log entries; each becomes one
            line of the output file.
        output_dir: Directory for the file; created (with parents) if missing.

    Returns:
        ``pathlib.Path`` of the written file, named
        ``conf_vs_loss_<YYYYmmdd_HHMMSS>.jsonl``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = output_dir / f"conf_vs_loss_{timestamp}.jsonl"

    # Actually persist the entries; previously only the path was computed.
    with open(output_file, "w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(entry) + "\n" for entry in results)

    return output_file


def _print_summary(results, output_file):
    """Print aggregate confidence/loss statistics over all analyzed tokens."""
    all_confidences = [c for entry in results for c in entry["masked_confidences"]]
    all_losses = [loss for entry in results for loss in entry["masked_losses"]]

    print("\n=== Summary Statistics ===")
    print(f"Total masked tokens analyzed: {len(all_confidences)}")

    if not all_confidences:
        print("No masked tokens found to analyze.")
        return

    print(f"Average confidence: {np.mean(all_confidences):.4f}")
    print(f"Median confidence: {np.median(all_confidences):.4f}")
    print(f"Std confidence: {np.std(all_confidences):.4f}")

    print(f"\nAverage loss: {np.mean(all_losses):.4f}")
    print(f"Median loss: {np.median(all_losses):.4f}")
    print(f"Std loss: {np.std(all_losses):.4f}")

    # Pearson correlation is undefined for a single point or a constant
    # series; np.corrcoef would emit a RuntimeWarning and return NaN.
    has_spread = np.std(all_confidences) > 0 and np.std(all_losses) > 0
    if len(all_confidences) > 1 and has_spread:
        correlation = np.corrcoef(all_confidences, all_losses)[0, 1]
        print(f"\nPearson correlation (confidence vs loss): {correlation:.4f}")
    else:
        print("\nPearson correlation (confidence vs loss): undefined (needs at least two non-constant values)")

    print(f"\n✓ Results saved to {output_file}")


def main():
    """Run the confidence-vs-loss analysis end to end.

    Seeds the RNGs, parses CLI arguments, loads the model, tokenizer and
    training data, streams per-token statistics to a timestamped JSONL file in
    ``--output_dir``, and prints summary statistics.
    """
    init_seed(42)
    args = parse_args()

    # Create the output directory up front so an unwritable path fails before
    # the model is loaded.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer, model = load_model_and_tokenizer(args)

    # Load data
    print("Loading data...")
    train_dataset = load_data(args, tokenizer)

    # NOTE(review): mask_token_id is hard-coded; presumably it matches the
    # default --model_name and may be wrong for other models — TODO confirm.
    dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=dLLMDataCollator(tokenizer=tokenizer, mask_token_id=126336, max_length=args.max_length),
        shuffle=False
    )

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(f"Using device: {device}")

    # Prepare output file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = output_dir / f"conf_vs_loss_{timestamp}.jsonl"

    # Analyze confidence vs loss with dynamic logging
    if args.num_samples is None:
        print("\nAnalyzing confidence vs loss for entire dataset...")
    else:
        print(f"\nAnalyzing confidence vs loss for {args.num_samples} samples...")
    results = analyze_confidence_vs_loss(
        model,
        dataloader,
        device,
        output_file=str(output_file),
        num_samples=args.num_samples
    )

    _print_summary(results, output_file)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
