import copy
import random
import torch
import torch.distributed as dist
from torch.utils.data import ConcatDataset

from transformers import AutoTokenizer

import importlib
from collections import defaultdict

from model import HFR2C2Model
from trainer import CustomTrainer
from multitask_datasets import (
    MultiTaskCollator,
    EpisodicDataset,
    WeightedConcatDataset,
)

from hexa.utils.base_utils import (
    build_config,
    set_seed,
    hf_trainer_args_from_opt,
    print_config
)

from hexa.utils.metrics import MetricLogger
from hexa.collators import (
    TensorBoardCallback, 
    LoggingCallback, 
    EvalLoggingCallback
)

from hexa.utils.dist_utils import setup_multi_gpu
from optimizer import set_optim_and_schedule


# Metric name -> aggregation mode. This mapping is handed to MetricLogger for
# both the train and valid loggers in the script body below.
# NOTE(review): the exact semantics of each mode string ('sum', 'average',
# 'ppl', 'timer', 'fixed') live in hexa.utils.metrics.MetricLogger — confirm there.
metric_config = dict(
    exs='sum',
    loss='average',
    ppl='ppl',
    token_acc='average',
    token_em='average',
    gnorm='average',
    clip='average',
    ctpb='average',  # context tokens per batch
    ltpb='average',  # label tokens per batch
    expb='average',  # examples per batch
    exps='timer',  # examples per second
    ups='timer',  # updates per second
    total_train_updates='fixed',
    gpu_mem='average',
)


def read_dataset(
        dataset_class_names,
        episodic_dataset_class,
        opt,
        tokenizer,
        **kwargs
):
    """Instantiate each dataset class and wrap it in an episodic dataset.

    Every dataset gets its own deep copy of ``opt.dataset`` so the ``fp16``
    and ``seed`` fields written here never leak back into ``opt`` or into a
    sibling dataset's config.

    Args:
        dataset_class_names: Iterable of dataset classes (callables, despite
            the parameter name) invoked as
            ``cls(config, tokenizer=None, use_cache=...)``.
        episodic_dataset_class: Wrapper invoked as
            ``wrapper(dataset, config, tokenizer, **kwargs)``.
        opt: Options object; reads ``opt.dataset``, ``opt.trainer.fp16`` and
            ``opt.seed``.
        tokenizer: Passed to the wrapper only; the underlying dataset is
            always built with ``tokenizer=None``.
        **kwargs: Forwarded unchanged to ``episodic_dataset_class``.

    Returns:
        A list with one wrapped dataset per entry of ``dataset_class_names``,
        in the same order.
    """
    def _build(dataset_cls):
        # Private per-dataset config, enriched with run-level settings.
        config = copy.deepcopy(opt.dataset)
        config.fp16 = opt.trainer.fp16
        config.seed = opt.seed
        raw_dataset = dataset_cls(
            config,
            tokenizer=None,
            use_cache=config.use_cache,
        )
        return episodic_dataset_class(raw_dataset, config, tokenizer, **kwargs)

    return [_build(dataset_cls) for dataset_cls in dataset_class_names]


def get_datasets(opt, tokenizer, datatype='train', **kwargs):
    """Build the concatenated multi-task dataset for one data split.

    Side effect: ``opt.dataset.datatype`` and ``opt.dataset.episodic`` are
    overwritten in place to match the requested split before any dataset is
    constructed.

    Args:
        opt: Options object; reads ``opt.dataset.tasks`` and, for non-train
            splits, ``opt.dataset.use_valid_subset`` / ``opt.dataset.eval_tasks``.
        tokenizer: Forwarded to ``read_dataset``.
        datatype: ``'train'`` selects the training split (episodic); any other
            value selects the validation split (non-episodic).
        **kwargs: Accepted for call-site compatibility but not used.
            NOTE(review): callers pass ``num_replicas`` / ``rank`` /
            ``batch_size`` here and they are dropped; forward them to
            ``read_dataset`` only after confirming ``EpisodicDataset``
            accepts them.

    Returns:
        A ``ConcatDataset`` over every ``EpisodicDataset`` built from the
        selected tasks, in task-group order.
    """
    is_train = datatype == 'train'
    opt.dataset.datatype = 'train' if is_train else 'valid'
    opt.dataset.episodic = is_train

    if is_train or not opt.dataset.use_valid_subset:
        tasks = opt.dataset.tasks
    else:
        tasks = opt.dataset.eval_tasks
        assert tasks, 'use_valid_subset is set but opt.dataset.eval_tasks is empty'

    # tasks: list of single-key dicts, {group_name: [class_name | option_dict, ...]}.
    # The group name doubles as the module name under ``hexa.teachers``.
    task_group = defaultdict(list)
    for task in tasks:
        task_group_name = list(task.keys())[0]
        task_module = importlib.import_module(f'hexa.teachers.{task_group_name}')
        for entry in task[task_group_name]:
            if isinstance(entry, dict):
                # Option entries (e.g. {'weights': [...]}) are not dataset
                # class names. Sampling weights are currently not applied
                # when concatenating, so these entries are skipped.
                continue
            task_group[task_group_name].append(getattr(task_module, entry))

    # Both splits are built the same way; only the opt.dataset flags set
    # above differ between them.
    dataset_list = [
        dataset
        for dataset_classes in task_group.values()
        for dataset in read_dataset(
            dataset_class_names=dataset_classes,
            episodic_dataset_class=EpisodicDataset,
            opt=opt,
            tokenizer=tokenizer,
        )
    ]
    return ConcatDataset(dataset_list)


