import logging

import numpy as np
from typing import Dict, Optional

import ray
from src.rllib.agents.dreamer.utils import FreezeParameters
from src.rllib.evaluation import MultiAgentEpisode
from src.rllib.models.catalog import ModelCatalog
from src.rllib.policy.policy import Policy
from src.rllib.policy.policy_template import build_policy_class
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils.framework import try_import_torch
from src.rllib.utils.torch_ops import apply_grad_clipping
from src.rllib.utils.typing import AgentID

torch, nn = try_import_torch()
if torch:
    from torch import distributions as td

logger = logging.getLogger(__name__)


# Builds the Dreamer losses (world model, actor, critic) for one training batch
def compute_dreamer_loss(obs,
                         action,
                         reward,
                         model,
                         imagine_horizon,
                         discount=0.99,
                         lambda_=0.95,
                         kl_coeff=1.0,
                         free_nats=3.0,
                         log=False):
    """Constructs the losses for the Dreamer objective.

    Three losses are produced:

    * model loss: negative log-likelihood of the observed images and rewards
      under the decoder / reward heads, plus ``kl_coeff`` times the mean
      posterior-vs-prior KL clamped from below at ``free_nats``.
    * actor loss: negative discount-weighted lambda-return of features
      imagined ``imagine_horizon`` steps ahead from detached posterior
      states (computed inside ``FreezeParameters`` over the world-model and
      critic weights).
    * critic loss: negative log-likelihood of the detached lambda-returns
      under the value head, evaluated on detached imagined features.

    Args:
        obs (TensorType): Observations (o_t)
        action (TensorType): Actions (a_(t-1))
        reward (TensorType): Rewards (r_(t-1))
        model (TorchModelV2): DreamerModel, encompassing all other models
        imagine_horizon (int): Imagine horizon for actor and critic loss
        discount (float): Discount
        lambda_ (float): Lambda, like in GAE
        kl_coeff (float): KL Coefficient for Divergence loss in model loss
        free_nats (float): Threshold for minimum divergence in model loss
        log (bool): If log, generate gifs

    Returns:
        dict: "model_loss", "actor_loss" and "critic_loss", plus the logging
        scalars "reward_loss", "image_loss", "divergence", "prior_ent" and
        "post_ent"; when ``log`` is set, also "log_gif".
    """
    encoder_weights = list(model.encoder.parameters())
    decoder_weights = list(model.decoder.parameters())
    reward_weights = list(model.reward.parameters())
    dynamics_weights = list(model.dynamics.parameters())
    critic_weights = list(model.value.parameters())
    model_weights = list(encoder_weights + decoder_weights + reward_weights +
                         dynamics_weights)

    # PlaNET Model Loss
    latent = model.encoder(obs)
    post, prior = model.dynamics.observe(latent, action)
    features = model.dynamics.get_feature(post)
    image_pred = model.decoder(features)
    reward_pred = model.reward(features)
    image_loss = -torch.mean(image_pred.log_prob(obs))
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    prior_dist = model.dynamics.get_dist(prior[0], prior[1])
    post_dist = model.dynamics.get_dist(post[0], post[1])
    div = torch.mean(
        torch.distributions.kl_divergence(post_dist, prior_dist).sum(dim=2))
    # Free nats: below `free_nats` the clamp is flat, so the KL term
    # contributes no gradient.
    div = torch.clamp(div, min=free_nats)
    model_loss = kl_coeff * div + reward_loss + image_loss

    # Actor Loss
    # Imagination starts from posterior states cut off from the model graph.
    # [imagine_horizon, batch_length*batch_size, feature_size]
    with torch.no_grad():
        actor_states = [v.detach() for v in post]
    with FreezeParameters(model_weights):
        imag_feat = model.imagine_ahead(actor_states, imagine_horizon)
    with FreezeParameters(model_weights + critic_weights):
        imag_reward = model.reward(imag_feat).mean
        value = model.value(imag_feat).mean
    pcont = discount * torch.ones_like(imag_reward)
    returns = lambda_return(imag_reward[:-1], value[:-1], pcont[:-1],
                            value[-1], lambda_)
    # Cumulative discount weights [1, d, d^2, ...]. `ones_like` keeps the
    # leading 1 on pcont's device/dtype instead of guessing CUDA-vs-CPU from
    # torch.cuda.is_available() (which breaks a CPU model on a GPU host).
    discount_weights = torch.cumprod(
        torch.cat([torch.ones_like(pcont[:1]), pcont[:-2]], dim=0), dim=0)
    actor_loss = -torch.mean(discount_weights * returns)

    # Critic Loss
    with torch.no_grad():
        val_feat = imag_feat.detach()[:-1]
        target = returns.detach()
        val_discount = discount_weights.detach()
    val_pred = model.value(val_feat)
    critic_loss = -torch.mean(val_discount * val_pred.log_prob(target))

    # Logging purposes
    prior_ent = torch.mean(prior_dist.entropy())
    post_ent = torch.mean(post_dist.entropy())

    log_gif = None
    if log:
        log_gif = log_summary(obs, action, latent, image_pred, model)

    return_dict = {
        "model_loss": model_loss,
        "reward_loss": reward_loss,
        "image_loss": image_loss,
        "divergence": div,
        "actor_loss": actor_loss,
        "critic_loss": critic_loss,
        "prior_ent": prior_ent,
        "post_ent": post_ent,
    }

    if log_gif is not None:
        return_dict["log_gif"] = log_gif
    return return_dict


