import os

# Process-wide settings for the libraries imported further down. They are
# assigned ahead of those imports (presumably so the libraries observe them
# at import time) -- keep this block above the torch/transformers imports.
os.environ.update(
    {
        "TORCH_NCCL_BLOCKING_WAIT": "0",
        "DS_SKIP_CUDA_CHECK": "1",
        "TRANSFORMERS_NO_ADVISORY_WARNINGS": "1",
        "TOKENIZERS_PARALLELISM": "false",
        "HF_HOME": "/home/user/huggingface",
    }
)

import argparse
import json
import random
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# Directory that holds the dataset JSON files.
DATA_DIR = "/home/user/train-countdown/stream-of-search/src/data/b4_rand"

# Split name -> file name inside DATA_DIR. The keys are the values accepted
# for the --split command-line option.
DATA_FILES = dict(
    train="train1_b4_t100_n500000_random.json",
    val="val1_b4_t100_n500000_random.json",
    test="val_target1_b4_t100_n500000_random.json",
)


def get_prompts(samples: List[Dict[str, str]]) -> List[str]:
    """
    Build one prompt per sample from the first line of its optimal path.

    Each sample's "optimal_path" string is stripped of surrounding
    whitespace; the text before its first newline, followed by a single
    newline character, becomes the prompt.
    """
    return [
        sample["optimal_path"].strip().partition("\n")[0] + "\n"
        for sample in samples
    ]


def forward(
    model: PreTrainedModel,
    score: nn.Linear,
    tokenizer: PreTrainedTokenizer,
    prompts: List[str],
) -> List[float]:
    """
    Compute one scalar value per prompt.

    The prompts are tokenized as a single batch and run through ``model``
    with hidden states enabled. For every prompt, the final-layer hidden
    state of its last real (non-padding) token is projected to a scalar by
    ``score``.

    Args:
        model: Causal LM called with ``output_hidden_states=True``.
        score: Linear head applied to the selected hidden state; expected to
            live on the same device as ``model``.
        tokenizer: Tokenizer used to encode ``prompts``.
        prompts: Prompt strings forming one batch.

    Returns:
        A list with one float per prompt, in input order. An empty list is
        returned for an empty batch.
    """
    if not prompts:
        return []

    # Encode prompts. Without padding, a batch whose prompts tokenize to
    # different lengths cannot be stacked into a tensor and the tokenizer
    # raises. Pad whenever a pad token is available; for equal-length batches
    # this changes nothing.
    can_pad = tokenizer.pad_token is not None
    inputs = tokenizer(
        prompts,
        padding=can_pad,
        truncation=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Compute values
    with torch.no_grad():
        # Call the modules rather than .forward() so registered hooks run.
        outputs = model(**inputs, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]

        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            last_hidden = hidden[:, -1]
        else:
            # Index of the last attended position in each row. This equals
            # the final position when nothing is padded or padding is on the
            # left, and skips trailing pad tokens when padding is on the right.
            # NOTE(review): with a left-padding tokenizer, position ids are
            # not adjusted here -- confirm that is acceptable for this model.
            positions = torch.arange(
                attention_mask.shape[1], device=attention_mask.device
            )
            last_token = (attention_mask * positions).argmax(dim=1)
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            last_hidden = hidden[rows, last_token.to(hidden.device)]

        values = score(last_hidden).squeeze(-1).tolist()

    return values


def main(args):
    """
    Score a slice of a dataset split with the critic and save the values.

    Loads the critic LM and its linear value head from
    ``<base_dir>/model_<epoch>/critic``, reads the JSON file for
    ``args.split`` from ``DATA_DIR``, computes a value for the prompt of
    every sample in ``data[start : start + offset]`` (clipped to the dataset
    length) in batches of ``args.batch_size``, and writes the result as
    ``{"values": [...]}`` to a JSON file inside ``<base_dir>/model_<epoch>``.

    Args:
        args: Parsed command-line namespace providing ``base_dir``, ``epoch``,
            ``seed``, ``split``, ``start``, ``offset`` and ``batch_size``.
    """
    # Set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Build critic model
    ckpt = os.path.join(args.base_dir, f"model_{args.epoch}")
    critic_ckpt = os.path.join(ckpt, "critic")
    critic_model = AutoModelForCausalLM.from_pretrained(
        critic_ckpt,
        device_map="cuda",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    critic_model = critic_model.cuda()
    critic_model = critic_model.eval()

    # Build critic score
    critic_score = nn.Linear(critic_model.config.hidden_size, 1, dtype=torch.bfloat16)
    # Load the head weights onto the CPU first: without map_location the
    # tensors are restored to the device they were saved from, which fails if
    # that device does not exist on this machine. The head is moved to CUDA
    # right below, so the end state is unchanged.
    score_state = torch.load(os.path.join(critic_ckpt, "score.pt"), map_location="cpu")
    critic_score.load_state_dict(score_state)
    critic_score = critic_score.cuda()
    critic_score = critic_score.eval()

    # Build tokenizer (stored next to the critic directory, not inside it)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    # Load data
    data_file = os.path.join(DATA_DIR, DATA_FILES[args.split])
    with open(data_file, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)

    # Compute values for samples [start, end), one batch at a time
    start = args.start
    end = min(start + args.offset, len(data))
    values = []
    for batch_start in trange(start, end, args.batch_size):
        batch_end = min(batch_start + args.batch_size, end)
        samples = data[batch_start:batch_end]
        prompts = get_prompts(samples)
        values += forward(critic_model, critic_score, tokenizer, prompts)

    # Save values; the file name records the seed, split and sample range
    values = {"values": values}
    value_file = f"values_final_s{args.seed}_{args.split}_{start}_{end}.json"
    value_file = os.path.join(ckpt, value_file)
    with open(value_file, "w", encoding="utf-8") as f:
        json.dump(values, f, indent=4)


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(
        description="Compute critic values for a slice of a dataset split."
    )
    parser.add_argument(
        "--base_dir",
        required=True,
        type=str,
        help="directory containing the model_<epoch> checkpoint folders",
    )
    parser.add_argument(
        "--epoch", default=0, type=int, help="checkpoint to load (model_<epoch>)"
    )
    parser.add_argument("--seed", default=0, type=int, help="random seed")
    # Restrict to the known splits so a typo is rejected here, before the
    # model is loaded, instead of surfacing later as a bare KeyError.
    parser.add_argument(
        "--split",
        default="train",
        type=str,
        choices=list(DATA_FILES),
        help="dataset split to score",
    )
    parser.add_argument(
        "--start", default=0, type=int, help="index of the first sample to score"
    )
    parser.add_argument(
        "--offset",
        default=1000,
        type=int,
        help="number of samples to score, counted from --start",
    )
    parser.add_argument(
        "--batch_size", default=100, type=int, help="prompts per forward pass"
    )
    args = parser.parse_args()

    # Run main
    main(args)
