import math
from os import stat
import time
from math import log
from numbers import Real

import numpy as np
import omegaconf
import torch
import torch.distributions.normal as normal
from torch.nn.utils.clip_grad import clip_grad_value_

import mbrl
from mbrl.models.one_dim_tr_model import OneDTransitionRewardModel
from mbrl.third_party.unrolled_actor_soft_critic import utils
from mbrl.third_party.unrolled_actor_soft_critic.agent import Agent, adjoint
from mbrl.third_party.unrolled_actor_soft_critic.buffer import CircularReplayBuffer


def unroll_actor(
    cfg: omegaconf.DictConfig,
    imagined_buffer: CircularReplayBuffer,
    dynamics_model: OneDTransitionRewardModel,
    agent: Agent,
    seed=None,
    rng=np.random.default_rng(),
    device: torch.DeviceObjType = torch.device("cpu"),
    logger=None,
    step=0,
    steps_unrolling=None,
):
    """
    Update the agent's actor by unrolling the learned dynamics model.

    The rollout horizon is read from the schedule ``cfg.actor_train.unroll_schedule``
    evaluated at ``step + 1`` and truncated to an int. It is logged under
    "train/max_rollout_length" when a logger is supplied, and the actual update is
    delegated to ``unroll_actor_direct`` with all arguments forwarded.
    """
    # Preconditions: the model must predict a delta for every state dimension
    # and must also predict the reward.
    assert dynamics_model.target_is_delta
    assert dynamics_model.no_delta_list == []
    assert dynamics_model.learned_rewards

    actor_cfg = cfg.actor_train
    should_profile = actor_cfg.get("profile", False)

    schedule_value = mbrl.util.math.truncated_linear(*actor_cfg.unroll_schedule, step + 1)
    horizon = int(schedule_value)

    if logger:
        logger.log("train/max_rollout_length", horizon, step)

    unroll_actor_direct(
        cfg,
        imagined_buffer,
        dynamics_model,
        agent,
        seed=seed,
        rng=rng,
        device=device,
        logger=logger,
        profile=should_profile,
        max_rollout_length=horizon,
        step=step,
        steps_unrolling=steps_unrolling,
    )


def obsv_entropy(loc, scale):
    """Return the differential entropy of a diagonal Gaussian N(loc, diag(scale**2)).

    For independent dimensions the entropy is the sum of the per-dimension
    entropies:

        H = sum_i [0.5 + 0.5 * log(2 * pi) + log(scale_i)]

    so the constant term has to be counted once per state dimension, not once
    per row.

    Args:
        loc: [batch, state] mean of the distribution. Unused, because the
            entropy of a Gaussian does not depend on its mean; kept so the
            signature stays the same.
        scale: [batch, state] standard deviation of each dimension.

    Returns: [batch,] -- the entropy (in nats) of each row's distribution.
    """
    # Entropy of a unit-variance 1-D Gaussian; each dimension contributes this once.
    per_dim_const = 0.5 + 0.5 * math.log(2 * math.pi)
    num_state = scale.size(1)

    return num_state * per_dim_const + torch.sum(torch.log(scale), 1)


