import os
import random
import torch
import torch.distributed as dist
from torch.utils.data import ConcatDataset
from transformers import AutoTokenizer

import importlib
from collections import defaultdict

from inference import BB3InferenceAgent
from trainer import SelfLearner
from multitask_datasets import (
    MultiTaskCollator,
)

from hexa.utils.base_utils import (
    build_config,
    set_seed,
    hf_trainer_args_from_opt,
    print_config
)

from hexa.utils.metrics import MetricLogger
from hexa.collators import (
    EvalLoggingCallback
)

from hexa.utils.dist_utils import setup_multi_gpu
from hexa.utils.self_learn_utils import update_output_dir, get_e2e_datasets
from optimizer import set_optim_and_schedule

# Aggregation mode for each metric name; handed to MetricLogger when the
# evaluation logger is built. Only the example count ('exs') is summed,
# every other metric is averaged.
metric_config = dict(
    exs='sum',
    score='average',
    bleu='average',
    f1='average',
    loss='average',
    cos_sim='average',
)


def _run_e2e_evaluation(opt, agent, rank, num_replicas, dataset, logfile_name,
                        evaluate_args=(), **trainer_kwargs):
    """Build a SelfLearner for one evaluation dataset and run ``e2e_evaluate``.

    Args:
        opt: Run configuration; ``num_loop``, ``num_bootstrap``, ``name_space``,
            ``save_eval_samples`` and ``trainer.fp16`` are read from it.
        agent: Inference agent handed to the trainer; ``agent.model`` is
            converted with ``half()`` when ``opt.trainer.fp16`` is truthy.
        rank: Rank value forwarded to ``SelfLearner``.
        num_replicas: Replica count forwarded to ``SelfLearner``.
        dataset: Passed to ``SelfLearner`` as ``eval_dataset``.
        logfile_name: Log file name given to ``EvalLoggingCallback``.
        evaluate_args: Positional arguments forwarded to
            ``trainer.e2e_evaluate`` (empty, or the task name).
        **trainer_kwargs: Extra keyword arguments forwarded unchanged to
            ``SelfLearner`` (training args, collator, tokenizer, optimizers).
    """
    # A fresh logger and callback per dataset, as in the original per-task loop.
    eval_metric_logger = MetricLogger('valid', metric_config)
    callbacks = [EvalLoggingCallback(eval_metric_logger, logfile_name=logfile_name)]

    trainer = SelfLearner(
        agent,
        rank,
        num_replicas,
        opt.num_loop,
        opt.num_bootstrap,
        is_finetune=True,
        eval_dataset=dataset,
        callbacks=callbacks,
        eval_metric_logger=eval_metric_logger,
        name_space=opt.name_space,
        save_eval_samples=opt.save_eval_samples,
        **trainer_kwargs,
    )

    # Kept after trainer construction, exactly where the original script did it.
    if opt.trainer.fp16:
        agent.model.half()
    trainer.e2e_evaluate(*evaluate_args)


def main():
    """Run end-to-end evaluation on the test split(s) described by the config."""
    opt = build_config()
    if not opt.eval_base:
        update_output_dir(opt)
        opt.model.model_file = opt.trainer.output_dir
    setup_multi_gpu(opt)

    # A server_port of -1 means "take the port from the PORT1 environment variable".
    opt.server_port = opt.server_port if opt.server_port != -1 else os.environ['PORT1']
    # The search server port is offset by the process rank whenever a process
    # group is initialised.
    server_rank = int(dist.get_rank()) if dist.is_initialized() else 0
    opt.dataset.search_server = 'http://127.0.0.1:{}'.format(int(opt.server_port) + server_rank)
    opt.dataset.skip_retrieval_token = 'no_passages_used'

    set_seed(opt.seed)
    device = torch.device(opt.device_id)
    # Tokenizer used only for building the datasets; the agent's own tokenizer
    # is used for collation and by the trainer below.
    dataset_tokenizer = AutoTokenizer.from_pretrained(opt.hf_tokenizer_path)

    # Dataset sharding follows opt.local_rank (not dist.is_initialized()),
    # preserving the original condition.
    if opt.local_rank > -1:
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()
    else:
        num_replicas = 1
        rank = 0

    test_dataset = get_e2e_datasets(
        opt,
        dataset_tokenizer,
        datatype='test',
        max_entry_num=opt.max_num_entries,
        return_dict=True,
        shuffle=False,
        num_replicas=num_replicas,
        rank=rank,
        batch_size=opt.trainer.per_device_eval_batch_size,
    )

    # hf model
    # NOTE(review): the seed is set a second time before the agent is built;
    # presumably so model construction does not depend on RNG state consumed
    # while loading datasets -- confirm before removing.
    set_seed(opt.seed)
    agent = BB3InferenceAgent(opt)
    tokenizer = agent.tokenizer
    collator = MultiTaskCollator(
        device=device,
        pad_token_id=tokenizer.pad_token_id,
        text_truncate=opt.dataset.truncate,
        return_episode_done=True,
    )

    # hf training args
    hf_trainer_args, extra_args = hf_trainer_args_from_opt(opt)
    print_config(hf_trainer_args, opt, convert=True)

    # optim and scheduler
    optimizers, scheduler = set_optim_and_schedule(hf_trainer_args, agent.model, extra_args)

    shared_trainer_kwargs = dict(
        args=hf_trainer_args,
        data_collator=collator,
        tokenizer=tokenizer,
        optimizers=(optimizers, scheduler),
    )

    # isinstance (rather than an exact type check) so that dict subclasses such
    # as OrderedDict/defaultdict are also evaluated task by task instead of
    # being passed to the trainer as a single dataset.
    if isinstance(test_dataset, dict):
        for task_name, dataset in test_dataset.items():
            _run_e2e_evaluation(
                opt,
                agent,
                rank,
                num_replicas,
                dataset,
                f'selflearn_eval-{task_name}_log',
                evaluate_args=(task_name,),
                **shared_trainer_kwargs,
            )
    else:
        _run_e2e_evaluation(
            opt,
            agent,
            rank,
            num_replicas,
            test_dataset,
            'selflearn_eval_log',
            **shared_trainer_kwargs,
        )


if __name__ == "__main__":
    main()