import sys
import os
import torch
import itertools
import argparse
import pickle as pkl
import random
import torch
import math
import json
import string
import logging
import numpy as np
import pdb
from tqdm import tqdm
from collections import Counter, defaultdict
from utils.dataset_utils import get_dataset

from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import GPT2Tokenizer, AutoTokenizer, AutoModelForSequenceClassification
from transformers import RobertaTokenizer, RobertaModel, GPTJForCausalLM
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM
from transformers import pipeline
from sentence_transformers import SentenceTransformer


def main(logger, args):
    """Encode the train and test splits of ``args.dataset`` with a local
    SentenceTransformer checkpoint and save one embedding tensor per split.

    Each example is rendered as a "Support: <support>" line followed by a
    "Question: <question>" line and encoded individually. For every split the
    embeddings are written with ``torch.save`` to
    ``data/<dataset>/<dataset>_<split>_<model_name>.pt``.

    Args:
        logger: accepted for interface compatibility; not used in the body.
        args: parsed CLI namespace; reads ``model_name``, ``dataset``,
            ``if_avg`` and ``if_avgall``.

    Raises:
        KeyError: if ``args.model_name`` is not a supported checkpoint.
        AssertionError: if both ``args.if_avg`` and ``args.if_avgall`` are set.
    """
    local_cache_dir = os.path.expanduser("/home/amax/exp/huggingface/sentence_transformers")

    # Output width of each supported checkpoint; the keys double as the whitelist.
    embedding_dims = {
        "all-roberta-large-v1": 1024,
        "sbert_large_mt_nlu_ru": 1024,
        "all-mpnet-base-v2": 768,
        "all-minilm-l6-v2": 384,
    }
    # Validate arguments before touching the filesystem or loading the model.
    if args.model_name not in embedding_dims:
        raise KeyError("No embedding for your input！")
    # Previously checked only after a full split had been encoded.
    assert not (args.if_avg and args.if_avgall), "--if_avg and --if_avgall are mutually exclusive"

    out_dir = os.path.join("data", args.dataset)
    os.makedirs(out_dir, exist_ok=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(os.path.join(local_cache_dir, args.model_name))
    model = model.to(device)
    dim = embedding_dims[args.model_name]

    # get_dataset returns a 4-tuple; positions 0 and 2 are used as the train
    # and test splits. Load it once instead of once per split.
    train_split, _, test_split, _ = get_dataset(dataset=args.dataset, load_from_local=True)

    for split_name, dataset in (("train", train_split), ("test", test_split)):
        sentence_embeddings = torch.zeros((len(dataset), dim))
        for i, example in enumerate(tqdm(dataset)):
            text = f"Support: {example['support']}\nQuestion: {example['question']}"
            if i == 1:
                # Show one rendered prompt as a sanity check.
                print(text)
            embedding = torch.from_numpy(model.encode([text]))
            sentence_embeddings[i, :] = embedding.squeeze()
        torch.save(
            sentence_embeddings,
            os.path.join(out_dir, f"{args.dataset}_{split_name}_{args.model_name}.pt"),
        )

def calculate_embeddings(text, model, tokenizer, device, last_layer=True, if_avg=False):
    """Embed ``text`` from the hidden states of a transformer model.

    The text is tokenized into PyTorch tensors, moved to ``device`` and run
    through ``model`` without gradients, requesting all hidden states.

    Args:
        text: input string passed to ``tokenizer``.
        model: callable returning an object with a ``hidden_states`` sequence.
        tokenizer: callable accepting ``return_tensors="pt"`` whose result
            supports ``.to(device)`` and ``**`` unpacking.
        device: target device for the tokenized inputs.
        last_layer: use the final hidden state if True, otherwise the one
            just before it.
        if_avg: average over dim 1 (presumably the token axis) instead of
            taking the last position along it.

    Returns:
        The selected hidden state reduced along dim 1.
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)

    chosen = result.hidden_states[-1 if last_layer else -2]
    if if_avg:
        return torch.mean(chosen, dim=1)
    return chosen[:, -1, :]

def get_icl_embeddings(logger, test_task, metaicl_data, metaicl_model, train_data, dev_data, train_ids, n_class, seed, is_classification, last_layer=True, get_layer_embs=True, if_avg=False, if_avgall=False):
    """Embed each dev example under one fixed set of in-context demonstrations.

    For every element of ``dev_data``, the demonstrations
    ``[train_data[j] for j in train_ids]`` and that single dev example are
    passed to ``run``. One layer of the returned per-layer embeddings is then
    reduced to a single row vector.

    Args:
        last_layer: take ``layer_embs[-1]`` if True, else ``layer_embs[-2]``.
        if_avg: mean over the last ``len_each_part[-1]`` rows.
        if_avgall: mean over all rows.
        (neither flag): take the final row.
        n_class: unused.
        Remaining arguments are forwarded to ``run``.

    Returns:
        torch.Tensor stacking one row per dev example. The tensor size and the
        average of the per-example ``perf`` values are printed.
    """
    assert (if_avg and if_avgall) == False
    layer_pick = -1 if last_layer else -2
    rows = []
    total_perf = 0
    for dev_example in tqdm(dev_data):
        demos = [train_data[j] for j in train_ids]
        _, perf, per_layer, part_lens = run(logger, test_task, metaicl_data, metaicl_model,
                                            demos, [dev_example],
                                            seed, is_classification, get_layer_embs=get_layer_embs)
        hidden = per_layer[layer_pick].numpy()
        tail = part_lens[-1]
        if if_avg:
            pooled = np.mean(hidden[-tail:, :], axis=0)
        elif if_avgall:
            pooled = np.mean(hidden, axis=0)
        else:
            pooled = hidden[-1, :]
        rows.append(pooled.reshape(1, -1))
        total_perf += perf
    stacked = torch.from_numpy(np.vstack(rows))
    print(stacked.size())
    print(f"Acc: {total_perf / len(dev_data)}")
    return stacked

def get_knn_embeddings(logger, test_task, metaicl_data, metaicl_model, train_data, dev_data, knn_ids, n_class, seed, is_classification, last_layer=True, get_layer_embs=True):
    """Embed each dev example under its own per-example demonstrations.

    Dev example ``pos`` is run (via ``run``) with the demonstrations
    ``[train_data[j] for j in knn_ids[pos]]``. The final row of the chosen
    layer's embeddings becomes that example's vector.

    Args:
        knn_ids: one list of training indices per dev example.
        last_layer: take ``layer_embs[-1]`` if True, else ``layer_embs[-2]``.
        n_class: unused.
        Remaining arguments are forwarded to ``run``.

    Returns:
        torch.Tensor stacking one row per dev example. The tensor size and the
        average of the per-example ``perf`` values are printed.
    """
    layer_pick = -1 if last_layer else -2
    rows = []
    total_perf = 0
    for pos, dev_example in enumerate(tqdm(dev_data)):
        demos = [train_data[j] for j in knn_ids[pos]]
        _, perf, per_layer, _ = run(logger, test_task, metaicl_data, metaicl_model,
                                    demos, [dev_example],
                                    seed, is_classification, get_layer_embs=get_layer_embs)
        hidden = per_layer[layer_pick].numpy()
        rows.append(hidden[-1, :].reshape(1, -1))
        total_perf += perf
    stacked = torch.from_numpy(np.vstack(rows))
    print(stacked.size())
    print(f"Acc: {total_perf / len(dev_data)}")
    return stacked

def get_knn_data_idx(train_embs, dev_embs, k, seed, train_data, allow_single_label=False, use_similarity=True, balanced=False, random_shuffled=True):
    """Select ``k`` demonstration indices from ``train_data`` for every dev row.

    For each row of ``dev_embs`` the training rows are ranked by distance, and
    a list of training indices is built as follows:

    * ``balanced``: start with the nearest example, then add the nearest
      examples that share its label and the nearest examples with any other
      label. At most ``k - k // 2`` share the label and at most ``k // 2``
      differ, so an odd ``k`` still yields ``k`` indices when enough examples
      exist.
    * otherwise: take the ``k`` nearest. If they all share the nearest
      example's label, replace the last one with the closest example whose
      label differs, when one exists.

    NOTE(review): ``use_similarity``, ``random_shuffled`` and
    ``allow_single_label`` are overwritten at the top of the body, so the
    caller's values are ignored. Ranking always uses squared Euclidean
    distance, and each result list is always shuffled. ``seed`` is unused;
    shuffling draws from the global ``random`` state. This is kept as-is so
    existing experiments do not silently change; confirm whether the
    overrides are intended.

    Args:
        train_embs: 2-D numpy array, one row per training example.
        dev_embs: 2-D numpy array, one row per dev example (same width).
        k: number of demonstrations per dev example.
        seed: unused.
        train_data: indexed like the rows of ``train_embs``; each item is a
            mapping with an ``"output"`` label.
        allow_single_label, use_similarity, random_shuffled: overridden (see note).
        balanced: use the label-balanced selection described above.

    Returns:
        A list with one list of training indices per dev row.
    """
    n_train, _ = train_embs.shape
    n_dev, _ = dev_embs.shape
    selected_ids = []
    train_emb_norm = np.sqrt(np.sum(train_embs * train_embs, axis=1))
    # Hard overrides of the caller's arguments (see NOTE in the docstring).
    use_similarity = False
    use_only_nearest = False
    random_shuffled = True
    allow_single_label = False
    for i in range(n_dev):
        if use_similarity:
            # Negative dot product scaled by each train row's norm. This ranks
            # by cosine similarity because the dev row's norm is constant per
            # query.
            dev_emb = dev_embs[i, :].reshape(-1, 1)
            dot_prod = (train_embs @ dev_emb).reshape(-1)
            dist = -dot_prod / train_emb_norm
        else:
            # Squared Euclidean distance from every train row to this dev row.
            dev_emb = dev_embs[i, :].reshape(1, -1)
            delta = train_embs - np.ones((n_train, 1)) @ dev_emb
            dist = np.sum(delta * delta, axis=1)
        assert dist.size == n_train

        if allow_single_label:
            selected_idx = dist.argsort()[:k].tolist()
        elif balanced:
            idx_sorted = dist.argsort().tolist()
            selected_idx = [idx_sorted[0]]
            original_label = train_data[idx_sorted[0]]["output"]
            # a0 counts picks sharing the nearest label (the nearest itself
            # included); a1 counts picks with any other label. The nearest
            # label takes the extra slot for odd k. With k // 2 on both sides,
            # an odd k could never reach k picks.
            same_limit = k - k // 2
            other_limit = k // 2
            a0, a1 = 1, 0
            for d in idx_sorted[1:]:
                if train_data[d]["output"] == original_label:
                    if a0 < same_limit:
                        selected_idx.append(d)
                        a0 += 1
                else:
                    if a1 < other_limit:
                        selected_idx.append(d)
                        a1 += 1
                if a0 + a1 >= k:
                    break
        else:
            idx_sorted = dist.argsort().tolist()
            if use_only_nearest:
                rest = idx_sorted[1:]
                random.shuffle(rest)
                idx_sorted = [idx_sorted[0]] + rest
            first_label = train_data[idx_sorted[0]]["output"]
            selected_idx = idx_sorted[:k]
            # Ensure at least two labels. If every pick shares the nearest
            # label, swap the last one for the closest differently labelled
            # example. The entries of idx_sorted are already training indices,
            # so they index train_data directly.
            if selected_idx and all(train_data[j]["output"] == first_label for j in selected_idx):
                for j in idx_sorted[k:]:
                    if train_data[j]["output"] != first_label:
                        selected_idx[-1] = j
                        break

        if random_shuffled:
            random.shuffle(selected_idx)
        selected_ids.append(selected_idx)
    return selected_ids

def run(logger, task, id_metaicl_data, metaicl_model, train_data, id_dev_data, seed,
        is_classification, get_layer_embs=True):
    """Run one inference pass over ``id_dev_data`` with ``train_data`` as context.

    Calls ``id_metaicl_data.tensorize`` on the demonstrations and dev
    examples, then ``metaicl_model.do_inference`` (batch size 1) and
    ``do_predict``. The predictions are scored against each dev example's
    ``"output"`` with ``id_metaicl_data.evaluate``, and the score is logged.

    Args:
        task, seed: unused.
        get_layer_embs: also request per-layer embeddings from inference.

    Returns:
        ``(probs, perf, layer_embs, len_each_part)`` when ``get_layer_embs``
        is truthy, else ``(probs, perf)``. ``probs`` is the tensor returned by
        ``do_predict`` converted to a numpy array, and ``len_each_part`` is
        whatever ``tensorize`` returned.
    """
    part_lengths = id_metaicl_data.tensorize(train_data, id_dev_data)

    layer_outputs = None
    if get_layer_embs:
        raw_probs, layer_outputs = metaicl_model.do_inference(
            id_metaicl_data, 1, verbose=False, get_layer_embs=get_layer_embs)
    else:
        raw_probs = metaicl_model.do_inference(id_metaicl_data, 1, verbose=False)

    predictions, scored = metaicl_model.do_predict(id_metaicl_data, probs=raw_probs)
    references = [example["output"] for example in id_dev_data]
    perf = id_metaicl_data.evaluate(predictions, references, is_classification)
    logger.info("Accuracy=%s" % (perf))
    probs_np = scored.detach().cpu().numpy()

    if not get_layer_embs:
        return probs_np, perf
    return probs_np, perf, layer_outputs, part_lengths

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--model_name", type=str, default="all-roberta-large-v1")

    parser.add_argument("--use_icl", default=None)

    parser.add_argument("--is_unlabel", action="store_true")
    # NOTE(review): store_true with default=True means this is always True and
    # cannot be switched off from the command line.
    parser.add_argument("--use_demonstrations", default=True, action="store_true")
    # store_false: is_classification defaults to True; passing the flag sets it to False.
    parser.add_argument('--is_classification', action='store_false')
    parser.add_argument("--trunc_method", type=str, default='middle', choices=['right', 'left', 'middle'])
    parser.add_argument("--dataset", type=str, default='contrast_boolq')
    parser.add_argument("--dev_dataset", type=str, default='boolq')
    parser.add_argument("--k", type=int, default=4)
    parser.add_argument("--max_length_per_example", type=int, default=128)
    parser.add_argument("--permute_fn_id", type=int, default=0)
    parser.add_argument("--n_comb", type=int, default=2500, 
        help="the number of distinct combinations of training examples")
    parser.add_argument("--n_perm", type=int, default=2, 
        help="the number of permutations under the same combination")
    parser.add_argument("--n_segments", type=int, default=250,
        help="divide prompts into different segments for multi-gpu speedup")
    parser.add_argument("--segment_id", type=int, default=0)
    parser.add_argument("--test_batch_size", type=int, default=64)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--gpt2", type=str, default="gpt-j-6b")
    parser.add_argument("--log_file", default=None, type=str)
    parser.add_argument("--shuffle_seed", default=0, type=int)
    parser.add_argument("--last_layer", default=False, action="store_true")
    parser.add_argument("--use_embs", default="gpt-j-6b", type=str)
    parser.add_argument("--if_avg", default=False, action="store_true")
    parser.add_argument("--if_avgall", default=False, action="store_true")

    args = parser.parse_args()
    # The two pooling modes are mutually exclusive. Reject the combination
    # here instead of failing inside main() after the model and data are loaded.
    if args.if_avg and args.if_avgall:
        parser.error("--if_avg and --if_avgall are mutually exclusive")
    if args.use_icl:
        args.out_dir = os.path.join('icl_embs', args.gpt2, f"{args.dataset}_{args.dev_dataset}")

    handlers = [logging.StreamHandler()]
    if args.log_file is not None:
        handlers.append(logging.FileHandler(args.log_file))
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=handlers)
    logger = logging.getLogger(__name__)
    logger.info(args)

    main(logger, args)
