from functools import total_ordering
from dataclasses import dataclass
import random
import os
import warnings
import math
from collections import namedtuple

# Drop every UserWarning. The filter is installed ahead of the third-party
# imports below, presumably so that warnings raised while importing them are
# silenced as well.
warnings.filterwarnings("ignore", category=UserWarning)

import tensorboardX
import tqdm
import torch
import rl_utils as dc

import learn
import run
import replay
from agent_pbt import Agent
from action_repeat_wrapper import ActionRepeatWrapper


@dataclass
class Hparams:
    """Hyperparameter set carried by one population member."""

    a: int  # actor updates per step
    c: int  # critic updates per step
    k: int  # action persistence (applied to environments through ``set_k``)
    h: float  # target entropy (handed to ``learn_actor`` as ``target_entropy_mul``)
    g: float  # discount factor parameter (``pbt_ac`` uses gamma = 1 - exp(g))


@total_ordering
class Member:
    """Population member: an agent (actor plus critics), its hparams and its fitness.

    Members are ordered by ``fitness`` alone, so ``sorted(population)`` ranks
    them from worst to best. ``total_ordering`` derives the remaining rich
    comparisons from ``__eq__`` and ``__gt__``. Because ``__eq__`` is
    overridden without ``__hash__``, instances are unhashable.

    Args:
        uid: Identifier of the member, stored as ``id``.
        agent: The agent owned by this member.
        hparams: The member's hyperparameters.
    """

    def __init__(self, uid, agent, hparams):
        self.id = uid
        self.agent = agent
        self.hparams = hparams
        # Lowest possible score until the member has been evaluated.
        self.fitness = -float("inf")

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on ``other.fitness``.
        if not isinstance(other, Member):
            return NotImplemented
        return self.fitness == other.fitness

    def __gt__(self, other):
        if not isinstance(other, Member):
            return NotImplemented
        return self.fitness > other.fitness


class EnvironmentWrapper:
    """Lazily built environment together with per-episode bookkeeping fields.

    Args:
        uid: Identifier stored as ``id``.
        make_env_function: Zero-argument callable that builds the environment.
        max_episode_steps: Episode step limit before scaling by the env's ``k``.
    """

    def __init__(self, uid, make_env_function, max_episode_steps):
        self._max_episode_steps = max_episode_steps
        self.make_env_function = make_env_function
        self.id = uid
        # Nothing is built yet; the environment appears on first use of ``env``.
        self._env = None
        # Bookkeeping starts as "episode finished, no steps taken".
        self.step_count = 0
        self.done = True

    @property
    def env(self):
        """Return the environment, creating and resetting it on first access."""
        if self._env is not None:
            return self._env
        self._env = self.make_env_function()
        # ``state`` is only assigned here, so it does not exist before the
        # environment has been created.
        self.state = self._env.reset()
        return self._env

    @property
    def max_episode_steps(self):
        """Configured step limit divided by the env's current ``k``, rounded."""
        repeat = self.env.k
        return round(self._max_episode_steps / repeat)

    def set_k(self, new_k):
        """Forward ``new_k`` to the environment's ``set_k`` (building the env if needed)."""
        self.env.set_k(new_k)


# Search range of one hyperparameter: lower bound, upper bound and the largest
# perturbation applied when exploring.
ParamSpace = namedtuple("ParamSpace", "min max delta")


