import glob
import logging
import math
import os
import random
import sys
import time
from dataclasses import dataclass

from multiprocessing.synchronize import Event as EventClass, Semaphore as SemaphoreClass
from typing import Any, Dict

import hydra

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

from omegaconf import MISSING, OmegaConf

sys.path.append("..")
from rl.her import HERConfig
from rl.mcts_policy import AlphaZeroConfig

from loss.loss_lib import *
from data.data_lib import *
from model.model_lib import *
from optimizer.optimizer_lib import *
from scheduler.scheduler_lib import *
from trainer.finetuner import FineTuningConfig, FineTuningNode

REPLAY_BUFFER_NODE = "ReplayBuffer"
TRAINER_NODE = "Trainer"


@dataclass
class AIGEnvConfig:
    """Environment settings grouped as a plain dataclass.

    NOTE(review): nothing in the visible part of this file reads these
    fields, so their exact meaning has to be confirmed against the
    environment implementation. The per-field notes below are assumptions.
    """

    # Name of the reward scheme to use; defaults to "simple".
    # (Other accepted values are not visible here - TODO confirm.)
    reward_type: str = "simple"
    # Presumably the reward handed out for an unsuccessful outcome - TODO confirm.
    negative_reward: float = -1.0
    # Presumably toggles inclusion of a constant node - TODO confirm.
    const_node: bool = True


# Register the structured-config schemas so they can be referenced by name
# when the configuration is composed.
# NOTE(review): ConfigStore is not imported explicitly in this file; it is
# presumably re-exported by one of the wildcard imports above - confirm.
cs = ConfigStore.instance()
# Two policy schemas, both stored under the "policy" config group.
cs.store(name="AlphaZero_config", node=AlphaZeroConfig, group="policy")
cs.store(name="HER_config", node=HERConfig, group="policy")
# Schema for the fine-tuning config, stored without a group.
cs.store(name="finetuning_base_config", node=FineTuningConfig)


def get_model_path(cfg) -> Tuple[str, str]:
    """Resolve the config file and checkpoint file of the model to load.

    Args:
        cfg: Object exposing ``path`` and ``model_name``. When ``path`` is
            ``"latest"`` or ``None`` the most recently modified directory
            under ``runs/`` is used; otherwise ``path`` itself is the run
            directory.

    Returns:
        ``(model_config_path, model_path)`` where the first element points
        at ``config.yaml`` inside the run directory and the second at the
        checkpoint file inside the same directory.

    Raises:
        ValueError: If the latest run is requested but ``runs/`` contains
            no sub-directories.
    """
    if cfg.path == "latest" or cfg.path is None:
        run_dirs = glob.glob(os.path.join("runs/", "*/"))
        if not run_dirs:
            raise ValueError(
                'No run directories found under "runs/"; '
                "cannot resolve the latest run."
            )
        run_dir = max(run_dirs, key=os.path.getmtime)
    else:
        run_dir = cfg.path

    # Append the extension only when the name does not already carry one.
    # The substring check is kept as-is so names such as "model.pth" are
    # still passed through untouched.
    model_file = (
        cfg.model_name if ".pt" in cfg.model_name else cfg.model_name + ".pt"
    )

    # os.path.join inserts the separator only when it is missing, so the
    # result is correct whether or not the run directory ends with "/".
    model_config_path = os.path.join(run_dir, "config.yaml")
    model_path = os.path.join(run_dir, model_file)
    return model_config_path, model_path


