import sys
import os
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
from pathlib import Path
import json
from scipy import stats
import torch
from collections import Counter
import numpy as np
from tqdm import tqdm
import time
import random

sys.path.append(str(Path(__file__).parent.parent.parent))
from mmap_dataset_lightning import setup_pythia_data


def get_model(
    model_type, step, dtype=torch.float16, file_system=None, model_location=None
):
    """Load a causal-LM checkpoint for a given training step.

    Args:
        model_type: Either the Hub id ``"EleutherAI/pythia-160m"`` or the name
            of a locally trained model directory.
        step: Training step of the checkpoint. For the Hub model an integer is
            turned into the branch name ``step<N>``; a string is passed through
            unchanged as the revision.
        dtype: Torch dtype the weights are loaded in.
        file_system: Unused by this function; kept so existing callers that
            pass it keep working.
        model_location: Directory (relative to the home directory) that holds
            locally trained models. Only used when ``model_type`` is not the
            Hub model.

    Returns:
        The loaded model, wrapped by ``torch.compile`` when this torch build
        provides it.
    """
    if model_type == "EleutherAI/pythia-160m":
        model_name_or_path = model_type
        # Hub revisions are branch names (strings). NOTE(review): Pythia
        # checkpoints are published as "step<N>" branches -- confirm that the
        # requested step actually exists on the Hub.
        revision = step if isinstance(step, str) else f"step{step}"
    else:
        # "~" is a shell convention; it must be expanded explicitly, otherwise
        # the string is not recognised as an existing local directory.
        model_name_or_path = os.path.expanduser(
            f"~/{model_location}/{model_type}/step={step}"
        )
        revision = None

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        revision=revision,
        device_map="auto",
        torch_dtype=dtype,
    )

    # torch.compile only exists in newer torch releases.
    if hasattr(torch, "compile"):
        model = torch.compile(model)

    return model


def index_of_valid_positions(labels):
    """Return a per-row boolean mask of rows whose labels are usable.

    A row is kept only when neither column 50 nor column 500 of ``labels``
    holds the value -100.
    """
    ignored_at_50 = labels[:, 50] == -100
    ignored_at_500 = labels[:, 500] == -100
    return ~(ignored_at_50 | ignored_at_500)


def compute_losses_at_positions(input_ids, logits, positions=(49, 499)):
    """Compute per-sample next-token cross-entropy at selected positions.

    For every ``pos`` in ``positions`` the logits at ``pos`` are scored against
    the token at ``pos + 1``.

    Args:
        input_ids: Integer tensor indexed as ``[batch, sequence]``.
        logits: Tensor indexed as ``[batch, sequence, vocab]``.
        positions: Sequence of at least two positions. Only the losses for the
            first two entries are returned.

    Returns:
        A pair ``(loss_at_first_position, loss_at_second_position)`` of
        float32 tensors with one entry per batch row.
    """
    per_position_losses = []
    for pos in positions:
        # Targets must live on the same device as the logits they are scored
        # against.
        target_tokens = input_ids[:, pos + 1].to(logits.device)
        # Upcast so the softmax over the vocabulary is not evaluated in half
        # precision when the model runs in fp16.
        predicted_logits = logits[:, pos, :].float()
        per_position_losses.append(
            torch.nn.functional.cross_entropy(
                predicted_logits, target_tokens, reduction="none"
            )
        )

    return per_position_losses[0], per_position_losses[1]