# Similar to GAE-Lambda, calculate value targets
def lambda_return(reward, value, pcont, bootstrap, lambda_):
    """Compute TD(lambda) value targets by a backward recursion.

    For each step t (iterating from the last step back to the first):

        R_t = reward_t + pcont_t * ((1 - lambda_) * V_{t+1} + lambda_ * R_{t+1})

    where V_{t+1} is ``value[t + 1]`` (or ``bootstrap`` for the final step)
    and the recursion is seeded with R_T = ``bootstrap``.

    Args:
        reward: Per-step rewards, indexed along dim 0 by time.
        value: Per-step value estimates, same leading length as ``reward``.
        pcont: Per-step continuation / discount factors.
        bootstrap: Value used after the last step; ``bootstrap[None]`` is
            concatenated onto ``value[1:]`` along dim 0.
        lambda_: Mixing coefficient between one-step and multi-step targets.

    Returns:
        Tensor of returns stacked along dim 0, one entry per time step.
    """
    shifted_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    one_step_part = reward + pcont * shifted_values * (1 - lambda_)

    horizon = len(one_step_part)
    collected = [None] * horizon
    running = bootstrap
    for t in range(horizon - 1, -1, -1):
        running = one_step_part[t] + pcont[t] * lambda_ * running
        collected[t] = running

    return torch.stack(collected, dim=0)


# Creates the image grid (ground truth / model output / error) used for gif summaries
def log_summary(obs, action, embed, image_pred, model):
    """Build a comparison image of truth, model output, and their error.

    Uses the first 6 sequences of the batch. The model row is the decoder
    reconstruction for the first 5 steps, followed by an open-loop
    prediction: the dynamics model observes the first 5 steps, then imagines
    the remaining ones from the given actions only.

    Args:
        obs: Observations; indexed as ``[sequence, time, ...]`` (assumed).
        action: Actions aligned with ``obs``.
        embed: Encoder output for ``obs``.
        image_pred: Decoder distribution for ``obs``; its ``.mean`` is used.
        model: The DreamerModel providing ``dynamics`` and ``decoder``.

    Returns:
        ``torch.cat([truth, model, error], 3)``.
    """
    num_seqs, context_len = 6, 5

    # NOTE(review): +0.5 presumably maps images stored in [-0.5, 0.5] to
    # [0, 1] for display — confirm against the observation preprocessing.
    ground_truth = obs[:num_seqs] + 0.5
    reconstruction = image_pred.mean[:num_seqs]

    # Filter over the context window, keep only the last state, then roll
    # the dynamics forward on the remaining actions.
    posterior, _ = model.dynamics.observe(embed[:num_seqs, :context_len],
                                          action[:num_seqs, :context_len])
    last_state = [component[:, -1] for component in posterior]
    imagined = model.dynamics.imagine(action[:num_seqs, context_len:],
                                      last_state)
    open_loop = model.decoder(model.dynamics.get_feature(imagined)).mean

    model_frames = torch.cat(
        [reconstruction[:, :context_len] + 0.5, open_loop + 0.5], 1)
    error = (model_frames - ground_truth + 1.0) / 2.0
    return torch.cat([ground_truth, model_frames, error], 3)


