import os

# Environment configuration. This runs before the third-party imports below,
# presumably so libraries that read these variables when first imported
# (e.g. huggingface_hub reading HF_HOME) see the intended values.
os.environ.update(
    {
        "TORCH_NCCL_BLOCKING_WAIT": "0",
        "DS_SKIP_CUDA_CHECK": "1",
        "TRANSFORMERS_NO_ADVISORY_WARNINGS": "1",
        "TOKENIZERS_PARALLELISM": "false",
        # NOTE(review): hard-coded cache location; overrides any HF_HOME
        # already present in the caller's environment.
        "HF_HOME": "/home/user/huggingface",
    }
)

import argparse
import json
import random
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


def forward(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    trajectories: List[str],
    ignore_index: int = -100,
) -> List[float]:
    """
    Computes the mean next-token cross-entropy loss of each trajectory.

    Args:
        model: Causal language model. The encoded batch is moved to
            ``model.device`` before the forward pass.
        tokenizer: Tokenizer used to encode the batch. It is called with
            ``padding="longest"``, so it needs a pad token.
        trajectories: Raw text sequences to score.
        ignore_index: Label value excluded from the loss.

    Returns:
        One float per trajectory: the sum of its token losses divided by the
        number of scored tokens (every non-padding token except the first).
        A trajectory with no scored tokens yields ``nan``. An empty input
        returns an empty list.
    """
    if not trajectories:
        return []

    # Encode trajectories
    inputs = tokenizer(
        trajectories,
        padding="longest",
        truncation=False,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Define labels
    input_ids: torch.Tensor = inputs["input_ids"]
    labels = input_ids.clone()
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        # Mask padding by position, not by token id. GPT-2 style tokenizers
        # often reuse EOS as the pad token, and an id comparison would also
        # drop genuine EOS tokens from the loss.
        labels[attention_mask == 0] = ignore_index
    else:
        labels[labels == tokenizer.pad_token_id] = ignore_index

    # Compute losses
    with torch.no_grad():
        outputs = model(**inputs)
        # Position i predicts token i + 1. Upcast so per-token losses and
        # their sum are not rounded to reduced precision (e.g. bfloat16).
        logits = outputs["logits"][:, :-1].float()
        targets = labels[:, 1:]
        token_losses = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=ignore_index,
            reduction="none",
        ).reshape(targets.shape)
        n_scored = (targets != ignore_index).sum(dim=1)
        losses = token_losses.sum(dim=1) / n_scored

    return losses.tolist()


def _load_results(path: str) -> tuple:
    """
    Loads one results JSON file.

    Returns:
        A ``(trajectories, ratings)`` pair taken from the file's
        ``"trajectories"`` and ``"ratings"`` keys.
    """
    with open(path, "r", encoding="utf-8") as f:
        results = json.load(f)
    return results["trajectories"], results["ratings"]


def _batched_losses(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    trajectories: List[str],
    batch_size: int,
) -> List[float]:
    """Scores ``trajectories`` with ``forward`` in consecutive chunks of ``batch_size``."""
    losses: List[float] = []
    for batch_start in range(0, len(trajectories), batch_size):
        batch = trajectories[batch_start : batch_start + batch_size]
        losses.extend(forward(model, tokenizer, batch))
    return losses


