import os

# Process-wide environment settings, assigned ahead of the torch / transformers
# imports that follow further down in this file.
# NOTE(review): presumably they must be set first so the libraries pick them up
# at import time — confirm before reordering.
# NOTE(review): looks like an NCCL wait-mode switch read by torch.distributed — confirm.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
# NOTE(review): presumably a DeepSpeed flag that skips its CUDA version check — confirm.
os.environ["DS_SKIP_CUDA_CHECK"] = "1"
# NOTE(review): presumably silences advisory warnings from transformers — confirm.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
# NOTE(review): presumably turns off parallelism inside the tokenizers library — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Machine-specific absolute path; NOTE(review): presumably the Hugging Face
# cache/home directory — adjust for the host this runs on.
os.environ["HF_HOME"] = "/home/user/huggingface"

import argparse
import json
import random
from typing import Dict, List

import numpy as np
import torch
import torch.nn.functional as F
from countdown_utils import metric_fn
from tqdm import trange
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# Directory holding the dataset JSON files (machine-specific absolute path).
DATA_DIR = "/home/user/train-countdown/stream-of-search/src/data/b4_rand"
# Maps a split name (the value of --split) to the file name inside DATA_DIR.
# Note that the "test" split points at the "val_target1_..." file.
DATA_FILES = {
    "train": "train1_b4_t100_n500000_random.json",
    "val": "val1_b4_t100_n500000_random.json",
    "test": "val_target1_b4_t100_n500000_random.json",
}


def get_prompts(samples: List[Dict[str, str]], depth: int = 0) -> List[str]:
    """
    Builds one prompt per sample from its "optimal_path" field.

    With depth == 3 the whole whitespace-stripped path is returned unchanged.
    For any other depth the prompt consists of the first 4 * depth + 1 lines
    of the stripped path, followed by a trailing newline.
    """

    def _cut(path: str) -> str:
        # depth 3 is the only case that keeps the path as-is (no trailing "\n").
        if depth == 3:
            return path
        kept_lines = path.split("\n")[: 4 * depth + 1]
        return "\n".join(kept_lines) + "\n"

    return [_cut(sample["optimal_path"].strip()) for sample in samples]


def get_search_paths(samples: List[Dict[str, str]]) -> List[str]:
    """
    Collects the "search_path" field of every sample, with leading and
    trailing whitespace removed, preserving the sample order.
    """
    return [entry["search_path"].strip() for entry in samples]