def set_device_map(cfg: FineTuningConfig):
    """Publish a rank-to-device assignment in the ``DEVICE_MAP`` env variable.

    The entries of ``CUDA_VISIBLE_DEVICES`` are handed out one per rank,
    first to the local trainer ranks and then to the local data-collector
    ranks. Ranks left over once the devices run out get no entry. The
    result is stored as comma-separated ``rank:device`` pairs, or as an
    empty string when no devices are listed.

    Args:
        cfg: Configuration providing ``trainer_rank_offset``,
            ``trainer_local_world_size``, ``data_collectors_rank_offset``
            and ``data_collectors_local_world_size``.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    # Drop surrounding whitespace and empty entries (e.g. "0, 1," or "") so
    # that no malformed "rank:" pair can end up in the map.
    devices = [entry.strip() for entry in visible.split(",") if entry.strip()]

    trainer_ranks = range(
        cfg.trainer_rank_offset,
        cfg.trainer_rank_offset + cfg.trainer_local_world_size,
    )
    collector_ranks = range(
        cfg.data_collectors_rank_offset,
        cfg.data_collectors_rank_offset + cfg.data_collectors_local_world_size,
    )

    # zip stops at the shorter sequence, so ranks beyond the number of
    # available devices are simply left unassigned.
    ordered_ranks = [*trainer_ranks, *collector_ranks]
    os.environ["DEVICE_MAP"] = ",".join(
        f"{rank}:{device}" for rank, device in zip(ordered_ranks, devices)
    )


def launch_training_node(
    rank: int, world_size: int, cfg: FineTuningConfig, **tensorpipe_kwargs
):
    """Join the RPC group as the trainer, run fine-tuning, then shut down.

    Args:
        rank: Global RPC rank of this process.
        world_size: Total number of processes in the RPC group.
        cfg: Fine-tuning configuration; ``master_addr`` and
            ``master_port_rpc`` are used for the RPC rendezvous and the
            whole object is handed to ``FineTuningNode``.
        **tensorpipe_kwargs: Extra options forwarded to
            ``rpc.TensorPipeRpcBackendOptions``.

    NOTE(review): every trainer rank registers under the same worker name
    (``TRAINER_NODE``); RPC worker names are expected to be unique, so
    running more than one trainer rank presumably needs a per-rank name -
    confirm.
    """
    logger = logging.getLogger(__name__)

    options = rpc.TensorPipeRpcBackendOptions(  # type: ignore
        num_worker_threads=128,
        init_method=f"tcp://{cfg.master_addr}:{cfg.master_port_rpc}",
        **tensorpipe_kwargs,
    )

    rpc.init_rpc(
        TRAINER_NODE,
        rank=rank,
        rpc_backend_options=options,
        world_size=world_size,
    )

    logger.info(f"Initiled RPC for Trainer Node {rank}")

    # Teardown must run even when construction or training fails: the other
    # workers sit in their own rpc.shutdown() call, which is expected to wait
    # for this process to shut down as well. On success the call order is the
    # same as before: train -> stop_collect -> cleanup -> rpc.shutdown.
    try:
        trainer = FineTuningNode(cfg=cfg, replay_buffer_node=REPLAY_BUFFER_NODE)
        try:
            trainer.train()
        finally:
            try:
                trainer.stop_collect()
            finally:
                trainer.cleanup()
    finally:
        rpc.shutdown()


def launch_data_collection_node(
    rank: int,
    world_size: int,
    master_addr: str,
    master_port_rpc: str,
    **tensorpipe_kwargs,
):
    """Join the RPC group as a data-collector worker and wait for shutdown.

    The process registers under the name ``DataCollector<rank>``, logs that
    RPC is up and then calls ``rpc.shutdown()``; it does no work of its own
    beyond that.

    Args:
        rank: Global RPC rank of this process.
        world_size: Total number of processes in the RPC group.
        master_addr: Host of the RPC rendezvous endpoint.
        master_port_rpc: Port of the RPC rendezvous endpoint.
        **tensorpipe_kwargs: Extra options forwarded to
            ``rpc.TensorPipeRpcBackendOptions``.
    """
    rendezvous_url = f"tcp://{master_addr}:{master_port_rpc}"
    backend_opts = rpc.TensorPipeRpcBackendOptions(  # type: ignore
        init_method=rendezvous_url,
        num_worker_threads=128,
        **tensorpipe_kwargs,
    )

    worker_name = f"DataCollector{rank}"
    rpc.init_rpc(
        worker_name,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=backend_opts,
    )

    log = logging.getLogger(__name__)
    log.info(f"Initiled RPC for Data Collector Node {rank}")

    # Default (graceful) shutdown; expected to return only once the rest of
    # the RPC group shuts down too.
    rpc.shutdown()


def launch_replay_buffer_node(
    rank: int,
    world_size: int,
    master_addr: str,
    master_port_rpc: str,
    **tensorpipe_kwargs,
):
    """Join the RPC group as the replay-buffer worker and wait for shutdown.

    The process registers under ``REPLAY_BUFFER_NODE``, logs that RPC is up
    and then calls ``rpc.shutdown()``; it does no work of its own beyond
    that.

    Args:
        rank: Global RPC rank of this process.
        world_size: Total number of processes in the RPC group.
        master_addr: Host of the RPC rendezvous endpoint.
        master_port_rpc: Port of the RPC rendezvous endpoint.
        **tensorpipe_kwargs: Extra options forwarded to
            ``rpc.TensorPipeRpcBackendOptions``.
    """
    rendezvous_url = f"tcp://{master_addr}:{master_port_rpc}"
    backend_opts = rpc.TensorPipeRpcBackendOptions(  # type: ignore
        init_method=rendezvous_url,
        num_worker_threads=128,
        **tensorpipe_kwargs,
    )

    rpc.init_rpc(
        REPLAY_BUFFER_NODE,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=backend_opts,
    )

    log = logging.getLogger(__name__)
    log.info(f"Initiled RPC for Replay Buffer Node {rank}")

    # Default (graceful) shutdown; expected to return only once the rest of
    # the RPC group shuts down too.
    rpc.shutdown()


@hydra.main(version_base=None, config_path="../conf", config_name="finetuning_config")
def launch_finetuning(cfg: FineTuningConfig):
    """Spawn every local process that takes part in distributed fine-tuning.

    Steps:
      1. Load the ``model`` section of the resolved run's ``config.yaml``
         into ``cfg.model``.
      2. Derive the RPC world size and the rank layout. Replay-buffer ranks
         are placed after all trainer and data-collector ranks.
      3. Publish the rank-to-device assignment via ``set_device_map``.
      4. Start one process per local replay-buffer, trainer and
         data-collector rank using the "spawn" start method, then wait for
         all of them to exit.

    Args:
        cfg: Fine-tuning configuration composed by Hydra. It is mutated in
            place (``model``, the rank offsets and, for single-node runs,
            the local world sizes and ``master_addr``).
    """
    ctx = mp.get_context("spawn")
    processes = []

    # Only the config path is needed here; the checkpoint path is unused.
    model_config_path, _ = get_model_path(cfg)
    model_cfg = ShortCircuitConfig(**(OmegaConf.to_container(OmegaConf.load(model_config_path).model)))  # type: ignore

    cfg.model = model_cfg

    assert cfg.trainer_global_world_size > 0, "at least one trainer is required"
    assert (
        cfg.replay_buffer_global_world_size > 0
    ), "at least one replay buffer is required"
    assert (
        cfg.data_collectors_global_world_size > 0
    ), "at least one data collector is required"

    world_size = (
        cfg.trainer_global_world_size
        + cfg.replay_buffer_global_world_size
        + cfg.data_collectors_global_world_size
    )

    if cfg.nnodes == 1:
        # Single machine: every global rank is also a local rank.
        cfg.trainer_local_world_size = cfg.trainer_global_world_size
        cfg.replay_buffer_local_world_size = cfg.replay_buffer_global_world_size
        cfg.data_collectors_local_world_size = cfg.data_collectors_global_world_size
        cfg.data_collectors_rank_offset = cfg.trainer_global_world_size
        cfg.master_addr = "localhost"

    cfg.replay_buffer_rank_offset = (
        cfg.trainer_global_world_size + cfg.data_collectors_global_world_size
    )

    # Assigns devices to each process. set_device_map reads the local world
    # sizes and the collector rank offset, so it has to run after they are
    # finalised above, and before any child is started so that the children
    # inherit the resulting environment variable.
    set_device_map(cfg)

    # Start replay buffer node
    for rank in range(
        cfg.replay_buffer_rank_offset,
        cfg.replay_buffer_rank_offset + cfg.replay_buffer_local_world_size,
    ):
        p = ctx.Process(
            target=launch_replay_buffer_node,
            args=(
                rank,
                world_size,
                cfg.master_addr,
                cfg.master_port_rpc,
            ),
        )
        p.start()
        processes.append(p)

    # Start training node
    for rank in range(
        cfg.trainer_rank_offset, cfg.trainer_rank_offset + cfg.trainer_local_world_size
    ):
        p = ctx.Process(
            target=launch_training_node,
            args=(
                rank,
                world_size,
                cfg,
            ),
        )
        p.start()
        processes.append(p)

    # Start data collection nodes
    for rank in range(
        cfg.data_collectors_rank_offset,
        cfg.data_collectors_rank_offset + cfg.data_collectors_local_world_size,
    ):
        p = ctx.Process(
            target=launch_data_collection_node,
            args=(
                rank,
                world_size,
                cfg.master_addr,
                cfg.master_port_rpc,
            ),
        )
        p.start()
        processes.append(p)

    # Wait for the children in reverse start order.
    for p in reversed(processes):
        p.join()


if __name__ == "__main__":
    launch_finetuning()