def main(args):
    """
    Scores the trajectories of one result shard and writes their losses to JSON.

    Reads the base results ``results_final_s{seed}_{split}_{start}_{end}.json``
    and the new results
    ``results_final_{mode}_s{seed}_depth2_{split}_{start}_{end}.json`` from the
    result directory selected by ``args.star_iter``. It keeps the base
    trajectories whose base rating is positive, and the new trajectories whose
    base rating is 0.0 and whose new rating is positive. It then computes their
    mean token losses with the SFT checkpoint and writes
    ``losses_final_{mode}_s{seed}_{split}_{start}_{end}.json`` (keys
    ``"base_losses"`` and ``"new_losses"``) to the same directory.

    Raises:
        ValueError: if ``args.star_iter`` is not 0, 1, or 2, or if the
            trajectory/rating lists of the two result files differ in length.
    """
    # Set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Resolve paths first, so an invalid star_iter fails before the model load.
    ckpt = os.path.join(args.base_dir, "sft-gpt2", "checkpoint-50000")
    if args.star_iter == 0:
        result_dir = ckpt
    elif args.star_iter in (1, 2):
        exp_name = f"star{args.star_iter}-final-{args.mode}-s{args.seed}-gpt2"
        result_dir = os.path.join(args.base_dir, exp_name, "checkpoint-20000")
    else:
        raise ValueError(f"Invalid star_iter: {args.star_iter}")
    start = args.start
    end = start + args.offset

    # Load base and new results
    base_trajectories, base_ratings = _load_results(
        os.path.join(
            result_dir, f"results_final_s{args.seed}_{args.split}_{start}_{end}.json"
        )
    )
    new_trajectories, new_ratings = _load_results(
        os.path.join(
            result_dir,
            f"results_final_{args.mode}_s{args.seed}_depth2_{args.split}_{start}_{end}.json",
        )
    )

    # Both files must be index-aligned: entry i refers to the same problem.
    lengths = {
        len(base_trajectories),
        len(base_ratings),
        len(new_trajectories),
        len(new_ratings),
    }
    if len(lengths) != 1:
        raise ValueError(
            "Result files are not aligned: "
            f"{len(base_trajectories)} base trajectories, "
            f"{len(base_ratings)} base ratings, "
            f"{len(new_trajectories)} new trajectories, "
            f"{len(new_ratings)} new ratings"
        )

    # Keep base trajectories that were rated positively.
    base_kept = [
        traj for traj, rating in zip(base_trajectories, base_ratings) if rating > 0.0
    ]
    # Keep new trajectories rated positively where the base attempt was rated 0.
    new_kept = [
        traj
        for traj, base_rating, new_rating in zip(
            new_trajectories, base_ratings, new_ratings
        )
        if base_rating == 0.0 and new_rating > 0.0
    ]

    # Build model. device_map="cuda" already places the weights on the GPU.
    # NOTE(review): the SFT checkpoint is used for every star_iter, even when
    # the results come from a star checkpoint. Confirm this is intended.
    model = AutoModelForCausalLM.from_pretrained(
        ckpt,
        device_map="cuda",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    model = model.eval()
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    # Compute losses
    base_losses = _batched_losses(model, tokenizer, base_kept, args.batch_size)
    new_losses = _batched_losses(model, tokenizer, new_kept, args.batch_size)

    # Save losses
    losses = {"base_losses": base_losses, "new_losses": new_losses}
    loss_file = f"losses_final_{args.mode}_s{args.seed}_{args.split}_{start}_{end}.json"
    loss_file = os.path.join(result_dir, loss_file)
    with open(loss_file, "w", encoding="utf-8") as f:
        json.dump(losses, f, indent=4)


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(
        description=(
            "Score base and new trajectories of one result shard with the SFT "
            "checkpoint and save their mean token losses."
        )
    )
    parser.add_argument(
        "--base_dir",
        required=True,
        type=str,
        help="Directory containing sft-gpt2/ and the star experiment folders.",
    )
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="Random seed; also part of the result and loss file names.",
    )
    parser.add_argument(
        "--star_iter",
        default=0,
        type=int,
        choices=(0, 1, 2),
        help=(
            "0 reads results from the SFT checkpoint directory; 1 or 2 from "
            "the corresponding star experiment checkpoint."
        ),
    )
    parser.add_argument(
        "--mode",
        default="rand",
        type=str,
        help="Mode tag used in the new-results, loss, and experiment names.",
    )
    parser.add_argument(
        "--split",
        default="train",
        type=str,
        help="Data split tag used in the result and loss file names.",
    )
    parser.add_argument(
        "--start",
        default=0,
        type=int,
        help="Start index of the result shard.",
    )
    parser.add_argument(
        "--offset",
        default=1000,
        type=int,
        help="Shard size; the end index is start + offset.",
    )
    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        help="Number of trajectories scored per forward pass.",
    )
    args = parser.parse_args()

    # Run main
    main(args)