def unroll_actor_direct(cfg : omegaconf.DictConfig, imagined_buffer: CircularReplayBuffer, dynamics_model : OneDTransitionRewardModel, agent : Agent, seed=None, rng=np.random.default_rng(), device: torch.DeviceObjType = torch.device("cpu"), logger=None, profile=False, max_rollout_length=-1, optimizer_override=None, step=0, steps_unrolling=None):
    """
    Update the agent's actor by unrolling the dynamics and differentiating the accumulated reward
    with PyTorch's autodiff.

    For each of `cfg.actor_train.rollout_batches` batches this samples start observations from
    `imagined_buffer`, computes per-ensemble returns (`compute_reward`, or
    `compute_reward_separate_action` when `cfg.actor_train.sample_action_separately` is set),
    merges them across the ensemble with `reducer`, and takes one optimizer step that maximizes
    the merged return. If `cfg.temperature_train.learnable` is set, the temperature is updated
    afterwards from the discounted log-probability returned by the reward computation.

    Args:
        cfg (omegaconf.DictConfig): configuration options
        imagined_buffer (uasc.buffer.CircularReplayBuffer): buffer to sample start observations from
        dynamics_model (mbrl.models.Model): the dynamics model that is unrolled
        agent (Agent): Agent with policy to optimize.
        seed (int, optional): if given, seeds torch globally, the local torch generator, and replaces `rng`.
        rng (np.RandomGenerator, optional): Random generator, defaults to np.random.default_rng().
        device (torch.DeviceObjType, optional): Torch device to work on, defaults to cpu.
        logger (logging.Logger, optional): Logger; Q-value logging is skipped when it is None.
        profile (bool, optional): if True, print entry and elapsed-time messages.
        max_rollout_length (int): number of model steps to unroll; must be >= 0.
        optimizer_override (optional): optimizer to use instead of `agent.actor_optimizer`.
        step (int, optional): The current step
        steps_unrolling (optional): accepted for interface compatibility; not used in this function.
    """
    assert max_rollout_length >= 0

    trng = torch.Generator(device=device)
    if seed is not None:
        torch.manual_seed(seed)
        trng.manual_seed(seed)
        rng = np.random.default_rng(seed=seed)

    optimizer = optimizer_override or agent.actor_optimizer

    if profile:
        print("Entered unroll_actor_direct")
    start = time.perf_counter()

    for _ in range(cfg.actor_train.rollout_batches):
        start_obs = imagined_buffer.sample(cfg.actor_train.rollouts_per_batch, rng)[0]
        assert torch.is_tensor(start_obs)

        optimizer.zero_grad()
        obs = start_obs.to(dtype=torch.float32, device=device)
        num_batch, _ = obs.size()

        assert cfg.actor_train.log_prob_per_action == "same"
        # If "same", the log-prob is calculated for the distribution that is used to generate the action and the same penalty is applied to all ensemble elements.
        # Other operations are far too slow, and have been removed.

        # Compute the rewards and the total log_prob with an intact execution graph linking it to obs, dynamics_model, and agent
        if cfg.actor_train.sample_action_separately:
            if step == 0:
                print("WARNING: USING BASELINE")
            rewards, total_log_prob = compute_reward_separate_action(cfg, dynamics_model, agent, obs, rollout_length=max_rollout_length, trng=trng)
        else:
            rewards, total_log_prob = compute_reward(cfg, dynamics_model, agent, obs, rollout_length=max_rollout_length, trng=trng)

        # Optional logging of encountered Q-values. Skipped when no logger was supplied.
        log_qs_every = cfg.actor_train.log_qs_every
        if logger is not None and log_qs_every > 0 and step % log_qs_every == 0:
            # These values are only logged, so no autograd graph is needed. start_obs is
            # already a tensor (asserted above), so convert it with .to() rather than
            # copy-constructing it through torch.tensor().
            with torch.no_grad():
                q_obs = start_obs.detach().to(dtype=torch.float32, device=device)
                dist = agent.actor(q_obs)
                actor_Q1, actor_Q2 = agent.critic_target(q_obs, dist.mean)
                target_V = torch.minimum(actor_Q1, actor_Q2) - agent.alpha.detach() * dist.log_prob(dist.mean).sum(-1, keepdim=True)
            target_V = target_V.detach().cpu().numpy()

            # Move the rewards to the host once instead of once per row.
            rewards_np = rewards.detach().cpu().numpy()
            for row in range(num_batch):
                logger.log_aux("Q-value", [step, target_V[row, 0]] + rewards_np[:, row].tolist())

        assert torch.all(torch.isfinite(rewards))

        # Merge the per-ensemble returns into one value per start state:
        loss = reducer(cfg, rewards, trng=trng, step=step)

        # We want to maximize the reward, but the torch optimizer minimizes its objective,
        # so we backpropagate the negated value.
        (-loss).mean().backward()

        if cfg.actor_train.gradient_clip > 0.0:
            clip_grad_value_(agent.actor.parameters(), cfg.actor_train.gradient_clip)

        optimizer.step()

        # Update the temperature if it is learnable. We use the same formulation as MBPO, except we
        # scale the log_prob by the discount, and normalize it by the total encountered discount factor.
        if cfg.temperature_train.learnable:
            agent.log_alpha_optimizer.zero_grad()
            alpha_loss = (agent.alpha * (-total_log_prob - agent.target_entropy).detach()).mean()
            alpha_loss.backward()
            agent.log_alpha_optimizer.step()

    if profile:
        print(f"\tDone {time.perf_counter()-start}")