def forward(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    trajectories: List[str],
    ignore_index: int = -100,
) -> List[float]:
    """
    Computes losses for the given trajectories.

    Each trajectory is tokenized (padded to the longest one in the batch, no
    truncation), run through the model without gradients, and scored with
    next-token cross-entropy. The returned list holds one value per
    trajectory: the sum of its per-token losses divided by the number of
    scored target tokens.

    Args:
        model: Causal language model whose output exposes "logits".
        tokenizer: Tokenizer used to encode the trajectories.
        trajectories: Texts to score.
        ignore_index: Label value assigned to padding positions so that
            cross-entropy skips them.

    Returns:
        One mean token loss per trajectory, as Python floats. A trajectory
        with no scored target token divides 0 by 0 and yields NaN.
    """
    # Encode trajectories
    inputs = tokenizer(
        trajectories,
        padding="longest",
        truncation=False,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Define labels
    input_ids: torch.Tensor = inputs["input_ids"]
    labels = input_ids.clone()
    if "attention_mask" in inputs:
        # Mask padding by position rather than by token id: when the pad token
        # shares its id with a real token (commonly EOS), comparing against
        # pad_token_id would also drop genuine tokens from the loss.
        labels[inputs["attention_mask"] == 0] = ignore_index
    else:
        # No mask available: fall back to identifying padding by token id.
        labels[labels == tokenizer.pad_token_id] = ignore_index

    # Compute losses
    with torch.no_grad():
        # Call the module itself instead of .forward() so registered hooks run.
        outputs = model(**inputs)
        logits: torch.Tensor = outputs["logits"]
        # Position i predicts token i + 1: drop the last logit and first label.
        logits = logits[:, :-1]
        labels = labels[:, 1:]
        b, t = labels.shape
        logits = logits.reshape(b * t, -1)
        labels = labels.reshape(b * t)
        losses = F.cross_entropy(
            logits, labels, ignore_index=ignore_index, reduction="none"
        )
        losses = losses.reshape(b, t)
        labels = labels.reshape(b, t)
        # Ignored positions contribute zero loss; average over the rest.
        losses = losses.sum(dim=1) / (labels != ignore_index).sum(dim=1)
        losses = losses.tolist()

    return losses


def generate(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    prompts: List[str],
    max_length: int = 4096,
    temperature: float = 0.8,
    stop_strings: List[str] = ["Goal Reached", "Exited"],
) -> List[str]:
    """
    Generates search paths starting from the given prompts.

    Args:
        model: Causal language model used for generation.
        tokenizer: Tokenizer used to encode the prompts and decode the output.
            Its padding side is switched to "left" only while encoding and is
            restored afterwards, even if encoding raises.
        prompts: Prompt texts, encoded together as one batch.
        max_length: Forwarded to model.generate as max_length.
        temperature: Sampling temperature. Exactly 0.0 selects greedy decoding
            (do_sample=False); any other value enables sampling.
        stop_strings: Forwarded to model.generate as stop_strings. The default
            list is shared across calls; this function only reads it.

    Returns:
        The sequences returned by model.generate, decoded with special tokens
        skipped, one string per prompt.
    """
    # Encode prompts
    prev_padding_side = tokenizer.padding_side
    # NOTE(review): left padding presumably keeps every prompt adjacent to its
    # generated continuation in the batch — confirm.
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(
            prompts,
            padding="longest",
            truncation=False,
            return_tensors="pt",
        )
    finally:
        # Always hand the tokenizer back in its original state, including when
        # encoding fails, so later right-padded uses are not silently affected.
        tokenizer.padding_side = prev_padding_side
    inputs = inputs.to("cuda")

    # Generate tokens
    generation_kwargs = {
        "max_length": max_length,
        "tokenizer": tokenizer,
        "stop_strings": stop_strings,
        "pad_token_id": tokenizer.eos_token_id,
        "use_cache": True,
    }
    if temperature == 0.0:
        # Greedy decoding; temperature is deliberately not passed.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
    all_tokens = model.generate(**inputs, **generation_kwargs)

    # Get search paths
    search_paths = tokenizer.batch_decode(all_tokens, skip_special_tokens=True)

    return search_paths


def main(args):
    """
    Runs the evaluation described by the parsed command-line arguments.

    Seeds the random generators, loads the model and tokenizer from
    args.ckpt, reads the selected data split, and walks the samples in
    [args.start, args.start + args.offset) batch by batch. For each batch it
    builds prompts, generates search paths (or reuses the prompts when
    args.depth is 3 or more), rates generated and reference paths with
    metric_fn, and computes losses on the generated paths. Running statistics
    are printed after every batch, and the trajectories, generated ratings
    and losses are finally written to a JSON file inside args.ckpt.
    """
    # Seed every random number generator in use
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)

    # Model on GPU in evaluation mode
    model = (
        AutoModelForCausalLM.from_pretrained(
            args.ckpt,
            device_map="cuda",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        .cuda()
        .eval()
    )

    # Tokenizer from the same checkpoint
    tokenizer = AutoTokenizer.from_pretrained(args.ckpt)

    # Read the requested split
    with open(os.path.join(DATA_DIR, DATA_FILES[args.split]), "r") as handle:
        data = json.load(handle)

    # Sample window, clipped to the dataset size
    first = args.start
    last = min(first + args.offset, len(data))

    all_paths, gen_scores, ref_scores, all_losses = [], [], [], []

    for lo in trange(first, last, args.batch_size):
        # Slice out the current batch
        hi = min(lo + args.batch_size, last)
        batch = data[lo:hi]

        prompts = get_prompts(batch, args.depth)
        references = get_search_paths(batch)

        # From depth 3 on, the prompts themselves are scored
        if args.depth >= 3:
            generated = prompts
        else:
            generated = generate(
                model,
                tokenizer,
                prompts,
                max_length=args.max_length,
                temperature=args.temperature,
            )

        all_paths.extend(generated)

        # Rate each generated path next to its reference path
        for candidate, reference in zip(generated, references):
            candidate_score, _ = metric_fn(candidate)
            reference_score, _ = metric_fn(reference)
            gen_scores.append(candidate_score)
            ref_scores.append(reference_score)

        # Losses of the model on the generated paths
        all_losses.extend(forward(model, tokenizer, generated))

        # Running statistics over everything processed so far
        mean_gen = np.mean(gen_scores)
        mean_ref = np.mean(ref_scores)
        acc_gen = np.mean([score > 0 for score in gen_scores])
        acc_ref = np.mean([score > 0 for score in ref_scores])

        print()
        print(f"Gen Rating: {mean_gen}, Gen Accuracy: {acc_gen}")
        print(f"Ref Rating: {mean_ref}, Ref Accuracy: {acc_ref}")

    # Write results next to the checkpoint
    out_name = f"results_final_optimal_s{args.seed}_depth{args.depth}_{args.split}_{first}_{last}.json"
    out_path = os.path.join(args.ckpt, out_name)
    payload = {"trajectories": all_paths, "ratings": gen_scores, "losses": all_losses}
    with open(out_path, "w") as out_file:
        json.dump(payload, out_file, indent=4)


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        required=True,
        type=str,
        help="Checkpoint to load the model and tokenizer from; the results "
        "file is written into this directory.",
    )
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="Seed for the random, numpy and torch generators.",
    )
    parser.add_argument(
        "--depth",
        default=0,
        type=int,
        help="Prompt depth: the prompt keeps the first 4*depth+1 lines of the "
        "optimal path (the whole path at depth 3). Values of 3 or more skip "
        "generation and score the prompt itself.",
    )
    # Restrict --split to the known splits so that a typo is rejected here,
    # instead of surfacing as a KeyError only after the model has been loaded.
    parser.add_argument(
        "--split",
        default="train",
        type=str,
        choices=tuple(DATA_FILES),
        help="Dataset split to evaluate.",
    )
    parser.add_argument(
        "--start",
        default=0,
        type=int,
        help="Index of the first sample to process.",
    )
    parser.add_argument(
        "--offset",
        default=1000,
        type=int,
        help="Number of samples to process from --start (clipped to the "
        "dataset size).",
    )
    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        help="Number of samples handled per batch.",
    )
    # Accepted so existing command lines keep working; main() does not read it.
    parser.add_argument(
        "--prompt_len",
        default=17,
        type=int,
        help="Not read by main(); kept for command-line compatibility.",
    )
    parser.add_argument(
        "--max_length",
        default=4096,
        type=int,
        help="max_length forwarded to model.generate.",
    )
    parser.add_argument(
        "--temperature",
        default=0.8,
        type=float,
        help="Sampling temperature; 0 selects greedy decoding.",
    )
    args = parser.parse_args()

    # Run main
    main(args)
