import torch
from torch.utils.data import DataLoader
from transformers import PretrainedConfig
from loader.model import load_model
from loader.checkpoint import get_checkpoint_id, load_tokenizer
from loader.data import _load_data, SimpleDataCollator, SimpleDataCollatorMatrix
from safetensors import safe_open
import click
import yaml
import os
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import re
import pandas as pd
import matplotlib.pyplot as plt


def _parse_matrix_ids(raw_ids):
    """Extract the ``(i, j)`` cell index from a batch's ``matrix_ids`` entry.

    The entry is stringified and split on ``", "``. The digits found in each of
    the first two pieces become the two indices, e.g. ``tensor([3, 5])`` or
    ``(3, 5)`` both parse as ``(3, 5)``. Raises ``IndexError`` when fewer than
    two pieces are present.

    NOTE(review): the concrete type/layout of ``matrix_ids`` is produced by
    ``SimpleDataCollatorMatrix`` (not visible here). Only the first pair is read,
    so every sample in a batch is assumed to share one cell; confirm this holds
    for the batch size used by the caller.
    """
    pieces = str(raw_ids).split(", ")
    row_text, col_text = pieces[0], pieces[1]
    return tuple(int("".join(re.findall(r"\d+", text))) for text in (row_text, col_text))


def matrix_evaluate(model, testloader, tokenizer, device="cuda"):
    """Greedy-decode every batch and tally exact-match hits per matrix cell.

    A prediction counts as a hit when its decoded string equals the decoded
    label string exactly. Each batch's hits are credited to the single cell
    parsed from ``batch["matrix_ids"]`` by ``_parse_matrix_ids``.

    Args:
        model: Object exposing ``greedy_generate(encoder_input, max_length=...)``.
        testloader: Iterable of batch dicts with ``"encoder_input"``,
            ``"labels"`` and ``"matrix_ids"`` entries.
        tokenizer: Object exposing ``batch_decode(ids, skip_special_tokens=True)``.
        device: Device the encoder input is moved to before generation.
            Defaults to ``"cuda"``.

    Returns:
        ``(matrix_dict, accuracy)``. ``matrix_dict[i][j]`` is the number of
        exact-match hits credited to cell ``(i, j)``. ``accuracy`` is the overall
        fraction of exact matches, or ``nan`` when the loader yields no samples.
    """
    hits, total = 0, 0
    matrix_dict = defaultdict(lambda: defaultdict(int))
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for batch in testloader:
            max_length = batch["labels"].shape[1]
            row, col = _parse_matrix_ids(batch["matrix_ids"])
            outputs = model.greedy_generate(
                batch["encoder_input"].to(device),
                max_length=max_length,
            )
            pred = tokenizer.batch_decode(outputs.cpu().numpy(), skip_special_tokens=True)
            target = tokenizer.batch_decode(batch["labels"].cpu().numpy(), skip_special_tokens=True)
            # Compare once and reuse the count for both the global and per-cell tallies.
            batch_hits = sum(p == t for p, t in zip(pred, target))
            hits += batch_hits
            total += len(pred)
            matrix_dict[row][col] += batch_hits

    accuracy = hits / total if total else float("nan")
    print(f"Accuracy: {accuracy}")
    for i, v in matrix_dict.items():
        for j, vv in v.items():
            print(f"({i}, {j}) : {vv}")

    return matrix_dict, accuracy


def evaluate(model, testloader, tokenizer, device="cuda"):
    """Greedy-decode every batch and report exact-match accuracy.

    A prediction counts as a hit when its decoded string equals the decoded
    label string exactly. The accuracy is printed and returned.

    Args:
        model: Object exposing ``greedy_generate(encoder_input, max_length=...)``.
        testloader: Iterable of batch dicts whose values are all tensors and
            which include ``"encoder_input"`` and ``"labels"``.
        tokenizer: Object exposing ``batch_decode(ids, skip_special_tokens=True)``.
        device: Device every batch tensor is moved to before generation.
            Defaults to ``"cuda"``.

    Returns:
        Fraction of exact matches over all samples, or ``nan`` when the loader
        yields no samples.
    """
    hits, total = 0, 0
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for batch in testloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # Generate exactly as many positions as the label sequence has.
            max_length = batch["labels"].shape[1]
            outputs = model.greedy_generate(
                batch["encoder_input"],
                max_length=max_length,
            )
            pred = tokenizer.batch_decode(outputs.cpu().numpy(), skip_special_tokens=True)
            target = tokenizer.batch_decode(batch["labels"].cpu().numpy(), skip_special_tokens=True)
            hits += sum(p == t for p, t in zip(pred, target))
            total += len(pred)

    accuracy = hits / total if total else float("nan")
    print(f"Accuracy: {accuracy}")
    return accuracy



