import os
import time
from pathlib import Path
from functools import partial
from tqdm import tqdm
from typing import Union, Any
from contextlib import contextmanager
from collections import deque
from collections.abc import Callable, Generator, MutableMapping

from beartype import beartype
from omegaconf import OmegaConf, DictConfig
from einops import rearrange
from termcolor import colored
import wandb
from wandb.errors import CommError
import numpy as np
from scipy.stats import entropy
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from gymnasium.core import Env
from gymnasium.vector.vector_env import VectorEnv

from helpers import logger
from agents.agent import Agent


@contextmanager
@beartype
def timed(op: str, timer: Callable[[], float]):
    """Log how long the body of the enclosing `with` statement takes.

    Args:
        op: label of the timed operation, echoed in the "starting" log line.
        timer: zero-argument clock returning a float; the reported duration is
            the difference between a call made before and one made after the body.

    Note: the "stopping" line is only emitted when the body exits normally
    (there is no try/finally around the `yield`).
    """
    start_msg = f"starting timer | op: {op}"
    logger.info(colored(start_msg, "magenta", attrs=["underline", "bold"]))
    began_at = timer()
    yield
    elapsed = timer() - began_at
    stop_msg = f"stopping timer | op took {elapsed}secs"
    logger.info(colored(stop_msg, "magenta"))


@beartype
def segment(env: Union[Env, VectorEnv],
            agent: Agent,
            seed: int,
            segment_len: int,
            learning_starts: int,
            action_repeat: int,
            *,
            check_if_envs_in_sync: bool = False,
    ) -> Generator[None, None, None]:
    """Endlessly step the env pool and push the transitions into `agent.rb`.

    The generator performs `segment_len` env steps, then yields control back to
    the caller (which trains the agent and bumps `agent.timesteps_so_far`), then
    resumes for the next segment. It never terminates.

    While `agent.timesteps_so_far < learning_starts`, actions are sampled from
    the action space; afterwards they come from `agent.predict(..., explore=True)`.
    A new action is chosen every `action_repeat` steps; in between, the previous
    one is repeated. Environment rewards are discarded (no reward leakage).

    Args:
        env: environment pool (`env.num_envs` is read).
        agent: agent owning the replay buffer `agent.rb`.
        seed: seed for the very first `env.reset` (never reseeded afterwards).
        segment_len: number of env steps between two yields.
        learning_starts: timestep threshold below which actions are random.
        action_repeat: number of consecutive steps each action is held for.
        check_if_envs_in_sync: optionally log the entropy of the done vectors
            to see whether the resets across the env pool are in sync:
            - High entropy: resets are uniformly and independently distributed
              across environments; no significant synchronization exists.
            - Low entropy: resets are clustered or synchronized, i.e. multiple
              environments reset simultaneously more often than by chance.
    """

    assert agent.rb is not None

    obs, _ = env.reset(seed=seed)  # for the very first reset, we give a seed (and never again)
    obs = torch.as_tensor(obs, device=agent.device, dtype=torch.float)
    actions_nt = None  # always overwritten at t == 0, since 0 % action_repeat == 0

    use_pwil = agent.hps.method == "pwil"
    if use_pwil:
        assert hasattr(agent, "pwil_rewarder")
        for idx in range(env.num_envs):
            agent.pwil_rewarder[idx].reset()

    t = 0
    reset_entropies = []  # only filled when check_if_envs_in_sync

    while True:

        # hand control back to the caller (which updates the agent and bumps
        # `agent.timesteps_so_far`) *before* choosing the next action, so that
        # the first action of each segment uses the freshly updated agent and
        # the random-vs-learned decision sees the up-to-date timestep counter
        if t > 0 and t % segment_len == 0:
            yield

        if t % action_repeat == 0:
            # pick a new action (otherwise the previous one is repeated)
            if agent.timesteps_so_far < learning_starts:
                actions_nt = env.action_space.sample()
            else:
                actions_nt = agent.predict(
                    TensorDict(
                        {
                            "observations": obs,
                        },
                        device=agent.device,
                    ),
                    explore=True,
                )

        # interact with env (while avoiding reward leakage: env rewards are dropped)
        next_obs, _, terminations, truncations, infos = env.step(actions_nt)
        terminations = np.asarray(terminations)
        truncations = np.asarray(truncations)
        dones = np.logical_or(terminations, truncations)

        if check_if_envs_in_sync:
            p = dones.astype(int) / env.num_envs  # make probs
            eps = 1e-10  # we add a small epsilon to avoid log(0)
            p = np.clip(p, eps, 1. - eps)
            reset_entropies.append(entropy(p, base=2))  # using base 2 for entropy in bits
            if t % 8_000 == 0:
                wandb.log({"vitals/reset_entropy": np.mean(reset_entropies)})
                reset_entropies = []

        actions = torch.as_tensor(actions_nt, device=agent.device, dtype=torch.float)

        next_obs = torch.as_tensor(next_obs, device=agent.device, dtype=torch.float)
        real_next_obs = next_obs.clone()
        # the pool autoresets: for every env that just finished (terminated OR
        # truncated), `next_obs[idx]` is already the first obs of its next
        # episode, while the true last obs is in `infos["final_observation"]`
        for idx, done in enumerate(dones):
            if done:
                real_next_obs[idx] = torch.as_tensor(
                    infos["final_observation"][idx], device=agent.device, dtype=torch.float)

        terminations = rearrange(
            torch.as_tensor(terminations, device=agent.device, dtype=torch.bool),
            "b -> b 1",
        )

        transitions = {
            "observations": obs,
            "next_observations": real_next_obs,
            "actions": actions,
            "terminations": terminations,
            # NOTE: "dones" holds terminations only (truncations are not flagged);
            # presumably intentional so truncated transitions still bootstrap
            "dones": terminations,
        }

        if use_pwil:
            pwil_rewards = []
            for idx, done in enumerate(dones):
                # score the transition with the rewarder state of the episode it
                # belongs to, and only then reset it for the episode that follows
                # (resetting first would score the last step of an episode as if
                # it were the first step of a brand new one)
                pwil_rewards.append(agent.pwil_rewarder[idx].compute_reward(
                    obs[idx, ...], actions[idx, ...], real_next_obs[idx, ...]))
                if done:
                    agent.pwil_rewarder[idx].reset()
            transitions["pwil_rewards"] = torch.stack(pwil_rewards)

        agent.rb.extend(
            TensorDict(
                transitions,
                batch_size=obs.shape[0],
                device=agent.device,
            ),
        )

        obs = next_obs
        t += 1


