from functools import total_ordering
from dataclasses import dataclass
import random
import os
import math
import warnings
from collections import namedtuple

# Suppress every UserWarning for the rest of the process. This runs before the
# third-party imports below, so warnings raised while importing them are muted too.
warnings.filterwarnings("ignore", category=UserWarning)

import tensorboardX
import tqdm
import torch
from torch import multiprocessing as mp
import rl_utils as dc
import numpy as np

import learn
import run
import replay
from agent_pbt import Agent
from action_repeat_wrapper import ActionRepeatWrapper


@dataclass
class Hparams:
    """Per-member hyperparameters explored by population-based training.

    Field order matters: instances are built positionally as
    ``Hparams(a, c, k, h, g)``.
    """

    a: int  # number of actor updates per environment step
    c: int  # number of critic updates per environment step
    k: int  # action repeat (action persistence) used by the environment
    h: float  # target-entropy multiplier passed to the actor update
    g: float  # log(1 - gamma): training uses gamma = 1 - exp(g)


@total_ordering
class Member:
    """A population member: an agent (which holds the actor and critic(s)) plus its hparams.

    Members compare by ``fitness`` only, so ``sorted(population)`` orders them
    from worst to best. Because ``__eq__`` is defined without ``__hash__``,
    instances are unhashable.
    """

    def __init__(self, uid, agent, hparams):
        self.id = uid
        self.agent = agent
        self.hparams = hparams
        # Not yet evaluated: ranks below any member that has a real fitness.
        self.fitness = -float("inf")

    def __eq__(self, other):
        # Defer to Python's default handling for non-members instead of
        # crashing with AttributeError on `other.fitness`.
        if not isinstance(other, Member):
            return NotImplemented
        return self.fitness == other.fitness

    def __gt__(self, other):
        if not isinstance(other, Member):
            return NotImplemented
        return self.fitness > other.fitness


class EnvironmentWrapper:
    """Holds an environment factory and creates the environment lazily.

    Constructing a wrapper does not call ``make_env_function``. The environment
    is built, and reset (storing the result in ``self.state``), the first time
    ``env`` is accessed, including indirectly through ``set_k`` or
    ``max_episode_steps``.
    """

    def __init__(self, uid, make_env_function, max_episode_steps):
        self.id = uid
        self.make_env_function = make_env_function
        # Step budget expressed in raw environment steps (before action repeat).
        self._max_episode_steps = max_episode_steps
        # Episode bookkeeping, starting as "done" with a zero step count.
        self.done = True
        self.step_count = 0
        # Filled in on first access to `env`.
        self._env = None

    @property
    def env(self):
        """The wrapped environment, created and reset on first access."""
        if self._env is not None:
            return self._env
        self._env = self.make_env_function()
        self.state = self._env.reset()
        return self._env

    @property
    def max_episode_steps(self):
        """Episode step budget divided by the environment's action repeat ``k``, rounded."""
        budget = self._max_episode_steps
        action_repeat = self.env.k
        return round(budget / action_repeat)

    def set_k(self, new_k):
        """Forward a new action-repeat value to the (lazily created) environment."""
        environment = self.env
        environment.set_k(new_k)


