import argparse
import json
import math
import os
import random
import time

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from model.modeling_llada import LLaDAModelLM
from generate import generate, generate_with_dual_cache, generate_with_prefix_cache
from gsm8k import GSM8KDataset
from math500 import MATH500Dataset
from countdown import CTDDataset
from sudoku import SudokuDataset
from aime23_dataset import AIME23Dataset
from aime24_dataset import AIME24Dataset
from aime25_dataset import AIME25Dataset
from math_500_level_1_dataset import MATH500L1Dataset
from math_500_level_2_dataset import MATH500L2Dataset
from math_500_level_3_dataset import MATH500L3Dataset
from math_500_level_4_dataset import MATH500L4Dataset
from math_500_level_5_dataset import MATH500L5Dataset
from olympiad_dataset import OLYMPIADDataset

DATASET_MAP = {
    "gsm8k": GSM8KDataset,
    "math": MATH500Dataset,
    "countdown": CTDDataset,
    "sudoku": SudokuDataset,
    "aime23": AIME23Dataset,
    "aime24": AIME24Dataset,
    "aime25": AIME25Dataset,
    "olympiad": OLYMPIADDataset,
    "mathl1": MATH500L1Dataset,
    "mathl2": MATH500L2Dataset,
    "mathl3": MATH500L3Dataset,
    "mathl4": MATH500L4Dataset,
    "mathl5": MATH500L5Dataset
}


def init_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def setup_ddp():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank


def cleanup_ddp():
    dist.destroy_process_group()


def evaluate(
    model,
    tokenizer,
    dataloader,
    gen_length=128,
    temperature=0.1,          # IMPORTANT: must be > 0 for diversity
    cfg_scale=0.0,
    steps=64,
    block_length=32,
    num_samples=1,            # number of generations per question
):
    model.eval()
    device = model.device

    total_processed = torch.tensor(0, device=device)
    wall_times = []
    all_generations = []

    for batch in tqdm(dataloader, disable=(dist.get_rank() != 0)):
        start_time = time.time()

        input_ids = batch["input_ids"].to(device)
        gt_answers = batch["answers"]
        questions = batch["questions"]
        prompts = batch["prompts"]

        batch_size = input_ids.shape[0]

        # Collect num_samples generations per example
        batch_generations = [[] for _ in range(batch_size)]

        for s in range(num_samples):
            out = generate_with_prefix_cache(
                model,
                input_ids,
                tokenizer,
                steps=steps,
                gen_length=gen_length,
                block_length=block_length,
                temperature=temperature,
                cfg_scale=cfg_scale,
                remasking="low_confidence",
            )

            decoded = tokenizer.batch_decode(
                out[:, -gen_length:], skip_special_tokens=False
            )

            for i in range(batch_size):
                batch_generations[i].append(decoded[i])

        # Save results per example
        for i in range(batch_size):
            all_generations.append(
                {
                    "question": questions[i],
                    "prompt_input": prompts[i],
                    "generations": batch_generations[i],  # list of 8
                    "ground_truth": gt_answers[i],
                }
            )

        total_processed += batch_size
        wall_times.append(time.time() - start_time)

        # Print one random example (rank 0 only)
        if dist.get_rank() == 0:
            idx = random.randint(0, batch_size - 1)
            print(f"Question: {questions[idx]}")
            print("-" * 50)
            print("Generation samples:")
            for k, g in enumerate(batch_generations[idx]):
                print(f"[{k}] {g}")
            print("-" * 50)
            print(f"Ground truth: {gt_answers[idx]}")

    avg_wall_time = sum(wall_times) / max(1, len(wall_times))

    metrics = {
        "wall_time": avg_wall_time,
        "generations": all_generations,
        "total_processed": total_processed.item(),
    }

    return metrics



class CustomDistributedSampler(DistributedSampler):
    """
    From torch docs:
    drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas

    We want drop_last = False, but don't want to have extra padding indices. Hence using a custom sampler.
    """

    def __init__(
        self,
        dataset,
        num_replicas=None,
        rank=None,
        shuffle=True,
        seed=0,
        drop_last=False,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
            self.total_size = self.num_samples * self.num_replicas
        else:
            # If we don't drop the last batch, we need to calculate the number of samples per rank.
            self.total_size = len(self.dataset)
            self.num_samples = len(self.dataset) // self.num_replicas + int(
                rank < (self.total_size % self.num_replicas)
            )

        self.shuffle = shuffle
        self.seed = seed


