import os
import copy
import random
import torch
import torch.distributed as dist
import math
from torch.utils.data import ConcatDataset

from transformers import AutoTokenizer

import importlib
from collections import defaultdict

from inference import BB3InferenceAgent
from trainer import CustomTrainer, SelfLearner
from multitask_datasets import (
    MultiTaskCollator,
    EpisodicDataset,
    WeightedConcatDataset,
)

from hexa.utils.base_utils import (
    build_config,
    set_seed,
    hf_trainer_args_from_opt,
    print_config
)

from hexa.utils.metrics import MetricLogger
from hexa.collators import (
    TensorBoardCallback,
    LoggingCallback,
    EvalLoggingCallback
)

from hexa.utils.dist_utils import setup_multi_gpu
from hexa.utils.self_learn_utils import (
    load_bootstrap_data,
    update_output_dir,
    get_model_file
)
from optimizer import set_optim_and_schedule
from train import metric_config, read_dataset


def set_steps(opt, num_bootstrap):
    """Derive the training step budget from the number of bootstrap samples.

    The result is written in place onto ``opt.trainer`` (``max_steps`` and
    ``logging_steps``); nothing is returned.

    Args:
        opt: Config object. Reads ``trainer.per_device_train_batch_size``,
            ``trainer.gradient_accumulation_steps`` and ``finetune_num_epoch``.
        num_bootstrap: Total number of bootstrap samples over all processes.
    """
    trainer_opt = opt.trainer

    # Without an initialized process group everything runs on one process.
    if dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    # Samples consumed by one optimizer step on a single process.
    samples_per_step = (
        trainer_opt.per_device_train_batch_size
        * trainer_opt.gradient_accumulation_steps
    )
    # Each process sees an equal (rounded-up) share of the samples.
    samples_per_process = math.ceil(num_bootstrap / world_size)
    steps_per_epoch = math.ceil(samples_per_process / samples_per_step)
    max_steps = steps_per_epoch * opt.finetune_num_epoch
    print(f'MAX STEPS: {max_steps}, num_epoch: {opt.finetune_num_epoch}')

    trainer_opt.max_steps = max_steps
    # Log roughly 100 times over the whole run, but at least every step.
    trainer_opt.logging_steps = max(max_steps // 100, 1)


def _read_task_datasets(opt, tokenizer, task_group):
    """Read the datasets of every task group and return them as one flat list.

    ``read_dataset`` is called once per value of ``task_group``; the
    collections it returns are concatenated in iteration order.

    Raises:
        ValueError: If no dataset was produced. ``ConcatDataset`` rejects an
            empty list with a bare assertion, so fail here with a message
            that points at the cause instead.
    """
    flat_datasets = [
        dataset
        for class_names in task_group.values()
        for dataset in read_dataset(
            dataset_class_names=class_names,
            episodic_dataset_class=EpisodicDataset,
            opt=opt,
            tokenizer=tokenizer,
        )
    ]
    if not flat_datasets:
        raise ValueError(
            'No task datasets were built: the task group mapping is empty.'
        )
    return flat_datasets


def get_datasets(opt, tokenizer, datatype='train', **kwargs):
    """Build the dataset for the requested split.

    For ``datatype == 'train'`` the bootstrap datasets are always loaded and
    ``set_steps`` is called with their total length. When ``opt.scheme``
    starts with ``'ft'`` or ``'mix'`` the task datasets are additionally read
    and combined with the bootstrap data through ``WeightedConcatDataset``;
    otherwise the bootstrap data alone is returned. For any other
    ``datatype`` only the task datasets are read and concatenated.

    Side effects: ``opt.dataset.datatype`` and ``opt.dataset.episodic`` are
    overwritten (``episodic`` is True only for the train split).

    Args:
        opt: Config object (reads ``scheme`` and ``dataset``).
        tokenizer: Tokenizer handed to the dataset constructors.
        datatype: Split name; ``'train'`` selects the training path.
        **kwargs: Forwarded to ``WeightedConcatDataset`` only; ignored on
            every other path.

    Returns:
        The dataset object for the requested split.

    Raises:
        ValueError: If task datasets are required but none could be built.
    """
    is_train = datatype == 'train'
    opt.dataset.datatype = datatype
    opt.dataset.episodic = is_train

    # multi-task dataset
    # NOTE(review): nothing in this function ever adds entries to task_group,
    # so every path that needs task datasets currently fails. Presumably it
    # should be filled from the dataset config - confirm against train.py.
    task_group = defaultdict(list)

    if not is_train:
        return ConcatDataset(_read_task_datasets(opt, tokenizer, task_group))

    use_pretrain_data = opt.scheme.startswith(('ft', 'mix'))
    bootstrap_dataset = ConcatDataset(load_bootstrap_dataset(opt, tokenizer))
    print(f'total bootstrap length: {len(bootstrap_dataset)}')

    if use_pretrain_data:
        dataset = ConcatDataset(_read_task_datasets(opt, tokenizer, task_group))
        # Weight order matches the list below: [task datasets, bootstrap data].
        weights_dict = {
            'ft': [1, 0],
            'mix': [1, 1],
            'bt': [0, 1]
        }
        weights = weights_dict.get(opt.scheme)
        if weights is None:
            # The gate above accepts any scheme *starting with* 'ft' or 'mix',
            # so resolve suffixed names by the same prefix rather than
            # failing with a KeyError on the exact-name lookup.
            prefix = 'ft' if opt.scheme.startswith('ft') else 'mix'
            weights = weights_dict[prefix]
        dataset = WeightedConcatDataset(
            [dataset, bootstrap_dataset], weights=weights, **kwargs
        )
    else:
        dataset = bootstrap_dataset

    # The step budget is derived from the bootstrap data only, even when the
    # task datasets are mixed in.
    set_steps(opt, len(bootstrap_dataset))
    return dataset


def load_bootstrap_dataset(opt, tokenizer):
    """Create one ``EpisodicDataset`` per finetune task from the bootstrap data.

    Every name in ``opt.dataset.finetune_tasks`` is looked up as a function in
    ``hexa.utils.self_learn_utils``; that function is called with the loaded
    bootstrap data and its result is wrapped in an ``EpisodicDataset``.

    Args:
        opt: Config object (reads ``name_space``, ``num_loop``,
            ``num_bootstrap``, ``dataset``, ``trainer.fp16``, ``seed`` and
            ``local_rank``).
        tokenizer: Tokenizer handed to each ``EpisodicDataset``.

    Returns:
        A list of datasets, in the same order as the configured tasks.
    """
    bootstrap_data = load_bootstrap_data(
        opt.name_space, opt.num_loop, opt.num_bootstrap
    )
    built_datasets = []

    # Work on a private copy so the fp16/seed overrides stay out of opt.dataset.
    dataset_config = copy.deepcopy(opt.dataset)
    dataset_config.fp16 = opt.trainer.fp16
    dataset_config.seed = opt.seed

    task_names = opt.dataset.finetune_tasks
    loader_module = importlib.import_module('hexa.utils.self_learn_utils')

    for task_name in task_names:
        # The task name doubles as the name of its loader function.
        build_examples = getattr(loader_module, task_name)
        task_dataset = EpisodicDataset(
            build_examples(bootstrap_data, opt=opt),
            dataset_config,
            tokenizer,
        )
        # Report sizes from local rank 0 only.
        if opt.local_rank == 0:
            print(f'task:{task_name}, data_len:{len(task_dataset)}')
        built_datasets.append(task_dataset)

    return built_datasets


if __name__ == "__main__":
    # Script entry point: build the config, prepare datasets and model, then
    # either evaluate only or run fine-tuning and save the results.
    opt = build_config()
    if opt.num_loop > 0:
        # From the second loop on, the model file comes from get_model_file
        # (presumably the checkpoint of the previous loop - TODO confirm).
        opt.model.model_file = get_model_file(opt, load_prev=True)
        print(f'Model file path to be loaded:', opt.model.model_file)
    update_output_dir(opt)
    # NOTE(review): presumably initializes torch.distributed when applicable;
    # the code below handles both the initialized and uninitialized case.
    setup_multi_gpu(opt)

    # A server_port of -1 means "take the port from the PORT1 environment
    # variable" (a KeyError is raised if that variable is unset).
    opt.server_port = opt.server_port if opt.server_port != -1 else os.environ['PORT1']
    rank = int(dist.get_rank()) if dist.is_initialized() else 0
    # Every rank gets its own search server URL: base port offset by the rank.
    opt.dataset.search_server = 'http://127.0.0.1:{}'.format(int(opt.server_port) + rank)
    opt.dataset.skip_retrieval_token = 'no_passages_used'

    set_seed(opt.seed)
    device = torch.device(opt.device_id)
    tokenizer = AutoTokenizer.from_pretrained(opt.hf_tokenizer_path)

    # dataset
    train_dataset, valid_dataset = None, None
    train_metric_logger, eval_metric_logger = None, None
    callbacks = None
    # True only when evaluation is requested without training.
    # NOTE(review): if both do_train and do_eval are False this is False too,
    # so the training branch at the bottom runs with train_dataset = None.
    do_only_eval = not opt.trainer.do_train and opt.trainer.do_eval
    num_replicas = dist.get_world_size() if dist.is_initialized() else 1
    if opt.trainer.do_train:
        # The extra keyword arguments are only consumed on the weighted
        # train path of get_datasets.
        train_dataset = get_datasets(
            opt,
            tokenizer,
            datatype='train',
            num_replicas=num_replicas,
            rank=rank,
            batch_size=opt.trainer.per_device_train_batch_size
        )
        # print(f"RANK:{dist.get_rank()} of {dist.get_world_size()}, train data size:{len(train_dataset)}")
        train_metric_logger = MetricLogger('train', metric_config)
        callbacks = [TensorBoardCallback(), LoggingCallback(train_metric_logger)]

    if opt.trainer.do_eval:
        valid_dataset = get_datasets(
            opt,
            tokenizer,
            datatype='valid',
            batch_size=opt.trainer.per_device_eval_batch_size
        )
        eval_metric_logger = MetricLogger('valid', metric_config)

        # The eval logging callback is installed only for eval-only runs;
        # when training too, the training callbacks set above are kept.
        if do_only_eval:
            callbacks = [EvalLoggingCallback(eval_metric_logger)]

    # NOTE(review): pad_token_id comes from the tokenizer loaded above, while
    # the trainer later receives agent.tokenizer - confirm both agree.
    collator = MultiTaskCollator(
        device=device,
        pad_token_id=tokenizer.pad_token_id,
        text_truncate=opt.dataset.truncate,
    )

    # hf model
    # Seed again right before the agent is created (presumably so model
    # construction does not depend on randomness used by dataset building).
    set_seed(opt.seed)
    agent = BB3InferenceAgent(opt)
    # From here on the agent's own tokenizer replaces the one loaded earlier.
    tokenizer = agent.tokenizer
    agent.model.train()

    # hf training args
    hf_trainer_args, extra_args = hf_trainer_args_from_opt(opt)
    print_config(hf_trainer_args, opt, convert=True)

    # optim and scheduler
    optimizers, scheduler = set_optim_and_schedule(hf_trainer_args, agent.model, extra_args)

    # hf trainer
    trainer = SelfLearner(
        agent,
        rank,
        num_replicas,
        opt.num_loop,
        opt.num_bootstrap,
        is_finetune=True,
        args=hf_trainer_args,
        data_collator=collator,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
        optimizers=(optimizers, scheduler),
        callbacks=callbacks,
        train_metric_logger=train_metric_logger,
        eval_metric_logger=eval_metric_logger,
        name_space=opt.name_space,
    )

    # Give Python's `random` module a different seed on every rank.
    random.seed(opt.seed + rank)

    if do_only_eval:
        trainer.evaluate()
    else:
        # Train, then persist the model, the metrics and the trainer state.
        train_result = trainer.train()
        trainer.save_model()
        metrics = train_result.metrics

        metrics["train_samples"] = len(train_dataset)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()