class Worker(mp.Process):
    """Process responsible for collecting experience, training, and evaluating one member.

    Every epoch the worker:
      1. receives ``(uid, member)`` on ``member_queue`` and moves the agent to ``dc.device``;
      2. for ``steps_per_epoch`` steps, collects experience with the train env,
         sends it through ``exp_queue``, waits on a ``step_events`` barrier until
         the main process has pushed it into ``replay_buffer``, then runs
         ``hparams.c`` critic updates and ``hparams.a`` actor updates;
      3. evaluates the member on a separate test env, moves the agent back to
         the CPU, returns it on ``member_queue`` and waits on an ``epoch_events``
         barrier while the main process performs exploit/explore.
    """

    def __init__(
        self,
        uid,
        make_env_function,
        max_episode_steps,
        replay_buffer,
        member_queue,
        exp_queue,
        step_events,
        epoch_events,
        epochs,
        steps_per_epoch,
        batch_size,
        num_gpus,
    ):
        super().__init__()
        self.id = uid
        # Separate wrappers so evaluation episodes do not disturb training episodes.
        self.train_env_wrapper = EnvironmentWrapper(
            uid, make_env_function, max_episode_steps
        )
        self.test_env_wrapper = EnvironmentWrapper(
            uid, make_env_function, max_episode_steps
        )
        self.replay_buffer = replay_buffer
        self.member_queue = member_queue
        self.exp_queue = exp_queue
        self.step_events = step_events
        self.epoch_events = epoch_events
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.batch_size = batch_size
        self.num_gpus = num_gpus

    def run(self):
        # Set sharing strategy to avoid errors when sharing agents to main process.
        torch.multiprocessing.set_sharing_strategy("file_system")
        # Spread workers round-robin over the visible GPUs using the (private,
        # presumably 1-based) process identity. With no GPU visible the modulo
        # would raise ZeroDivisionError, so device selection is skipped.
        if self.num_gpus > 0:
            gpu_id = (
                torch.multiprocessing.current_process()._identity[0] - 1
            ) % self.num_gpus
            torch.cuda.set_device(gpu_id)

        for epoch in range(self.epochs):
            _uid, member = self.member_queue.get()
            # Explicit check (not `assert`) so it survives `python -O`.
            if _uid != self.id:
                raise RuntimeError("Worker id and member id mismatch.")

            # `requires_grad` needs to be set to `False` before sending it to GPU.
            member.agent.log_alpha.requires_grad = False
            member.agent.to(dc.device)

            # Make sure environment uses member's k.
            self.train_env_wrapper.set_k(member.hparams.k)

            for step in range(self.steps_per_epoch):
                exp = run.collect_experience(member, self.train_env_wrapper)
                self.exp_queue.put((self.id, exp))

                # Wait until the main process is done modifying the replay
                # buffer. The two events alternate by step parity so the main
                # process can re-arm one while workers are released by the other.
                self.step_events[step % 2].wait()

                # Critic updates; the discount is parameterized as g = log(1 - gamma).
                for _ in range(member.hparams.c):
                    learn.learn_critics(
                        member=member,
                        buffer=self.replay_buffer,
                        batch_size=self.batch_size,
                        gamma=1.0 - math.exp(member.hparams.g),
                    )

                # Actor updates.
                for _ in range(member.hparams.a):
                    learn.learn_actor(
                        member=member,
                        buffer=self.replay_buffer,
                        batch_size=self.batch_size,
                        target_entropy_mul=member.hparams.h,
                    )

            # Evaluate fitness of the member using the test env.
            self.test_env_wrapper.env.reset()
            self.test_env_wrapper.set_k(member.hparams.k)
            # Both wrappers now use the member's k, so the train wrapper's
            # per-episode step budget is the same as the test wrapper's.
            member.fitness = run.evaluate_agent(
                member.agent,
                self.test_env_wrapper.env,
                10,
                self.train_env_wrapper.max_episode_steps,
            )

            member.agent.log_alpha.requires_grad = False
            member.agent.to(torch.device("cpu"))
            self.member_queue.put((_uid, member))

            # Wait until main process is done crossover-ing bad and elite members.
            self.epoch_events[epoch % 2].wait()


class WorkerPool:
    """Applies lifecycle calls to a group of workers and drains their queues."""

    def __init__(self, workers, replay_buffer, member_queues, exp_queue):
        self.workers = workers
        self.replay_buffer = replay_buffer
        self.member_queues = member_queues
        self.exp_queue = exp_queue

    def start(self):
        """Start every worker."""
        for proc in self.workers:
            proc.start()

    def join(self):
        """Wait for every worker to finish."""
        for proc in self.workers:
            proc.join()

    def close(self):
        """Release the resources held by every worker."""
        for proc in self.workers:
            proc.close()

    def get_population(self):
        """Return each worker's member, in worker order.

        Blocks on each worker's own member queue in turn. The uid sent alongside
        each member is discarded.
        """
        replies = (self.member_queues[proc.id].get() for proc in self.workers)
        return [member for _uid, member in replies]

    def collect_experiences(self):
        """Pull one item per worker from ``exp_queue`` and push it into the replay buffer.

        Items are taken in arrival order. The sender uid is not checked.
        """
        for _ in self.workers:
            _sender, experience = self.exp_queue.get(block=True)
            self.replay_buffer.push(*experience)


# Search range for one hyperparameter: inclusive bounds plus the maximum
# perturbation applied during the PBT "explore" step.
ParamSpace = namedtuple("ParamSpace", "min max delta")


def _signal_barrier(events, index):
    """Release the workers waiting on ``events[index % 2]``.

    Workers wait on the two events alternately (by step or epoch parity). The
    event for the *next* phase is cleared before the current one is set. A
    worker released here can only reach its next barrier after the ``set``
    below, so it always finds that barrier closed. Setting first and clearing
    afterwards leaves a window in which a fast worker can pass the next barrier
    on a stale ``set``.

    Args:
        events (tuple): Pair of multiprocessing events.
        index (int): Current step or epoch number.
    """
    events[(index + 1) % 2].clear()
    events[index % 2].set()


