import copy
import json
import random
from datetime import datetime
import argparse
from tqdm import tqdm
import csv
import os

# Point the Hugging Face cache at a directory given relative to the current working dir.
# NOTE(review): presumably these must be set before `transformers` is imported below
# for them to take effect — confirm before reordering imports.
os.environ['HF_HOME'] = '../../_hf'
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To avoid warnings about parallelism in tokenizers

from loguru import logger
import torch
from torch import nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

from vds_load import Metric, NeoLoader
from vds_shared import MODEL_REGISTRY, REPORT_OUTS_DIR
from utils.dataset import *
from utils.template import *


class AnchorStore(nn.Module):
    """Fixed-capacity store of anchor probability vectors and their class labels.

    Anchors are appended with :meth:`enqueue`. :meth:`knn_infer` labels queries by
    (k-)nearest-neighbour voting. The distance from anchor ``p`` to query ``q`` is
    ``KL(p || q)``, averaged (not summed) over the last dimension.

    Args:
        K: maximum number of anchors the store can hold.
        dim: length of each anchor vector (``main`` passes the model vocab size).
        knn: number of nearest anchors that vote in :meth:`knn_infer`.
        n_class: number of classes; only labels in ``[0, n_class)`` receive votes.
    """

    def __init__(self, K=1024, dim=50257, knn=1, n_class=2):
        super().__init__()
        # Pre-allocated storage; only the first ``queue_ptr`` rows ever hold real anchors.
        self.register_buffer("queue_anchor", torch.randn(K, dim))
        # -1 marks slots that have not been written yet.
        self.register_buffer("queue_label", torch.full((K,), -1, dtype=torch.int32))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.int32))
        self.knn = knn
        self.n_class = n_class

    def enqueue(self, anchors, labels):
        """Append a batch of anchors and labels at the current write position.

        Args:
            anchors: tensor of shape ``[bs, dim]``, one anchor per row.
            labels: ``bs`` integer labels, or a single scalar label broadcast to all rows.

        Raises:
            RuntimeError: if the batch does not fit into the remaining capacity.
        """
        ptr = int(self.queue_ptr)
        bs = anchors.shape[0]
        capacity = self.queue_anchor.shape[0]
        if ptr + bs > capacity:
            raise RuntimeError(
                f"AnchorStore overflow: cannot add {bs} anchors at position {ptr} "
                f"(capacity {capacity})."
            )

        self.queue_anchor[ptr:ptr + bs, :] = anchors
        self.queue_label[ptr:ptr + bs] = labels
        self.queue_ptr[0] = ptr + bs

    def knn_infer(self, query):
        """Predict a label for each query distribution by voting among its nearest anchors.

        Args:
            query: tensor of shape ``[B, dim]`` (or ``[dim]`` for a single query)
                holding probability distributions.

        Returns:
            list of python ints, the predicted label for each query.

        Raises:
            RuntimeError: if no anchor has been enqueued yet.
        """
        n_filled = int(self.queue_ptr)
        if n_filled == 0:
            raise RuntimeError("AnchorStore is empty; call enqueue() before knn_infer().")
        # Ignore slots that were never written: they still hold the random init.
        anchors = self.queue_anchor[:n_filled]
        labels = self.queue_label[:n_filled]

        # KL(anchor || query) for every pair: [n_filled, B, dim] -> mean -> [n_filled, B]
        # -> transpose -> [B, n_filled].
        # xlogy defines 0 * log(0) == 0, so exact zeros in the probabilities no longer
        # yield NaN distances that would corrupt the argmin/topk ranking.
        p = anchors[:, None, :]
        kl_distance = (torch.xlogy(p, p) - torch.xlogy(p, query)).mean(dim=2).transpose(1, 0)

        # topk cannot request more neighbours than there are stored anchors.
        k = min(self.knn, n_filled)
        if k == 1:
            # Directly return the label of the nearest anchor.
            return labels[kl_distance.argmin(dim=1)].tolist()

        _, indices = torch.topk(kl_distance, k, dim=1, largest=False)
        neighbour_labels = labels[indices]  # [B, k]
        classes = torch.arange(self.n_class, device=neighbour_labels.device)
        # votes[b, c] = number of query b's k neighbours that carry label c
        votes = (neighbour_labels.unsqueeze(-1) == classes).sum(dim=1)
        # argmax returns the first maximal index, so ties go to the smallest class id.
        return votes.argmax(dim=1).tolist()


