import os
import sys
from os.path import join
from typing import List
import yaml
from contextlib import nullcontext
import time
from tqdm import tqdm
import random
from sklearn.metrics.pairwise import cosine_similarity
import json
import string
from collections import Counter, defaultdict
import numpy as np
import random
import fire
import torch
from torch.nn import CrossEntropyLoss
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
from transformers import default_data_collator
from utils.dataset_utils import get_dataset
from utils.icl_utils import get_icl_examples
from utils.prompt_utils import apply_icl_prompt
from utils.model_utils import (
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    load_peft_model,
)
import math

def evaluate(model, eval_dataloader, tokenizer, device="cuda:0"):
    """Compute multiple-choice accuracy by ranking options on LM loss.

    Each batch holds several candidate sequences per example (dimension 1 of
    ``input_ids`` / ``labels`` / ``attention_mask``). Every candidate is run
    through ``model``; its summed token cross-entropy is divided by the number
    of supervised label tokens, and the candidate with the lowest normalized
    loss is taken as the prediction and compared against ``gold``.

    Args:
        model: Callable as ``model(input_ids=..., labels=...,
            attention_mask=...)`` returning an object with a ``.logits``
            attribute.
        eval_dataloader: Iterable of dict batches of tensors with keys
            ``"input_ids"``, ``"labels"``, ``"attention_mask"`` (indexed as
            ``[:, i, :]``) and ``"gold"`` (index of the correct candidate).
        tokenizer: Unused; kept so existing call sites keep working.
        device: Device every batch tensor is moved to. Defaults to
            ``"cuda:0"``, the value that used to be hard-coded.

    Returns:
        A 0-dim float tensor with the fraction of examples whose predicted
        candidate equals ``gold``. The fraction is taken over all examples
        seen, so a short final batch or a skipped empty batch does not skew
        it. Returns a zero tensor when no example was evaluated.
    """
    model.eval()

    # Per-token losses are needed so each sequence can be normalized on its own.
    loss_fct = CrossEntropyLoss(reduction="none")
    ignore_index = loss_fct.ignore_index

    num_correct = 0
    num_examples = 0

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, colour="green", desc="evaluating Epoch"):
            batch = {key: value.to(device) for key, value in batch.items()}

            gold = batch["gold"]
            batch_size = gold.shape[0]
            if batch_size == 0:
                continue

            option_losses = []
            for i in range(batch["input_ids"].shape[1]):
                labels = batch["labels"][:, i, :]
                outputs = model(
                    input_ids=batch["input_ids"][:, i, :],
                    labels=labels,
                    attention_mask=batch["attention_mask"][:, i, :],
                )
                logits = outputs.logits

                # Standard causal-LM shift: position t predicts token t + 1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

                # Count exactly the positions that contribute to the loss;
                # a plain "> 0" test would drop the valid token id 0.
                label_length = torch.sum(shift_labels != ignore_index, dim=-1)

                token_loss = loss_fct(
                    shift_logits.view(-1, shift_logits.shape[-1]),
                    shift_labels.view(-1),
                ).view(batch_size, -1)
                seq_loss = torch.sum(token_loss, dim=-1) / label_length
                option_losses.append(seq_loss.to(gold.device))

            # Shape (batch, num_options): lowest normalized loss wins.
            loss = torch.stack(option_losses, dim=-1)
            cls_preds = torch.argmin(loss, dim=-1)

            num_correct = num_correct + (cls_preds == gold).sum()
            num_examples += batch_size

    if num_examples == 0:
        return torch.tensor(0.0)

    return num_correct.float() / num_examples