def process_batches(
    model,
    dataloader,
    num_of_samples,
    batch_size=100,
    device="cuda",
    skip_specific_positions=False,
):
    """Run the model over the dataloader and collect per-sample losses.

    Each batch is truncated to its first 501 tokens. The losses come from
    ``compute_losses_at_positions`` at positions 49 and 499.

    Args:
        model: Callable returning an object with a ``logits`` attribute when
            given a batch of input ids.
        dataloader: Iterable of dicts with an ``"input_ids"`` tensor and, when
            ``skip_specific_positions`` is set, a ``"labels"`` tensor.
        num_of_samples: Maximum number of samples to collect.
        batch_size: Unused; kept so existing callers that pass it keep working.
        device: Device the batch tensors are moved to.
        skip_specific_positions: When true, rows rejected by
            ``index_of_valid_positions`` are dropped before the forward pass.

    Returns:
        ``(loss_diff, loss_50, loss_500, num_skipped)``: three float32 arrays
        of equal length (at most ``num_of_samples``), where ``loss_diff`` is
        ``loss_50 - loss_500``, plus the number of dropped rows.
    """
    # Rows: loss difference, loss at position 49, loss at position 499. The
    # three series always advance together, so one buffer and one cursor
    # suffice.
    collected = np.empty((3, num_of_samples), dtype=np.float32)
    filled = 0
    num_skipped = 0

    with torch.no_grad():
        for batch in dataloader:
            if filled >= num_of_samples:
                break

            batch_input_ids = batch["input_ids"][:, :501].to(device)

            if skip_specific_positions:
                # Labels are only needed to decide which rows to keep.
                labels = batch["labels"][:, :501].to(device)
                keep_mask = index_of_valid_positions(labels)
                kept_input_ids = batch_input_ids[keep_mask]
                num_skipped += batch_input_ids.shape[0] - kept_input_ids.shape[0]
                batch_input_ids = kept_input_ids

                if batch_input_ids.shape[0] == 0:
                    continue

            # No labels are passed: only the logits are used, so asking the
            # model for its own full-sequence loss would be wasted work.
            logits = model(batch_input_ids).logits

            loss_50, loss_500 = compute_losses_at_positions(
                batch_input_ids, logits, [49, 499]
            )

            # One device-to-host transfer for all three series.
            batch_values = (
                torch.stack((loss_50 - loss_500, loss_50, loss_500))
                .float()
                .cpu()
                .numpy()
            )
            # Never write past the requested number of samples.
            take = min(batch_values.shape[1], num_of_samples - filled)
            collected[:, filled : filled + take] = batch_values[:, :take]
            filled += take

    return (
        collected[0, :filled],
        collected[1, :filled],
        collected[2, :filled],
        num_skipped,
    )


def _mean_with_confidence_interval(values, confidence_level=0.95):
    """Return ``(mean, standard_error, (ci_low, ci_high))`` for a 1-D sample.

    The interval is a Student-t interval with ``len(values) - 1`` degrees of
    freedom, centred on the mean and scaled by the standard error.
    """
    mean = np.mean(values)
    std_error = stats.sem(values)
    interval = stats.t.interval(
        confidence_level, len(values) - 1, loc=mean, scale=std_error
    )
    return mean, std_error, interval


