import os
import json
import time
import argparse
from pathlib import Path
from collections import defaultdict

from accelerate.utils import set_seed

import torch
import torch.nn as nn

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.calibration import calibration_curve
from transformers import AutoTokenizer
from datasets import load_dataset

from models import LlamaDraftForCausalLM, LlamaForCausalLM
from models.token import KVCache
from dataloader import Dataset, DataCollator
from utils import Timer
from preprocess import get_tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache


def _parse_args():
    """Parse and validate the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=Path, default="preprocessed_data")
    parser.add_argument(
        "--output_dir", type=Path, default="output")
    parser.add_argument(
        "--ea-model-path",
        type=str,
        default="checkpoint-60000",
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="The model ID",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="The maximum length of the input sequence.",
    )
    parser.add_argument(
        "--depth",
        type=int,
        default=3,
        help="The number of depth to be used.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode",
    )
    parser.add_argument(
        "--draft-as-label",
        action="store_true",
        help="Use draft as label",
    )
    parser.add_argument(
        "--nth-token",
        type=int,
        default=0,
        help="The nth token to be used for evaluation. Default is 0, which means the first token.",
    )
    parser.add_argument(
        "--max-batches",
        type=int,
        default=300,
        help="Stop after this many validation batches.",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Save a calibration-curve plot per depth into --output_dir.",
    )

    args = parser.parse_args()
    if args.depth < 1:
        parser.error("--depth must be at least 1.")
    # Only the top-5 draft tokens are computed, so the index must address one of them.
    if args.draft_as_label and not -5 <= args.nth_token < 5:
        parser.error("--nth-token must index one of the top-5 draft tokens.")
    return args


def _load_models(args):
    """Load the base model and, unless ``--debug``, the draft model on CUDA.

    The draft model is handed the base model's inner ``.model`` via
    ``set_base_model`` and shares the base model's LM head and token
    embeddings. In debug mode the base model is returned in both positions.

    Returns:
        ``(model, base_model)``, where ``model`` is the model being evaluated.
    """
    base_model = LlamaForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype="float16",
        device_map="cuda",
    )
    model = base_model
    if not args.debug:
        model = LlamaDraftForCausalLM.from_pretrained(
            args.ea_model_path,
            torch_dtype="float16",
            device_map="cuda",
        )
        model.set_base_model(base_model.model)
        model.lm_head = base_model.lm_head
        model.draft_lm_head = base_model.lm_head
        model.model.embed_tokens = base_model.model.embed_tokens

    print(base_model.dtype)
    print(model.dtype)
    return model, base_model


def build_attention_mask(batch_size, q_len, q_len_valid, depth, keep, *, dtype, device):
    """Build the additive attention mask for one rollout depth.

    Args:
        batch_size: Leading dimension of the mask. Only batch index 0 is
            filled in, so this assumes ``batch_size == 1`` (the evaluation
            dataloader uses ``batch_size=1``).
        q_len: Full (padded) sequence length processed at depth 0.
        q_len_valid: Number of positions selected by ``keep``.
        depth: Rollout depth (0 for the first pass over the full sequence).
        keep: 1-D boolean tensor of length ``q_len`` marking the positions
            that are rolled forward after depth 0.
        dtype: Mask dtype; masked entries hold ``torch.finfo(dtype).min``.
        device: Device the mask is created on.

    Returns:
        Tensor of shape ``(batch_size, 1, rows, q_len + q_len_valid * depth)``.
        ``rows`` is ``q_len`` at depth 0 and ``q_len_valid`` afterwards.
        The first ``q_len`` columns are causal over the depth-0 entries.
        Each later block of ``q_len_valid`` columns lets query row ``r`` see
        only column ``r`` of that block.
    """
    min_value = torch.finfo(dtype).min
    attention_mask = torch.full(
        (batch_size, 1, q_len, q_len + q_len_valid * depth),
        min_value,
        device=device,
        dtype=dtype,
    )
    # Causal part: 0 (attend) on and below the diagonal, min_value above it.
    attention_mask[0, 0, :, :q_len] = torch.triu(
        attention_mask[0, 0, :, :q_len],
        diagonal=1,
    )
    if depth > 0:
        # After depth 0 only the kept positions are queried.
        attention_mask = attention_mask[:, :, keep, :]
    for j in range(depth):
        start = q_len + j * q_len_valid
        attention_mask[0, 0, :, start:start + q_len_valid].diagonal(0).fill_(0)
    return attention_mask


def _as_float64_array(values):
    """Flatten a tensor or array-like into a 1-D float64 NumPy array."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().double().numpy()
    return np.asarray(values, dtype=np.float64).ravel()


