import math
from typing import Callable, List, Tuple, Union

import torch
import torch.distributions as pyd
from torch.distributions.utils import _standard_normal

from lambda_ac.rl_types import (
    EncoderActorModule,
    EncoderCriticModule,
    EncoderModelNetwork,
    EncoderOutput,
    FeatureInput,
    ModelTrajectory,
)
from lambda_ac.util.distributions import TanhGaussian


def rollout_model_with_actions(
    next_feature_encoding: EncoderOutput,
    actions: torch.Tensor,
    done: torch.Tensor,
    model: EncoderModelNetwork,
    discretize_done: bool = False,
    predict_done: bool = False,
    detach_encoder: bool = False,
    return_first: bool = True,
) -> ModelTrajectory:
    """Unroll ``model`` along a given action sequence.

    Starting from ``next_feature_encoding``, the model is called once per
    action in ``actions`` (iterating over dim 1) and the predicted features,
    rewards, encodings and raw done outputs are collected.

    Args:
        next_feature_encoding: Encoder output for the start of the rollout.
        actions: Action sequence; ``actions[:, i]`` is fed to the model at
            step ``i`` and ``actions.shape[1]`` is the rollout depth.
        done: Termination flags; only ``done[:, 0]`` is read.
        model: Callable returning ``(next_encoding, reward, done_output)``
            for a ``(feature, action)`` pair.
        discretize_done: With ``predict_done``, use the hard threshold
            ``0.0 > done_output`` for the emitted masks instead of
            ``1 - sigmoid(done_output)``.
        predict_done: If true, masks are derived from the model's done
            output; otherwise every step reuses the mask from ``done``.
        detach_encoder: Call ``detach()`` on the initial feature.
        return_first: Whether the initial feature/encoding is included as
            the first entry of the returned states/encodings.

    Returns:
        A ``ModelTrajectory`` with tensors stacked along dim 1.
        ``log_probs`` is all zeros (the actions were not sampled here) and
        ``masks`` always has ``depth + 1`` entries, the first being all
        ones.
    """
    depth = actions.shape[1]

    next_feature = FeatureInput.from_output(next_feature_encoding)
    if detach_encoder:
        # NOTE(review): the return value is discarded, so this only has an
        # effect if FeatureInput.detach() operates in place -- confirm.
        next_feature.detach()

    features: List[FeatureInput] = [next_feature] if return_first else []
    encodings: List[EncoderOutput] = (
        [next_feature_encoding] if return_first else []
    )

    rewards = []
    raw_dones = []
    # "Not done" mask of the first transition, forced to one trailing column.
    mask = torch.logical_not(done[:, 0].view(-1, 1))
    running_model_mask = mask.clone()
    masks = [torch.ones_like(mask)]
    model_actions = [actions[:, i] for i in range(depth)]

    for model_action in model_actions:
        next_feature_encoding, next_reward, model_done = model(
            next_feature, model_action
        )
        next_feature = FeatureInput.from_output(next_feature_encoding)

        features.append(next_feature)
        rewards.append(next_reward)
        encodings.append(next_feature_encoding)
        raw_dones.append(model_done)

        if predict_done:
            # A negative done output counts as "still running".
            still_running = 0.0 > model_done
            if discretize_done:
                masks.append(running_model_mask * still_running)
            else:
                masks.append(
                    running_model_mask * (1.0 - torch.sigmoid(model_done))
                )
            # The running mask is always propagated with the hard threshold.
            running_model_mask = running_model_mask * still_running
        else:
            # Reuse the (batch, 1) mask so every entry of ``masks`` has the
            # same shape as the initial one; indexing ``done[:, 0]`` again
            # without the view breaks torch.stack for a 2-D ``done``.
            masks.append(mask)

    combined_features = FeatureInput.from_list(features, dim=1)
    rewards = torch.stack(rewards, dim=1)
    masks = torch.stack(masks, dim=1)
    model_actions = torch.stack(model_actions, dim=1)
    raw_dones = torch.stack(raw_dones, dim=1)

    return ModelTrajectory(
        length=depth,
        states=combined_features,
        actions=model_actions,
        rewards=rewards,
        masks=masks,
        log_probs=torch.zeros_like(rewards),
        encodings=encodings,
        done_predictions=raw_dones,
    )


