import torch
import argparse
from transformers import AutoTokenizer, AutoModel, TrainingArguments
from datasets import load_dataset
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType
import os
from sft_trainer import *
import torch.distributed as dist
import random
import numpy as np
import json
from datetime import datetime
import torch.nn.functional as F


def init_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` (which only influences interpreters
    started after this call) and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Apply the same seed to each RNG, in the order: stdlib, NumPy, torch, CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


# Initialize argument parser
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters
    parser.add_argument(
        "--model_name", type=str, default="GSAI-ML/LLaDA-8B-Instruct", help="Name of the pretrained model"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument(
        "--max_length", type=int, default=1024, help="Maximum sequence length for tokenization"
    )
    parser.add_argument("--num_epochs", type=int, default=20, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for the optimizer")
    parser.add_argument("--grad_accum_steps", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/data0/devaansh",
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument("--job_name", type=str, default="llada-s1", help="Job Name")
    parser.add_argument("--train_data", type=str, default="simplescaling/s1K", help="Path to training data")
    parser.add_argument(
        "--debugging", action="store_true", help="Use while debugging model - only disables wandb logging"
    )

    # Analysis options. The analysis raises if neither timestep option is
    # given, so the help text states that requirement explicitly.
    parser.add_argument(
        "--fixed_timestep", type=float, default=None,
        help="Fixed timestep for the experiment (0.0 to 1.0). Required unless --random_timestep is set."
    )
    parser.add_argument(
        "--random_timestep", action="store_true",
        help="If set, uses random timestep for each sample instead of fixed timestep."
    )
    parser.add_argument(
        "--max_samples", type=int, default=None,
        help="Maximum number of samples to analyze. If None, uses all data."
    )
    parser.add_argument(
        "--fill_top_half", action="store_true",
        help="If set, fills top half confident tokens with ground truth and re-evaluates bottom half."
    )
    parser.add_argument(
        "--compute_entropy", action="store_true",
        help="If set, computes entropy of logit distributions at each masked position before and after filling."
    )

    return parser.parse_args(argv)


# Model loading with LoRA integration
def load_model_and_tokenizer(args):
    """Load the tokenizer and the LoRA-wrapped bf16 model named by ``args.model_name``.

    The model is moved to CUDA when available and wrapped in
    ``torch.nn.DataParallel`` when more than one GPU is visible.

    Args:
        args: Parsed CLI namespace; only ``model_name`` is read here.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, padding_side="right", trust_remote_code=True, use_fast=True
    )

    base_model = AutoModel.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    # Rank-128 LoRA adapters on the q/k/v projections.
    adapter_config = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["q_proj", "k_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # Attach the adapters, then cast so the fp32 LoRA params become bf16 too.
    model = get_peft_model(base_model, adapter_config).to(torch.bfloat16)

    # CPU fallback: nothing more to set up.
    if not torch.cuda.is_available():
        print("Warning: No GPU available, using CPU")
        return tokenizer, model

    model = model.cuda()
    gpu_count = torch.cuda.device_count()
    print(f"Using {gpu_count} GPU(s)")
    for gpu_id in range(gpu_count):
        print(f"  GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")

    # Replicate across devices only when there is more than one.
    if gpu_count > 1:
        model = torch.nn.DataParallel(model)
        print(f"Model wrapped with DataParallel across {gpu_count} GPUs")

    return tokenizer, model


# Dataset loading
def load_data(args, tokenizer):
    """Load the ``train`` split of ``args.train_data`` and build train/eval datasets.

    Args:
        args: Parsed CLI namespace; reads ``train_data``, ``max_samples`` and
            ``max_length``.
        tokenizer: Tokenizer forwarded to preprocessing and to the datasets.

    Returns:
        ``(train_dataset, eval_dataset)`` tuple of ``dLLMSFTDataset`` objects.
    """
    raw = load_dataset(args.train_data, split="train")

    # Optionally keep only the first ``max_samples`` rows.
    sample_cap = args.max_samples
    if sample_cap is not None:
        keep = min(sample_cap, len(raw))
        raw = raw.select(range(keep))
        print(f"Limited dataset to {len(raw)} samples")

    train_split, eval_split = preprocess_dataset(raw, tokenizer, args.max_length)
    print("Train data length: ", len(train_split))
    print("Eval data length: ", len(eval_split))

    train_dataset = dLLMSFTDataset(train_split, tokenizer, args.max_length)
    eval_dataset = dLLMSFTDataset(eval_split, tokenizer, args.max_length, eval=True)
    return train_dataset, eval_dataset


# Confidence-based filling analysis
# Divisor that sets the size of the high-confidence ("top") group.
# NOTE(review): 2 ** 0.25 places roughly 84% of the masked tokens in the group
# that the CLI and log messages call the "top half"; confirm this is intended.
_TOP_GROUP_DIVISOR = 2 ** 0.25


def _entropy_by_position(probs, rows, cols):
    """Return ``{col: entropy}`` for the distributions ``probs[row, col]``.

    Entropy is ``-sum(p * log(p))`` with a small epsilon inside the log to
    avoid ``log(0)``. Positions are handled one at a time so only a single
    vocabulary-sized temporary is alive at once.
    """
    entropies = {}
    for row, col in zip(rows, cols):
        p = probs[row, col]
        entropies[col] = -(p * torch.log(p + 1e-10)).sum().item()
    return entropies


def _write_results_json(results, results_file):
    """Write the records as a JSON array with one ``"key": value`` per line.

    Lists and dicts are serialized compactly so each stays on a single line.
    """
    def render(value):
        if isinstance(value, (list, dict)):
            return json.dumps(value, separators=(",", ":"))
        return json.dumps(value)

    records = []
    for result in results:
        fields = ",\n".join(
            f"    {json.dumps(key)}: {render(value)}" for key, value in result.items()
        )
        records.append("  {\n" + fields + "\n  }")

    with open(results_file, "w", encoding="utf-8") as f:
        f.write("[\n" + ",\n".join(records) + ("\n" if records else "") + "]")


def _print_summary(results, compute_entropy):
    """Print aggregate confidence (and optionally entropy) change statistics."""
    all_changes = [
        change
        for r in results
        if r["confidence_changes"]
        for change in r["confidence_changes"]
    ]
    all_entropy_changes = []
    if compute_entropy:
        for r in results:
            all_entropy_changes.extend(r.get("entropy_changes_bottom_half", []))

    print("\nSummary Statistics:")
    print(f"  Total samples analyzed: {len(results)}")
    print(f"  Total bottom-half tokens: {len(all_changes)}")

    # Nothing to aggregate (e.g. every sample was skipped): avoid dividing by
    # zero and taking the mean of an empty list.
    if not all_changes:
        print("  No confidence changes recorded; skipping statistics.")
        return

    total = len(all_changes)
    improved = sum(1 for c in all_changes if c > 0)
    degraded = sum(1 for c in all_changes if c < 0)
    print("\n  Confidence Changes:")
    print(f"    Mean confidence change: {np.mean(all_changes):.4f}")
    print(f"    Abs mean confidence change: {np.mean(np.abs(all_changes)):.4f}")
    print(f"    Median confidence change: {np.median(all_changes):.4f}")
    print(f"    Std confidence change: {np.std(all_changes):.4f}")
    print(f"    Tokens improved: {improved} ({100*improved/total:.1f}%)")
    print(f"    Tokens degraded: {degraded} ({100*degraded/total:.1f}%)")

    if compute_entropy and all_entropy_changes:
        total_e = len(all_entropy_changes)
        decreased = sum(1 for e in all_entropy_changes if e < 0)
        increased = sum(1 for e in all_entropy_changes if e > 0)
        print("\n  Entropy Changes:")
        print(f"    Mean entropy change: {np.mean(all_entropy_changes):.4f}")
        print(f"    Median entropy change: {np.median(all_entropy_changes):.4f}")
        print(f"    Std entropy change: {np.std(all_entropy_changes):.4f}")
        print(f"    Entropy decreased: {decreased} ({100*decreased/total_e:.1f}%)")
        print(f"    Entropy increased: {increased} ({100*increased/total_e:.1f}%)")


def analyze_confidence_filling(args, tokenizer, model):
    """
    Analyze the effect of filling high-confidence tokens on low-confidence token predictions.

    Process (each training sample is collated and evaluated on its own):
    1. Run the model on the collator-masked input and take the maximum softmax
       probability at every masked position as that token's confidence.
    2. Rank the masked positions by confidence and split them into a
       high-confidence ("top") group and a low-confidence ("bottom") group.
    3. With ``--fill_top_half``, write the ground-truth label into the
       top-group positions and run the model again.
    4. Record the ground-truth-token probability at the bottom-group positions
       after filling and, with ``--compute_entropy``, the entropy before/after.
    5. Write all per-sample records to ``confidence_analysis.json`` inside a
       timestamped directory under ``args.output_dir`` and print a summary.

    Args:
        args: Parsed CLI namespace. Reads ``random_timestep``,
            ``fixed_timestep``, ``output_dir``, ``fill_top_half``,
            ``compute_entropy`` and ``max_length`` (plus whatever
            ``load_data`` reads).
        tokenizer: Tokenizer passed to ``load_data`` and the data collator.
        model: Model called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.

    Returns:
        List of per-sample result dicts (samples without masked tokens are
        skipped).

    Raises:
        ValueError: If neither ``--fixed_timestep`` nor ``--random_timestep``
            was given.
    """
    # Validate timestep arguments before the dataset load so bad invocations fail fast
    if not args.random_timestep and args.fixed_timestep is None:
        raise ValueError("Either --fixed_timestep must be specified or --random_timestep must be set")

    # Load dataset
    train_dataset, _ = load_data(args, tokenizer)

    # Setup output directory
    timestep_str = "random" if args.random_timestep else f"t{args.fixed_timestep}"
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"conf_fill_{timestep_str}_{run_stamp}")
    os.makedirs(output_dir, exist_ok=True)

    banner = "=" * 60
    timestep_mode = "Random" if args.random_timestep else f"Fixed ({args.fixed_timestep})"
    print(f"\n{banner}")
    print("Confidence-based Filling Analysis")
    print(f"Timestep mode: {timestep_mode}")
    print(f"Fill top half: {args.fill_top_half}")
    print(f"Compute entropy: {args.compute_entropy}")
    print(f"Output directory: {output_dir}")
    print(f"{banner}\n")

    # Create data collator with appropriate timestep mode
    collator_kwargs = {
        "tokenizer": tokenizer,
        # Hard-coded mask token id; presumably the mask id of the default
        # LLaDA model -- confirm if --model_name is changed.
        "mask_token_id": 126336,
        "max_length": args.max_length,
    }
    if not args.random_timestep:
        collator_kwargs["fixed_timestep"] = args.fixed_timestep

    collator = dLLMDataCollator(**collator_kwargs)

    model.eval()
    device = next(model.parameters()).device
    num_samples = len(train_dataset)
    results = []

    with torch.no_grad():
        for sample_idx in range(num_samples):
            print(f"\nProcessing sample {sample_idx + 1}/{num_samples}...")

            # Collate a single sample and move its tensors to the model device
            batch = collator([train_dataset[sample_idx]])
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attention_mask = batch.get("attention_mask")

            # Positions that carry a label (anything other than -100) are the masked ones
            masked_positions = labels != -100
            num_masked = int(masked_positions.sum().item())

            if num_masked == 0:
                print("  No masked tokens, skipping...")
                continue

            # Get the actual timestep used (from batch if random)
            sample_timestep = batch["t"][0, 0].item() if args.random_timestep else args.fixed_timestep

            print(f"  Total masked tokens: {num_masked}")
            print(f"  Timestep: {sample_timestep:.4f}")

            # STEP 1: Get initial predictions for all masked tokens.
            # Softmax is taken in float32 so confidences and entropies are
            # not limited to the precision of the model's output dtype.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = F.softmax(outputs.logits, dim=-1, dtype=torch.float32)
            max_probs = probs.max(dim=-1).values  # probability of the argmax token

            # (num_masked, 2) rows of (batch index, sequence position); boolean
            # indexing below yields values in the same row-major order.
            masked_indices = masked_positions.nonzero(as_tuple=False)
            batch_rows = masked_indices[:, 0].tolist()
            positions = masked_indices[:, 1].tolist()
            ground_truth_tokens = labels[masked_positions].tolist()
            confidences = max_probs[masked_positions].cpu().tolist()

            # Sort by confidence (ascending) and split into bottom / top groups
            sorted_indices = np.argsort(confidences)
            num_top_half = int(len(confidences) // _TOP_GROUP_DIVISOR)
            num_bottom_half = len(confidences) - num_top_half

            bottom_half_indices = sorted_indices[:num_bottom_half].tolist()  # lowest confidence
            # Slice from the front: a negative slice of length 0 (``[-0:]``)
            # would select every index when there is no top group.
            top_half_indices = sorted_indices[num_bottom_half:].tolist()  # highest confidence

            bottom_rows = [batch_rows[i] for i in bottom_half_indices]
            bottom_positions = [positions[i] for i in bottom_half_indices]

            # Compute entropy if requested (before filling)
            initial_entropies = {}
            if args.compute_entropy:
                initial_entropies = _entropy_by_position(probs, bottom_rows, bottom_positions)

            initial_bottom_confidences = [confidences[i] for i in bottom_half_indices]

            print(f"  Top half tokens: {num_top_half}")
            print(f"  Bottom half tokens: {num_bottom_half}")
            print(f"  Initial bottom half confidence: mean={np.mean(initial_bottom_confidences):.4f}, "
                  f"std={np.std(initial_bottom_confidences):.4f}")

            filled_bottom_confidences = None
            confidence_changes = None
            abs_mean_change = None
            filled_entropies = None
            entropy_changes = None

            # STEP 2: Fill top half with ground truth and re-evaluate (if enabled)
            if args.fill_top_half:
                # Create modified input with top-group positions un-masked
                modified_input = input_ids.clone()
                for i in top_half_indices:
                    modified_input[batch_rows[i], positions[i]] = ground_truth_tokens[i]

                # Get predictions with filled input
                outputs_filled = model(input_ids=modified_input, attention_mask=attention_mask)
                probs_filled = F.softmax(outputs_filled.logits, dim=-1, dtype=torch.float32)

                # Compute entropy if requested (after filling)
                if args.compute_entropy:
                    filled_entropies = _entropy_by_position(probs_filled, bottom_rows, bottom_positions)

                # Re-evaluate bottom half: probability of the ground-truth token.
                # NOTE(review): the initial confidence is the max probability
                # while this is the ground-truth-token probability, so the
                # "change" below mixes two measures; confirm this is intended.
                bottom_sel = torch.as_tensor(
                    bottom_half_indices, dtype=torch.long, device=masked_indices.device
                )
                rows, cols = masked_indices[bottom_sel].unbind(dim=1)
                filled_bottom_confidences = probs_filled[rows, cols, labels[rows, cols]].cpu().tolist()

                # Calculate changes
                confidence_changes = [
                    after - before
                    for after, before in zip(filled_bottom_confidences, initial_bottom_confidences)
                ]
                abs_mean_change = float(np.mean(np.abs(confidence_changes)))

                print("  After filling top half:")
                print(f"    Bottom half confidence: mean={np.mean(filled_bottom_confidences):.4f}, "
                      f"std={np.std(filled_bottom_confidences):.4f}")
                print(f"    Mean confidence change: {np.mean(confidence_changes):.4f}")
                print(f"    Abs mean confidence change: {abs_mean_change:.4f}")
                print(f"    Improved: {sum(1 for c in confidence_changes if c > 0)}/{len(confidence_changes)}")

                if args.compute_entropy:
                    # Per-position entropy before/after, in bottom-group order
                    initial_list = [initial_entropies[pos] for pos in bottom_positions]
                    filled_list = [filled_entropies[pos] for pos in bottom_positions]
                    entropy_changes = [after - before for after, before in zip(filled_list, initial_list)]

                    print("  Entropy analysis (bottom half):")
                    print(f"    Initial entropy: mean={np.mean(initial_list):.4f}")
                    print(f"    Filled entropy: mean={np.mean(filled_list):.4f}")
                    print(f"    Mean entropy change: {np.mean(entropy_changes):.4f}")
                    print(f"    Entropy decreased: {sum(1 for e in entropy_changes if e < 0)}/{len(entropy_changes)}")

            # Store results (the "filled"/"change" fields stay None when filling is disabled)
            sample_result = {
                "sample_idx": sample_idx,
                "timestep": sample_timestep,
                "num_masked": num_masked,
                "num_top_half": num_top_half,
                "num_bottom_half": num_bottom_half,
                "top_half_positions": [positions[i] for i in top_half_indices],
                "all_confidences": list(confidences),
                "bottom_half_positions": bottom_positions,
                "top_half_confidences": [confidences[i] for i in top_half_indices],
                "bottom_half_confidences_initial": initial_bottom_confidences,
                "bottom_half_confidences_filled": filled_bottom_confidences,
                "confidence_changes": confidence_changes,
                "abs_mean_confidence_change": abs_mean_change,
                "bottom_half_tokens": [ground_truth_tokens[i] for i in bottom_half_indices],
            }

            # Add entropy data if computed
            if args.compute_entropy:
                sample_result["initial_entropies_bottom_half"] = dict(initial_entropies)
                if filled_entropies is not None:
                    sample_result["filled_entropies_bottom_half"] = dict(filled_entropies)
                    sample_result["entropy_changes_bottom_half"] = entropy_changes
            results.append(sample_result)

    # Save results with nice formatting
    results_file = os.path.join(output_dir, "confidence_analysis.json")
    _write_results_json(results, results_file)

    print(f"\n{banner}")
    print(f"Analysis complete! Results saved to: {results_file}")

    # Print summary statistics
    if args.fill_top_half:
        _print_summary(results, args.compute_entropy)

    print(f"{banner}\n")

    return results


if __name__ == "__main__":
    # Seed every RNG first so the whole run is repeatable.
    init_seed(42)

    # CLI options drive both model loading and the analysis itself.
    args = parse_args()
    tokenizer, model = load_model_and_tokenizer(args)

    # Per-sample records are also written to disk by the analysis.
    results = analyze_confidence_filling(args=args, tokenizer=tokenizer, model=model)