import torch
import argparse
from transformers import AutoTokenizer, AutoModel, TrainingArguments
from datasets import load_dataset
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType
import os
from sft_trainer import *
import torch.distributed as dist
import random
import numpy as np
import json
from datetime import datetime
import torch.nn.functional as F


def init_seed(seed):
    """Seed Python, NumPy and torch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to every random number generator.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects subprocesses started later;
        the running interpreter's string-hash randomization is fixed at startup.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one: the model may be
    # replicated across devices by DataParallel. This is lazy and a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different kernels from run to run, which
    # defeats `deterministic`; disable it for reproducibility.
    torch.backends.cudnn.benchmark = False


# Initialize argument parser
def parse_args():
    """Build the command-line parser and return the validated arguments.

    Returns:
        argparse.Namespace with the hyperparameters and analysis options below.

    Exits (via ``parser.error``, status 2) when ``--fixed_timestep`` lies
    outside [0.0, 1.0] (NaN included) or ``--max_samples`` is negative.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters
    parser.add_argument(
        "--model_name", type=str, default="GSAI-ML/LLaDA-8B-Instruct", help="Name of the pretrained model"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument(
        "--max_length", type=int, default=1024, help="Maximum sequence length for tokenization"
    )
    parser.add_argument("--num_epochs", type=int, default=20, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for the optimizer")
    parser.add_argument("--grad_accum_steps", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/data0/devaansh",
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument("--job_name", type=str, default="llada-s1", help="Job Name")
    parser.add_argument("--train_data", type=str, default="simplescaling/s1K", help="Path to training data")
    parser.add_argument(
        "--debugging", action="store_true", help="Use while debugging model - only disables wandb logging"
    )
    parser.add_argument(
        "--fixed_timestep", type=float, required=True,
        help="Fixed timestep for the experiment (0.0 to 1.0)."
    )
    parser.add_argument(
        "--max_samples", type=int, default=None,
        help="Maximum number of samples to analyze. If None, uses all data."
    )
    parser.add_argument(
        "--fill_top_half", action="store_true",
        help="If set, fills top half confident tokens with ground truth and re-evaluates bottom half."
    )

    args = parser.parse_args()

    # The help text promises a timestep in [0, 1]. Enforce it here so a typo
    # (or "nan", which type=float accepts) fails fast instead of reaching the collator.
    if not 0.0 <= args.fixed_timestep <= 1.0:
        parser.error(f"--fixed_timestep must be in [0.0, 1.0], got {args.fixed_timestep}")
    if args.max_samples is not None and args.max_samples < 0:
        parser.error(f"--max_samples must be non-negative, got {args.max_samples}")

    return args


# Model loading with LoRA integration
def load_model_and_tokenizer(args):
    """Load the tokenizer and a bf16 base model wrapped with a LoRA adapter.

    The model is moved to CUDA when a GPU is available. When more than one GPU
    is visible, it is additionally wrapped in ``torch.nn.DataParallel``.

    Args:
        args: parsed CLI namespace; only ``args.model_name`` is read.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, padding_side="right", trust_remote_code=True, use_fast=True
    )
    base_model = AutoModel.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    # LoRA adapter on the attention query/key/value projections.
    adapter_config = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["q_proj", "k_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    # Cast afterwards so the freshly created (fp32) LoRA params match the bf16 base.
    model = get_peft_model(base_model, adapter_config).to(torch.bfloat16)

    if not torch.cuda.is_available():
        print("Warning: No GPU available, using CPU")
        return tokenizer, model

    model = model.cuda()
    gpu_count = torch.cuda.device_count()
    print(f"Using {gpu_count} GPU(s)")
    for gpu_id in range(gpu_count):
        print(f"  GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")

    # Spread batches over all visible GPUs when there is more than one.
    if gpu_count > 1:
        model = torch.nn.DataParallel(model)
        print(f"Model wrapped with DataParallel across {gpu_count} GPUs")

    return tokenizer, model


# Dataset loading
def load_data(args, tokenizer):
    """Load the training split, optionally truncate it, and wrap it in SFT datasets.

    Args:
        args: parsed CLI namespace; reads ``train_data``, ``max_samples`` and
            ``max_length``.
        tokenizer: tokenizer forwarded to preprocessing and the dataset wrappers.

    Returns:
        (train_dataset, eval_dataset) built from ``preprocess_dataset``'s split;
        the eval wrapper is constructed with ``eval=True``.
    """
    raw = load_dataset(args.train_data, split="train")

    cap = args.max_samples
    if cap is not None:
        keep = min(cap, len(raw))
        raw = raw.select(range(keep))
        print(f"Limited dataset to {len(raw)} samples")

    train_split, eval_split = preprocess_dataset(raw, tokenizer, args.max_length)
    print("Train data length: ", len(train_split))
    print("Eval data length: ", len(eval_split))

    return (
        dLLMSFTDataset(train_split, tokenizer, args.max_length),
        dLLMSFTDataset(eval_split, tokenizer, args.max_length, eval=True),
    )


# Confidence-based filling analysis
def _gt_confidences(logits, batch_idx, pos_idx, gt_tokens):
    """Return the model's probability of the ground-truth token at each selected position.

    Args:
        logits: tensor indexed as ``logits[b, pos, token]`` (a model output's ``.logits``).
        batch_idx: 1-D index tensor of batch rows to score.
        pos_idx: 1-D index tensor of sequence positions, aligned with ``batch_idx``.
        gt_tokens: 1-D tensor of ground-truth token ids, aligned with the indices.

    Returns:
        List of Python floats, one per selected position, in index order.
    """
    # Only the selected rows are scored, and in float32. A softmax over the full
    # vocabulary on bf16 logits keeps only ~3 significant digits and would also
    # allocate probabilities for every (unmasked) position.
    log_probs = F.log_softmax(logits[batch_idx, pos_idx].float(), dim=-1)
    return log_probs.gather(-1, gt_tokens.unsqueeze(-1)).squeeze(-1).exp().tolist()


def _split_by_confidence(confidences):
    """Split indices into ``confidences`` into (top_half, bottom_half).

    The top half holds the ``len // 2`` most confident entries; the bottom half
    holds the rest. With an odd count, the extra entry therefore lands in the
    bottom half, and a single entry yields an empty top half. Both index arrays
    are ordered by ascending confidence; ties are broken stably.
    """
    order = np.argsort(np.asarray(confidences), kind="stable")
    split = len(order) - len(order) // 2
    # Explicit start index: ``order[-k:]`` would return the WHOLE array when k == 0.
    return order[split:], order[:split]


def _write_results_json(results, path):
    """Write ``results`` as indented JSON with numeric arrays collapsed onto single lines."""
    import re  # only needed for this cosmetic post-processing

    text = json.dumps(results, indent=2)
    # Strip whitespace after '[' and before ']', and between numeric array elements.
    text = re.sub(r"\[\s+", "[", text)
    text = re.sub(r"\s+\]", "]", text)
    text = re.sub(r",\s+(?=[-\d])", ", ", text)
    # Serialize before opening the file so a failure cannot leave an empty/truncated file.
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


def _print_summary(results):
    """Print aggregate statistics over every sample's bottom-half confidence changes."""
    all_changes = [c for r in results if r["confidence_changes"] for c in r["confidence_changes"]]

    print(f"\nSummary Statistics:")
    print(f"  Total samples analyzed: {len(results)}")
    print(f"  Total bottom-half tokens: {len(all_changes)}")
    if not all_changes:
        # Guard: the percentage lines below would divide by zero.
        print("  No bottom-half tokens were evaluated; nothing to summarize.")
        return

    changes = np.asarray(all_changes)
    total = len(changes)
    improved = int((changes > 0).sum())
    degraded = int((changes < 0).sum())
    print(f"  Mean confidence change: {changes.mean():.4f}")
    print(f"  Abs mean confidence change: {np.abs(changes).mean():.4f}")
    print(f"  Median confidence change: {np.median(changes):.4f}")
    print(f"  Std confidence change: {changes.std():.4f}")
    print(f"  Tokens improved: {improved} ({100 * improved / total:.1f}%)")
    print(f"  Tokens degraded: {degraded} ({100 * degraded / total:.1f}%)")


def analyze_confidence_filling(args, tokenizer, model):
    """
    Analyze the effect of filling high-confidence tokens on low-confidence token predictions.

    Process:
    1. Get model predictions with fixed timestep masking
    2. Rank the masked tokens by the model's probability of their ground-truth token
    3. Fill the top (most confident) half with ground truth (only when ``args.fill_top_half``)
    4. Re-evaluate the bottom (least confident) half
    5. Log the changes in confidence and save per-sample results to
       ``<args.output_dir>/conf_fill_t<timestep>_<timestamp>/confidence_analysis.json``

    With an odd number of masked tokens, the extra token goes to the bottom half.
    A sample with a single masked token therefore has nothing to fill, and its
    confidence change is 0.

    Args:
        args: parsed CLI namespace. Reads ``output_dir``, ``fixed_timestep``,
            ``max_length`` and ``fill_top_half`` (plus whatever ``load_data`` reads).
        tokenizer: tokenizer passed to the dataset and the data collator.
        model: model whose forward accepts ``input_ids``/``attention_mask`` and
            returns an object with ``.logits``.

    Returns:
        List of per-sample result dicts (the same data written to the JSON file).
    """
    train_dataset, _ = load_data(args, tokenizer)
    num_samples = len(train_dataset)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"conf_fill_t{args.fixed_timestep}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*60}")
    print("Confidence-based Filling Analysis")
    print(f"Fixed timestep: {args.fixed_timestep}")
    print(f"Fill top half: {args.fill_top_half}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*60}\n")

    collator = dLLMDataCollator(
        tokenizer=tokenizer,
        mask_token_id=126336,  # NOTE(review): presumably LLaDA's [MASK] id — confirm against the tokenizer
        max_length=args.max_length,
        fixed_timestep=args.fixed_timestep,
    )

    model.eval()
    # Inputs go to the device of the first parameter (the primary device under
    # DataParallel). This is invariant across samples, so it is looked up once.
    device = next(model.parameters()).device
    results = []

    with torch.no_grad():
        for sample_idx in range(num_samples):
            print(f"\nProcessing sample {sample_idx + 1}/{num_samples}...")

            batch = collator([train_dataset[sample_idx]])
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attention_mask = batch.get("attention_mask")

            # Positions whose label is not -100 are treated as the masked targets.
            masked_positions = labels != -100
            num_masked = int(masked_positions.sum().item())
            if num_masked == 0:
                print("  No masked tokens, skipping...")
                continue
            print(f"  Total masked tokens: {num_masked}")

            # Row-major order, i.e. the same order as iterating masked_positions.nonzero().
            batch_idx, pos_idx = masked_positions.nonzero(as_tuple=True)
            gt_tokens = labels[batch_idx, pos_idx]
            positions = pos_idx.tolist()
            ground_truth_tokens = gt_tokens.tolist()

            # STEP 1: ground-truth confidence for every masked token.
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            confidences = _gt_confidences(logits, batch_idx, pos_idx, gt_tokens)
            del logits  # release the full logits before a possible second forward pass

            top_half_indices, bottom_half_indices = _split_by_confidence(confidences)
            num_top_half = len(top_half_indices)
            num_bottom_half = len(bottom_half_indices)
            initial_bottom_confidences = [confidences[i] for i in bottom_half_indices]

            print(f"  Top half tokens: {num_top_half}")
            print(f"  Bottom half tokens: {num_bottom_half}")
            print(f"  Initial bottom half confidence: mean={np.mean(initial_bottom_confidences):.4f}, "
                  f"std={np.std(initial_bottom_confidences):.4f}")

            # STEP 2: fill the top half with ground truth and re-evaluate (if enabled).
            if args.fill_top_half:
                if num_top_half > 0:
                    top = torch.as_tensor(top_half_indices, device=device)
                    bottom = torch.as_tensor(bottom_half_indices, device=device)
                    modified_input = input_ids.clone()
                    modified_input[batch_idx[top], pos_idx[top]] = gt_tokens[top].to(modified_input.dtype)
                    logits_filled = model(input_ids=modified_input, attention_mask=attention_mask).logits
                    filled_bottom_confidences = _gt_confidences(
                        logits_filled, batch_idx[bottom], pos_idx[bottom], gt_tokens[bottom]
                    )
                    del logits_filled
                else:
                    # Nothing to reveal: the input is unchanged, so STEP 1's scores stand.
                    filled_bottom_confidences = list(initial_bottom_confidences)

                confidence_changes = [
                    after - before
                    for before, after in zip(initial_bottom_confidences, filled_bottom_confidences)
                ]
                abs_mean_change = float(np.mean(np.abs(confidence_changes)))

                print("  After filling top half:")
                print(f"    Bottom half confidence: mean={np.mean(filled_bottom_confidences):.4f}, "
                      f"std={np.std(filled_bottom_confidences):.4f}")
                print(f"    Mean confidence change: {np.mean(confidence_changes):.4f}")
                print(f"    Abs mean confidence change: {abs_mean_change:.4f}")
                print(f"    Improved: {sum(1 for c in confidence_changes if c > 0)}/{len(confidence_changes)}")
            else:
                filled_bottom_confidences = None
                confidence_changes = None
                abs_mean_change = None

            results.append({
                "sample_idx": sample_idx,
                "num_masked": num_masked,
                "num_top_half": num_top_half,
                "num_bottom_half": num_bottom_half,
                "top_half_positions": [positions[i] for i in top_half_indices],
                "bottom_half_positions": [positions[i] for i in bottom_half_indices],
                "top_half_confidences": [confidences[i] for i in top_half_indices],
                "bottom_half_confidences_initial": initial_bottom_confidences,
                "bottom_half_confidences_filled": filled_bottom_confidences,
                "confidence_changes": confidence_changes,
                "abs_mean_confidence_change": abs_mean_change,
                "bottom_half_tokens": [ground_truth_tokens[i] for i in bottom_half_indices],
            })

    results_file = os.path.join(output_dir, "confidence_analysis.json")
    _write_results_json(results, results_file)

    print(f"\n{'='*60}")
    print(f"Analysis complete! Results saved to: {results_file}")

    if args.fill_top_half:
        _print_summary(results)

    print(f"{'='*60}\n")

    return results


if __name__ == "__main__":
    # Fixed seed (42) for the Python, NumPy and torch generators, set before anything else runs.
    init_seed(42)

    # CLI -> model/tokenizer -> confidence-filling analysis.
    args = parse_args()
    tokenizer, model = load_model_and_tokenizer(args)
    results = analyze_confidence_filling(args, tokenizer, model)