import logging
import os
import time
import warnings
from calendar import c

import hydra
import torch

from accelerate import (
    Accelerator,
    DataLoaderConfiguration,
    DistributedDataParallelKwargs,
)

from hydra.core.config_store import ConfigStore

from misc_utils import time_formatter
from model.utils import (
    kldiv_activation,
    normalize_action,
    tanh_activation,
    TrainCollate,
)

from omegaconf import MISSING, OmegaConf
from optimizer.optimizer_lib import OptimizerConfig
from rl.mcts_policy import AlphaZeroConfig
from scheduler.scheduler_lib import SchedulerConfig
from torch.distributed import destroy_process_group, init_process_group
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    from torchmetrics.aggregation import MeanMetric

from trainer.pretrainer import AcceleratorTrainer, BasicTrainer, DDPTrainer, TrainConfig

# Register the structured-config nodes with Hydra's ConfigStore singleton.
# This runs at import time, i.e. before the @hydra.main entry point below is
# invoked.
cs = ConfigStore.instance()
# AlphaZeroConfig is stored in the "policy" config group as "AlphaZero_config".
cs.store(name="AlphaZero_config", node=AlphaZeroConfig, group="policy")
# TrainConfig is stored at the root (no group) as "train_base_config".
cs.store(name="train_base_config", node=TrainConfig)


def ddp_setup() -> None:
    """Bind this process to its GPU and join the NCCL process group.

    The GPU ordinal is read from the ``LOCAL_RANK`` environment variable.
    ``init_process_group`` is called without explicit rank/world-size
    arguments, so it relies on its default initialisation.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
        ValueError: if ``LOCAL_RANK`` is not an integer string.
    """
    local_rank = int(os.environ["LOCAL_RANK"])

    # Select the GPU by ordinal instead of narrowing CUDA_VISIBLE_DEVICES to
    # LOCAL_RANK. Once only a single device is visible it is renumbered to
    # ordinal 0, so combining that restriction with set_device(local_rank)
    # requests an invalid ordinal on every rank greater than 0.
    torch.cuda.set_device(local_rank)
    init_process_group(backend="nccl")

    # Release cached allocator blocks before training starts.
    torch.cuda.empty_cache()


@hydra.main(version_base=None, config_path="../conf", config_name="train_config")
def launch_pretraining(cfg: TrainConfig) -> None:
    """Build all training components from ``cfg`` and run the trainer.

    Steps, in order:

    1. Reject configs that enable both ``use_accelerator`` and ``use_ddp``.
    2. Copy top-level options onto ``cfg.dataset`` and force
       ``cfg.get_causal_mask`` on when ``cfg.model.is_causal`` is set.
    3. Instantiate two model instances from ``cfg.model`` (``model`` and
       ``test_model``), the dataset, both loss criteria, the optimizer and
       the LR scheduler.
    4. Pick the trainer: ``AcceleratorTrainer`` if ``cfg.use_accelerator``,
       ``DDPTrainer`` if ``cfg.use_ddp`` (after ``ddp_setup()``), otherwise
       ``BasicTrainer`` on ``cfg.device``.
    5. Call ``trainer.train()`` and, for DDP, destroy the process group.

    Args:
        cfg: Hydra-composed training configuration. It is mutated in place
            by step 2.

    Raises:
        ValueError: if both ``cfg.use_accelerator`` and ``cfg.use_ddp`` are
            truthy.
    """
    log = logging.getLogger(__name__)

    if cfg.use_accelerator and cfg.use_ddp:
        raise ValueError("Cannot have both use_ddp and use_accelerator set to True")

    # Make implicit adjustments to the config
    cfg.dataset.return_action_mask = cfg.return_action_mask
    cfg.dataset.const_node = cfg.const_node
    cfg.dataset.reward_type = cfg.reward_type
    if cfg.model.is_causal:
        cfg.get_causal_mask = True

    # Create model (two independent instances built from the same config)
    model = hydra.utils.instantiate(cfg.model)
    test_model = hydra.utils.instantiate(cfg.model)

    # Activation functions
    action_activation = kldiv_activation
    target_activation = normalize_action
    value_activation = tanh_activation
    # NOTE: config-driven selection of the activations is currently disabled:
    # action_activation = hydra.utils.call(cfg.action_activation)
    # target_activation = hydra.utils.call(cfg.target_activation)
    # value_activation = hydra.utils.call(cfg.value_activation)

    # LOCAL_RANK is only used for log prefixes here; default to 0 when unset.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Dataset
    log.info(f"[{local_rank}] Creating dataset")

    loading_start_time = time.time()
    dataset = hydra.utils.instantiate(cfg.dataset, embedding_size=model.embedding_size)
    train_dts, eval_dts = dataset.split_data(1 - cfg.train_split, seed=cfg.seed)

    # Generation samples are only drawn when requested; otherwise pass an
    # empty list to the trainer.
    generation_dts = []
    if cfg.test_generation:
        generation_dts = eval_dts.get_generation_data(cfg.test_limit)

    elapsed_time = time_formatter(time.time() - loading_start_time, show_ms=False)
    # A space separates the message from the bracketed timing; the original
    # adjacent f-strings were concatenated without one.
    log.info(
        f"[{local_rank}] Finished loading dataset [Elapsed Time: {elapsed_time}]"
    )
    log.info(
        f"[{local_rank}] Dataset Size: [Training: {len(train_dts)}] "
        f"[Evaluation: {len(eval_dts)}]"
    )

    # Set up Loss function
    policy_criterion = hydra.utils.instantiate(cfg.policy_loss)
    value_criterion = hydra.utils.instantiate(cfg.value_loss)

    # Create optimizer and lr scheduler
    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)

    # Positional arguments shared by every trainer, in constructor order.
    # Each trainer variant appends at most one backend-specific argument.
    trainer_args = (
        cfg,
        model,
        test_model,
        policy_criterion,
        value_criterion,
        optimizer,
        scheduler,
        action_activation,
        target_activation,
        value_activation,
        train_dts,
        eval_dts,
        generation_dts,
    )

    # Start training
    if cfg.use_accelerator:
        accelerator = Accelerator(
            dataloader_config=DataLoaderConfiguration(
                split_batches=False,
                even_batches=False,
            ),
            kwargs_handlers=[
                DistributedDataParallelKwargs(find_unused_parameters=True),
            ],
        )
        trainer = AcceleratorTrainer(*trainer_args, accelerator)
    # DDP
    elif cfg.use_ddp:
        ddp_setup()
        trainer = DDPTrainer(*trainer_args)
    # Basic training
    else:
        device = torch.device(cfg.device)
        trainer = BasicTrainer(*trainer_args, device)

    trainer.train()

    # Tear down the process group created by ddp_setup().
    if cfg.use_ddp:
        destroy_process_group()


# Script entry point: run the Hydra-decorated launcher only when this file is
# executed directly, not when it is imported. It is called without arguments;
# the @hydra.main decorator supplies `cfg`.
if __name__ == "__main__":
    launch_pretraining()