def rollout_model_with_policy(
    next_feature_encoding: Union[EncoderOutput, FeatureInput],
    done: torch.Tensor,
    model: EncoderModelNetwork,
    policy: EncoderActorModule,
    depth: int,
    discretize_done: bool = False,
    predict_done: bool = False,
    detach_encoder: bool = False,
    add_exploration_noise: bool = False,
) -> ModelTrajectory:
    """Unroll ``model`` for ``depth`` steps with actions sampled from ``policy``.

    Args:
        next_feature_encoding: Start of the rollout, either an encoder
            output or an already converted ``FeatureInput``.
        done: Termination flags; only ``done[:, 0]`` is read.
        model: Callable returning ``(next_encoding, reward, done_output)``
            for a ``(feature, action)`` pair.
        policy: Actor whose ``head.forward_sample_log_prob`` yields the
            action and its log-probability.
        depth: Number of model steps.
        discretize_done: With ``predict_done``, use the hard threshold
            ``0.0 > done_output`` for the emitted masks instead of
            ``1 - sigmoid(done_output)``.
        predict_done: If true, masks come from the model's done output;
            otherwise every step reuses the mask derived from ``done``.
        detach_encoder: Call ``detach()`` on the initial feature.
        add_exploration_noise: Add Gaussian noise with standard deviation
            0.1 to each sampled action. The stored log-probability is the
            one of the un-noised sample.

    Returns:
        A ``ModelTrajectory`` whose ``states`` and ``masks`` hold
        ``depth + 1`` steps along dim 1 and whose ``actions``, ``rewards``
        and ``log_probs`` hold ``depth`` steps. ``encodings`` and
        ``done_predictions`` are ``None``.
    """
    # Encoder outputs cannot be stacked directly (they carry distributions),
    # so the rollout works on FeatureInput objects throughout.
    current = (
        next_feature_encoding
        if isinstance(next_feature_encoding, FeatureInput)
        else FeatureInput.from_output(next_feature_encoding)
    )
    if detach_encoder:
        # NOTE(review): the result is not reassigned; this presumably
        # relies on FeatureInput.detach() working in place -- confirm.
        current.detach()

    # "Not done" flag of the first transition as a (batch, 1) column.
    not_done = torch.logical_not(done[:, 0].view(-1, 1))
    alive = not_done.clone()

    visited: List[FeatureInput] = [current]
    step_masks = [torch.ones_like(not_done)]
    step_rewards = []
    step_log_probs = []
    step_actions = []

    for _step in range(depth):
        action, log_prob, _ = policy.head.forward_sample_log_prob(current)
        if add_exploration_noise:
            action = action + torch.randn_like(action) * 0.1
        step_actions.append(action.clone())

        predicted, reward, done_out = model(current, action)
        current = FeatureInput.from_output(predicted)

        visited.append(current)
        step_rewards.append(reward)
        step_log_probs.append(log_prob)

        if not predict_done:
            step_masks.append(not_done)
            continue

        # A negative done output counts as "still running".
        still_running = 0.0 > done_out
        if discretize_done:
            step_masks.append(alive * still_running)
        else:
            step_masks.append(alive * (1.0 - torch.sigmoid(done_out)))
        # The running mask is always propagated with the hard threshold.
        alive = alive * still_running

    stacked_encoded = torch.stack([f.encoded for f in visited], dim=1)
    stacked_hidden = torch.stack([f.hidden for f in visited], dim=1)

    return ModelTrajectory(
        length=depth,
        states=FeatureInput.from_tensors(stacked_encoded, stacked_hidden),
        actions=torch.stack(step_actions, dim=1),
        rewards=torch.stack(step_rewards, dim=1),
        masks=torch.stack(step_masks, dim=1),
        log_probs=torch.stack(step_log_probs, dim=1),
        encodings=None,
        done_predictions=None,
    )