def parse_args():
    """Parse the command-line options for KNN prompting.

    ``--n_train_shot`` and ``--n_demo_shot`` are required because ``main`` always
    derives the anchor shot count from their difference. ``--n_anchor_shot`` is
    still accepted for backward compatibility, but ``main`` overwrites it.
    If ``--output_dir`` is given, that directory is created (existing ones are kept).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="KNN Prompting.")
    parser.add_argument(
        "--llm_code",
        type=str,
        default=None,
        help="Key of the model in MODEL_REGISTRY.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Name of the dataset to evaluate on.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed used when subsampling demonstrations and anchors.",
    )
    parser.add_argument(
        "--knn",
        type=int,
        default=3,
        help="Number of nearest anchors that vote on each prediction.",
    )
    parser.add_argument(
        "--n_train_shot",
        type=int,
        required=True,
        help="Total number of training shots (demonstrations + anchors).",
    )
    parser.add_argument(
        "--n_demo_shot",
        type=int,
        required=True,
        help="Number of shots used as in-context demonstrations.",
    )
    parser.add_argument(
        "--n_anchor_shot",
        type=int,
        default=None,
        help="Ignored: recomputed as n_train_shot - n_demo_shot.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Optional directory to create for outputs.",
    )
    args = parser.parse_args()
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args


def llm_gen(model, prompt, tokenizer, max_context_len):
    """Return the model's logits at the last position of ``prompt``.

    The prompt is tokenized and, if longer than ``max_context_len`` tokens,
    left-truncated so that its final tokens are kept. It is then run through
    ``model`` without gradient tracking.

    Args:
        model: causal LM exposing ``.device``; its forward output has ``.logits``.
        prompt: prompt text.
        tokenizer: tokenizer providing ``encode_plus``.
        max_context_len: maximum number of tokens fed to the model.

    Returns:
        float32 CPU tensor of shape ``[1, vocab_size]``.
    """
    inputs = tokenizer.encode_plus(prompt, return_tensors="pt", padding=True).to(device=model.device)
    if inputs['input_ids'].shape[1] > max_context_len:
        # Keep the end of the prompt: the next-token prediction depends on it.
        inputs['input_ids'] = inputs['input_ids'][:, -max_context_len:]
        inputs['attention_mask'] = inputs['attention_mask'][:, -max_context_len:]
    with torch.no_grad():
        logits = model.forward(input_ids=inputs['input_ids'],
                               attention_mask=inputs['attention_mask'],
                               return_dict=True).logits
        # The output at position t predicts token t+1, so the next-token logits sit at
        # the last input position. Slice on the model's device before moving to CPU so
        # the full [1, seq_len, vocab] tensor is never copied to host memory.
        gen_logits = logits[:, -1, :].float().cpu()

    return gen_logits


def _load_llm(llm_name, device):
    """Load ``(tokenizer, model_config, model)`` for ``llm_name``.

    gemma / Qwen / llama names go through ``NeoLoader``. Any other name is loaded
    with the ``transformers`` Auto* classes, moved to ``device`` and put in eval mode.
    """
    if 'gemma' in llm_name or 'Qwen' in llm_name or 'llama' in llm_name:
        tokenizer = NeoLoader.load_tokenizer(llm_name)
        model_config, model, _ = NeoLoader.load_model(llm_name)
        return tokenizer, model_config, model

    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    # set pad token ids for batched inference cus gpt2 does not have one
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model_config = AutoConfig.from_pretrained(llm_name)
    model = AutoModelForCausalLM.from_pretrained(llm_name)
    model.to(device)
    model.eval()
    return tokenizer, model_config, model


def _next_token_probs(model, tokenizer, prompt_prefix, ins, dataset, max_context_len):
    """Return the softmax of the next-token logits for example ``ins`` after the demo prefix."""
    prompt = prompt_prefix + make_prompt(ins, dataset, mode='inference')
    gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
    return torch.softmax(gen_logits, dim=-1)


def _append_summary(llm_code, dataset, acc):
    """Append an ``llm,dataset,acc`` row to ``REPORT_OUTS_DIR/summary_kp.csv``.

    The header row is written only when the file does not exist yet.
    """
    REPORT_OUTS_DIR.mkdir(parents=True, exist_ok=True)
    summary_path = REPORT_OUTS_DIR / 'summary_kp.csv'
    write_header = not summary_path.exists()
    with open(summary_path, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        if write_header:
            writer.writerow(['llm', 'dataset', 'acc'])
        writer.writerow([llm_code, dataset, acc])


def main():
    """Run KNN prompting end to end.

    Stage 1 builds an anchor store from training examples that are not used as
    demonstrations. Stage 2 classifies every dev example by KNN over that store,
    logs accuracy, and appends it to the summary CSV.

    Raises:
        ValueError: if ``--n_demo_shot`` is not smaller than ``--n_train_shot``.
    """
    args = parse_args()

    # Anchors are the training shots left over after taking the demonstrations;
    # any --n_anchor_shot given on the command line is overwritten here.
    args.n_anchor_shot = args.n_train_shot - args.n_demo_shot
    if args.n_anchor_shot <= 0:
        raise ValueError("Num. of demonstration must be set smaller than num. of training.")

    # NOTE(review): if subsamplebyshot samples per class, the store holds more than
    # n_anchor_shot anchors and this clamp is stricter than needed — confirm.
    args.knn = min(args.knn, args.n_anchor_shot)  # knn can not exceed num. of anchors

    logger.info(f"{args=}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    llm_name = MODEL_REGISTRY[args.llm_code]
    tokenizer, model_config, model = _load_llm(llm_name, device)
    max_context_len = 1024 if 'gpt2' in llm_name else 2048

    # prepare dataset
    train_data, dev_data = load_dataset(dataset=args.dataset)
    anchor_data = copy.deepcopy(train_data)

    # Stage1: Meta Test
    train_data.subsamplebyshot(args.n_demo_shot, args.seed)
    prompt_prefix = make_prompt(train_data, args.dataset, mode='train')
    anchor_data.subsamplebyshot(args.n_anchor_shot, args.seed, exclude=train_data.data)
    label2id = dev_data.label2id
    n_anchors = len(anchor_data)
    logger.info(f"===== build anchor store of {n_anchors} anchor examples =====")
    anchor_store = AnchorStore(K=n_anchors,
                               dim=model_config.vocab_size,
                               knn=args.knn,
                               n_class=len(label2id))
    for ins in tqdm(anchor_data.data, total=n_anchors):
        label_id = label2id[ins['label']]
        probs = _next_token_probs(model, tokenizer, prompt_prefix, ins, args.dataset, max_context_len)
        anchor_store.enqueue(probs, torch.tensor(label_id))

    # Stage2: Formal Test
    n_dev = len(dev_data)
    logger.info(f"===== eval on {n_dev} dev examples =====")
    dev_labels = []
    dev_pred = []
    for ins in tqdm(dev_data.data, total=n_dev):
        dev_labels.append(label2id[ins['label']])
        probs = _next_token_probs(model, tokenizer, prompt_prefix, ins, args.dataset, max_context_len)
        dev_pred.extend(anchor_store.knn_infer(probs))

    acc = Metric.same_accuracy(dev_pred, dev_labels)
    logger.info(f"Acc: {acc}")
    Metric.general_gen_scoring(dev_pred, dev_labels)

    _append_summary(args.llm_code, args.dataset, acc)


if __name__ == "__main__":
    # Start a run only when executed as a script; importing this module has no run side effects.
    main()