def main(
    peft_model: str=None,
    quantization: bool=False,
    model_name: str="llama-3.1-8b",
    train_dataset: str="arc-easy",
    test_dataset: str="arc-easy",
    seed: int=42,
    subset_size: int=100,
    k=4,
    method="base",
    dp_choice="knn",
    emb: str="all-roberta-large-v1",
    metric: str="cosine_similarity",
    if_qwa=False,
    permutation=1,
    batch_size=8,
    model_root: str="/home/amax/exp/huggingface/transformers/",
):
    """Run in-context-learning evaluation and append accuracies to a results file.

    Steps: seed the RNGs, optionally initialise the distributed environment
    (when ``WORLD_SIZE`` is set), fetch the example-index matrix from
    ``get_icl_examples``, load the causal LM and tokenizer, draw a set of
    distinct column permutations of that matrix, then evaluate once per
    permutation and append the accuracy to
    ``./results/{method}/{model_name}/{test_dataset}/{k}/{emb}.txt``.

    Args:
        peft_model: Optional PEFT checkpoint passed to ``load_peft_model``.
        quantization: Not used by this function.
        model_name: Sub-directory of ``model_root`` holding the model; also
            forwarded to ``get_icl_examples`` and used in the results path.
        train_dataset: Dataset name whose train split is loaded and which is
            forwarded to ``get_icl_examples``.
        test_dataset: Dataset name whose test split is evaluated.
        seed: Seed for torch, numpy and ``random``; also the ``shuffle_seed``
            given to ``get_icl_examples``.
        subset_size: Forwarded to ``get_icl_examples``.
        k: Length of each permutation drawn (presumably the number of
            in-context examples per query -- confirm against
            ``get_icl_examples``). ``k == 0`` skips the re-ordering.
        method: Forwarded to ``get_icl_examples``; when it contains ``"knn"``
            or ``"k_means"`` every drawn permutation is evaluated, otherwise
            only the first.
        dp_choice: Forwarded to ``get_icl_examples``.
        emb: Embedding name; must be a string. Also the results file stem.
        metric: Forwarded to ``get_icl_examples``.
        if_qwa: Not used by this function.
        permutation: Forwarded to ``get_icl_examples``.
        batch_size: DataLoader batch size; also passed to ``apply_icl_prompt``.
        model_root: Directory containing the model folders. Defaults to the
            path that used to be hard-coded.

    Raises:
        TypeError: If ``emb`` is not a string.
    """
    print(f"start time: {time.asctime()}")
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    if not isinstance(emb, str):
        raise TypeError("Main Expected 'emb' to be a string, but got wrong type")
    model_path = join(model_root, model_name)

    # Distributed setup only happens when a launcher exported WORLD_SIZE.
    world_size = os.environ.get('WORLD_SIZE')
    if world_size is not None:
        world_size = int(world_size)
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ['RANK'])
        print(f"rank: {rank} local rank: {local_rank} world size: {world_size}")

        setup()
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    idx_mat_origin = get_icl_examples(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        emb=emb,
        shuffle_seed=seed,
        k=k,
        method=method,
        dp_choice=dp_choice,
        subset_size=subset_size,
        metric=metric,
        model_name=model_name,
        permutation=permutation,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        return_dict=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    if peft_model:
        model = load_peft_model(model, peft_model)

    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

    dataset_train, _ = get_dataset(dataset=train_dataset)
    _, dataset_test = get_dataset(dataset=test_dataset)

    np.random.seed(seed)
    random.seed(seed)

    # Cap at 10 orderings once k! exceeds that; otherwise enumerate all k!.
    per_total = 10 if k > 3 else math.factorial(k)

    # Rejection-sample until per_total distinct orderings have been drawn.
    permutations = []
    while len(permutations) < per_total:
        perm = np.random.permutation(k).tolist()
        if perm not in permutations:
            permutations.append(perm)

    if "knn" in method or "k_means" in method:
        permutation_num = per_total
    else:
        permutation_num = 1

    result_dir = f"./results/{method}/{model_name}/{test_dataset}/{k}"
    result_file = f"{result_dir}/{emb}.txt"

    for permutation_index in range(permutation_num):
        print(f"model_name: {model_name}, train_dataset: {train_dataset}, test_dataset: {test_dataset}, emb: {emb}, k: {k}, method: {method}, dp_choice: {dp_choice}, subset_size: {subset_size}, permutation: {permutation_index}, seed: {seed}")
        if k == 0:
            idx_mat = idx_mat_origin
        else:
            # Re-order the columns of the index matrix with this permutation.
            idx_mat = idx_mat_origin[:, np.array(permutations[permutation_index])]
        icl_tokenized_dataset = apply_icl_prompt(batch_size, test_dataset, dataset_train, dataset_test, idx_mat, tokenizer)

        eval_dataloader = torch.utils.data.DataLoader(
            icl_tokenized_dataset,
            batch_size=batch_size,
            collate_fn=default_data_collator,
        )

        acc = evaluate(model, eval_dataloader, tokenizer)
        print(f"accuracy: {acc}")

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(result_dir, exist_ok=True)
        with open(result_file, "a") as f:
            f.write(f"model_name: {model_name}, k: {k}, method: {method}, dp_choice: {dp_choice}, subset_size: {subset_size}, emb: {emb}, metric: {metric}, permutation: {permutation_index}, seed: {seed}\n")
            f.write(f"train_dataset: {train_dataset}\n")
            f.write(f"accuracy: {acc}\n")
        print(f"file has been saved in {result_file}")
    print(f"end time: {time.asctime()}")
# Script entry point: python-fire exposes main()'s parameters as command-line flags.
if __name__ == '__main__':
    fire.Fire(main)