def evaluate_trajectory(
    trajectory: ModelTrajectory,
    policy: EncoderActorModule,
    critic: EncoderCriticModule,
    discount: torch.Tensor,
    alpha: torch.Tensor,
    reduction: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    use_td_average: bool = False,
    add_first: bool = False,
) -> Tuple[FeatureInput, torch.Tensor, torch.Tensor]:
    """Compute entropy-regularised n-step value targets for a rollout.

    For every step ``k`` along dim 1 the target is::

        sum_{i<k} discount**i * mask_i * (r_i - alpha * log_prob_i)
            + discount**k * mask_k * (Q_k - alpha * q_log_prob_k)

    where ``Q_k`` is ``reduction`` applied to the two critic outputs for a
    freshly sampled policy action at step ``k``.

    Args:
        trajectory: Rollout to evaluate. ``masks`` and ``states`` are
            expected to hold one more step along dim 1 than ``rewards`` and
            ``log_probs``. The trajectory is not modified.
        policy: Actor; ``head.forward_sample_log_prob`` is called on all
            states at once.
        critic: Critic; ``head`` returns two value estimates.
        discount: Discount factor (``unsqueeze(-1)`` is applied before it
            is raised to the per-step exponent).
        alpha: Entropy coefficient; detached before use.
        reduction: Combines the two critic outputs (e.g. ``torch.min``).
        use_td_average: Return the mean over all kept step targets instead
            of only the last one.
        add_first: Keep the step-0 target and the first log-probability.
            When false, the first log-probability is zeroed and the step-0
            target is dropped.

    Returns:
        ``(states[:, 0], actions[:, 0], target)`` where ``target`` is the
        selected (or averaged) target with the step dimension removed.
    """
    features = trajectory.states
    actions = trajectory.actions
    rewards = trajectory.rewards
    masks = trajectory.masks
    log_probs = trajectory.log_probs

    # Compute Q(s, a) for all model steps
    q_actions, q_log_probs, _ = policy.head.forward_sample_log_prob(features)
    qf1, qf2 = critic.head(features, q_actions)
    qs = reduction(qf1, qf2)

    if not add_first:
        # Remove first log probability due to the off-by-one setting in SAC
        # Q = R + gamma * (V(s') - alpha * log pi(a|s'))
        # Zero it on a copy: writing into trajectory.log_probs in place
        # would corrupt the caller's trajectory for any later evaluation.
        log_probs = log_probs.clone()
        log_probs[:, 0] = 0.0

    # apply mask to rewards and qs
    rewards = rewards * masks[:, :-1]
    log_probs = log_probs * masks[:, :-1]
    qs = qs * masks
    q_log_probs = q_log_probs * masks

    # per-step discount factor: discount ** [0, 1, ..., n]
    exponent = torch.ones_like(qs).cumsum(1) - 1.0
    discount = discount.unsqueeze(-1) ** exponent

    # apply discount and log_probs
    rewards = (rewards - alpha.detach() * log_probs) * discount[:, :-1]
    # Prepend a zero so that entry k of the cumulative sum below contains
    # only the rewards collected before step k.
    rewards = torch.cat([torch.zeros_like(rewards[:, 0:1]), rewards], dim=1)
    qs = (qs - alpha.detach() * q_log_probs) * discount

    # compute n-step reward sums
    rewards = torch.cumsum(rewards, dim=1)
    targets = rewards + qs

    if not add_first:
        targets = targets[:, 1:]

    if use_td_average:
        # compute TD average
        targets = targets.mean(1, keepdim=True)
    else:
        targets = targets[:, -1:]

    return features[:, 0], actions[:, 0], targets[:, 0]


