import os
import random
import torch
import torch.distributed as dist
from transformers import AutoTokenizer

from inference import BB3InferenceAgent
from train import metric_config
from multitask_datasets import (
    MultiTaskCollator,
)
from optimizer import set_optim_and_schedule
from trainer import SelfLearner
from hexa.utils.base_utils import (
    build_config,
    set_seed,
    hf_trainer_args_from_opt,
    print_config
)

from hexa.utils.metrics import MetricLogger
from hexa.collators import (
    TensorBoardCallback,
    LoggingCallback
)

from hexa.utils.dist_utils import setup_multi_gpu
from hexa.utils.self_learn_utils import update_output_dir, get_e2e_datasets


if __name__ == "__main__":
    # ---- options and distributed setup ----
    opt = build_config()
    if opt.num_loop > 0:
        # After the first loop: update the output dir with load_prev=True and
        # load the model from opt.trainer.output_dir (presumably the previous
        # loop's checkpoint — see update_output_dir).
        update_output_dir(opt, load_prev=True)
        opt.model.model_file = opt.trainer.output_dir
    setup_multi_gpu(opt)

    # A server_port of -1 means "not configured": fall back to $PORT1.
    if opt.server_port == -1:
        try:
            opt.server_port = os.environ['PORT1']
        except KeyError:
            raise RuntimeError(
                "opt.server_port is -1 and the PORT1 environment variable is not set"
            ) from None

    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    num_replicas = dist.get_world_size() if distributed else 1

    # Each rank queries its own search server at port server_port + rank.
    opt.dataset.search_server = f'http://127.0.0.1:{int(opt.server_port) + rank}'
    opt.dataset.skip_retrieval_token = 'no_passages_used'
    device = torch.device(opt.device_id)
    tokenizer = AutoTokenizer.from_pretrained(opt.hf_tokenizer_path)

    # ---- dataset: bootstrap with the (per-rank sharded) train split ----
    train_dataset = get_e2e_datasets(
        opt,
        tokenizer,
        datatype='train',
        num_replicas=num_replicas,
        rank=rank,
        batch_size=opt.trainer.per_device_train_batch_size
    )
    train_metric_logger = MetricLogger('train', metric_config)
    callbacks = [TensorBoardCallback(), LoggingCallback(train_metric_logger)]

    collator = MultiTaskCollator(
        device=device,
        pad_token_id=tokenizer.pad_token_id,
        text_truncate=opt.dataset.truncate,
        return_episode_done=True
    )

    # ---- agent ----
    # NOTE(review): set_seed runs after get_e2e_datasets; if dataset
    # construction shuffles/samples, it is not covered by this seed — confirm.
    set_seed(opt.seed)
    agent = BB3InferenceAgent(opt, tokenizer=tokenizer)

    hf_trainer_args, extra_args = hf_trainer_args_from_opt(opt)
    print_config(hf_trainer_args, opt, convert=True)

    # ---- optimizer and LR scheduler ----
    optimizers, scheduler = set_optim_and_schedule(hf_trainer_args, agent.model, extra_args)

    # ---- trainer ----
    trainer = SelfLearner(
        agent,
        rank,
        num_replicas,
        opt.num_loop,
        opt.num_bootstrap,
        args=hf_trainer_args,
        data_collator=collator,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        optimizers=(optimizers, scheduler),
        callbacks=callbacks,
        train_metric_logger=train_metric_logger,
        name_space=opt.name_space,
        use_cos_sim=opt.use_cos_sim,
        bt_inc_rate=opt.bt_inc_rate,
        base_threshold=opt.base_threshold
    )

    # Per-rank seed for Python's `random`, so ranks draw different samples.
    random.seed(opt.seed + rank)

    if opt.trainer.fp16:
        # NOTE(review): the model is cast to fp16 *after* the optimizer was
        # built, so the optimizer steps fp16 params directly (no fp32 master
        # weights). If SelfLearner trains with HF AMP, this conflicts with
        # GradScaler ("Attempting to unscale FP16 gradients") — confirm intent.
        agent.model.half()
    trainer.bootstrap()