import os
import sys
from os.path import join
from typing import List
import yaml
from contextlib import nullcontext
import time
from tqdm import tqdm
import random
from sklearn.metrics.pairwise import cosine_similarity
import json
import string
from collections import Counter, defaultdict
import numpy as np
import random
import fire
import torch
from torch.nn import CrossEntropyLoss
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
from transformers import default_data_collator
from utils.dataset_utils import get_dataset
from utils.icl_utils import get_icl_examples
from utils.prompt_utils import apply_icl_prompt
from utils.model_utils import (
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    load_peft_model,
)
import math

def evaluate(model, eval_dataloader, tokenizer, options):
    """Score each answer option of every example and count correct predictions.

    For every batch, the label token id of each option is read from
    ``labels`` (positions where the shifted labels are > 0). The model is run
    once and, at the label position, the logit of each option's label token
    is collected. The option with the highest logit is the prediction, which
    is compared against ``gold``.

    NOTE(review): only the inputs of the *last* option are fed to the model,
    and its label position is reused for all options. This presumably relies
    on every option sharing the same prompt and differing only in the label
    token -- confirm against ``apply_icl_prompt``.

    Args:
        model: Model whose forward call returns an object with ``.logits``.
        eval_dataloader: Iterable of dict batches with ``input_ids``,
            ``labels`` and ``attention_mask`` (indexed as
            ``[:, option, :]``) plus ``gold``.
        tokenizer: Unused; kept so existing callers keep working.
        options: Number of columns of the returned logit array.

    Returns:
        A tuple ``(selected_logits, correct_pred)``. ``selected_logits`` is a
        float64 array with ``options`` columns and one row per evaluated
        example; ``correct_pred`` is the number of rows whose argmax equals
        ``gold``.

    Raises:
        ValueError: If a non-empty batch has no options along axis 1.
    """
    model.eval()
    # Seeding with an empty (0, options) array keeps the result's shape and
    # float64 dtype even when no batch is evaluated, and makes the final
    # stack fail loudly if a batch does not have `options` columns.
    collected = [np.empty((0, options))]
    correct_pred = 0
    for batch in tqdm(eval_dataloader, colour="green", desc="evaluating Epoch"):
        for key in batch.keys():
            batch[key] = batch[key].to('cuda:0')

        gold = batch["gold"]
        batch_size = gold.shape[0]
        if batch_size == 0:
            continue

        num_choices = batch["input_ids"].shape[1]
        if num_choices == 0:
            # Without this guard `b` below would be undefined (or stale from
            # the previous batch).
            raise ValueError("evaluate: batch contains no answer options")

        # Integer dtype on purpose: these values are used as vocabulary
        # indices into the logits.
        label_token_ids = np.empty((batch_size, 0), dtype=np.int64)
        for i in range(num_choices):
            b = {
                "input_ids": batch["input_ids"][:, i, :],
                "labels": batch["labels"][:, i, :],
                "attention_mask": batch["attention_mask"][:, i, :],
            }

            shift_labels = b["labels"][..., 1:].contiguous()
            label_location = torch.where(shift_labels > 0)
            batch_indices, seq_indices = label_location[0], label_location[1]
            current_labels = shift_labels[label_location].cpu().numpy()
            # column_stack only succeeds when there is exactly one label
            # token per row, i.e. len(current_labels) == batch_size.
            label_token_ids = np.column_stack((label_token_ids, current_labels))

        with torch.no_grad():
            # `b`, `batch_indices` and `seq_indices` deliberately refer to the
            # last option processed above (see the NOTE in the docstring).
            logits = model(**b).logits
            shift_logits = logits[..., :-1, :].contiguous()
            token_ids = torch.as_tensor(
                label_token_ids, dtype=torch.long, device=shift_logits.device
            )

            per_option = []
            for j in range(num_choices):
                picked = shift_logits[batch_indices, seq_indices, token_ids[:, j]]
                # Upcast first: NumPy has no bfloat16, so `.numpy()` fails if
                # the model returns its logits in that dtype.
                per_option.append(
                    picked.detach().float().cpu().numpy().reshape(-1, 1)
                )

        selected_logit = np.hstack(per_option)
        collected.append(selected_logit)

        # Flatten so a (batch, 1) gold tensor cannot broadcast against the
        # (batch,) predictions and inflate the count.
        gold = gold.cpu().numpy().reshape(-1)
        correct_pred += np.sum(selected_logit.argmax(axis=1).reshape(-1) == gold)

    return np.vstack(collected), correct_pred