def compute_reward(cfg, dynamics_model, agent, obs, rollout_length=-1, trng=None):
    """
    Unroll the dynamics model from `obs` and accumulate a discounted, entropy-regularized return
    separately for every ensemble member, sharing one sampled action per start state.

    At every step a random ensemble member is drawn per batch element; the action is sampled from
    the actor evaluated on that member's state and applied to all members. For the first
    `rollout_length` steps the predicted reward minus `alpha * log_prob` is accumulated with the
    running discount; at the last step the minimum of the two target-critic values (minus the same
    penalty) is used instead.

    Args:
        cfg: configuration; only `cfg.discount` is read here.
        dynamics_model: model exposing `model.num_members` and `get_ensemble_distribution`.
        agent: provides `actor`, `critic_target` and `alpha`.
        obs: start observations, [batch, state].
        rollout_length (int): number of model steps before bootstrapping; must be >= 0.
        trng (torch.Generator): generator used for the random draws; must not be None.

    Returns:
        (returns, total_log_prob): `returns` is [ensemble, batch]; `total_log_prob` is [batch,],
        the detached, discount-weighted average of the action log-probabilities.
    """
    n_members = dynamics_model.model.num_members
    batch_size, state_dim = obs.size()
    dev = obs.device

    assert rollout_length >= 0
    assert trng is not None

    # One copy of the state per ensemble member (a broadcast view, not a real copy).
    ens_obs = obs.unsqueeze(0).expand((n_members, -1, -1))

    gamma = cfg.discount
    discount_weight = 1.0
    total_discount = sum(discount_weight * (gamma ** k) for k in range(rollout_length + 1))

    # Discounted log-prob, tracked for the temperature update:
    total_log_prob = torch.zeros((batch_size,), dtype=torch.float32, device=dev, requires_grad=False)
    returns = torch.zeros((n_members, batch_size), device=dev)

    for t in range(rollout_length + 1):
        transition = mbrl.types.TransitionBatch(ens_obs, None, None, None, None)
        assert ens_obs.size() == (n_members, batch_size, state_dim)

        # Draw one ensemble member per batch element and broadcast that index over the state axis,
        # so the policy sees exactly one member's state for each start state.
        member_idx = torch.randint(0, n_members, (batch_size,), generator=trng, device=dev)
        gather_index = member_idx.reshape(1, -1, 1).expand((1, -1, state_dim))
        policy_input = torch.gather(ens_obs, 0, gather_index).squeeze(0)
        policy = agent.actor(policy_input)

        transition.act = policy.rsample()
        assert transition.act.dim() == 2

        # alpha * log pi(a|s), shape (batch,); subtracted from the reward below.
        action_log_prob = policy.log_prob(transition.act).sum(-1)
        entropy_penalty = agent.alpha.detach() * action_log_prob

        # TODO: scale/check policy on this; this currently follows SAC-SVG with added discounting
        total_log_prob += action_log_prob.detach() * discount_weight / total_discount

        if t < rollout_length:
            # Model step: advance every ensemble member and collect its predicted reward.
            ens_obs, predicted_reward = dynamics_model.get_ensemble_distribution(
                transition, deterministic=True, rng=trng
            )
            assert predicted_reward is not None  # the model must also predict the reward
            assert predicted_reward.size() == (n_members, batch_size, 1)

            # gamma^t * (reward(x_t) - alpha * log_prob(a_t))
            step_return = predicted_reward.squeeze(-1) - entropy_penalty[None, :]
            assert step_return.size() == (n_members, batch_size)
            returns += discount_weight * step_return

            discount_weight = discount_weight * gamma
            continue

        # Terminal step: bootstrap with the slow-updating target critic, V(x) = min(Q1, Q2)(x, a).
        # All ensemble members are evaluated in one call; if that runs out of memory, evaluate
        # the critic member by member on ens_obs[k] instead.
        assert ens_obs.size() == (n_members, batch_size, state_dim)
        assert returns.size() == (n_members, batch_size)

        # .repeat tiles the actions (like numpy.tile), matching the ensemble-major flattening of ens_obs.
        flat_obs = ens_obs.reshape(n_members * batch_size, -1)
        tiled_act = transition.act.repeat(n_members, 1)
        q1, q2 = agent.critic_target(flat_obs, tiled_act)
        assert q1.size() == (n_members * batch_size, 1)

        bootstrap = torch.minimum(q1, q2).squeeze(-1)
        bootstrap = bootstrap.reshape(n_members, batch_size) - entropy_penalty[None, :]
        returns += bootstrap * discount_weight

    return returns, total_log_prob

