import json
import random
from datetime import datetime
import argparse
from tqdm import tqdm
import csv
import os

os.environ['HF_HOME'] = '../../_hf'
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To suppress warnings about parallelism in tokenizers
# logger = logging.getLogger(__name__)

from loguru import logger
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

from vds_load import Metric, NeoLoader
from vds_shared import MODEL_REGISTRY, REPORT_OUTS_DIR
from utils.dataset import *
from utils.template import *


def parse_args(argv=None):
    """Parse command-line options for the in-context-learning baseline.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal CLI
            behaviour; passing a list is useful for tests and programmatic use.

    Returns:
        argparse.Namespace with ``llm_code``, ``dataset``, ``seed``,
        ``n_train_shot`` and ``output_dir``. When ``output_dir`` is given, the
        directory (and any missing parents) is created if it does not exist.
    """
    parser = argparse.ArgumentParser(description="In-Context Learning baseline.")
    parser.add_argument(
        "--llm_code",
        type=str,
        # main() indexes MODEL_REGISTRY with this value unconditionally, so a
        # missing value is rejected here with a usage message instead of
        # failing later inside main().
        required=True,
        help="Key into MODEL_REGISTRY selecting the model to evaluate.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Dataset name passed to load_dataset() and make_prompt().",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed passed to subsamplebyshot() when picking demonstrations.",
    )
    parser.add_argument(
        "--n_train_shot",
        type=int,
        default=None,
        help="Shot count passed to subsamplebyshot() for the demonstration prompt.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Optional directory to create before the run.",
    )
    args = parser.parse_args(argv)
    # NOTE(review): main() writes its summary to REPORT_OUTS_DIR and never
    # reads output_dir; the directory is only created here.
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args


def llm_gen(model, prompt, tokenizer, max_context_len):
    """Return the model's next-token logits for ``prompt``.

    The prompt is tokenized as a single sequence. If it is longer than
    ``max_context_len`` tokens, only the *last* ``max_context_len`` tokens are
    kept (left truncation): the end of the prompt is preserved and the
    beginning is dropped.

    Args:
        model: causal LM exposing ``.device``; called with ``input_ids`` /
            ``attention_mask`` and expected to return an object whose
            ``.logits`` is ``[batch, seq_len, vocab]``.
        prompt: full prompt string.
        tokenizer: tokenizer providing ``encode_plus``.
        max_context_len: maximum number of input tokens fed to the model.

    Returns:
        float32 CPU tensor of shape ``[1, vocab_size]`` holding the logits at
        the last input position.
    """
    # A single sequence never needs padding. Requesting it anyway makes Hugging
    # Face tokenizers without a pad token raise ("Asking to pad but the
    # tokenizer does not have a padding token"), so padding is not requested.
    inputs = tokenizer.encode_plus(prompt, return_tensors="pt").to(device=model.device)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    if input_ids.shape[1] > max_context_len:
        input_ids = input_ids[:, -max_context_len:]
        attention_mask = attention_mask[:, -max_context_len:]

    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True)
    # Logits at position t predict token t+1, so the last input position holds
    # the next-token distribution. Slice on-device before moving to CPU: copying
    # the full [1, seq_len, vocab] tensor is wasteful for large vocabularies.
    gen_logits = outputs.logits[:, -1, :].float().cpu()

    return gen_logits


def parse_response(gen_logits, tokenizer, id2verb):
    """Pick the class whose label word is the most probable next token.

    Args:
        gen_logits: next-token logits as produced by ``llm_gen``
            (``[1, vocab]``).
        tokenizer: tokenizer used to map each label word to a token id.
        id2verb: sequence of label words, indexed by class id.

    Returns:
        int class index into ``id2verb``; ties go to the lowest index.
    """
    probs = torch.softmax(gen_logits, dim=-1)
    # Each label word is encoded with a leading space, and only the LAST token
    # id of that encoding is scored.
    verb_token_ids = [tokenizer.encode(' ' + verb)[-1] for verb in id2verb]
    # Gather the label columns, then transpose + flatten so the element order
    # is class-major (same as concatenating the per-class columns on dim 0).
    scores = probs[:, verb_token_ids].T.flatten()
    return torch.argmax(scores).tolist()


def _load_llm(llm_name, device):
    """Load and return ``(tokenizer, model)`` for ``llm_name``."""
    if 'gemma' in llm_name or 'Qwen' in llm_name or 'llama' in llm_name:
        # These families go through the project's NeoLoader. Its first return
        # value (the model config) and third return value are not used here.
        tokenizer = NeoLoader.load_tokenizer(llm_name)
        _, model, _ = NeoLoader.load_model(llm_name)
        return tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    # set pad token ids for batched inference cus gpt2 does not have one
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(llm_name)
    model.to(device)
    model.eval()
    return tokenizer, model


def _append_summary(llm_code, dataset, acc):
    """Append an (llm, dataset, acc) row to REPORT_OUTS_DIR/summary_icl.csv."""
    REPORT_OUTS_DIR.mkdir(parents=True, exist_ok=True)
    save_results_file = REPORT_OUTS_DIR / 'summary_icl.csv'
    # Also write the header when the file exists but is empty (e.g. left
    # behind by an interrupted run); otherwise the first row would lack it.
    write_header = (not save_results_file.exists()
                    or save_results_file.stat().st_size == 0)
    with open(save_results_file, 'a', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        if write_header:
            csvwriter.writerow(['llm', 'dataset', 'acc'])
        csvwriter.writerow([llm_code, dataset, acc])


def main():
    """Run the ICL baseline on a dataset's dev split and record accuracy.

    Steps:
    1. Build one demonstration prefix from the subsampled training set.
    2. For each dev example, append its inference prompt to the prefix and
       predict the label via ``llm_gen`` + ``parse_response``.
    3. Log accuracy and append a summary row to
       ``REPORT_OUTS_DIR / 'summary_icl.csv'``.
    """
    args = parse_args()
    logger.info(f"{args=}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    llm_name = MODEL_REGISTRY[args.llm_code]
    tokenizer, model = _load_llm(llm_name, device)
    # Prompt-token budget: 1024 for GPT-2 models, 2048 for everything else.
    max_context_len = 1024 if 'gpt2' in llm_name else 2048

    # prepare dataset
    train_data, dev_data = load_dataset(dataset=args.dataset)

    # inference
    train_data.subsamplebyshot(args.n_train_shot, args.seed)
    n_dev = len(dev_data)
    logger.info(f"===== eval on {n_dev} dev examples =====")
    # The demonstration prefix is built once and shared by every dev query.
    prompt_prefix = make_prompt(train_data, args.dataset, mode='train')
    # Gold ids come from dev_data.label2id; label words from train_data.id2verb.
    label2id = dev_data.label2id
    id2verb = train_data.id2verb
    dev_labels = []
    dev_pred = []
    for ins in tqdm(dev_data.data, total=n_dev):
        dev_labels.append(label2id[ins['label']])
        prompt = prompt_prefix + make_prompt(ins, args.dataset, mode='inference')
        gen_logits = llm_gen(model, prompt, tokenizer, max_context_len)
        dev_pred.append(parse_response(gen_logits, tokenizer, id2verb))

    acc = Metric.same_accuracy(dev_pred, dev_labels)
    logger.info(f"Acc: {acc}")
    Metric.general_gen_scoring(dev_pred, dev_labels)

    _append_summary(args.llm_code, args.dataset, acc)


if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly, not on import.
    main()