def dreamer_loss(policy, model, dist_class, train_batch):
    """Compute the Dreamer losses for ``train_batch``.

    The full statistics dict from ``compute_dreamer_loss`` is cached on
    ``policy.stats_dict``. A gif summary is requested whenever the batch
    contains a "log_gif" key. Note that the ``model`` and ``dist_class``
    arguments are not used; ``policy.model`` is passed instead.

    Returns:
        Tuple ``(model_loss, actor_loss, critic_loss)``, in the same order as
        the optimizers returned by ``dreamer_optimizer_fn``.
    """
    cfg = policy.config
    want_gif = "log_gif" in train_batch

    stats = compute_dreamer_loss(
        train_batch["obs"],
        train_batch["actions"],
        train_batch["rewards"],
        policy.model,
        cfg["imagine_horizon"],
        cfg["discount"],
        cfg["lambda"],
        cfg["kl_coeff"],
        cfg["free_nats"],
        want_gif,
    )
    policy.stats_dict = stats

    return stats["model_loss"], stats["actor_loss"], stats["critic_loss"]


def build_dreamer_model(policy, obs_space, action_space, config):
    """Create the Dreamer model for ``policy`` through the ModelCatalog.

    Side effect: stores the model's ``variables()`` on
    ``policy.model_variables``.

    Args:
        policy: The policy being built.
        obs_space: Observation space passed to the catalog.
        action_space: Action space passed to the catalog.
        config: Policy config; its "dreamer_model" entry is the model config.

    Returns:
        The torch model returned by the catalog (named "DreamerModel").
    """
    dreamer_model_config = config["dreamer_model"]
    dreamer_model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        1,  # num_outputs
        dreamer_model_config,
        framework="torch",
        name="DreamerModel",
    )

    policy.model_variables = dreamer_model.variables()
    return dreamer_model


def action_sampler_fn(policy, model, input_dict, state, explore, timestep):
    """Sample an action for the Dreamer policy.

    Two phases:

    * Prefill (``timestep <= config["prefill_timesteps"]``): the action is
      drawn uniformly from [-1, 1) with shape ``(1, action_dim)``, ``logp``
      is ``None``, and the recurrent state is reset to the model's initial
      state.
    * Training: the action comes from ``model.policy``. Only when
      ``explore`` is truthy, additive Gaussian noise with std
      ``config["explore_noise"]`` is applied to incentivize exploration.
      Previously the noise was added unconditionally, so non-exploring
      (evaluation) actions were noisy too. The result is clipped to
      [-1, 1].

    In both phases ``policy.global_timestep`` is advanced by
    ``config["action_repeat"]``.

    Returns:
        Tuple ``(action, logp, state)``.
    """
    obs = input_dict["obs"]
    config = policy.config

    # Custom Exploration
    if timestep <= config["prefill_timesteps"]:
        logp = None
        # Random action in space [-1.0, 1.0]
        action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0
        state = model.get_initial_state()
    else:
        # RLlib quirk: per the original author, a 3-D first state tensor
        # shows up when the env resets; re-initialize the recurrent state in
        # that case (acknowledged as a hack that works on all envs tried).
        if len(state[0].size()) == 3:
            state = model.get_initial_state()
        action, logp, state = model.policy(obs, state, explore)
        if explore:
            action = td.Normal(action, config["explore_noise"]).sample()
        action = torch.clamp(action, min=-1.0, max=1.0)

    policy.global_timestep += config["action_repeat"]

    return action, logp, state


def dreamer_stats(policy, train_batch):
    """Return the statistics cached on ``policy.stats_dict`` by the loss.

    ``train_batch`` is ignored; the dict is returned as-is (not copied).
    """
    cached_stats = policy.stats_dict
    return cached_stats