if __name__ == "__main__":
    init_seed(42)

    # Note: This evaluation script saves only model generations. A separate parser is used later to extract
    # predictions and calculate metrics.

    local_rank = setup_ddp()

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="GSAI-ML/LLaDA-8B-Instruct")
    parser.add_argument("--few_shot", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument(
        "--dataset", type=str, choices=["gsm8k", "math", "countdown", "sudoku", "aime23", "aime24", "aime25", "olympiad", "mathl1", "mathl2" , "mathl3", "mathl4", "mathl5"], default="gsm8k"
    )
    parser.add_argument("--suffix", type=str, default="")
    parser.add_argument("--checkpoint_path", type=str, default="")
    parser.add_argument("--gen_length", type=int, default=128)
    parser.add_argument("--block_length", type=int, default=32)
    parser.add_argument("--diffusion_steps", type=int, default=64)
    parser.add_argument("--add_reasoning", action="store_true")
    parser.add_argument("--dont_save", action="store_true")
    parser.add_argument("--output_dir", type=str, default="results/")
    parser.add_argument("--dont_use_box", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--num_samples", type=int, default=1)
    args = parser.parse_args()
    print(f'Getting Results for: {args.dataset}')
    # args.diffusion_steps = args.gen_length // 2
    num_evals = {"gsm8k": -1, "math": -1, "countdown": 256, "sudoku": 256, "aime23": -1, "aime24":-1, "aime25":-1, "olympiad": -1, "mathl1": -1, "mathl2": -1, "mathl3": -1, "mathl4": -1, "mathl5": -1}

    model = LLaDAModelLM.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(
        local_rank
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    if args.checkpoint_path:
        model = PeftModel.from_pretrained(model, args.checkpoint_path, torch_dtype=torch.bfloat16).to(
            local_rank
        )

        if dist.get_world_size() > 1:
            dist.barrier()  # Make sure all processes are ready
            for param in model.parameters():
                dist.broadcast(param.data, src=0)
            print(f"Rank {local_rank}: Parameters synchronized")

    dataset = DATASET_MAP[args.dataset](
        tokenizer,
        subsample=num_evals[args.dataset],
        num_examples=args.few_shot,
        add_reasoning=True,  # prefill for all models
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=CustomDistributedSampler(dataset, shuffle=False),
        collate_fn=dataset.collate_fn,
    )

    if len(args.checkpoint_path):
        model_name = args.checkpoint_path.split("/")
        model_name = model_name[-2] + "_" + model_name[-1]
    else:
        #model_name = "instruct" if "Instruct" in args.model_path else "base"
        model_name = args.model_path.replace("/", "-").replace(".", "-")

    if args.few_shot > 0:
        model_name = model_name + f"_fs{args.few_shot}"

    if len(args.suffix) > 0:
        model_name = model_name + f"_{args.suffix}"

    os.makedirs(args.output_dir, exist_ok=True)
    time_str = time.strftime("%Y%m%d-%H%M%S")
    filename = f"{args.output_dir}/{args.dataset}_{model_name}_{args.gen_length}_{args.diffusion_steps}_{dist.get_rank()}_{time_str}_generations.json"
    print(f"Saving generations to {filename}")

    metrics = evaluate(
        model,
        tokenizer,
        dataloader,
        gen_length=args.gen_length,
        block_length=args.block_length,
        steps=args.diffusion_steps,
        temperature=args.temperature,
        num_samples=args.num_samples,
    )

    if not args.dont_save:
        with open(filename, "w") as f:
            json.dump(
                {
                    "generations": metrics["generations"],
                    "metrics": {
                        "wall_time": metrics["wall_time"],
                        "total_processed": metrics["total_processed"],
                    },
                    "model_path": args.model_path,
                    "checkpoint_path": args.checkpoint_path,
                    "gen_length": args.gen_length,
                    "diffusion_steps": args.diffusion_steps,
                    "block_length": args.block_length,
                },
                f,
                indent=2,
            )

    cleanup_ddp()