def parallel_pbt_ac(
    make_env_function,
    epochs=1_000,
    steps_per_epoch=1_000,
    population_size=20,
    a_param=(1, 10, 2),
    c_param=(1, 40, 5),
    h_param=(0.25, 1.75, 0.25),
    k_param=(1, 15, 2),
    g_param=(-6.5, -1.0, 0.5),
    max_episode_steps=1000,
    batch_size=512,
    name="parallel_pbt_ac",
):
    """Parallel Implementation of AAC (population-based training of actor-critic agents).

    One ``Worker`` process per population member collects experience and trains
    that member. This (main) process pushes all experience into the shared
    replay buffer, synchronizes workers with barrier events, and after every
    epoch overwrites the weakest 20% of the population with copies of randomly
    paired members from the strongest 20%, using perturbed hyperparameters.

    Args:
        make_env_function (callable):
            Zero-argument callable that returns environment as `ActionRepeatWrapper`.
            Must be picklable (e.g. a module-level function) because it is sent to
            worker processes.
        epochs (int, optional):
            Evolutionary epochs. Defaults to 1_000.
        steps_per_epoch (int, optional):
            Training steps per epoch. Defaults to 1_000.
        population_size (int, optional):
            Population size. Defaults to 20.
        a_param (tuple[int, int, int], optional):
            Tuple of min, max, and delta value for hyperparameter `a` (actor updates
            per step). Defaults to (1, 10, 2).
        c_param (tuple[int, int, int], optional):
            Tuple of min, max, and delta value for hyperparameter `c` (critic updates
            per step). Defaults to (1, 40, 5).
        h_param (tuple[float, float, float], optional):
            Tuple of min, max, and delta value for hyperparameter `h` (target entropy
            multiplier). Defaults to (0.25, 1.75, 0.25).
        k_param (tuple[int, int, int], optional):
            Tuple of min, max, and delta value for hyperparameter `k` (action repeat).
            Defaults to (1, 15, 2).
        g_param (tuple[float, float, float], optional):
            Tuple of min, max, and delta value for hyperparameter `g`, where
            gamma = 1 - exp(g). Defaults to (-6.5, -1.0, 0.5).
        max_episode_steps (int, optional):
            Maximum number of environment steps for an episode. Defaults to 1000.
        batch_size (int, optional):
            Batch size of experiences from replay buffer used for training. Defaults to 512.
        name (str, optional):
            Name of run. Used for logging to Tensorboard. Defaults to "parallel_pbt_ac".
    """
    a_param = ParamSpace(*a_param)
    c_param = ParamSpace(*c_param)
    h_param = ParamSpace(*h_param)
    k_param = ParamSpace(*k_param)
    g_param = ParamSpace(*g_param)

    test_env = make_env_function()
    obs_space = test_env.observation_space
    act_space = test_env.action_space

    save_dir = dc.utils.make_process_dirs(name)
    writer = tensorboardX.SummaryWriter(save_dir)
    # NOTE(review): `locals()` here also contains non-scalar objects (the env
    # factory, ParamSpace tuples, the env and its spaces, the writer itself).
    # Whether tensorboardX accepts or skips them depends on its version; verify.
    writer.add_hparams(locals(), {})

    # Helpers for sampling and perturbing hyperparameters.
    def clamp(x, param):
        return max(min(x, param.max), param.min)

    def shift_int_by_add(current, param):
        return current + random.randint(-param.delta, param.delta)

    def shift_float_by_add(current, param):
        return current + random.uniform(-param.delta, param.delta)

    def make_int_range(x):
        return random.randint(x.min, x.max)

    def make_float_range(x):
        return random.uniform(x.min, x.max)

    # Create a centralized replay buffer and add a few random samples to get started.
    buffer_size = 2_000_000
    replay_buffer = replay.ReplayBuffer(
        size=buffer_size,
        state_shape=obs_space.shape,
        action_repeat=k_param.max,
        state_dtype=float,
        action_shape=act_space.shape,
    )

    print("Warm up replay buffer...")
    warmup_size = 10_000
    rand_env = make_env_function()
    pbar = tqdm.tqdm(total=warmup_size, dynamic_ncols=True)
    while len(replay_buffer) < warmup_size:
        # Collect random samples at each action repeat value.
        for k in range(1, k_param.max + 1):
            prev = len(replay_buffer)
            rand_env.reset()
            rand_env.set_k(k)
            max_steps = round(max_episode_steps / k)
            run.warmup_buffer(
                replay_buffer, rand_env, max_steps + 1, max_steps,
            )
            if len(replay_buffer) >= warmup_size:
                pbar.update(warmup_size - prev)
                break
            else:
                pbar.update(len(replay_buffer) - prev)
    pbar.close()

    # Initialize the population.
    population = []
    for i in range(population_size):
        agent = Agent(obs_space.shape[0], act_space.shape[0])
        hparams = Hparams(
            make_int_range(a_param),
            make_int_range(c_param),
            make_int_range(k_param),
            make_float_range(h_param),
            make_float_range(g_param),
        )
        member = Member(i, agent, hparams)
        population.append(member)

    # Moving replay buffer to shared memory allows us to share it among workers without copying.
    replay_buffer.share_memory_()
    # Separate queue is created for each member to ensure correct member is sent to each worker.
    member_queues = {m.id: mp.Queue() for m in population}
    # Queue for sharing collected experiences.
    exp_queue = mp.Queue()

    # Events for synchronization. Two events are alternated so one can be re-armed
    # (cleared) while workers are being released by the other.
    step_events = (mp.Event(), mp.Event())
    epoch_events = (mp.Event(), mp.Event())
    num_gpus = torch.cuda.device_count()

    # Initialize workers.
    workers = [
        Worker(
            i,
            make_env_function,
            max_episode_steps,
            replay_buffer,
            member_queues[i],
            exp_queue,
            step_events,
            epoch_events,
            epochs,
            steps_per_epoch,
            batch_size,
            num_gpus,
        )
        for i in range(len(population))
    ]
    pool = WorkerPool(workers, replay_buffer, member_queues, exp_queue)
    pool.start()

    for epoch in range(epochs):
        print(f"EPOCH {epoch}")
        # Every worker has already passed all step barriers of the previous epoch
        # (each returned its member), and none can reach a step barrier before
        # receiving its member below, so both step events can be reset safely.
        # Without this, an odd `steps_per_epoch` leaves `step_events[0]` set and
        # workers would skip the first barrier of the new epoch.
        for event in step_events:
            event.clear()
        for member in population:
            member_queues[member.id].put((member.id, member))

        for step in tqdm.tqdm(range(steps_per_epoch), dynamic_ncols=True):
            # Push collected experiences to shared replay buffer.
            pool.collect_experiences()
            # Notify workers that they're good to proceed with training.
            _signal_barrier(step_events, step)

        # Get members that have been trained.
        population = pool.get_population()

        # Save every member to disk at the end of each epoch.
        for member in population:
            member.agent.save(save_dir, member.id)

        # Sort the population by increasing fitness (as computed by the workers).
        population = sorted(population)
        # Take the bottom and top 20% of the population and shuffle them so that
        # zip() below creates a random pairing. With fewer than 5 members both
        # groups are empty and no exploit/explore happens.
        n_replace = population_size // 5
        worst_members = population[:n_replace]
        best_members = population[len(population) - n_replace :]
        random.shuffle(worst_members)
        random.shuffle(best_members)
        for bad, elite in zip(worst_members, best_members):
            # Copy the good agent's network weights.
            dc.utils.hard_update(bad.agent.actor, elite.agent.actor)
            dc.utils.hard_update(bad.agent.critic1, elite.agent.critic1)
            dc.utils.hard_update(bad.agent.critic2, elite.agent.critic2)
            # Copy the good agent's optimizers (this may not be necessary).
            bad.agent.online_actor_optimizer.load_state_dict(
                elite.agent.online_actor_optimizer.state_dict()
            )
            bad.agent.critic_optimizer.load_state_dict(
                elite.agent.critic_optimizer.state_dict()
            )
            # Copy the good agent's max ent constraint and optimizer.
            bad.agent.log_alpha_optimizer.load_state_dict(
                elite.agent.log_alpha_optimizer.state_dict()
            )
            # NOTE(review): this rebinds `log_alpha` to a new tensor. If
            # `log_alpha_optimizer` still references the old tensor it will not
            # update this one; presumably `Agent.to()` rebuilds it. Verify.
            bad.agent.log_alpha = elite.agent.log_alpha.clone()

            # Explore the param space, clamped within a specified range.
            new_a = clamp(shift_int_by_add(elite.hparams.a, a_param), a_param)
            new_c = clamp(shift_int_by_add(elite.hparams.c, c_param), c_param)
            new_g = clamp(shift_float_by_add(elite.hparams.g, g_param), g_param)
            new_k = clamp(shift_int_by_add(elite.hparams.k, k_param), k_param)
            new_h = clamp(shift_float_by_add(elite.hparams.h, h_param), h_param)
            bad.hparams = Hparams(new_a, new_c, new_k, new_h, new_g)

        # Logging.
        fitness_distrib = torch.Tensor([m.fitness for m in population]).float()
        a_distrib = torch.Tensor([m.hparams.a for m in population]).float()
        c_distrib = torch.Tensor([m.hparams.c for m in population]).float()
        k_distrib = torch.Tensor([m.hparams.k for m in population]).float()
        h_distrib = torch.Tensor([m.hparams.h for m in population]).float()
        g_distrib = torch.Tensor([m.hparams.g for m in population]).float()

        writer.add_histogram("Fitness", fitness_distrib, epoch)
        writer.add_histogram("A Param", a_distrib, epoch)
        writer.add_histogram("C Param", c_distrib, epoch)
        writer.add_histogram("K Param", k_distrib, epoch)
        writer.add_histogram("H Param", h_distrib, epoch)
        writer.add_histogram("G Param", g_distrib, epoch)

        # The elite (last after sorting) is never overwritten above.
        best = population[-1]
        writer.add_scalar("BestReturn", best.fitness, epoch)
        writer.add_scalar("BestA", best.hparams.a, epoch)
        writer.add_scalar("BestC", best.hparams.c, epoch)
        writer.add_scalar("BestK", best.hparams.k, epoch)
        writer.add_scalar("BestH", best.hparams.h, epoch)
        writer.add_scalar("BestG", best.hparams.g, epoch)
        with open(os.path.join(save_dir, "population_fitness.csv"), "a") as f:
            # float() accepts Python floats as well as NumPy/torch scalars.
            f.write(",".join(f"{float(m.fitness):.1f}" for m in population) + "\n")

        # Notify workers they're good to proceed with next epoch.
        _signal_barrier(epoch_events, epoch)

    # After the final epoch barrier every worker leaves its loop and exits. All
    # queued items have been consumed, so joining cannot block on queue feeders.
    pool.join()
    pool.close()
    writer.close()