# Script entry point: build the config, datasets, model and trainer, then run
# either evaluation only or training (followed by saving model and metrics).
if __name__ == "__main__":
    opt = build_config()

    # Fail fast: with neither flag set the script used to fall through to the
    # training branch with train_dataset=None and only fail after the model
    # had been constructed.
    if not (opt.trainer.do_train or opt.trainer.do_eval):
        raise ValueError(
            "Nothing to do: enable at least one of trainer.do_train / trainer.do_eval."
        )

    setup_multi_gpu(opt)
    set_seed(opt.seed)
    device = torch.device(opt.device_id)
    tokenizer = AutoTokenizer.from_pretrained(opt.hf_tokenizer_path)

    # dataset
    train_dataset, valid_dataset = None, None
    train_metric_logger, eval_metric_logger = None, None
    callbacks = None
    do_only_eval = not opt.trainer.do_train and opt.trainer.do_eval
    # Single-process defaults; replaced below for distributed training runs.
    rank = 0
    num_replicas = 1
    if opt.trainer.do_train:
        if opt.local_rank > -1:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()

        train_dataset = get_datasets(
            opt,
            tokenizer,
            datatype='train',
            num_replicas=num_replicas,
            rank=rank,
            batch_size=opt.trainer.per_device_train_batch_size
        )
        train_metric_logger = MetricLogger('train', metric_config)
        callbacks = [TensorBoardCallback(), LoggingCallback(train_metric_logger)]

    if opt.trainer.do_eval:
        # NOTE: this call switches opt.dataset.datatype / episodic to the
        # validation settings, so it must come after the train datasets exist.
        valid_dataset = get_datasets(
            opt,
            tokenizer,
            datatype='valid',
            batch_size=opt.trainer.per_device_eval_batch_size
        )
        eval_metric_logger = MetricLogger('valid', metric_config)

        # The eval logging callback is installed only for eval-only runs;
        # when training, the training callbacks set above are kept.
        if do_only_eval:
            callbacks = [EvalLoggingCallback(eval_metric_logger)]

    collator = MultiTaskCollator(
        device=device,
        pad_token_id=tokenizer.pad_token_id,
        text_truncate=opt.dataset.truncate,
    )

    # hf model
    model = HFR2C2Model(
        opt=opt
    )

    print("Number of parameters:", "{:,}".format(sum(p.numel() for p in model.parameters())))

    # hf training args
    hf_trainer_args, extra_args = hf_trainer_args_from_opt(opt)
    print_config(hf_trainer_args, opt, convert=True)

    # optim and scheduler
    optimizers, scheduler = set_optim_and_schedule(hf_trainer_args, model, extra_args)

    # hf trainer
    trainer = CustomTrainer(
        args=hf_trainer_args,
        data_collator=collator,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        model=model,
        tokenizer=tokenizer,
        optimizers=(optimizers, scheduler),
        callbacks=callbacks,
        train_metric_logger=train_metric_logger,
        eval_metric_logger=eval_metric_logger,
    )

    # Offset the Python RNG seed by rank so each process gets its own stream.
    random.seed(opt.seed + rank)

    if do_only_eval:
        trainer.evaluate()
    else:
        # do_train is guaranteed here by the fail-fast check at the top.
        train_result = trainer.train()
        trainer.save_model()
        metrics = train_result.metrics

        metrics["train_samples"] = len(train_dataset)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()