def compute_reward_separate_action(cfg, dynamics_model, agent, obs, rollout_length=-1, trng=None):
    """
    This is a baseline for comparison, not part of our method.

    Same accumulation as `compute_reward`, except that every ensemble member samples its own
    action from the actor evaluated on its own state, instead of sharing one action per start state.

    Returns:
        (returns, total_log_prob), both of shape [ensemble, batch]; `total_log_prob` is the
        detached, discount-weighted average of each member's action log-probabilities.
    """
    n_members = dynamics_model.model.num_members
    batch_size, state_dim = obs.size()
    dev = obs.device

    assert rollout_length >= 0
    assert trng is not None

    # One copy of the state per ensemble member (a broadcast view, not a real copy).
    ens_obs = obs.unsqueeze(0).expand((n_members, -1, -1))

    gamma = cfg.discount
    discount_weight = 1.0
    total_discount = sum(discount_weight * (gamma ** k) for k in range(rollout_length + 1))

    # Discounted log-prob, tracked for the temperature update:
    total_log_prob = torch.zeros((n_members, batch_size), dtype=torch.float32, device=dev, requires_grad=False)
    returns = torch.zeros((n_members, batch_size), device=dev)

    for t in range(rollout_length + 1):
        transition = mbrl.types.TransitionBatch(ens_obs, None, None, None, None)

        entropy_penalty = torch.zeros_like(total_log_prob)
        assert ens_obs.size() == (n_members, batch_size, state_dim)

        member_actions = []
        for k in range(n_members):
            policy = agent.actor(ens_obs[k])
            sampled = policy.rsample()
            member_actions.append(sampled)

            # alpha * log pi(a|s) for this member, shape (batch,)
            member_log_prob = policy.log_prob(sampled).sum(-1)
            total_log_prob[k] += member_log_prob.detach() * discount_weight / total_discount
            entropy_penalty[k] = agent.alpha.detach() * member_log_prob

        # Stack the per-member actions into a single [ensemble, batch, action] tensor:
        transition.act = torch.stack(member_actions, 0)
        assert transition.act.dim() == 3

        if t < rollout_length:
            # Model step: advance every ensemble member and collect its predicted reward.
            ens_obs, predicted_reward = dynamics_model.get_ensemble_distribution(
                transition, deterministic=True, rng=trng
            )
            assert predicted_reward is not None  # the model must also predict the reward
            assert predicted_reward.size() == (n_members, batch_size, 1)

            # gamma^t * (reward(x_t) - alpha * log_prob(a_t))
            step_return = predicted_reward.squeeze(-1) - entropy_penalty
            assert step_return.size() == (n_members, batch_size)
            returns += discount_weight * step_return

            discount_weight = discount_weight * gamma
            continue

        # Terminal step: bootstrap with the target critic on the flattened (ensemble * batch) rows.
        assert ens_obs.size() == (n_members, batch_size, state_dim)
        assert returns.size() == (n_members, batch_size)

        flat_obs = ens_obs.reshape(n_members * batch_size, -1)
        flat_act = transition.act.reshape(n_members * batch_size, -1)
        q1, q2 = agent.critic_target(flat_obs, flat_act)
        assert q1.size() == (n_members * batch_size, 1)

        bootstrap = torch.minimum(q1, q2).squeeze(-1)
        bootstrap = bootstrap.reshape(n_members, batch_size) - entropy_penalty
        returns += bootstrap * discount_weight

    return returns, total_log_prob