"""
Parallel implementation requires us to pass the function that creates new
environments between processes, so these functions need to be picklable.
A quick solution is to define them at module (global) scope:
"""


import gym

try:
    import dmc2gym
except Exception:
    # MuJoCo not found or dmc2gym not installed: skip it. The DMC env factories
    # below will fail when called, everything else still works. `Exception`
    # (not a bare `except`) so Ctrl-C / SystemExit during import still propagate.
    pass


try:
    import or_gym
except Exception:
    # or-gym not installed: skip it. The or_gym-based factories below
    # (`inventory`, `newsvendor`) will fail when called.
    pass

try:
    from industrial_benchmark_python.IBGym import IBGym
except Exception:
    # industrial_benchmark_python not installed: skip it. `_ib` and the
    # `industrial_benchmark_*` factories below will fail when called.
    # Kept separate from the or-gym import so one missing package does not
    # hide the other.
    pass


def fish_swim():
    """DeepMind Control `fish`/`swim` task, wrapped in `ActionRepeatWrapper`."""
    base_env = dmc2gym.make("fish", "swim")
    return ActionRepeatWrapper(base_env)


def walker_run():
    """DeepMind Control `walker`/`run` task, wrapped in `ActionRepeatWrapper`."""
    base_env = dmc2gym.make("walker", "run")
    return ActionRepeatWrapper(base_env)