def main(
    peft_model: str=None,
    quantization: bool=False,
    model_name: str="llama-3.1-8b",
    train_dataset: str="arc-easy", 
    test_dataset: str="arc-easy",
    seed: int=42,
    subset_size: int=100,
    k=4,
    method="base",
    dp_choice="knn",
    emb: str="all-roberta-large-v1",
    metric: str="cosine_similarity",
    if_qwa=False,
    permutation=1,
    batch_size=2,
):  
    """Run in-context-learning evaluation and store accuracy and logits.

    Retrieves an index matrix of demonstrations with ``get_icl_examples``,
    loads the causal LM and tokenizer, then, for each sampled column
    permutation of that matrix, builds the prompts, evaluates them and writes
    the index matrix, the accuracy and the logits under ``./idx_mat`` and
    ``./results``.

    Args:
        peft_model: If truthy, passed with the model to ``load_peft_model``.
        quantization: Not used by this function.
        model_name: Joined onto the hard-coded model base directory; also
            part of the output paths.
        train_dataset: Name passed to ``get_dataset`` / ``get_icl_examples``
            for the demonstration pool.
        test_dataset: Name passed to ``get_dataset`` / ``get_icl_examples``
            for the evaluated split.
        seed: Seeds torch, NumPy and ``random``; also forwarded as
            ``shuffle_seed``.
        subset_size: Forwarded to ``get_icl_examples``.
        k: Length of the sampled permutations, i.e. number of index-matrix
            columns that get reordered (presumably the number of
            demonstrations per prompt -- confirm). ``0`` skips reordering.
        method: Forwarded to ``get_icl_examples``. When it contains "knn",
            "knn_diversity", "voke_k" or "k_means", every sampled permutation
            is evaluated; otherwise only the first one.
        dp_choice: Forwarded to ``get_icl_examples``.
        emb: Must be a string; forwarded to ``get_icl_examples`` and used in
            output file names.
        metric: Forwarded to ``get_icl_examples``.
        if_qwa: Not used by this function.
        permutation: Forwarded to ``get_icl_examples``.
        batch_size: Passed to ``apply_icl_prompt`` and the ``DataLoader``.

    Raises:
        TypeError: If ``emb`` is not a string.
    """
    print(f"method: {method}, subset_size: {subset_size}")
    print(f"start time: {time.asctime(time.localtime(time.time()))}")
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    if not isinstance(emb, str):
        raise TypeError("Main Expected 'emb' to be a string, but got wrong type")

    # NOTE: machine-specific base directory -- change it to wherever the
    # model checkpoints live.
    model_path = join("/home/amax/exp/huggingface/transformers/", model_name)

    # Distributed setup only happens when the launcher exported WORLD_SIZE.
    world_size = os.environ.get('WORLD_SIZE')
    if world_size is not None:
        world_size = int(world_size)
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ['RANK'])
        print(f"rank: {rank} local rank: {local_rank} world size: {world_size}")

        setup()
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    idx_mat_origin = get_icl_examples(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        emb=emb,
        shuffle_seed=seed,
        k=k,
        method=method,
        dp_choice=dp_choice,
        subset_size=subset_size,
        metric=metric,
        model_name=model_name,
        permutation=permutation,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        return_dict=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    if peft_model:
        model = load_peft_model(model, peft_model)

    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    dataset_train, _ = get_dataset(dataset=train_dataset)
    _, dataset_test = get_dataset(dataset=test_dataset)
    options = len(dataset_train[0]['choices']['text'])

    np.random.seed(seed)
    random.seed(seed)

    # Sample distinct orderings of the k columns: at most 10 when k > 3,
    # otherwise all k! of them. Rejection sampling keeps the RNG stream (and
    # therefore the chosen orderings) reproducible for a given seed.
    per_total = 10 if k > 3 else math.factorial(k)
    permutations = []
    while len(permutations) < per_total:
        perm = np.random.permutation(k).tolist()
        if perm not in permutations:
            permutations.append(perm)

    # Only these selection methods are evaluated under every permutation.
    permuted_methods = ("knn", "knn_diversity", "voke_k", "k_means")
    if any(tag in method for tag in permuted_methods):
        permutation_num = per_total
    else:
        permutation_num = 1

    results_dir = f"./results/{method}/{model_name}/{test_dataset}/{k}"
    results_file = f"{results_dir}/{emb}.txt"
    logit_dir = f"./results/logit/{method}/{model_name}/{test_dataset}/{k}"

    for permutation_index in range(permutation_num):
        print(f"model_name: {model_name}, train_dataset: {train_dataset}, test_dataset: {test_dataset}, emb: {emb}, k: {k}, method: {method}, dp_choice: {dp_choice}, subset_size: {subset_size}, permutation: {permutation_index}, seed: {seed}")
        if k == 0:
            idx_mat = idx_mat_origin
        else:
            idx_mat = idx_mat_origin[:, np.array(permutations[permutation_index])]

        idx_mat_dir = f"./idx_mat/{method}/{model_name}/{test_dataset}/{train_dataset}/{k}/{seed}/{permutation_index}"
        # exist_ok avoids the check-then-create race when several processes
        # (e.g. distributed ranks) write to the same tree.
        os.makedirs(idx_mat_dir, exist_ok=True)
        idx_mat_file = f"{idx_mat_dir}/{emb}.npy"
        np.save(idx_mat_file, idx_mat)
        print(f"idx_mat has been saved in {idx_mat_file}")

        icl_tokenized_dataset = apply_icl_prompt(batch_size, test_dataset, dataset_train, dataset_test, idx_mat, tokenizer)
        eval_dataloader = torch.utils.data.DataLoader(
            icl_tokenized_dataset,
            batch_size=batch_size,
            collate_fn=default_data_collator,
        )

        logit, correct_pred = evaluate(model, eval_dataloader, tokenizer, options)
        acc = correct_pred / len(dataset_test)
        print(f"accuracy: {acc}")

        os.makedirs(results_dir, exist_ok=True)
        # Append mode: results of every permutation and run accumulate in one file.
        with open(results_file, "a") as f:
            f.write(f"model_name: {model_name}, k: {k}, method: {method}, dp_choice: {dp_choice}, subset_size: {subset_size}, emb: {emb}, metric: {metric}, permutation: {permutation_index}, seed: {seed}\n")
            f.write(f"train_dataset: {train_dataset}\n")
            f.write(f"accuracy: {acc}\n")
        print(f"file has been saved in {results_file}")

        os.makedirs(logit_dir, exist_ok=True)
        # NOTE(review): this file name contains neither permutation_index nor
        # emb, so each loop iteration overwrites the previous logits and only
        # the last permutation survives. Left unchanged to keep existing
        # output paths stable -- confirm whether that is intended.
        np.save(f"{logit_dir}/{train_dataset}_{seed}.npy", logit)

    print(f"end time: {time.asctime(time.localtime(time.time()))}")
# Script entry point: hand main() to python-fire so it can be invoked from the
# command line.
if __name__ == '__main__':
    fire.Fire(main)