# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
# Append the parent of this script's directory to the import path; presumably this
# is what lets the project packages imported below (B_train_Topic_model,
# A_data_preprocess) resolve when the script is run directly.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
sys.path.append(_PROJECT_ROOT)
print(_PROJECT_ROOT)

import argparse
import random
import torch
import json
import logging
import numpy as np
import time
from datetime import datetime

from collections import Counter

from B_train_Topic_model.Topic_XICL.data import XICLData, XICLData_xnli
from B_train_Topic_model.Topic_XICL.model import XICLModel, XICLModel_xnli

from A_data_preprocess.util.data import load_data

def main(logger, args):
    """Tensorize the training data and, unless ``args.do_tensorize`` is set, train a model.

    With ``args.do_tensorize`` true, the function builds the tensorized inputs via
    ``tensorize_for_training`` and returns early. Otherwise it seeds the RNGs,
    builds the model, and runs ``do_train``.

    Args:
        logger: logger used for progress messages.
        args: parsed command-line namespace. ``args.max_length`` is overwritten
            in place with the effective sequence length.
    """
    max_length_per_example = args.max_length_per_example
    if args.use_demonstrations:
        # Room for test_k demonstrations plus the query, capped by --max_length.
        args.max_length = min(max_length_per_example * (args.test_k + 1), args.max_length)
    else:
        args.max_length = max_length_per_example
    max_length = args.max_length

    logger.info("batch_size=%d\tmax_length=%d\tmax_length_per_example=%d" % (
        args.batch_size, max_length, max_length_per_example))

    # NOTE(review): tgt is set to args.src (not a separate target language); confirm intended.
    train_data = load_data(args.dataset, args.file_path, 'cluster_in_cross', src=args.src, tgt=args.src,
                           seed=args.data_seed, k=args.data_k, n_clusters=args.n_clusters, mode='train')

    # Number of training examples per task label.
    train_counter = Counter(dp["task"] for dp in train_data)
    is_main_process = args.local_rank <= 0
    if is_main_process:
        for task, count in train_counter.items():
            logger.info("[Train] %s\t%d" % (task, count))
        logger.info("%s on %s (%d train)" % (args.method, args.dataset, len(train_counter)))

    # "tydiqa" uses the generic data/model classes; every other dataset uses the XNLI variants.
    if args.dataset == "tydiqa":
        data_cls, model_cls = XICLData, XICLModel
    else:
        data_cls, model_cls = XICLData_xnli, XICLModel_xnli

    ######### load tensorize data
    XICL_data = data_cls(logger, args.model_path, args.method, args.use_demonstrations,
                         args.test_k, max_length, max_length_per_example,
                         do_tensorize=args.do_tensorize,
                         tensorize_dir=args.tensorize_dir,
                         n_process=args.n_process, n_gpu=args.n_gpu,
                         local_rank=args.local_rank,
                         n_prefix_tokens=args.n_prefix_tokens,
                         task_counts=train_counter,
                         n_cluster=args.n_clusters)
    XICL_data.tensorize_for_training(train_data, keyword=args.dataset,
                                     seed=args.data_seed)

    if args.do_tensorize:
        return

    ######## actual training part
    random.seed(args.train_seed)
    np.random.seed(args.train_seed)
    torch.manual_seed(args.train_seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(args.train_seed)

    num_training_steps = args.num_training_steps
    if args.gradient_accumulation_steps > 1:
        # The step budget is scaled by the accumulation factor before being passed on.
        num_training_steps *= args.gradient_accumulation_steps
    save_period = args.save_period
    log_period = 10

    if args.no_masking:
        # All-ones token_type_ids: every token is treated the same way.
        XICL_data.tensorized_inputs["token_type_ids"] = torch.ones_like(XICL_data.tensorized_inputs["input_ids"])
    XICL_data.print_tensorized_example()

    logger.info(args.out_dir)

    # Only the main process creates the output directory and writes the shared
    # mapping file; exist_ok avoids the check-then-create race, and gating the
    # write prevents several ranks from writing the same file concurrently.
    if is_main_process:
        os.makedirs(args.out_dir, exist_ok=True)
        with open(os.path.join(args.out_dir, 'task2token.json'), 'w', encoding='utf-8') as f:
            json.dump(XICL_data.prefix_token_ids, f, ensure_ascii=False)

    XICL_model = model_cls(args.model_path, logger,
                           args.out_dir, args.fp16, args.local_rank, True, args.n_prefix_tokens,
                           prefix_embed_file=args.prefix_embed_file, task_counts=train_counter,
                           max_length=args.max_length, data=XICL_data)

    XICL_model.setup_optimizer(args.optimization, num_training_steps, args.lr,
                               args.weight_decay, args.warmup_steps)
    XICL_model.parallel()
    XICL_model.train()
    XICL_model.do_train(XICL_data, args.batch_size, num_training_steps, save_period, log_period,
                        gradient_accumulation_steps=args.gradient_accumulation_steps)


def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    # NOTE(review): default=True with action="store_true" means this flag cannot be
    # switched off from the command line; the __main__ section runs a tensorize pass
    # and then forces it to False for the training pass.
    parser.add_argument("--do_tensorize", default=True, action="store_true")
    parser.add_argument("--tensorize_dir", type=str, default="tensorize_save_path")
    parser.add_argument("--n_gpu", type=int, default=1)
    parser.add_argument("--n_process", type=int, default=1)
    parser.add_argument("--n_prefix_tokens", type=int, default=10)
    parser.add_argument("--n_clusters", type=int, default=20)
    parser.add_argument("--max_length_per_example", type=int, default=256)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--src", type=str, default="en")

    parser.add_argument("--use_demonstrations", default=False, action="store_true")
    parser.add_argument("--prefix_embed_file", default=None, type=str)

    parser.add_argument("--dataset", type=str, default='xnli')
    parser.add_argument("--file_path", type=str, default='data_path')
    parser.add_argument("--split", type=str, default='train')
    parser.add_argument("--data_k", type=int, default=4)
    parser.add_argument("--test_k", type=int, default=4)
    parser.add_argument("--data_seed", type=int, default=32)
    parser.add_argument("--train_seed", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--num_training_steps", type=int, default=500)
    parser.add_argument("--save_period", type=int, default=1000)
    parser.add_argument("--weight_decay", type=float, default=0.03)
    parser.add_argument("--no_masking", default=False, action="store_true")

    parser.add_argument("--out_dir", type=str, default="out_path")
    parser.add_argument("--method", type=str, default="direct",
                        choices=["direct", "channel", "causal", "anti-causal"])
    parser.add_argument("--model_path", type=str, default="bigscience/bloomz-1b7/")

    parser.add_argument("--optimization", type=str, default="adamw")
    parser.add_argument("--fp16", default=False, action="store_true")
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    return parser


def _log_dir_for(tensorize_dir):
    """Return the parent directory of ``tensorize_dir`` ('' for a bare name)."""
    # The original `tensorize_dir.split["/"]` subscripted the method object and
    # raised TypeError at startup; os.path.dirname is the intended "drop the last
    # path component" operation.
    return os.path.dirname(tensorize_dir)


def _setup_logging(log_dir):
    """Log to stderr and to a timestamp-named file inside ``log_dir``; return this module's logger."""
    # FileHandler fails if the directory is missing, so create it first
    # ('' means the current directory, which always exists).
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, datetime.fromtimestamp(time.time()).isoformat())
    handlers = [logging.StreamHandler(), logging.FileHandler(log_file)]

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=handlers)
    logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
    return logging.getLogger(__name__)


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    args.log_dir = _log_dir_for(args.tensorize_dir)
    logger = _setup_logging(args.log_dir)
    logger.info(args)

    # First pass: do_tensorize is True, so main returns right after tensorizing.
    main(logger, args)
    # Second pass: with do_tensorize False, main continues into the training part.
    args.do_tensorize = False
    main(logger, args)