def dreamer_optimizer_fn(policy, config):
    """Create the three Adam optimizers used by Dreamer.

    Args:
        policy: Policy whose ``model`` exposes ``encoder``, ``decoder``,
            ``reward``, ``dynamics``, ``actor`` and ``value`` sub-modules.
        config: Dict providing "td_model_lr", "actor_lr" and "critic_lr".

    Returns:
        Tuple ``(model_opt, actor_opt, critic_opt)``:
        ``model_opt`` covers encoder, decoder, reward and dynamics parameters
        (in that order); ``actor_opt`` covers the actor; ``critic_opt``
        covers the value head.
    """
    net = policy.model
    world_model_params = [
        param
        for module in (net.encoder, net.decoder, net.reward, net.dynamics)
        for param in module.parameters()
    ]
    actor_params = list(net.actor.parameters())
    critic_params = list(net.value.parameters())

    return (
        torch.optim.Adam(world_model_params, lr=config["td_model_lr"]),
        torch.optim.Adam(actor_params, lr=config["actor_lr"]),
        torch.optim.Adam(critic_params, lr=config["critic_lr"]),
    )


def preprocess_episode(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Re-align a trajectory into (s_t, a_(t-1), r_(t-1)) rows.

    The returned batch has one more row than the input. Row 0 pairs the
    first observation (the obs after reset) with an all-zero action and a
    zero reward. Each following row pairs an observation with the action and
    reward that led to it. The last row holds the final NEXT_OBS, and its
    episode id is copied from the last input row. Only OBS, ACTIONS, REWARDS
    and EPS_ID are kept.

    The zero padding rows take the dtype and trailing shape of the real
    columns, so e.g. float32 actions/rewards are not upcast to float64 and
    action arrays of any rank are supported.

    Args:
        policy: Unused.
        sample_batch: Trajectory with OBS, NEXT_OBS, ACTIONS, REWARDS and
            EPS_ID columns (assumed to be a single trajectory — confirm).
        other_agent_batches: Unused.
        episode: Unused.

    Returns:
        A new SampleBatch with the re-aligned columns.

    Raises:
        IndexError: If ``sample_batch`` is empty (no final NEXT_OBS).
    """
    obs = sample_batch[SampleBatch.OBS]
    new_obs = sample_batch[SampleBatch.NEXT_OBS]
    action = sample_batch[SampleBatch.ACTIONS]
    reward = sample_batch[SampleBatch.REWARDS]
    eps_ids = sample_batch[SampleBatch.EPS_ID]

    # "Previous" action/reward for the reset step: zeros shaped and typed
    # like a single real row.
    act_reset = np.zeros_like(action[:1])
    rew_reset = np.zeros_like(reward[:1])
    # Observation reached after the last action of the trajectory.
    obs_end = np.asarray(new_obs[len(action) - 1])[None]

    batch_obs = np.concatenate([obs, obs_end], axis=0)
    batch_action = np.concatenate([act_reset, action], axis=0)
    batch_rew = np.concatenate([rew_reset, reward], axis=0)
    batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0)

    new_batch = {
        SampleBatch.OBS: batch_obs,
        SampleBatch.REWARDS: batch_rew,
        SampleBatch.ACTIONS: batch_action,
        SampleBatch.EPS_ID: batch_eps_ids,
    }
    return SampleBatch(new_batch)


def _get_default_config():
    """Return DEFAULT_CONFIG of the Dreamer trainer module in this package.

    The import is done lazily at call time because the trainer module
    presumably imports this policy module, which would make a top-level
    import circular. The config is resolved from ``src.rllib`` (the package
    every other import in this file uses). The previous
    ``ray.rllib.agents.dreamer.dreamer`` attribute lookup pointed at the
    separately installed Ray package. That lookup only succeeds if Ray's
    own dreamer module happened to be imported already.
    """
    from src.rllib.agents.dreamer.dreamer import DEFAULT_CONFIG
    return DEFAULT_CONFIG


# Torch policy class for Dreamer, assembled from the functions above.
DreamerTorchPolicy = build_policy_class(
    name="DreamerTorchPolicy",
    framework="torch",
    get_default_config=_get_default_config,
    action_sampler_fn=action_sampler_fn,
    postprocess_fn=preprocess_episode,
    loss_fn=dreamer_loss,
    stats_fn=dreamer_stats,
    make_model=build_dreamer_model,
    optimizer_fn=dreamer_optimizer_fn,
    extra_grad_process_fn=apply_grad_clipping)