def swimmer_swimmer6():
    """DeepMind Control `swimmer`/`swimmer6` task, wrapped in `ActionRepeatWrapper`."""
    base_env = dmc2gym.make("swimmer", "swimmer6")
    return ActionRepeatWrapper(base_env)


def humanoid_stand():
    """DeepMind Control `humanoid`/`stand` task, wrapped in `ActionRepeatWrapper`."""
    base_env = dmc2gym.make("humanoid", "stand")
    return ActionRepeatWrapper(base_env)


def reacher_hard():
    """DeepMind Control `reacher`/`hard` task, wrapped in `ActionRepeatWrapper`."""
    base_env = dmc2gym.make("reacher", "hard")
    return ActionRepeatWrapper(base_env)


def cheetah_run():
    """DeepMind Control `cheetah`/`run` task, wrapped in `ActionRepeatWrapper`."""
    base_env = dmc2gym.make("cheetah", "run")
    return ActionRepeatWrapper(base_env)


def bipedal_hardcore():
    """Gym `BipedalWalkerHardcore-v3`, wrapped in `ActionRepeatWrapper`.

    AAC gets good results here, but the task was left out of the paper because
    it is too similar to the DMC benchmarks.
    """
    base_env = gym.make("BipedalWalkerHardcore-v3")
    return ActionRepeatWrapper(base_env)