def main():
    """Evaluate token-loss differences for a range of training checkpoints.

    For every step in ``range(first_step, last_step, 100)`` (``last_step`` is
    therefore exclusive) the checkpoint is loaded, run over the validation
    dataloader, and one JSON line with means, standard errors and 95%
    confidence intervals is appended to a file under ``./results``.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # Anything that is not recognisably the "disk" file system is treated like
    # "share", so file_system is always bound.
    cwd_str = str(Path.cwd())
    file_system = "disk" if "disk" in cwd_str else "share"

    if file_system == "disk":
        model_location = "pythia_replicate/hf_output"
    else:
        model_location = "pythia_replicate_public_models"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--first_step", type=int, default=100)
    parser.add_argument("--last_step", type=int, default=20000)
    parser.add_argument("--num_of_samples", type=int, default=50000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--use_fp16", action="store_true", default=True)
    # --use_fp16 defaults to True and therefore cannot switch fp16 off; this
    # companion flag makes float32 selectable without changing the default.
    parser.add_argument("--no_fp16", dest="use_fp16", action="store_false")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--config_path",
        type=str,
        default="~/pythia_replicate/pythia-160m.json",
    )
    parser.add_argument("--extra_identifier", type=str, default="")

    args = parser.parse_args()

    dtype = torch.float16 if args.use_fp16 else torch.float32

    # open() does not expand "~", so do it explicitly.
    config_path = os.path.expanduser(args.config_path)
    with open(config_path, "r") as f:
        config = json.load(f)

    config["train_micro_batch_size_per_gpu"] = args.batch_size

    mask_bigram_loss = bool(config.get("mask_bigram_loss", False))
    use_equivalence_masking = bool(
        config.get("use_equivalence_bigram_masking", False)
    )

    equiv_lookup = None
    if use_equivalence_masking:
        lookup_name = config.get("equivalence_lookup_path", None)
        # Test for a missing entry before building the path; os.path.exists
        # does not expand "~" either.
        lookup_path = (
            None
            if lookup_name is None
            else os.path.expanduser(f"~/pythia_replicate/{lookup_name}")
        )
        if lookup_path is None or not os.path.exists(lookup_path):
            print(
                "WARNING: use_equivalence_bigram_masking=True but equivalence_lookup_path is missing or not found. No masking will occur."
            )
        else:
            equiv_lookup = torch.load(lookup_path, map_location="cpu")
            if not torch.is_tensor(equiv_lookup):
                raise ValueError("equivalence_lookup_path must load a Tensor")
            equiv_lookup = equiv_lookup.bool()

    data_module = setup_pythia_data(config, equiv_lookup=equiv_lookup)
    data_module.setup()

    if mask_bigram_loss:
        filename = f"./results/{args.model_type}_tld_mask_bigram_loss_{args.extra_identifier}.jsonl"
    elif use_equivalence_masking:
        filename = (
            f"./results/{args.model_type}_tld_use_equivalence_bigram_masking_{args.extra_identifier}.jsonl"
        )
    else:
        filename = f"./results/{args.model_type}_tld_{args.extra_identifier}.jsonl"
    skip_specific_positions = mask_bigram_loss or use_equivalence_masking

    # model_type may contain "/" (e.g. a Hub id), which places the file in a
    # sub-directory of ./results, so create the file's actual parent.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    for step in tqdm(range(args.first_step, args.last_step, 100), desc="Steps"):
        start_time = time.time()

        model = get_model(
            args.model_type,
            step,
            dtype=dtype,
            file_system=file_system,
            model_location=model_location,
        )
        model.eval()
        device = next(model.parameters()).device

        if skip_specific_positions:
            val_dataloader = data_module.val_dataloader_with_masking()
        else:
            val_dataloader = data_module.val_dataloader()

        loss_array, loss_50_array, loss_500_array, num_skipped = process_batches(
            model,
            val_dataloader,
            args.num_of_samples,
            args.batch_size,
            device,
            skip_specific_positions,
        )

        confidence_level = 0.95
        avg_score, std_error, confidence_interval = _mean_with_confidence_interval(
            loss_array, confidence_level
        )
        avg_loss_50, std_error_50, confidence_interval_50 = (
            _mean_with_confidence_interval(loss_50_array, confidence_level)
        )
        avg_loss_500, std_error_500, confidence_interval_500 = (
            _mean_with_confidence_interval(loss_500_array, confidence_level)
        )

        elapsed = time.time() - start_time
        print(f"\nStep {step}: Processed {len(loss_array)} samples in {elapsed:.1f}s")
        print(f"Average token-loss diff: {avg_score:.4f}")
        print(f"95% CI: [{confidence_interval[0]:.4f}, {confidence_interval[1]:.4f}]")
        if skip_specific_positions:
            print(f"Skipped {num_skipped} samples with repeating bigrams")

        record = {
            "step": step,
            "avg_tld": float(avg_score),
            "std_tld_error": float(std_error),
            "ci_lower_tld": float(confidence_interval[0]),
            "ci_upper_tld": float(confidence_interval[1]),
            "avg_loss_50": float(avg_loss_50),
            "std_loss_50_error": float(std_error_50),
            "ci_lower_loss_50": float(confidence_interval_50[0]),
            "ci_upper_loss_50": float(confidence_interval_50[1]),
            "avg_loss_500": float(avg_loss_500),
            "std_loss_500_error": float(std_error_500),
            "ci_lower_loss_500": float(confidence_interval_500[0]),
            "ci_upper_loss_500": float(confidence_interval_500[1]),
            "num_of_skipped_samples": num_skipped,
            "processed_samples": len(loss_array),
            "processing_time": elapsed,
        }
        # The file is reopened and closed for every step, so each record is
        # flushed to disk as soon as the with-block exits.
        with open(filename, "a") as f:
            f.write(json.dumps(record) + "\n")

        # Release the checkpoint before the next one is loaded.
        del model
        torch.cuda.empty_cache()

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