def pbt_ac(
    make_env_function,
    epochs=1000,
    steps_per_epoch=1_000,
    population_size=20,
    a_param=(1, 10, 2),
    c_param=(1, 40, 5),
    h_param=(0.25, 1.75, 0.25),
    k_param=(1, 15, 2),
    g_param=(-6.5, -1.0, 0.5),
    max_episode_steps=1000,
    batch_size=512,
    name="pbt_ac",
):
    """
    Simple, single-thread implementation of AAC that is easy to read.
    Note that parallel version is used for actual experiments.

    Every epoch each member collects experience into one shared replay buffer
    and runs its own number of critic and actor updates. Members are then
    evaluated, the bottom 20% copy the networks and optimizer states of a
    randomly paired member from the top 20%, and receive a perturbed copy of
    that member's hyperparameters.

    Args:
        make_env_function (callable):
            Zero-argument callable that returns environment as `ActionRepeatWrapper`.
        epochs (int, optional):
            Number of evolutionary epochs. Defaults to 1000.
        steps_per_epoch (int, optional):
            Training steps per epoch. Defaults to 1_000.
        population_size (int, optional):
            Population size. Defaults to 20. With fewer than 5 members no
            member is ever replaced.
        a_param (tuple[int, int, int], optional):
            Tuple of min, max, and delta value for hyperparameter `a`. Defaults to (1, 10, 2).
        c_param (tuple, optional):
            Tuple of min, max, and delta value for hyperparameter `c`. Defaults to (1, 40, 5).
        h_param (tuple, optional):
            Tuple of min, max, and delta value for hyperparameter `h`. Defaults to (0.25, 1.75, 0.25).
        k_param (tuple, optional):
            Tuple of min, max, and delta value for hyperparameter `k`. Defaults to (1, 15, 2).
        g_param (tuple, optional):
            Tuple of min, max, and delta value for hyperparameter `g`. Defaults to (-6.5, -1.0, 0.5).
            The discount used for critic updates is `1 - exp(g)`.
        max_episode_steps (int, optional):
            Maximum number of steps for an episode. Defaults to 1000.
        batch_size (int, optional):
            Batch size of experiences from replay buffer used for training. Defaults to 512.
        name (str, optional):
            Name of run. Used for logging to Tensorboard. Defaults to "pbt_ac".

    Returns:
        None. Agent checkpoints, Tensorboard logs and `population_fitness.csv`
        are written to the run directory.
    """
    a_param = ParamSpace(*a_param)
    c_param = ParamSpace(*c_param)
    h_param = ParamSpace(*h_param)
    k_param = ParamSpace(*k_param)
    g_param = ParamSpace(*g_param)

    test_env = make_env_function()
    obs_space = test_env.observation_space
    act_space = test_env.action_space

    save_dir = dc.utils.make_process_dirs(name)
    writer = tensorboardX.SummaryWriter(save_dir)
    # Logs the locals bound so far; keep this ahead of any further local
    # definitions so the logged set does not change.
    writer.add_hparams(locals(), {})

    # Helpers for sampling and perturbing hyperparameters.
    def clamp(x, param):
        return max(min(x, param.max), param.min)

    def shift_int_by_add(current, param):
        return current + random.randint(-param.delta, param.delta)

    def shift_float_by_add(current, param):
        return current + random.uniform(-param.delta, param.delta)

    def make_int_range(x):
        return random.randint(x.min, x.max)

    def make_float_range(x):
        return random.uniform(x.min, x.max)

    # Create a centralized replay buffer and add a few random samples to get started
    buffer_size = 2_000_000
    replay_buffer = replay.ReplayBuffer(
        size=buffer_size,
        state_shape=obs_space.shape,
        action_repeat=k_param.max,
        state_dtype=float,
        action_shape=act_space.shape,
    )

    print("Warm up replay buffer...")
    warmup_size = 10_000
    rand_env = make_env_function()
    pbar = tqdm.tqdm(total=warmup_size, dynamic_ncols=True)
    while len(replay_buffer) < warmup_size:
        # Collect random samples at each action repeat value
        for k in range(1, k_param.max + 1):
            prev = len(replay_buffer)
            rand_env.reset()
            rand_env.set_k(k)
            max_steps = round(max_episode_steps / k)
            run.warmup_buffer(
                replay_buffer, rand_env, max_steps + 1, max_steps,
            )
            if len(replay_buffer) >= warmup_size:
                # Cap the final update so the bar ends exactly at its total.
                pbar.update(warmup_size - prev)
                break
            else:
                pbar.update(len(replay_buffer) - prev)
    pbar.close()

    # Initialize the population
    population = []
    # EnvironmentWrapper for each population member.
    envs = []
    for i in range(population_size):
        agent = Agent(obs_space.shape[0], act_space.shape[0])
        agent.to(dc.device)
        hparams = Hparams(
            make_int_range(a_param),
            make_int_range(c_param),
            make_int_range(k_param),
            make_float_range(h_param),
            make_float_range(g_param),
        )
        member = Member(i, agent, hparams)
        env_wrapper = EnvironmentWrapper(i, make_env_function, max_episode_steps)
        population.append(member)
        envs.append(env_wrapper)

    # Number of members in each of the bottom and top 20% slices.
    num_replaced = population_size // 5

    for epoch in range(epochs):
        print(f"EPOCH {epoch}")

        # Make sure each environment uses member's k.
        for member, env_wrapper in zip(population, envs):
            env_wrapper.set_k(member.hparams.k)

        for _ in tqdm.tqdm(range(steps_per_epoch)):
            for member, env_wrapper in zip(population, envs):
                # Collect experiences from the environment.
                run.collect_experience(member, env_wrapper, buffer=replay_buffer)

            for member in population:
                # Do critic updates
                for _ in range(member.hparams.c):
                    learn.learn_critics(
                        member=member,
                        buffer=replay_buffer,
                        batch_size=batch_size,
                        # g parameterizes the discount as gamma = 1 - exp(g).
                        gamma=1.0 - math.exp(member.hparams.g),
                    )

                # Do actor updates
                for _ in range(member.hparams.a):
                    learn.learn_actor(
                        member=member,
                        buffer=replay_buffer,
                        batch_size=batch_size,
                        target_entropy_mul=member.hparams.h,
                    )

        # Evaluate the fitness of population members using the test env.
        for member, env_wrapper in zip(population, envs):
            test_env.set_k(member.hparams.k)
            member.fitness = run.evaluate_agent(
                member.agent, test_env, 10, env_wrapper.max_episode_steps
            )

        # Checkpoint every member's agent. The same save_dir and id are used
        # each epoch, so this presumably replaces the previous checkpoint.
        for member in population:
            member.agent.save(save_dir, member.id)

        # Rank by increasing fitness on a sorted copy. `population` itself keeps
        # its construction order so that zip(population, envs) continues to
        # pair every member with the wrapper created for it.
        ranked = sorted(population)
        # Get the bottom and top 20% of the population and randomly shuffle them.
        # The explicit guard avoids `ranked[-0:]`, which would be the whole list.
        worst_members = ranked[:num_replaced]
        best_members = ranked[-num_replaced:] if num_replaced else []
        random.shuffle(worst_members)
        random.shuffle(best_members)
        # They were shuffled so that zip() creates a random pairing
        for bad, elite in zip(worst_members, best_members):
            # Copy the good agent's network weights
            dc.utils.hard_update(bad.agent.actor, elite.agent.actor)
            dc.utils.hard_update(bad.agent.critic1, elite.agent.critic1)
            dc.utils.hard_update(bad.agent.critic2, elite.agent.critic2)
            # Copy the good agent's optimizers (this may not be necessary)
            bad.agent.online_actor_optimizer.load_state_dict(
                elite.agent.online_actor_optimizer.state_dict()
            )
            bad.agent.critic_optimizer.load_state_dict(
                elite.agent.critic_optimizer.state_dict()
            )
            # Copy the good agent's max ent constraint and optimizer
            bad.agent.log_alpha_optimizer.load_state_dict(
                elite.agent.log_alpha_optimizer.state_dict()
            )
            # Copy the value in place instead of rebinding the attribute to a
            # clone: a clone stays attached to the elite's autograd graph, and
            # the optimizer above presumably still references the existing
            # tensor object (assumes both tensors share a shape - TODO confirm
            # against Agent).
            with torch.no_grad():
                bad.agent.log_alpha.copy_(elite.agent.log_alpha)

            # Explore the param space, clamped within a specified range
            new_a = clamp(shift_int_by_add(elite.hparams.a, a_param), a_param)
            new_c = clamp(shift_int_by_add(elite.hparams.c, c_param), c_param)
            new_g = clamp(shift_float_by_add(elite.hparams.g, g_param), g_param)
            new_k = clamp(shift_int_by_add(elite.hparams.k, k_param), k_param)
            new_h = clamp(shift_float_by_add(elite.hparams.h, h_param), h_param)
            bad.hparams = Hparams(new_a, new_c, new_k, new_h, new_g)

        # Logging. Fitness values are the ones measured this epoch; hparams of
        # replaced members are already the new, perturbed ones.
        fitness_distrib = torch.Tensor([m.fitness for m in ranked]).float()
        a_distrib = torch.Tensor([m.hparams.a for m in ranked]).float()
        c_distrib = torch.Tensor([m.hparams.c for m in ranked]).float()
        k_distrib = torch.Tensor([m.hparams.k for m in ranked]).float()
        h_distrib = torch.Tensor([m.hparams.h for m in ranked]).float()
        g_distrib = torch.Tensor([m.hparams.g for m in ranked]).float()

        writer.add_histogram("Fitness", fitness_distrib, epoch)
        writer.add_histogram("A Param", a_distrib, epoch)
        writer.add_histogram("C Param", c_distrib, epoch)
        writer.add_histogram("K Param", k_distrib, epoch)
        writer.add_histogram("H Param", h_distrib, epoch)
        writer.add_histogram("G Param", g_distrib, epoch)

        # The last entry of the ascending ranking is the fittest member.
        best = ranked[-1]
        writer.add_scalar("BestReturn", best.fitness, epoch)
        writer.add_scalar("BestA", best.hparams.a, epoch)
        writer.add_scalar("BestC", best.hparams.c, epoch)
        writer.add_scalar("BestK", best.hparams.k, epoch)
        writer.add_scalar("BestH", best.hparams.h, epoch)
        writer.add_scalar("BestG", best.hparams.g, epoch)
        # One CSV row per epoch, fitness values in ascending order.
        with open(os.path.join(save_dir, "population_fitness.csv"), "a") as f:
            f.write(",".join([f"{m.fitness.item():.1f}" for m in ranked]) + "\n")

    # Flush pending Tensorboard events and release the writer.
    writer.close()


if __name__ == "__main__":
    import argparse
    import dmc2gym

    # Command line: three mandatory string options.
    parser = argparse.ArgumentParser()
    for flag in ("--domain", "--task", "--name"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    # Factory handed to pbt_ac; every call builds a fresh wrapped environment.
    make_env = lambda: ActionRepeatWrapper(dmc2gym.make(args.domain, args.task))
    # NOTE(review): this instance is not passed to pbt_ac, which receives only
    # the factory and the run name.
    test_env = make_env()
    pbt_ac(make_env, name=args.name)