def rollout_with_vf_target(
    states: torch.Tensor,
    actions: torch.Tensor,
    rewards: torch.Tensor,
    next_states: torch.Tensor,
    done: torch.Tensor,
    mask: torch.Tensor,
    discount: torch.Tensor,
    alpha: torch.Tensor,
    model: EncoderModelNetwork,
    policy: EncoderActorModule,
    critic: EncoderCriticModule,
    depth: int,
    reduction: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.min,
    use_model: bool = False,
    use_muzero: bool = False,
    add_exploration_noise: bool = False,
    start_from_zero: bool = False,
):
    """Roll the model along the given actions and build a value target.

    The model is first unrolled over ``actions`` starting from the encoding
    of ``states``. A time index is then chosen per batch element and a
    target is computed without gradients, in one of three ways:

    * ``use_muzero`` or ``start_from_zero``: target from the recorded
      transition at the chosen index via ``get_real_target``.
    * otherwise, ``use_model``: target from a policy rollout started at the
      model-predicted state at the chosen index via ``get_model_target``.
    * otherwise: the index is reset to 0 and the target comes from the
      first recorded transition via ``get_real_target``.

    Args:
        states: Observations passed to ``model.encoder``.
        actions: Recorded action sequence, indexed ``[batch, time]``.
        rewards: Recorded rewards, indexed ``[batch, time]``.
        next_states: Recorded successor observations, indexed
            ``[batch, time]``.
        done: Recorded termination flags, indexed ``[batch, time]``.
        mask: Unused by this function.
        discount: Discount factor.
        alpha: Entropy coefficient.
        model: Latent model with an ``encoder`` attribute.
        policy: Actor used for target computation.
        critic: Critic used for target computation.
        depth: Rollout depth forwarded to the target functions.
        reduction: Combines the two critic outputs.
        use_model: Use model rollouts inside the target computation.
        use_muzero: Bootstrap from the recorded transition at the chosen
            index.
        add_exploration_noise: Forwarded to the target computation (not in
            the plain first-transition case).
        start_from_zero: Use time index 0 for every batch element.

    Returns:
        ``(Q, actions, trajectory, idx)``: the target, the actions that
        belong to it, the model trajectory over the given actions, and the
        per-sample time index that was used.
    """
    trajectory = rollout_model_with_actions(
        model.encoder(states),
        actions,
        done[:, :1],
        model,
        discretize_done=False,
        predict_done=True,
        detach_encoder=False,
        return_first=True,
    )

    batch_size = rewards.shape[0]
    device = rewards.device
    idx_aux = torch.arange(batch_size).to(device)
    idx = torch.randint(0, trajectory.length, (batch_size,)).to(device)
    if start_from_zero:
        idx = torch.zeros_like(idx)

    bootstrap_from_data = use_muzero or start_from_zero

    with torch.no_grad():
        if use_model and not bootstrap_from_data:
            Q, _, actions = get_model_target(
                trajectory.states[idx_aux, idx],
                done[idx_aux, idx],
                discount,
                alpha,
                model,
                policy,
                critic,
                depth,
                reduction,
                discretize_done=True,
                predict_done=False,
                get_policy_target=False,
                detach_encoder=False,
                use_td_average=True,
                add_exploration_noise=add_exploration_noise,
            )
        else:
            if bootstrap_from_data:
                noise = add_exploration_noise
            else:
                # Plain case: always the first transition, never with
                # exploration noise.
                idx[:] = 0
                noise = False
            Q, actions = get_real_target(
                next_states[idx_aux, idx],
                actions[idx_aux, idx],
                rewards[idx_aux, idx],
                done[idx_aux, idx],
                discount,
                alpha,
                model,
                policy,
                critic,
                reduction,
                depth,
                use_model=use_model,
                add_exploration_noise=noise,
            )

    return Q, actions, trajectory, idx