@beartype
def episode(env: Env,
            agent: Agent,
            seed: int,
    ) -> Generator[dict[str, np.ndarray], None, None]:
    """Endlessly roll out `agent` in `env`, yielding one trajectory per episode.

    Actions are sampled from the action space when `agent.hps.method == "random"`,
    and otherwise predicted with `explore=False`. Every reset uses a fresh seed
    drawn from a generator seeded with `seed`, so the sequence is reproducible.

    Yields:
        dict with keys "observations", "actions", "next_observations" (numpy
        arrays stacked over time) and "length", "return" (0-d numpy arrays read
        from `infos["final_info"][...]["episode"]`).

    NOTE(review): the episode end is detected through `infos["final_info"]`,
    which suggests `env` is an autoresetting vector env with a single sub-env.
    """
    # `append` is significantly faster on lists than on numpy arrays;
    # the lists are converted to numpy arrays right before each yield

    rng = np.random.default_rng(seed)  # aligned on seed, so always reproducible

    def randomize_seed() -> int:
        # seeded Generator: deterministic -> reproducible
        return seed + rng.integers(2**32 - 1, size=1).item()

    obs_list = []
    next_obs_list = []
    actions_list = []

    ob, _ = env.reset(seed=randomize_seed())
    obs_list.append(ob)
    ob = torch.as_tensor(ob, device=agent.device, dtype=torch.float)

    while True:

        if agent.hps.method == "random":
            action = env.action_space.sample()
        else:
            # predict action
            action = agent.predict(
                TensorDict(
                    {
                        "observations": ob,
                    },
                    device=agent.device,
                ),
                explore=False,
            )

        new_ob, _, termination, truncation, infos = env.step(action)

        done = termination or truncation

        if done and "final_observation" in infos:
            # autoresetting env: `new_ob` is already the first obs of the next
            # episode; the true last obs of this episode is stored in the infos
            real_new_ob = np.stack(infos["final_observation"]).astype(
                np.asarray(new_ob).dtype, copy=False)
        else:
            real_new_ob = new_ob

        next_obs_list.append(real_new_ob)
        actions_list.append(action)
        if not done:
            obs_list.append(new_ob)

        ob = torch.as_tensor(new_ob, device=agent.device, dtype=torch.float)

        if "final_info" in infos:
            # we have len(infos["final_info"]) == 1
            for info in infos["final_info"]:
                ep_len = float(info["episode"]["l"].item())
                ep_ret = float(info["episode"]["r"].item())

            yield {
                "observations": np.array(obs_list),
                "actions": np.array(actions_list),
                "next_observations": np.array(next_obs_list),
                "length": np.array(ep_len),
                "return": np.array(ep_ret),
            }

            obs_list = []
            next_obs_list = []
            actions_list = []

            ob, _ = env.reset(seed=randomize_seed())
            obs_list.append(ob)
            ob = torch.as_tensor(ob, device=agent.device, dtype=torch.float)