def reducer(cfg, rewards, trng=None, step=-1):
    """
    Merge per-ensemble rewards of shape [ensemble, batch] into one value per batch element.

    The merge rule is selected by `cfg.actor_train.unroll_merge`:
        "max" / "min" / "mean" / "median": the corresponding reduction over the ensemble axis.
        "random": for each batch element, the reward of a randomly drawn ensemble member (uses `trng`).
        "max-until": "max" while `step < cfg.actor_train.unroll_merge_max_until`, "mean" afterwards.
        "topk": mean of the `cfg.actor_train.unroll_merge_top_k` largest rewards per batch element.
        "mellowmax": (logsumexp(omega * rewards) - log(ensemble)) / omega with omega = e^e.

    Returns a tensor of shape [batch,]. Raises RuntimeError for an unrecognized rule.
    """
    num_ensemble, num_batch = rewards.size()
    opts = cfg.actor_train
    mode = opts.unroll_merge

    if mode == "max":
        return torch.max(rewards, 0).values

    if mode == "mean":
        return torch.mean(rewards, 0)

    if mode == "median":
        return torch.median(rewards, 0).values

    if mode == "min":
        return torch.min(rewards, 0).values

    if mode == "random":
        # One ensemble index per batch element, shaped (1, batch) for gather along dim 0.
        picks = torch.randint(0, num_ensemble, (num_batch,), generator=trng, device=rewards.device)
        merged = torch.gather(rewards, 0, picks.reshape(1, -1)).squeeze(0)
        assert merged.size() == (num_batch,)
        return merged

    if mode == "max-until":
        # Max until we reach the step threshold, then mean
        assert opts.unroll_merge_max_until > 0
        if step < opts.unroll_merge_max_until:
            return torch.max(rewards, 0).values
        return torch.mean(rewards, 0)

    if mode == "topk":
        assert 0 < opts.unroll_merge_top_k < num_ensemble
        # Mean over the k largest ensemble rewards of each batch element:
        best_k = torch.topk(rewards, opts.unroll_merge_top_k, dim=0, sorted=False).values
        return torch.mean(best_k, 0)

    if mode == "mellowmax":
        omega = np.exp(np.exp(1))  # e^e
        return (torch.logsumexp(omega * rewards, dim=0) - np.log(num_ensemble)) / omega

    raise RuntimeError(f"actor_train.unroll_merge value {mode} not recognized!")

def debug_Qvalues(cfg, dynamics_model, agent, start_obs, max_rollout_length=-1, device=None):
    """
    Evaluate `compute_reward` from `start_obs` at horizon 0 and at horizon `max_rollout_length`,
    without tracking gradients.

    Returns:
        A pair of numpy arrays: the rewards for rollout_length=0 and for
        rollout_length=max_rollout_length, in that order.
    """
    assert max_rollout_length > 0
    assert device is not None

    initial_obs = torch.tensor(start_obs, dtype=torch.float32, device=device)
    generator = torch.Generator(device=device)

    # Both evaluations share one generator, horizon 0 first.
    with torch.no_grad():
        values = [
            compute_reward(cfg, dynamics_model, agent, initial_obs, rollout_length=horizon, trng=generator)[0]
            for horizon in (0, max_rollout_length)
        ]

    bootstrap_only, unrolled = values
    return bootstrap_only.cpu().numpy(), unrolled.cpu().numpy()