@torch.no_grad()
def get_real_target(
    states: torch.Tensor,
    actions: torch.Tensor,
    rewards: torch.Tensor,
    done: torch.Tensor,
    discount: torch.Tensor,
    alpha: torch.Tensor,
    model: EncoderModelNetwork,
    policy: EncoderActorModule,
    critic: EncoderCriticModule,
    reduction: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.min,
    depth: int = 1,
    use_model: bool = False,
    add_exploration_noise: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute a one-step bootstrapped target from a recorded transition.

    The target is ``rewards + discount * (1 - done) * Q`` where ``Q`` is the
    entropy-regularised value of ``states`` (the successor observation of
    the transition). The whole function runs without gradient tracking.

    Args:
        states: Successor observations to bootstrap from.
        actions: Actions of the transition; returned unchanged.
        rewards: Rewards of the transition.
        done: Termination flags of the transition.
        discount: Discount factor.
        alpha: Entropy coefficient.
        model: Latent model; ``model.encoder`` is used when ``use_model``
            is false.
        policy: Actor used to sample the bootstrap action.
        critic: Critic providing two value estimates.
        reduction: Combines the two critic outputs.
        depth: Model rollout depth; only used when ``use_model`` is true.
        use_model: Estimate ``Q`` with ``get_model_target`` (TD-averaged
            policy rollout in the model) instead of a single critic
            evaluation.
        add_exploration_noise: Forwarded to ``get_model_target``; ignored
            when ``use_model`` is false.

    Returns:
        ``(target, actions)``.
    """
    if use_model:
        Q, _, _ = get_model_target(
            states,
            done,
            discount,
            alpha,
            model,
            policy,
            critic,
            depth,
            reduction,
            discretize_done=True,
            predict_done=False,
            get_policy_target=True,
            detach_encoder=False,
            use_td_average=True,
            add_exploration_noise=add_exploration_noise,
        )
    else:
        next_features = FeatureInput.from_output(model.encoder(states))
        next_actions, next_log_probs, _ = policy.head.forward_sample_log_prob(
            next_features
        )
        qf1, qf2 = critic.head.forward(next_features, next_actions)
        # Soft value of the successor state: Q(s', a') - alpha * log pi(a'|s').
        # The discount is applied exactly once, in the target below;
        # multiplying the entropy term by it here as well would scale the
        # entropy bonus by discount ** 2 and disagree with the model-based
        # branch above.
        Q = reduction(qf1, qf2) - alpha.detach() * next_log_probs
    target = rewards + discount * (1 - done) * Q
    return target, actions


def get_model_target(
    state: Union[torch.Tensor, EncoderOutput, FeatureInput],
    done: torch.Tensor,
    discount: torch.Tensor,
    alpha: torch.Tensor,
    model: EncoderModelNetwork,
    policy: EncoderActorModule,
    critic: EncoderCriticModule,
    depth: int,
    reduction: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    discretize_done: bool = False,
    predict_done: bool = False,
    get_policy_target: bool = False,
    detach_encoder: bool = False,
    add_exploration_noise: bool = False,
    use_td_average: bool = False,
) -> Tuple[torch.Tensor, FeatureInput, torch.Tensor]:
    """Compute a value target from a policy rollout inside the model.

    The policy is rolled out in ``model`` for ``depth`` steps starting at
    ``state`` and the resulting trajectory is scored with
    ``evaluate_trajectory``.

    Args:
        state: Start of the rollout. A raw tensor is passed through
            ``model.encoder`` first; anything else is used as is.
        done: Termination flags forwarded to the rollout.
        discount: Discount factor.
        alpha: Entropy coefficient.
        model: Latent model.
        policy: Actor used for the rollout and for evaluation.
        critic: Critic used for evaluation.
        depth: Number of model steps.
        reduction: Combines the two critic outputs.
        discretize_done: Forwarded to the rollout.
        predict_done: Forwarded to the rollout.
        get_policy_target: Forwarded to ``evaluate_trajectory`` as
            ``add_first``.
        detach_encoder: Forwarded to the rollout.
        add_exploration_noise: Forwarded to the rollout.
        use_td_average: Forwarded to ``evaluate_trajectory``.

    Returns:
        ``(targets, features, actions)``: the value target together with
        the first state and first action of the rollout.
    """
    start = model.encoder(state) if isinstance(state, torch.Tensor) else state

    rollout = rollout_model_with_policy(
        start,
        done,
        model,
        policy,
        depth,
        discretize_done=discretize_done,
        predict_done=predict_done,
        detach_encoder=detach_encoder,
        add_exploration_noise=add_exploration_noise,
    )
    first_features, first_actions, value_targets = evaluate_trajectory(
        rollout,
        policy,
        critic,
        discount,
        alpha,
        reduction,
        use_td_average=use_td_average,
        add_first=get_policy_target,
    )
    # Note the reordering relative to evaluate_trajectory's return value.
    return value_targets, first_features, first_actions
