# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import sys

import argparse
import torch
import json
import logging
import time
import numpy as np
from datetime import datetime
from pathlib import Path

from collections import Counter

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
print(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from B_train_Topic_model.Topic_XICL.data import XICLData_infer, XICLData_xnli_infer
from B_train_Topic_model.Topic_XICL.model import XICLModel, XICLModel_xnli
from A_data_preprocess.util.data import load_data, save_file, task2lang


def main(logger, args):
    """Pick in-context demonstrations for each target language and save them.

    For every language in ``task2lang[args.dataset]`` the test split is loaded
    and, for each task found in it, training examples are chosen according to
    ``args.prior``:

    * ``"hardest"`` / ``"easiest"``: rank the task's training examples by
      ``1 - exp(-nll / args.concept_temperature)``; the per-example NLLs come
      from the soft-prefix model or from a cached ``<task>-nll.npy`` file.
    * ``"most_similar"``: rank by sentence-embedding cosine similarity.
    * otherwise: sample training examples uniformly at random.

    When ``args.use_similarity`` is truthy the ranked ids are only a candidate
    pool, from which the ``args.k`` candidates most similar to each dev input
    are kept.  The chosen examples are stored under the ``'demos'`` key of the
    dev examples and the annotated dev set is written with ``save_file``.

    Args:
        logger: ``logging.Logger`` used for progress messages.
        args: Parsed command-line namespace built in the ``__main__`` block.
            ``args.lang`` is set here from ``task2lang``.
    """
    assert (args.dataset is not None and args.task is None) or (args.dataset is None and args.task is not None)
    if args.dataset != "tydiqa":
        XICL_Data = XICLData_xnli_infer
        XICL_Model = XICLModel_xnli
    else:
        XICL_Data = XICLData_infer
        XICL_Model = XICLModel
    args.lang = task2lang[args.dataset]
    add_newlines = not args.model_path.startswith("gpt2")

    # task2token.json, when present next to the prefix embeddings, supplies
    # the value passed as task_counts / prefix_token_ids below.
    task_counts = None
    if args.prefix_embed_file is not None:
        model_dir = Path(args.prefix_embed_file).parent.absolute()
        task2token_path = os.path.join(model_dir, 'task2token.json')
        if os.path.exists(task2token_path):
            with open(task2token_path) as f:
                task_counts = json.load(f)

    concept_model = None
    # The sentence encoder is needed by the "most_similar" prior and by the
    # use_similarity re-ranking, so load it for either of them.
    if "most_similar" in args.prior or args.use_similarity:
        import sentence_transformers.util as sent_util
        from sentence_transformers import SentenceTransformer
        embedding_model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
        embedding_model.cuda()
        embedding_model.eval()

    os.makedirs(args.out_dir, exist_ok=True)

    # setup hyperparams for data
    max_length_per_example = args.max_length_per_example
    if args.use_demonstrations:
        max_length = min(max_length_per_example * args.k, args.max_length)
    else:
        max_length = max_length_per_example

    logger.info("batch_size=%d\tmax_length=%d\tmax_length_per_example=%d" % (
        args.test_batch_size, max_length, max_length_per_example))

    seed = int(args.seed)
    np.random.seed(seed)
    # random.sample below draws from Python's own generator, so it has to be
    # seeded as well for the demonstrations to be reproducible.
    random.seed(seed)

    train_data = load_data(args.dataset, args.file_path, 'cluster_in_cross', args.src, args.src, seed, k=args.k,
                           n_clusters=args.n_clusters, mode='train')
    if args.use_demonstrations:
        demonstration_data = random.sample(train_data, args.k)
    else:
        demonstration_data = None

    # Embeddings of the whole training set, row i <-> train_data[i].  Only the
    # use_similarity re-ranking reads them; they are kept separate from the
    # per-task embeddings computed inside the loop.
    all_embeddings = None
    if args.use_similarity:
        all_embeddings = embedding_model.encode([d["input"] for d in train_data])

    for lang in args.lang:
        print(lang)

        dev_data = load_data(args.dataset, file_path=args.file_path, set_up='cluster_in_cross', tgt=lang,
                             src=args.src, seed=seed, k=args.k, n_clusters=args.n_clusters, mode='test')
        print(len(dev_data))

        train_counter = Counter(dp["task"] for dp in train_data)
        dev_counter = Counter(dp["task"] for dp in dev_data)
        for task_name, count in train_counter.items():
            logger.info("[Train] %s\t%d" % (task_name, count))
        for task_name, count in dev_counter.items():
            logger.info("[Dev] %s\t%d" % (task_name, count))

        logger.info("%s on %s (%d train, %d dev)" % (args.method, args.task, len(train_counter), len(dev_counter)))

        for test_task in dev_counter:
            # Positions in dev_data and the matching examples are kept in
            # lockstep, including after subsampling, so that 'demos' is
            # written back to the example it was computed for.
            curr_dev_id = [idx for idx, dp in enumerate(dev_data) if dp["task"] == test_task]
            curr_dev_data = [dev_data[idx] for idx in curr_dev_id]
            assert len(curr_dev_data) > 0
            if args.test_size < len(curr_dev_data) and args.split == "test":
                dev_subsample_ids = np.random.choice(len(curr_dev_data), args.test_size, replace=False)
                curr_dev_data = [curr_dev_data[i] for i in dev_subsample_ids]
                curr_dev_id = [curr_dev_id[i] for i in dev_subsample_ids]

            _train_data_idx = [idx for idx, dp in enumerate(train_data) if dp["task"] == test_task]
            _train_data = [train_data[idx] for idx in _train_data_idx]
            # Local copy: writing the clamped value back to args would shrink
            # the cap for every later task and language.
            train_size = min(len(_train_data), args.train_size)
            if train_size > 0:
                train_subsample_ids = np.random.choice(len(_train_data), train_size, replace=False)
                curr_train_data = [_train_data[i] for i in train_subsample_ids]
                curr_train_data_id = [_train_data_idx[i] for i in train_subsample_ids]
            else:
                # train_size == 0 means "use every training example of the task".
                train_subsample_ids = None
                curr_train_data = _train_data
                curr_train_data_id = _train_data_idx

            if "most_similar" in args.prior:
                embedding_cache_dir = os.path.join(args.embedding_dir, test_task,
                                                   args.embedding_model)
                os.makedirs(embedding_cache_dir, exist_ok=True)

                # The cache always holds embeddings for the full task training
                # set; the subsample (if any) is applied after loading.
                train_embedding_file = os.path.join(embedding_cache_dir, 'train.npy')
                if os.path.isfile(train_embedding_file):
                    task_train_embeddings = np.load(train_embedding_file)
                else:
                    task_train_embeddings = embedding_model.encode(
                        [d["input"] for d in _train_data])
                    np.save(train_embedding_file, task_train_embeddings)
                if train_subsample_ids is not None:
                    task_train_embeddings = task_train_embeddings[train_subsample_ids]

                # NOTE(review): the dev cache path depends on the task only, so
                # it is shared across languages and dev subsamples -- confirm
                # that this is intended before relying on cached dev.npy files.
                dev_embedding_file = os.path.join(embedding_cache_dir, 'dev.npy')
                if os.path.isfile(dev_embedding_file):
                    task_dev_embeddings = np.load(dev_embedding_file)
                else:
                    task_dev_embeddings = embedding_model.encode(
                        [d["input"] for d in curr_dev_data])
                    np.save(dev_embedding_file, task_dev_embeddings)

                sims = sent_util.cos_sim(task_dev_embeddings, task_train_embeddings)
                sims = sims.cpu().detach().numpy()
                dev_train_exp_sims = np.exp(sims / args.similarity_temperature)

            priors = set(args.prior)
            use_difficulty = len({"easiest", "hardest"}.intersection(priors)) > 0

            if use_difficulty:
                task = test_task
                prefix_name = args.prefix_embed_file.split("/")[-1].split(".")[0]
                if train_size > 0:
                    run_name = (f"{test_task}-train-{seed}_use_demonstrations" if args.use_demonstrations
                                else f"{test_task}-dev-{seed}")
                else:
                    run_name = (f"{test_task}-train-{seed}-t_use_demonstrations" if args.use_demonstrations
                                else f"{test_task}-dev-{seed}-t")
                concept_dir = os.path.join(args.concept_dir, prefix_name, run_name)
                nll_file = os.path.join(concept_dir, f'{task}-nll.npy')

                if os.path.exists(nll_file):
                    logger.info("loading saved concept likelihoods")
                    all_nll = np.load(nll_file)
                else:
                    assert args.prefix_embed_file is not None
                    logger.info("start running soft prefix model")
                    start_time = time.time()
                    concept_data = XICL_Data(logger, args.model_path,
                                             args.method, args.use_demonstrations, args.use_instruction, args.k,
                                             max_length, max_length_per_example,
                                             add_newlines=add_newlines,
                                             n_prefix_tokens=args.n_prefix_tokens,
                                             prefix=False, task_counts=task_counts,
                                             prefix_token_ids=task_counts, task=task, n_cluster=args.n_clusters)
                    # The model is built once and reused for later tasks.
                    if concept_model is None:
                        prefix_token_ids = concept_data.prefix_token_ids
                        concept_model = XICL_Model(args.model_path,
                                                   logger, args.out_dir, soft_prefix=True,
                                                   n_tokens=args.n_prefix_tokens,
                                                   prefix_embed_file=args.prefix_embed_file,
                                                   task_counts=task_counts, data=concept_data)
                        concept_model.cuda()
                        concept_model.eval()
                        concept_model.prefix_token_ids = prefix_token_ids

                    results = run(concept_data, concept_model, demonstration_data, curr_train_data, False)
                    # run() returns a tuple ending in the NLLs when return_all
                    # is set, and the bare per-example losses otherwise.
                    all_nll = results[-1] if isinstance(results, tuple) else results
                    os.makedirs(concept_dir, exist_ok=True)
                    np.save(nll_file, all_nll)

                    logger.info(
                        f"time use for computing {len(curr_train_data)} examples: {time.time() - start_time}")

                torch.cuda.empty_cache()
                # Temperature-scaled likelihood of each training example;
                # difficulty is its complement.
                log_p = [-nll / args.concept_temperature for nll in all_nll]
                difficulties = np.array(1 - np.exp(log_p))
                assert len(difficulties) == len(curr_train_data)
                print(difficulties)

                sorted_diff = np.sort(difficulties)
                min_diff = sorted_diff[0]
                logger.info(f"min difficulty: {min_diff}")
                max_diff = sorted_diff[-1]
                logger.info(f"max difficulty: {max_diff}")
                logger.info(f"average difficulty: {np.mean(difficulties)}")

                sorted_ids = np.argsort(difficulties)
                # With similarity re-ranking a pool of 20 candidates is kept
                # and narrowed down to k per dev example further below.
                pool_size = 20 if args.use_similarity else args.k
                if "hardest" in args.prior:
                    demo_ids = sorted_ids[-pool_size:]
                else:
                    demo_ids = sorted_ids[:pool_size]

            elif "most_similar" in args.prior:
                # NOTE(review): this flattens the per-dev-example top-k ids into
                # one list that is then shared by every dev example below --
                # confirm whether per-example demonstrations were intended.
                sorted_sims = np.argsort(dev_train_exp_sims)
                demo_ids = sorted_sims[:, -args.k:].reshape(-1)

            else:
                # NOTE(review): as above, all sampled ids are attached to every
                # dev example -- confirm the intended shape of 'demos'.
                demo_ids = np.random.choice(len(curr_train_data),
                                            args.test_size * args.k)

            # Translate positions within curr_train_data into positions within
            # train_data; identical for every dev example of this task.
            demo_ids_train = np.array(curr_train_data_id)[demo_ids].tolist()

            if args.use_similarity:
                pool_embeddings = all_embeddings[demo_ids_train]
                # topk raises if asked for more entries than the pool holds.
                top_k = min(args.k, len(demo_ids_train))
                for dev_idx, dp in zip(curr_dev_id, curr_dev_data):
                    query_embedding = embedding_model.encode(dp["input"])
                    cos_scores = sent_util.cos_sim(query_embedding, pool_embeddings)[0]
                    top_results = torch.topk(cos_scores, k=top_k)
                    dev_data[dev_idx]['demos'] = [train_data[demo_ids_train[j]]
                                                  for j in top_results[1].tolist()]
            else:
                for dev_idx in curr_dev_id:
                    # A fresh list per example so they do not alias each other.
                    dev_data[dev_idx]['demos'] = [train_data[i] for i in demo_ids_train]

        # Last "-"-separated piece of the prefix-embedding file's base name.
        embed_tag = str(args.prefix_embed_file.split("/")[-1].split(".")[0].split("-")[-1])
        difficulty_tag = "" if args.prior == ["easiest"] else "_hard"
        if args.use_demonstrations:  # ,c=args.n_prefix_tokens
            file_tag = u'{prefix}_{d}_d{a}_{b}'.format(prefix=args.data_name, b=seed,
                                                       a=difficulty_tag, d=embed_tag)
        else:
            file_tag = u'{prefix}{a}_{d}_{b}'.format(prefix=args.data_name, b=seed,
                                                     a=difficulty_tag, d=embed_tag)
        save_file(dev_data, args.dataset, args.file_path, file_tag, args.src, lang,
                  seed, args.k, args.n_clusters,
                  mode='test')


def permutation(lst):
    """Return every ordering of the items of ``lst`` as a list of lists.

    An empty input yields ``[]`` (no orderings at all), and a one-item input
    yields a list holding ``lst`` itself.  Orderings are produced in index
    order: all those starting with ``lst[0]`` come first, then ``lst[1]``, etc.
    """
    size = len(lst)
    if size == 0:
        return []
    if size == 1:
        return [lst]

    orderings = []
    for pos, head in enumerate(lst):
        # Everything except the element chosen as the head.
        rest = lst[:pos] + lst[pos + 1:]
        orderings.extend([head] + tail for tail in permutation(rest))
    return orderings


def run(XICL_data, XICL_model, train_data, dev_data, return_all=False, use_task=False, use_output=False,
        test_batch_size=None):
    """Tensorize the data, run inference and return the losses or full results.

    Args:
        XICL_data: Data object providing ``tensorize``,
            ``print_tensorized_example``, ``__len__`` and, when ``return_all``
            is set, ``metadata`` and ``evaluate_mrc``.
        XICL_model: Model object providing ``do_inference`` and, when
            ``return_all`` is set, ``do_predict``.
        train_data: Passed to ``XICL_data.tensorize`` as the first argument.
        dev_data: Passed to ``XICL_data.tensorize`` as the second argument.
        return_all: If true, also predict and evaluate; see Returns.
        use_task: Forwarded to ``XICL_data.tensorize``.
        use_output: Forwarded to ``XICL_data.tensorize``.
        test_batch_size: Batch size handed to ``do_inference``.  When ``None``
            the module-level ``args.test_batch_size`` set up in the
            ``__main__`` block is used, which only exists when this file is
            run as a script.

    Returns:
        The value of ``do_inference`` when ``return_all`` is false; otherwise
        the tuple ``(exact, f1, predictions, groundtruths, all_nlls)``.
    """
    if test_batch_size is None:
        # Fall back to the command-line namespace created under __main__.
        test_batch_size = args.test_batch_size

    XICL_data.tensorize(train_data, dev_data,
                        instruction="The task is to choose the best prefix for a given input from the following options:\n",
                        use_task=use_task, use_output=use_output)
    XICL_data.print_tensorized_example()
    losses = XICL_model.do_inference(XICL_data, test_batch_size)
    assert len(losses) == len(XICL_data)
    if not return_all:
        return losses

    # The third returned value (ground-truth labels) is not used here.
    predictions, all_nlls, _ = XICL_model.do_predict(
        XICL_data, losses=losses, return_nll=True)

    groundtruths = XICL_data.metadata
    results = XICL_data.evaluate_mrc(predictions, 'en')
    print(results)
    return (results['exact'], results['f1'], predictions, groundtruths, all_nlls)


if __name__ == '__main__':
    # Script entry point: parse the command line, set up logging to both the
    # console and a timestamped file under --log_dir, then run main().

    def _str2bool(value):
        """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_demonstrations", default=False, action="store_true")
    parser.add_argument('--use_instruction', default=False, action='store_true')
    parser.add_argument("--n_prefix_tokens", type=int, default=15)
    parser.add_argument("--n_clusters", type=int, default=20)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--max_length_per_example", type=int, default=512)

    parser.add_argument("--prior", type=str, nargs='+', default=["hardest"],
                        choices=["most_similar", "easiest", "hardest"])

    parser.add_argument("--log_dir", default='output/logs', type=str)
    parser.add_argument("--out_dir", type=str, default='output/bloomz')
    parser.add_argument("--load_dir", default=None, type=str)
    # Without a type any supplied value, including "False", was a non-empty
    # and therefore truthy string.  Both "--use_similarity" on its own and
    # "--use_similarity True/False" are accepted.
    parser.add_argument("--use_similarity", type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--concept_dir", default='', type=str)
    parser.add_argument("--file_path", type=str, default='')
    parser.add_argument("--prefix_embed_file", default='', type=str)

    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--dataset", type=str, default='tydiqa')
    parser.add_argument("--k", type=int, default=4)
    parser.add_argument("--seed", type=str, default="32")
    parser.add_argument("--src", type=str, default="en")
    parser.add_argument("--data_name", type=str, default="cluster_1b7")

    parser.add_argument("--test_batch_size", type=int, default=1)

    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--method", type=str, default="direct",
                        choices=["direct", "channel"])
    parser.add_argument("--model_path", type=str, default="bigscience/bloomz-1b7/")
    parser.add_argument("--api", type=str, default=None)

    parser.add_argument("--test_size", type=int, default=100000)
    parser.add_argument("--train_size", type=int, default=0)
    parser.add_argument("--embedding_dir", type=str, default='embedding/')
    parser.add_argument("--embedding_model", type=str, default='all-mpnet-base-v2',
                        choices=['all-mpnet-base-v2'])
    parser.add_argument("--similarity_temperature", type=float, default=0.1)
    parser.add_argument("--concept_temperature", type=float, default=50.0)

    args = parser.parse_args()

    os.makedirs(args.log_dir, exist_ok=True)
    # One log file per run, named after the start time in ISO format.
    log_file = os.path.join(args.log_dir, datetime.fromtimestamp(time.time()).isoformat())
    handlers = [logging.StreamHandler(), logging.FileHandler(log_file)]

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=handlers)
    logger = logging.getLogger(__name__)
    logger.info(args)
    main(logger, args)