@beartype
def train(cfg: MutableMapping[Any, Any],
          env: Union[Env, VectorEnv],
          eval_env: Env,
          agent_wrapper: Callable[[], Agent],
          name: str,
          progress_files: dict[str, Path]):
    """Train an agent online, periodically evaluating it and logging to wandb.

    Each outer iteration collects one segment (`cfg.segment_len` steps on each of
    the `cfg.num_envs` envs). Once past `cfg.learning_starts`, and unless
    `cfg.method == "random"`, it then runs one Q-network update, possibly a burst
    of actor updates, a target-network update, and one reward update against a
    batch of expert demonstrations. Evaluation runs every time the timestep
    counter reaches or steps over a multiple of `cfg.eval_every`.

    Args:
        cfg: run configuration; must be an omegaconf `DictConfig`.
        env: training environment pool.
        eval_env: environment used for the evaluation rollouts.
        agent_wrapper: zero-argument factory building the agent.
        name: run name, also used as the wandb run id; everything before its
            last "." is used as the wandb group.
        progress_files: files uploaded to wandb at every evaluation.
    """

    assert isinstance(cfg, DictConfig)

    agent = agent_wrapper()

    assert agent.rb is not None

    # set up wandb (retry until a connection is established)
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    group = ".".join(name.split(".")[:-1])  # everything in name except seed
    logger.warn(f"{name=}")
    logger.warn(f"{group=}")
    while True:
        try:
            config = OmegaConf.to_object(cfg)
            assert isinstance(config, dict)
            wandb.init(
                project=cfg.wandb_project,
                name=name,
                id=name,
                group=group,
                config=config,
                dir=cfg.root,
            )
            break
        except CommError:
            pause = 10
            logger.info(f"wandb co error. Retrying in {pause} secs.")
            time.sleep(pause)
    logger.info("wandb co established!")

    # create segment generator for training the agent
    seg_gen = segment(
        env, agent, cfg.seed, cfg.segment_len, cfg.learning_starts, cfg.action_repeat,
    )
    # create episode generator for evaluating the agent
    ep_gen = episode(eval_env, agent, cfg.seed)

    i = 0
    start_time = None
    measure_burnin = None
    pbar = tqdm(range(cfg.num_timesteps))
    time_spent_eval = 0

    tlog = TensorDict({})
    maxlen = cfg.deque_window * cfg.eval_steps
    len_buff = deque(maxlen=maxlen)
    ret_buff = deque(maxlen=maxlen)
    w2d_buff = deque(maxlen=maxlen)  # only filled when cfg.wasserstein_two

    mode = None
    tc_update_actor = agent.update_actor
    tc_update_qnets = agent.update_qnets
    tc_update_reward = agent.update_reward
    if cfg.compile:
        tc_update_actor = torch.compile(tc_update_actor, mode=mode)
        tc_update_qnets = torch.compile(tc_update_qnets, mode=mode)
        tc_update_reward = torch.compile(tc_update_reward, mode=mode)
    if cfg.cudagraphs:
        cuda_graph_module = partial(CudaGraphModule, warmup=2, in_keys=[], out_keys=[])
        tc_update_actor = cuda_graph_module(tc_update_actor)
        tc_update_qnets = cuda_graph_module(tc_update_qnets)
        tc_update_reward = cuda_graph_module(tc_update_reward)

    if cfg.pretrain:
        logger.info(("pretrain").upper())
        agent.load(cfg.load_ckpt)

    # with a delay d, the actor gets d updates every (d + 1) iterations
    # ("wait d rounds, do d updates"); d == 0 means no delay, i.e. one update
    # per iteration (taken literally, d == 0 would never update the actor)
    actor_update_period = cfg.actor_update_delay + 1
    num_actor_updates = max(cfg.actor_update_delay, 1)

    @beartype
    def _wrapper(tensor_buff: deque[torch.Tensor]) -> torch.Tensor:
        # mean over the rolling window of per-episode scalars
        return torch.stack(list(tensor_buff)).mean()

    while agent.timesteps_so_far <= cfg.num_timesteps:

        if ((agent.timesteps_so_far >= (cfg.measure_burnin + cfg.learning_starts)) and
            (start_time is None)):
            start_time = time.time()
            measure_burnin = agent.timesteps_so_far
            # only the eval time spent from now on must be discounted from the
            # elapsed wall-clock time used in the speed computation
            time_spent_eval = 0

        logger.info(("interact").upper())
        next(seg_gen)
        agent.timesteps_so_far += (increment := cfg.segment_len * cfg.num_envs)
        pbar.update(increment)

        if agent.timesteps_so_far <= cfg.learning_starts:
            # start training when enough data
            pbar.set_description("not learning yet")
            i += 1
            continue

        logger.info(("train").upper())

        if cfg.method != "random":
            # sample batch of transitions
            batch = agent.rb.sample(cfg.batch_size)
            # update qnets
            tlog.update(tc_update_qnets(batch))
            agent.qnet_updates_so_far += 1
            # update actor (and alpha)
            if i % actor_update_period == 0:
                # compensate for delay: wait X rounds, do X updates
                for _ in range(num_actor_updates):
                    tlog.update(tc_update_actor(batch))
                    agent.actor_updates_so_far += 1
            # update the target networks
            agent.update_targ_nets()
            # sample batch of transitions (a new one)
            batch = agent.rb.sample(cfg.batch_size)
            # sample batch of expert data
            demos = agent.expert_dataset.sample(cfg.batch_size)
            # update reward
            tlog.update(tc_update_reward(batch, demos))
            agent.reward_updates_so_far += 1

        # evaluate whenever a multiple of `eval_every` was reached or stepped over
        # during this segment (a plain `% eval_every == 0` test never fires when
        # `eval_every` is not a multiple of the per-segment increment)
        if ((agent.timesteps_so_far - increment) // cfg.eval_every <
                agent.timesteps_so_far // cfg.eval_every):
            logger.info(("eval").upper())
            eval_start = time.time()

            for _ in range(cfg.eval_steps):
                ep = next(ep_gen)
                len_buff.append(torch.as_tensor(ep["length"], dtype=torch.float))  # cpu
                ret_buff.append(torch.as_tensor(ep["return"], dtype=torch.float))  # cpu
                if cfg.wasserstein_two:
                    w2d = agent.distance_to_expert(ep)
                    if w2d is not None:
                        w2d_buff.append(w2d.cpu())

            with torch.no_grad():
                eval_metrics = {
                    "length": _wrapper(len_buff),
                    "return": _wrapper(ret_buff),
                }
                # `distance_to_expert` may return None, leaving the buffer empty;
                # `torch.stack` rejects empty inputs, so skip the metric then
                if cfg.wasserstein_two and w2d_buff:
                    eval_metrics["w2dist"] = _wrapper(w2d_buff)

            # log with logger
            logger.record_tabular("timestep", agent.timesteps_so_far)
            for k, v in eval_metrics.items():
                logger.record_tabular(k, v.numpy())
            logger.dump_tabular()

            # log with wandb
            for v in progress_files.values():
                wandb.save(v, base_path=str(v.parent))
            wandb.log(
                {
                    **tlog.to_dict(),
                    **{f"eval/{k}": v for k, v in eval_metrics.items()},
                    "vitals/replay_buffer_numel": len(agent.rb),
                },
                step=agent.timesteps_so_far,
            )

            time_spent_eval += time.time() - eval_start

            if start_time is not None:
                # compute the speed in steps per second (eval time excluded)
                speed = (
                    (agent.timesteps_so_far - measure_burnin) /
                    (time.time() - start_time - time_spent_eval)
                )
                desc = f"speed={speed: 4.4f} sps"
                pbar.set_description(desc)
                wandb.log(
                    {
                        "vitals/speed": speed,
                    },
                    step=agent.timesteps_so_far,
                )

        i += 1
        tlog.clear()

    # mark a run as finished, and finish uploading all data (from docs)
    wandb.finish()
    logger.warn("bye")


@beartype
def evaluate(cfg: MutableMapping[Any, Any],
             env: Env,
             agent_wrapper: Callable[[], Agent],
             name: str):
    """Evaluate a saved agent over `cfg.num_episodes` episodes and log the means.

    Args:
        cfg: run configuration; must be an omegaconf `DictConfig`.
        env: environment the agent is rolled out in.
        agent_wrapper: zero-argument factory building the agent, whose weights
            are then loaded from `cfg.load_ckpt`.
        name: run name (not used by this function).
    """

    assert isinstance(cfg, DictConfig)

    agent = agent_wrapper()

    agent.load(cfg.load_ckpt)

    # create episode generator
    ep_gen = episode(env, agent, cfg.seed)

    # `tqdm` is the class itself (`from tqdm import tqdm`): `tqdm.tqdm` does not exist
    pbar = tqdm(range(cfg.num_episodes))
    pbar.set_description("evaluating")

    len_list = []
    ret_list = []
    w2d_list = []  # only filled when cfg.wasserstein_two

    for _ in pbar:

        ep = next(ep_gen)
        len_list.append(torch.as_tensor(ep["length"], dtype=torch.float))  # cpu
        ret_list.append(torch.as_tensor(ep["return"], dtype=torch.float))  # cpu
        if cfg.wasserstein_two:
            w2d = agent.distance_to_expert(ep)
            if w2d is not None:
                w2d_list.append(w2d.cpu())

    with torch.no_grad():

        @beartype
        def _wrapper(tensor_list: list[torch.Tensor]) -> torch.Tensor:
            return torch.stack(tensor_list).mean()

        eval_metrics = {
            "length": _wrapper(len_list),
            "return": _wrapper(ret_list),
        }
        # `distance_to_expert` may return None, leaving the list empty;
        # `torch.stack` rejects empty inputs, so skip the metric then
        if cfg.wasserstein_two and w2d_list:
            eval_metrics.update({"w2dist": _wrapper(w2d_list)})

    # log with logger
    for k, v in eval_metrics.items():
        logger.record_tabular(k, v.numpy())
    logger.dump_tabular()


@beartype
def clone(cfg: MutableMapping[Any, Any],
          eval_env: Env,
          agent_wrapper: Callable[[], Agent],
          name: str,
          progress_files: dict[str, Path]):
    """Train an agent by behavioral cloning on the expert dataset.

    Runs `cfg.num_bc_iters` iterations. Each one samples `cfg.batch_size` expert
    transitions and calls `agent.behavioral_cloning` on them. Every
    `cfg.eval_every` iterations (iteration 0 included), the agent is evaluated
    and metrics are logged to the tabular logger and to wandb. Whenever the mean
    eval return beats `agent.best_eval_ep_ret`, the model is saved under
    `cfg.checkpoint_dir / name` with suffix "best".

    Args:
        cfg: run configuration; must be an omegaconf `DictConfig`.
        eval_env: environment used for the evaluation rollouts.
        agent_wrapper: zero-argument factory building the agent.
        name: run name, also used as the wandb run id and checkpoint subfolder;
            everything before its last "." is used as the wandb group.
        progress_files: files uploaded to wandb at every evaluation.
    """

    assert isinstance(cfg, DictConfig)

    agent = agent_wrapper()

    assert agent.rb is not None

    # set up model save directory
    ckpt_dir = Path(cfg.checkpoint_dir) / name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # set up wandb (retry until a connection is established)
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    group = ".".join(name.split(".")[:-1])  # everything in name except seed
    logger.warn(f"{name=}")
    logger.warn(f"{group=}")
    while True:
        try:
            config = OmegaConf.to_object(cfg)
            assert isinstance(config, dict)
            wandb.init(
                project=cfg.wandb_project,
                name=name,
                id=name,
                group=group,
                config=config,
                dir=cfg.root,
            )
            break
        except CommError:
            pause = 10
            logger.info(f"wandb co error. Retrying in {pause} secs.")
            time.sleep(pause)
    logger.info("wandb co established!")

    # create episode generator for evaluating the agent
    ep_gen = episode(eval_env, agent, cfg.seed)

    start_time = None
    measure_burnin = None
    pbar = tqdm(range(cfg.num_bc_iters))
    time_spent_eval = 0

    tlog = TensorDict({})
    maxlen = cfg.deque_window * cfg.eval_steps
    len_buff = deque(maxlen=maxlen)
    ret_buff = deque(maxlen=maxlen)
    w2d_buff = deque(maxlen=maxlen)  # only filled when cfg.wasserstein_two

    mode = None
    tc_update_actor = agent.behavioral_cloning
    if cfg.compile:
        tc_update_actor = torch.compile(tc_update_actor, mode=mode)
    if cfg.cudagraphs:
        tc_update_actor = CudaGraphModule(tc_update_actor, in_keys=[], out_keys=[])

    @beartype
    def _wrapper(tensor_buff: deque[torch.Tensor]) -> torch.Tensor:
        # mean over the rolling window of per-episode scalars
        return torch.stack(list(tensor_buff)).mean()

    for i in pbar:

        # burn-in is measured in iterations: this loop never increments
        # `agent.actor_updates_so_far`, so a test on that counter may never fire
        # NOTE(review): unlike `train`, the update counter is not bumped here;
        # confirm whether `behavioral_cloning` does it itself
        if (i >= cfg.measure_burnin) and (start_time is None):
            start_time = time.time()
            measure_burnin = i
            # only the eval time spent from now on must be discounted from the
            # elapsed wall-clock time used in the speed computation
            time_spent_eval = 0

        # sample batch of expert data
        demos = agent.expert_dataset.sample(cfg.batch_size)
        # update actor
        tlog.update(tc_update_actor(demos))

        if i % cfg.eval_every == 0:
            logger.info(("eval").upper())
            eval_start = time.time()

            for _ in range(cfg.eval_steps):
                ep = next(ep_gen)
                len_buff.append(torch.as_tensor(ep["length"], dtype=torch.float))  # cpu
                ret_buff.append(torch.as_tensor(ep["return"], dtype=torch.float))  # cpu
                if cfg.wasserstein_two:
                    w2d = agent.distance_to_expert(ep)
                    if w2d is not None:
                        w2d_buff.append(w2d.cpu())

            with torch.no_grad():
                eval_metrics = {
                    "length": _wrapper(len_buff),
                    "return": _wrapper(ret_buff),
                }
                # `distance_to_expert` may return None, leaving the buffer empty;
                # `torch.stack` rejects empty inputs, so skip the metric then
                if cfg.wasserstein_two and w2d_buff:
                    eval_metrics["w2dist"] = _wrapper(w2d_buff)

            # log with logger
            logger.record_tabular("iteration", i)
            for k, v in eval_metrics.items():
                logger.record_tabular(k, v.numpy())
            logger.dump_tabular()

            # save the model when the eval return improves
            if (new_best := eval_metrics["return"].item()) > agent.best_eval_ep_ret:
                logger.info("new best eval! -- saving model to disk and wandb")
                agent.best_eval_ep_ret = new_best
                agent.save(ckpt_dir, sfx="best")

            # log with wandb
            for v in progress_files.values():
                wandb.save(v, base_path=str(v.parent))
            wandb.log(
                {
                    **tlog.to_dict(),
                    **{f"eval/{k}": v for k, v in eval_metrics.items()},
                    "vitals/replay_buffer_numel": len(agent.rb),
                },
                step=i,
            )

            time_spent_eval += time.time() - eval_start

            if start_time is not None:
                # compute the speed in iterations per second (eval time excluded)
                speed = (
                    (i - measure_burnin) /
                    (time.time() - start_time - time_spent_eval)
                )
                desc = f"speed={speed: 4.4f} sps"
                pbar.set_description(desc)
                wandb.log(
                    {
                        "vitals/speed": speed,
                    },
                    step=i,
                )

        tlog.clear()

    # mark a run as finished, and finish uploading all data (from docs)
    wandb.finish()
    logger.warn("bye")