def _ib(setpoint):
    """Industrial Benchmark env for ``setpoint``, wrapped in `ActionRepeatWrapper`.

    Configured with reward_type="classic", action_type="continuous" and
    observation_type="include_past".
    """
    ib_env = IBGym(
        setpoint=setpoint,
        reward_type="classic",
        action_type="continuous",
        observation_type="include_past",
    )
    return ActionRepeatWrapper(ib_env)


def industrial_benchmark_70():
    """Industrial Benchmark with setpoint 70."""
    return _ib(setpoint=70)


def industrial_benchmark_100():
    """Industrial Benchmark with setpoint 100."""
    return _ib(setpoint=100)


def inventory():
    """or-gym `InvManagement-v1`, passed through `NormalizeContinuousActionSpace`, then action-repeat wrapped."""
    base_env = gym.make("or_gym:InvManagement-v1")
    normalized_env = dc.envs.NormalizeContinuousActionSpace(base_env)
    return ActionRepeatWrapper(normalized_env)


def newsvendor():
    """or-gym `Newsvendor-v0` with normalized actions and rewards scaled by 1e-4.

    A quick random-agent baseline showed the raw Newsvendor rewards are far too
    large, hence the reward scaling.
    """
    reward_scale = 1e-4
    base_env = gym.make("or_gym:Newsvendor-v0")
    normalized_env = dc.envs.NormalizeContinuousActionSpace(base_env)
    scaled_env = dc.envs.ScaleReward(normalized_env, reward_scale)
    return ActionRepeatWrapper(scaled_env)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--make_env_func", type=str, required=True, help="Name of domain & task to use."
    )
    parser.add_argument("--name", type=str, required=True, help="Name of the run.")
    parser.add_argument(
        "--epochs", type=int, default=250, help="Number of evolutionary epochs."
    )
    parser.add_argument(
        "--num_seeds", type=int, default=1, help="Number of trials with random seeds."
    )
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=1000,
        help="Number of training steps per epoch.",
    )
    parser.add_argument(
        "--population_size", type=int, default=20, help="Population size"
    )
    parser.add_argument(
        "--max_episode_steps",
        type=int,
        default=1000,
        help="Maximum steps of an episode.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="Batch size of experiences from replay buffer used for training",
    )
    parser.add_argument(
        "--a_max", type=int, default=10, help="Maximum value for hyperparam `a`."
    )
    parser.add_argument(
        "--a_delta", type=int, default=2, help="Delta value for hyperparam `a`."
    )
    parser.add_argument(
        "--c_max", type=int, default=40, help="Maximum value for hyperparam `c`."
    )
    parser.add_argument(
        "--c_delta", type=int, default=5, help="Delta value for hyperparam `c`."
    )
    parser.add_argument(
        "--k_max", type=int, default=15, help="Maximum value for hyperparam `k`."
    )
    parser.add_argument(
        "--k_delta", type=int, default=2, help="Delta value for hyperparam `k`."
    )
    args = parser.parse_args()

    # Workers are started with the "spawn" method, so the env factory handed to
    # them must be picklable (a module-level function).
    torch.multiprocessing.set_start_method("spawn")
    torch.multiprocessing.set_sharing_strategy("file_system")

    # NOTE(review): `eval` executes arbitrary code taken from the command line.
    # It is meant to resolve one of the env factories defined in this module
    # (e.g. `--make_env_func cheetah_run`); only pass trusted input.
    make_env_function = eval(args.make_env_func)

    for _ in range(args.num_seeds):
        parallel_pbt_ac(
            make_env_function,
            name=args.name,
            epochs=args.epochs,
            steps_per_epoch=args.steps_per_epoch,
            population_size=args.population_size,
            max_episode_steps=args.max_episode_steps,
            batch_size=args.batch_size,
            a_param=(1, args.a_max, args.a_delta),
            c_param=(1, args.c_max, args.c_delta),
            k_param=(1, args.k_max, args.k_delta),
        )
