import os
import random

import numpy as np
from tqdm import tqdm

import pandas as pd
import pickle
import torch

# Enable cuDNN and let it benchmark candidate kernels to pick the fastest one.
# NOTE(review): per the PyTorch reproducibility notes, benchmark mode may select
# different algorithms between runs, so results may not be bit-identical even
# though create_few_shots seeds every RNG -- confirm this is acceptable.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Dataset name -> language codes. create_few_shots iterates these as the target
# (test) languages for which a demonstration file is generated.
# NOTE(review): 'bg' under 'tydiqa' looks like it may be a typo for 'bn'
# (Bengali) -- verify against the 'language' column of the tydiqa CSVs.
# NOTE(review): create_few_shots special-cases dataset_name == 'xcodah', but
# there is no 'xcodah' entry here, so a lookup for it would raise KeyError.
task2lang = {
    'tydiqa': ['ar', 'bg', 'en', 'fi', 'id', 'ko', 'ru', 'sw', 'te'],
    'mlqa': ['ar', 'de', 'en', 'es', 'hi', 'vi', 'zh'],
    'xquad': ['ar', 'de', 'el', 'en', 'es', 'hi', 'ru', 'th', 'tr', 'vi', 'zh'],
    'mlsum': ['de', 'en', 'es', 'fr'],
    'mltg': ['de', 'en', 'es', 'fr'],
    'xnli': ['en', 'ar', 'bg', 'de', 'el', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
    'xcopa': ['en', 'et', 'ht', 'id', 'it', 'qu', 'sw', 'ta', 'th', 'tr', 'vi', 'zh'],
    'xcsqa': ["ar","de","en","es","fr","hi","it","ja","nl","pl","pt","ru","sw","ur","vi","zh"]
}


def save_file(test_final, dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode='test'):
    """Pickle ``test_final`` under ``<file_path>/process_data/<dataset_name>/<mode>/``.

    The output directory is created when missing and the file is named
    ``<set_up>_src=<src>_tgt=<tgt>_k=<k>_seed=<seed>_class=<n_clusters>.pkl``.
    An existing file with the same name is overwritten.

    :param test_final: object to serialise (any picklable value)
    :param dataset_name: dataset sub-directory name
    :param file_path: root directory that holds ``process_data``
    :param set_up: sampling set-up name, used as the file-name prefix
    :param src: source-language tag embedded in the file name
    :param tgt: target-language tag embedded in the file name
    :param seed: seed embedded in the file name
    :param k: number of shots embedded in the file name
    :param n_clusters: cluster count embedded in the file name
    :param mode: split sub-directory, ``'test'`` by default
    :return: None
    """
    out_dir = os.path.join(file_path, f'process_data/{dataset_name}/{mode}')
    os.makedirs(out_dir, exist_ok=True)

    print('Saving Files')

    pkl_name = f'{set_up}_src={src}_tgt={tgt}_k={k}_seed={seed}_class={n_clusters}.pkl'
    destination = os.path.join(out_dir, pkl_name)

    print(f'Saving tgt {tgt} src {src}')

    with open(destination, 'wb') as sink:
        pickle.dump(test_final, sink, protocol=pickle.HIGHEST_PROTOCOL)


def have_file(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode='test'):
    """Report whether the pickle for this configuration already exists on disk.

    The path checked is
    ``<file_path>/process_data/<dataset_name>/<mode>/<set_up>_src=<src>_tgt=<tgt>_k=<k>_seed=<seed>_class=<n_clusters>.pkl``.
    This is a pure query: it never creates directories or files.

    :param dataset_name: dataset sub-directory name
    :param file_path: root directory that holds ``process_data``
    :param set_up: sampling set-up name (file-name prefix)
    :param src: source-language tag in the file name
    :param tgt: target-language tag in the file name
    :param seed: seed in the file name
    :param k: number of shots in the file name
    :param n_clusters: cluster count in the file name
    :param mode: split sub-directory, ``'test'`` by default
    :return: True when the file exists (a notice is printed), otherwise False
    """
    file_name = os.path.join(
        file_path,
        'process_data/{}/{}'.format(dataset_name, mode),
        '{}_src={}_tgt={}_k={}_seed={}_class={}.pkl'.format(set_up, src, tgt, k, seed, n_clusters))

    if os.path.isfile(file_name):
        print(f'Existing {tgt} src {src} file')
        return True
    return False


def load_data(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode='test'):
    """Load a pickled demonstration file from ``<file_path>/process_data/<dataset_name>/<mode>/``.

    Two file names are tried in order:

    1. ``<set_up>_src=<src>_tgt=<tgt>_k=<k>_seed=<seed>_class=<n_clusters>.pkl``
    2. ``<set_up>_src=<src>_s=<tgt>_k=<k>_class=<n_clusters>.pkl`` (fallback name,
       used only when the first file does not exist)

    :param dataset_name: dataset sub-directory name
    :param file_path: root directory that holds ``process_data``
    :param set_up: sampling set-up name (file-name prefix)
    :param src: source-language tag; ``None`` is replaced by ``'all'``
    :param tgt: target-language tag
    :param seed: seed (only part of the first file name)
    :param k: number of shots
    :param n_clusters: cluster count
    :param mode: split sub-directory, ``'test'`` by default
    :return: the unpickled object
    :raises ValueError: when the data directory does not exist
    :raises FileNotFoundError: when neither candidate file exists
    """
    if src is None:
        src = 'all'
    # e.g. cluster_560m_100_hard_src=en_tgt=en_k=4_seed=100_class=20.pkl
    data_dir = os.path.join(file_path, 'process_data/{}/{}'.format(dataset_name, mode))
    if not os.path.exists(data_dir):
        print(data_dir)
        raise ValueError('We dont have data for specified target language')

    primary = os.path.join(
        data_dir, '{}_src={}_tgt={}_k={}_seed={}_class={}.pkl'.format(set_up, src, tgt, k, seed, n_clusters))
    fallback = os.path.join(
        data_dir, '{}_src={}_s={}_k={}_class={}.pkl'.format(set_up, src, tgt, k, n_clusters))

    # Fall back only when the primary file is missing. A corrupt primary file
    # must surface its own error rather than be masked by a lookup of the
    # fallback name.
    try:
        handle = open(primary, 'rb')
    except FileNotFoundError:
        handle = open(fallback, 'rb')

    with handle:
        # SECURITY: unpickling can execute arbitrary code; only point this at
        # files this project wrote itself.
        return pickle.load(handle)


def sample_from_dataframe(df, k, seed):
    """Draw ``k`` rows from ``df`` with a fixed ``seed`` and renumber them from 0.

    :param df: DataFrame to sample from (left unmodified)
    :param k: number of rows to draw
    :param seed: value forwarded to ``DataFrame.sample`` as ``random_state``
    :return: a new DataFrame holding the sampled rows with a fresh 0..k-1 index
    """
    return df.sample(n=k, random_state=seed).reset_index(drop=True)


def input_form_converter(test_df, demo_df=None):
    """Turn a test DataFrame into a list of record dicts, each carrying the demonstrations.

    :param test_df: DataFrame of test rows; every row becomes one dict
    :param demo_df: demonstrations, either a DataFrame (converted to a list of
        row dicts) or an already-built list (used as is). ``None`` means no
        demonstrations.
    :return: list of test-row dicts, each with an added ``'demos'`` key. All
        records reference the *same* demos list object.
    """
    if demo_df is None:
        # A fresh list per call: a shared default list would leak any mutation
        # of one result's 'demos' into every later call.
        demos = []
    elif isinstance(demo_df, list):
        demos = demo_df
    else:
        demos = demo_df.to_dict(orient='records')

    test = test_df.to_dict(orient='records')

    # Attach the demonstrations to every test record.
    for d in test:
        d['demos'] = demos
    return test


# Cluster-based set-ups -> whether a test item's cluster is decided by a
# majority vote over its nearest training rows (True) or by assigning the test
# embedding directly with KMeans.predict (False).
_CLUSTER_SET_UPS = {
    'cluster_in_cross': True,
    'cluster_in_cross_new': False,
    'cluster_and_sim': True,
}


def _seed_everything(seed):
    """Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _rows_for_language(df, language):
    """Return the rows of ``df`` whose 'language' column equals ``language``, re-indexed from 0."""
    return df[df['language'].isin([language])].reset_index(drop=True)


def _group_indices_by_cluster(labels):
    """Map the task name ``'t<label>'`` to the positions in ``labels`` carrying that label."""
    groups = {}
    for idx, label in enumerate(labels):
        groups.setdefault(u"t{}".format(label), []).append(idx)
    return groups


def _vote_cluster_labels(test_embeddings, corpus_embeddings, train_labels, n_neighbours=10):
    """Label each test embedding with the majority cluster among its most similar training rows."""
    from sentence_transformers import util

    # torch.topk raises when asked for more entries than exist; clamp for tiny corpora.
    n_neighbours = min(n_neighbours, len(train_labels))
    voted = []
    for query_embedding in test_embeddings:
        cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
        neighbour_idx = torch.topk(cos_scores, k=n_neighbours)[1].cpu().numpy()
        # The neighbours are training rows, so their cluster is already known:
        # look it up instead of re-running KMeans.predict for every query.
        neighbour_labels = train_labels[neighbour_idx].tolist()
        # On a tie the label that appears first in similarity order wins.
        voted.append(max(neighbour_labels, key=neighbour_labels.count))
    return voted


def _build_cluster_few_shots(dataset_name, file_path, src_l, k, n_clusters, seed, set_up, mode,
                             embedder, df_train, df_test, vote_over_neighbours):
    """Cluster each source-language corpus and save per-target test files for one seed."""
    from sklearn.cluster import KMeans

    for src in src_l:
        # In train mode an existing train file skips this source language
        # entirely (including its test files); check before the expensive work.
        if mode == 'train' and have_file(dataset_name, file_path, set_up, src, src, seed, k, n_clusters, mode):
            continue

        df_train_l = _rows_for_language(df_train, src)
        corpus_input = df_train_l['input'].to_list()
        corpus_out = df_train_l['output'].to_list()
        corpus_id = df_train_l['id']

        corpus_embeddings = embedder.encode(corpus_input, convert_to_tensor=True)
        corpus_matrix = corpus_embeddings.cpu().numpy()

        km = KMeans(n_clusters=n_clusters, random_state=seed).fit(corpus_matrix)
        print(km.cluster_centers_)

        train_labels = km.predict(corpus_matrix)
        center_dict_train = _group_indices_by_cluster(train_labels)

        if mode == 'train':
            train_final = [
                {
                    'id': corpus_id[idx],
                    'task': task_name,
                    'input': corpus_input[idx],
                    'output': corpus_out[idx],
                    'language': src,
                }
                for task_name, idxs in center_dict_train.items()
                for idx in idxs
            ]
            save_file(train_final, dataset_name, file_path, set_up, src, src, seed, k, n_clusters, mode)

        for tgt in task2lang[dataset_name]:
            if have_file(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode="test"):
                continue

            df_test_l = _rows_for_language(df_test, tgt)
            test_input = df_test_l['input'].to_list()
            print(len(test_input))
            test_out = df_test_l['output'].to_list()
            test_id = df_test_l['id']

            test_embeddings = embedder.encode(test_input, convert_to_tensor=True)
            if vote_over_neighbours:
                label_test = _vote_cluster_labels(test_embeddings, corpus_embeddings, train_labels)
            else:
                label_test = km.predict(test_embeddings.cpu().numpy())
            center_dict = _group_indices_by_cluster(label_test)

            test_final = []
            for task_name, idxs in center_dict.items():
                pool = center_dict_train[task_name]
                # random.sample raises when k exceeds the cluster size; a small
                # cluster contributes all of its members instead.
                sample_demos_idx = random.sample(pool, min(k, len(pool)))
                for idx in idxs:
                    test_final.append({
                        'id': test_id[idx],
                        'task': task_name,
                        'input': test_input[idx],
                        'output': test_out[idx],
                        'language': tgt,
                        # Key spelling 'demons' is kept as is: it is part of
                        # the on-disk format.
                        'demons': [
                            {
                                'input': corpus_input[i],
                                'output': corpus_out[i],
                                'language': src,
                            }
                            for i in sample_demos_idx
                        ],
                    })

            save_file(test_final, dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode="test")


def _build_random_few_shots(dataset_name, file_path, src_l, k, n_clusters, seed, set_up, mode,
                            languages, df_train, df_test):
    """Attach the same ``k`` randomly sampled source-language rows to every test row, for one seed."""
    for tgt in languages:
        for src in src_l:
            if have_file(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode):
                continue
            df_train_l = _rows_for_language(df_train, src)
            df_test_l = _rows_for_language(df_test, tgt)

            # DataFrame.sample raises ValueError when fewer than k rows are
            # available, so the result always has exactly k rows.
            demo_df = sample_from_dataframe(df_train_l, k, seed=seed)

            test_final = input_form_converter(df_test_l, demo_df)
            save_file(test_final, dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode)


def _build_similarity_few_shots(dataset_name, file_path, src_l, k, n_clusters, seed, set_up, mode,
                                languages, embedder, df_train, df_test):
    """Attach to every test row its ``k`` most cosine-similar source-language rows, for one seed."""
    from sentence_transformers import util

    # For 'xcodah' the 'output' column is embedded and matched; otherwise 'input'.
    text_column = 'output' if dataset_name == 'xcodah' else 'input'

    # The source corpus does not depend on the target language: encode it once per source.
    corpus_cache = {}

    for tgt in languages:
        for src in src_l:
            if have_file(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode):
                continue

            if src not in corpus_cache:
                df_train_l = _rows_for_language(df_train, src)
                corpus_cache[src] = (
                    df_train_l['input'].to_list(),
                    df_train_l['output'].to_list(),
                    embedder.encode(df_train_l[text_column].to_list(), convert_to_tensor=True),
                )
            corpus_input, corpus_out, corpus_embeddings = corpus_cache[src]

            test_final = _rows_for_language(df_test, tgt).to_dict(orient='records')

            # torch.topk raises when k exceeds the corpus size; clamp instead.
            top_k = min(k, len(corpus_input))

            for record in tqdm(test_final):
                query_embedding = embedder.encode(record[text_column], convert_to_tensor=True)
                cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
                top_results = torch.topk(cos_scores, k=top_k)
                record['demos'] = [
                    {
                        'input': corpus_input[idx],
                        'output': corpus_out[idx],
                        'score': score.cpu().detach().numpy().tolist(),
                        'language': src,
                    }
                    for score, idx in zip(top_results[0], top_results[1])
                ]

            save_file(test_final, dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, mode)


def create_few_shots(dataset_name='tydiqa', file_path=None, src_l=["en"], k=16, n_clusters=100,
                         seeds=[32, 5, 232, 100, 42], set_up='src_is_cross', mode='test',sentence_model="sentence-transformers/distiluse-base-multilingual-cased-v1"):
    """
    Build few-shot demonstration files for every seed and save them as pickles.

    Reads ``<file_path>/<dataset_name>_train.csv`` and ``<file_path>/<dataset_name>_test.csv``.
    The code uses the columns 'language', 'input', 'output' and, for the cluster
    set-ups, 'id'. Results are written with ``save_file``; combinations whose
    file already exists (``have_file``) are skipped.

    Supported ``set_up`` values:

    * ``'random'``: the same ``k`` randomly sampled source-language rows are
      attached to every test row under the key ``'demos'``.
    * ``'sim_in_cross'``: each test row gets its ``k`` most cosine-similar
      source-language rows (with ``'score'``) under ``'demos'``.
    * ``'cluster_in_cross'`` / ``'cluster_and_sim'``: the source corpus is
      clustered with KMeans; a test row is assigned the majority cluster of its
      10 nearest training rows and receives up to ``k`` random rows of that
      cluster under the key ``'demons'``.
    * ``'cluster_in_cross_new'``: as above, but the test row is assigned to a
      cluster directly with ``KMeans.predict``.

    Any other value (including the default) only prints a notice and saves nothing.

    :param dataset_name: prefix of the CSV files and key into ``task2lang``
    :param file_path: directory holding the CSVs and ``process_data``; ``None`` means ``"./"``
    :param src_l: iterable of source languages to draw demonstrations from
        (it is iterated directly, so ``None`` is not supported; never mutated)
    :param k: number of demonstrations per test item, default 16
    :param n_clusters: number of KMeans clusters (cluster set-ups); always part of the file name
    :param seeds: seeds to run, default [32, 5, 232, 100, 42] (never mutated)
    :param set_up: name of the demonstration sampling technique
    :param mode: ``'train'`` or ``'test'``. For ``'random'``/``'sim_in_cross'`` it
        selects the languages iterated (``src_l`` in train mode, otherwise
        ``task2lang[dataset_name]``) and the output sub-directory. For the
        cluster set-ups, train mode additionally saves the clustered training
        corpus; test files are always written under ``'test'``.
    :param sentence_model: SentenceTransformer model name used for embeddings
    :return: None. Each saved pickle holds a list of dicts, one per test (or training) row.
    """
    if file_path is None:
        file_path = "./"
    # os.path.join keeps working when file_path has no trailing separator.
    train_set = os.path.join(file_path, f'{dataset_name}_train.csv')
    test_set = os.path.join(file_path, f'{dataset_name}_test.csv')

    # The CSVs do not depend on the seed: read them once.
    df_train = pd.read_csv(train_set)
    if dataset_name == 'xnli':
        # Only the first 10000 training rows are used for xnli.
        df_train = df_train.iloc[:10000]
    df_test = pd.read_csv(test_set)

    is_cluster_set_up = set_up in _CLUSTER_SET_UPS

    embedder = None
    if is_cluster_set_up or set_up == 'sim_in_cross':
        from sentence_transformers import SentenceTransformer
        embedder = SentenceTransformer(sentence_model)

    languages = None
    if set_up in ('random', 'sim_in_cross'):
        languages = src_l if mode == "train" else task2lang[dataset_name]
    elif not is_cluster_set_up:
        print(f'No demonstration sampling implemented for set_up={set_up!r}; nothing will be saved')

    for seed in seeds:
        _seed_everything(seed)
        print(f'For seed {seed} and dataset {dataset_name} crating {k} few shot of {set_up} set_up')

        if is_cluster_set_up:
            _build_cluster_few_shots(dataset_name, file_path, src_l, k, n_clusters, seed, set_up, mode,
                                     embedder, df_train, df_test, _CLUSTER_SET_UPS[set_up])
        elif set_up == 'random':
            _build_random_few_shots(dataset_name, file_path, src_l, k, n_clusters, seed, set_up, mode,
                                    languages, df_train, df_test)
        elif set_up == 'sim_in_cross':
            _build_similarity_few_shots(dataset_name, file_path, src_l, k, n_clusters, seed, set_up, mode,
                                        languages, embedder, df_train, df_test)