def _plot_matrix(matrix, acc, out_path):
    """Render per-cell hit counts from ``matrix_evaluate`` as an annotated heatmap.

    ``matrix[i][j]`` is drawn at row ``i`` / column ``j``. Rows and columns are
    sorted so both axes read in increasing order. Cells with no data (missing
    ``(i, j)`` pairs) are left unannotated. The figure is written to ``out_path``.
    """
    df = pd.DataFrame(matrix).T.sort_index().sort_index(axis=1)
    fig, ax = plt.subplots(figsize=(8, 6))
    # Fixed colour scale of 0-150 hits, as in the original script.
    ax.matshow(df, cmap="Blues", vmin=0, vmax=150)

    for (i, j), val in np.ndenumerate(df.values):
        # Missing (i, j) combinations become NaN after DataFrame construction.
        if np.isnan(val):
            continue
        ax.text(j, i, f"{val:g}", ha="center", va="center", color="black")

    ax.set_xticks(np.arange(len(df.columns)))
    ax.set_yticks(np.arange(len(df.index)))
    ax.set_xticklabels(df.columns)
    ax.set_yticklabels(df.index)
    ax.set_xlabel("digits")
    ax.set_ylabel("digits")
    ax.set_title(f"Accuracy Matrix: {acc:.3f}")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close(fig)


@click.command()
@click.option(
    "--params-cfg",
    default="/mnt/nfs/results/large/prod_padding/prod_padding_ZZ_digits=5_standard_bs=128_2/params.yaml",
    help="Path to config file",
)
@click.option("--data-path", default=None, help="Path to data file")
@click.option("--use-matrix", is_flag=True, help="Use matrix")
def main(params_cfg, data_path, use_matrix):
    """Evaluate a trained run's checkpoint on its test split.

    Reads the run's YAML config and loads the tokenizer. The model weights come
    from the checkpoint selected by ``get_checkpoint_id``. The script then
    reports exact-match accuracy on ``<data_path>.test``. With ``--use-matrix``,
    per-cell hit counts are also saved as ``matrix.png`` next to the config file.
    """
    # Use a context manager so the config file handle is closed after parsing.
    with open(params_cfg) as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    save_path = cfg["save_path"]
    if data_path is None:
        data_path = cfg["data_path"]
    cpid = get_checkpoint_id(save_path)
    ckpt_path = os.path.join(save_path, f"checkpoint-{cpid}")
    config = PretrainedConfig.from_pretrained(os.path.join(ckpt_path, "config.json"))

    # load tokenizer and model
    tokenizer = load_tokenizer(save_path)
    model = load_model(config, vocab=tokenizer.vocab, tokenizer=tokenizer)
    model.cuda().eval()
    state_dict = {}
    # device=0 loads tensors straight onto the first CUDA device.
    with safe_open(os.path.join(ckpt_path, "model.safetensors"), framework="pt", device=0) as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
    model.load_state_dict(state_dict)

    # load data
    testset = _load_data(f"{data_path}.test")
    if use_matrix:
        dc = SimpleDataCollatorMatrix(tokenizer)
    else:
        dc = SimpleDataCollator(tokenizer)
    # shuffle=False keeps samples in file order. matrix_evaluate relies on each
    # batch belonging to a single cell.
    testloader = DataLoader(testset, batch_size=100, collate_fn=dc, shuffle=False)

    # evaluate
    if use_matrix:
        matrix, acc = matrix_evaluate(model, testloader, tokenizer)
        _plot_matrix(matrix, acc, os.path.join(os.path.dirname(params_cfg), "matrix.png"))
    else:
        evaluate(model, testloader, tokenizer)


if "__main__" == __name__:
    # Script entry point: hand control to the click command.
    main()