def brier_score(probs, labels):
    """Return the per-position multi-class Brier score.

    Args:
        probs: Predicted probabilities, class dimension last (``(..., classes)``).
        labels: Integer reference class indices with shape ``probs.shape[:-1]``.

    Returns:
        Float tensor of shape ``probs.shape[:-1]`` holding
        ``sum_c (probs[..., c] - one_hot(labels)[..., c]) ** 2``.
    """
    probs = probs.float()
    # Expand sum_c (p_c - 1[c == y])^2 = sum_c p_c^2 - 2 p_y + 1 so that no
    # vocabulary-sized one-hot tensor has to be materialised.
    label_probs = probs.gather(-1, labels.long().unsqueeze(-1)).squeeze(-1)
    return probs.pow(2).sum(dim=-1) - 2.0 * label_probs + 1.0


def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Return the bin-population-weighted expected calibration error (ECE).

    Confidences are assigned to ``n_bins`` equal-width bins on [0, 1]. This
    is the same assignment rule ``sklearn.calibration.calibration_curve``
    uses with its default uniform strategy. Each bin's
    |accuracy - mean confidence| gap is weighted by the fraction of samples
    in that bin.

    Args:
        y_true: Binary outcomes (1 = prediction correct); tensor or array-like.
        y_prob: Confidence of each prediction; tensor or array-like.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a Python float, or ``nan`` when there are no samples.
    """
    y_true = _as_float64_array(y_true)
    y_prob = _as_float64_array(y_prob)
    if y_true.size == 0:
        return float("nan")
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.searchsorted(edges[1:-1], y_prob)
    bin_conf_sum = np.bincount(bin_ids, weights=y_prob, minlength=n_bins)
    bin_true_sum = np.bincount(bin_ids, weights=y_true, minlength=n_bins)
    # sum_b (n_b / N) * |acc_b - conf_b| == sum_b |true_sum_b - conf_sum_b| / N
    return float(np.abs(bin_true_sum - bin_conf_sum).sum() / y_true.size)


def nll(y_true, y_prob, eps=1e-7):
    """Return the mean binary negative log-likelihood of confidence scores.

    ``y_prob`` is treated as the predicted probability that ``y_true`` is 1,
    and ``-[y log p + (1 - y) log(1 - p)]`` is averaged. ``y_prob`` is
    clamped to ``[eps, 1 - eps]`` so fully confident mistakes yield a large
    but finite penalty.

    Args:
        y_true: Tensor of binary outcomes.
        y_prob: Tensor of confidences, same shape as ``y_true``.
        eps: Clamping margin.

    Returns:
        0-dim float64 tensor (``nan`` for empty input).
    """
    y_true = y_true.double()
    p = y_prob.double().clamp(eps, 1.0 - eps)
    return -(y_true * p.log() + (1.0 - y_true) * torch.log1p(-p)).mean()


def _summarize(is_correct, confidence, brier):
    """Reduce per-token statistics to the four reported scalar metrics."""
    return {
        "brier": brier.mean().item(),
        "ece": expected_calibration_error(is_correct, confidence),
        "nll": nll(is_correct, confidence).item(),
        "acc": is_correct.mean().item(),
    }


def _plot_calibration_curve(is_correct, confidence, depth, output_dir, draft_as_label, n_bins=10):
    """Save a reliability diagram for one rollout depth into ``output_dir``."""
    fraction_of_positives, mean_predicted_value = calibration_curve(
        is_correct.flatten().cpu().numpy(),
        confidence.flatten().clamp(0.0, 1.0).cpu().numpy(),
        n_bins=n_bins,
    )
    fig, ax = plt.subplots()
    ax.plot(mean_predicted_value, fraction_of_positives, marker="o", label="Draft model")
    ax.plot([0, 1], [0, 1], linestyle="--", color="red", label="Perfectly calibrated")
    ax.set_xlabel("Mean predicted value")
    ax.set_ylabel("Fraction of positives")
    mode = "draft" if draft_as_label else "base"
    ax.set_title(f"{mode.capitalize()} as label - Depth {depth}")
    ax.legend()
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir / f"depth_{depth}_{mode}_as_label_calibration_curve.png")
    plt.close(fig)


@torch.no_grad()
def main():
    """Evaluate the calibration of the draft model's top-1 predictions.

    For up to ``--max-batches`` validation samples, the draft model and the
    base model are rolled forward ``--depth`` steps on the same fed-back
    tokens. At each depth the draft's top-1 token is compared with the base
    model's top-1 token. The script prints Brier score, ECE, binary NLL and
    accuracy of the draft's confidence.

    "Valid" metrics at depth ``i`` restrict to positions whose flag was set
    at depth ``i - 1``. Every position is valid unless ``--draft-as-label``
    is given. In that case a position is valid when the fed-back draft token
    matched the base model's top-1.
    """
    args = _parse_args()

    with Timer("Loading tokenizer..."):
        tokenizer = get_tokenizer(args.model_name, max_length=args.max_length)

    with Timer("Loading dataset..."):
        valid_dataset = Dataset(
            args.data_dir,
            split="valid",
            debug=args.debug,
            add_feature_noise=False,
            tokenizer=tokenizer,
            max_length=args.max_length,
        )
    data_collator = DataCollator(tokenizer=tokenizer, pad_to_multiple_of=8)

    dataloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,
        collate_fn=data_collator,
        num_workers=4,
        pin_memory=True,
    )

    with Timer("Loading model..."):
        model, base_model = _load_models(args)

    model.eval()
    print('Check model training state:', model.training)

    draft_past_key_values = KVCache(
        model.config.draft_num_hidden_layers,
        model.config.num_key_value_heads,
        model.config.max_position_embeddings * 8,
        model.config.head_dim,
        model.device,
        model.dtype,
        num_seqs=1,
    )
    base_past_key_values = KVCache(
        model.config.num_hidden_layers,
        model.config.num_key_value_heads,
        model.config.max_position_embeddings * 8,
        model.config.head_dim,
        model.device,
        model.dtype,
        num_seqs=1,
    )

    # Per depth: one 1-D CPU tensor per batch. Only per-token summaries are
    # kept (not full-vocabulary log-probs), so host memory stays proportional
    # to the number of evaluated tokens.
    confidence_history = [[] for _ in range(args.depth)]
    is_correct_history = [[] for _ in range(args.depth)]
    brier_history = [[] for _ in range(args.depth)]
    valid_mask_history = [[] for _ in range(args.depth)]

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if batch_idx >= args.max_batches:
            break
        labels = batch["labels"].cuda()
        loss_masks = batch["loss_masks"].cuda()
        hidden_states = batch["hidden_states"].cuda()

        # Positions that carry loss; depths > 0 only roll these forward.
        keep = loss_masks.bool().flatten()
        q_len = labels.size(1)
        q_len_valid = int(keep.sum().item())
        if q_len_valid == 0:
            continue

        last_input_ids = labels.clone()
        last_hidden_states = hidden_states.clone()
        draft_kv_cache_indices = []
        base_kv_cache_indices = []

        # One flag per rolled-forward position, shaped like `output_ids`
        # (1, q_len_valid). Without --draft-as-label every position stays valid.
        valid_mask = torch.ones((1, q_len_valid), dtype=torch.bool, device=labels.device)

        for i in range(args.depth):
            # Depth 0 processes the whole sequence; later depths add one cache
            # entry per rolled-forward position.
            num_new = q_len if i == 0 else q_len_valid
            draft_kv_cache_indices.extend(draft_past_key_values.allocate(num_new))
            base_kv_cache_indices.extend(base_past_key_values.allocate(num_new))

            attention_mask = build_attention_mask(
                labels.size(0),
                q_len,
                q_len_valid,
                i,
                keep,
                dtype=model.dtype,
                device=model.device,
            )
            position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0) + i
            if i > 0:
                position_ids = position_ids[:, keep]

            draft_output = model(
                input_ids=last_input_ids,
                position_ids=position_ids,
                hidden_states=last_hidden_states,
                attention_mask=attention_mask,
                past_key_values=draft_past_key_values,
                past_key_value_indices=draft_kv_cache_indices,
                use_cache=True,
                shift_tokens=i == 0,
                cut_last_token=i == 0,
            )
            # Each model gets its own position_ids tensor in case one of them
            # modifies its inputs in place.
            base_output = base_model(
                input_ids=last_input_ids,
                position_ids=position_ids.clone(),
                attention_mask=attention_mask,
                past_key_values=base_past_key_values,
                past_key_value_indices=base_kv_cache_indices,
                use_cache=True,
            )

            draft_log_probs = draft_output["logits"].log_softmax(-1)
            base_log_probs = base_output["logits"].log_softmax(-1)
            draft_hidden_state = draft_output["draft_hidden_states"]
            if i == 0:
                # Depth 0 ran over the full sequence; keep only loss positions.
                draft_log_probs = draft_log_probs[:, keep]
                base_log_probs = base_log_probs[:, keep]
                draft_hidden_state = draft_hidden_state[:, keep]

            base_top1 = base_log_probs.topk(5, dim=-1).indices[..., 0].reshape(1, q_len_valid)
            if args.draft_as_label:
                output_ids = draft_log_probs.topk(5, dim=-1).indices[..., args.nth_token].reshape(1, q_len_valid)
                valid_mask = output_ids == base_top1
            else:
                output_ids = base_top1

            # Summarise this depth immediately: draft confidence and correctness
            # of its top-1 against the base top-1, plus the draft's Brier score.
            draft_max_log_probs, draft_top1 = draft_log_probs.max(dim=-1)
            reference_ids = base_log_probs.argmax(dim=-1)
            confidence_history[i].append(draft_max_log_probs.float().exp().flatten().cpu())
            is_correct_history[i].append((draft_top1 == reference_ids).flatten().cpu())
            brier_history[i].append(
                brier_score(draft_log_probs.float().exp(), reference_ids).flatten().cpu()
            )
            valid_mask_history[i].append(valid_mask.flatten().cpu())

            last_input_ids = output_ids
            last_hidden_states = draft_hidden_state

        draft_past_key_values.free(draft_kv_cache_indices)
        base_past_key_values.free(base_kv_cache_indices)

    if not is_correct_history[0]:
        raise RuntimeError("No validation batch with a non-empty loss mask was evaluated.")

    metric_values = {"brier": [], "ece": [], "nll": [], "acc": []}
    valid_metric_values = {"brier": [], "ece": [], "nll": [], "acc": []}

    for i in range(args.depth):
        is_correct = torch.cat(is_correct_history[i]).float()
        confidence = torch.cat(confidence_history[i])
        brier = torch.cat(brier_history[i])

        metrics = _summarize(is_correct, confidence, brier)
        print(f"Brier score: {metrics['brier']}")
        print(f"Expected calibration error: {metrics['ece']}")
        print(f"Negative log likelihood: {metrics['nll']}")
        print(f"Accuracy: {metrics['acc']}")
        for name, value in metrics.items():
            metric_values[name].append(value)

        if i > 0:
            # NOTE(review): only the previous depth's flag is used, not the
            # product over all earlier depths -- confirm this is intended.
            accepted = torch.cat(valid_mask_history[i - 1])
            valid_metrics = _summarize(is_correct[accepted], confidence[accepted], brier[accepted])
            print(f"Valid Brier score: {valid_metrics['brier']}")
            print(f"Valid expected calibration error: {valid_metrics['ece']}")
            print(f"Valid negative log likelihood: {valid_metrics['nll']}")
            print(f"Valid accuracy: {valid_metrics['acc']}")
            for name, value in valid_metrics.items():
                valid_metric_values[name].append(value)

        if args.plot:
            _plot_calibration_curve(is_correct, confidence, i, args.output_dir, args.draft_as_label)

    print("Brier score values:", metric_values["brier"])
    print("ECE values:", metric_values["ece"])
    print("NLL values:", metric_values["nll"])
    print("Accuracy values:", metric_values["acc"])

    print("Valid Brier score values:", valid_metric_values["brier"])
    print("Valid ECE values:", valid_metric_values["ece"])
    print("Valid NLL values:", valid_metric_values["nll"])
    print("Valid Accuracy values:", valid_metric_values["acc"])


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
