import math

import numpy as np
import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions.categorical import Categorical


class Agent(nn.Module):
    """
    Agent actor-critic network.
    Must consist of network, actor and critic submodules, of which the first is frozen for AID.

    Observations arrive channels-last as (N, H, W, C), with C == observation_shape[2]. The
    convolutional ``network`` must reduce the spatial size to 7x7 (e.g. 84x84 inputs) to match
    its 64 * 7 * 7 flatten layer. Its 512-d output is concatenated with the agent indicator and
    the current tax rates, and the result feeds the linear ``actor`` (policy logits) and
    ``critic`` (state value) heads.
    """

    def __init__(self, observation_shape, num_actions, num_agents, num_brackets):
        super().__init__()
        self.num_brackets = num_brackets
        self.num_actions = num_actions
        self.network = nn.Sequential(
            self.layer_init(nn.Conv2d(observation_shape[2], 32, 8, stride=4)),
            nn.ReLU(),
            self.layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            self.layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            self.layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        # Both heads see the conv embedding plus agent indicator plus tax-rate features.
        head_in_features = 512 + num_agents + num_brackets
        self.actor = self.layer_init(nn.Linear(head_in_features, num_actions), std=0.01)
        self.critic = self.layer_init(nn.Linear(head_in_features, 1), std=1)

    def layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise ``layer.weight`` with gain ``std``, fill ``layer.bias`` with
        ``bias_const``, and return the layer."""
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def _embed(self, obs, player_idx, tax_rates):
        """Embed observations and append agent indicators and current tax rates.

        Args:
            obs (Tensor): observations laid out (N, H, W, C)
            player_idx (Tensor): agent indicator features, concatenated on the last dim
                                 (presumably one-hot of width num_agents — TODO confirm)
            tax_rates (Tensor): current tax rates, concatenated on the last dim
        Returns:
            (Tensor): features of width 512 + player_idx width + tax_rates width
        """
        # torch Conv2d needs (N, C, H, W); observations are channels-last.
        obs_embedding = self.network(obs.permute((0, 3, 1, 2)))
        return torch.cat([obs_embedding, player_idx, tax_rates], dim=-1)

    def _mixed_policy(self, embedding, epsilon):
        """Return the actor's categorical policy mixed with a uniform distribution.

        Each action gets probability ``(1 - epsilon) * p + epsilon / num_actions``, so for
        0 <= epsilon <= 1 every action keeps at least ``epsilon / num_actions`` mass.
        """
        policy = Categorical(logits=self.actor(embedding))
        return Categorical((1 - epsilon) * policy.probs + epsilon / self.num_actions)

    def get_value(self, obs, player_idx, tax_rates):
        """Provide a value function estimate for an observation.
        Args:
            obs (Tensor): observation, laid out (N, H, W, C)
            player_idx (Tensor): agent indicator features
            tax_rates (Tensor): current tax rates
        Returns:
            (Tensor): estimate of the value function, flattened to 1-D
        """
        return self.critic(self._embed(obs, player_idx, tax_rates)).flatten()

    def generate_action_and_value_no_grads(self, obs, player_idx, tax_rates, epsilon=0):
        """Provide an action sampled from policy, its log-probabilities, the policy entropy, and a value function estimate.
        All have gradients detached except the log-probabilities.
        Args:
            obs (Tensor): observation, laid out (N, H, W, C)
            player_idx (Tensor): agent indicator features
            tax_rates (Tensor): current tax rates
            epsilon (float): weight of the uniform distribution mixed into the policy
        Returns:
            tensordict (TensorDict): a tensordict holding results of net forward pass
        """
        embedding = self._embed(obs, player_idx, tax_rates)
        values = self.critic(embedding).flatten()
        mixed_probs = self._mixed_policy(embedding, epsilon)
        action = mixed_probs.sample()

        return TensorDict(
            {
                "actions": action.detach(),
                "logprobs": mixed_probs.log_prob(action),
                "entropy": mixed_probs.entropy().detach(),
                "values": values.detach(),
            }
        )

    def get_action_logprobs_and_value(self, obs, player_idx, tax_rates, action, epsilon=0):
        """Provide policy log-probabilities for a given action and observation, the policy entropy, and a value function estimate.
        All have gradients attached.
        Args:
            obs (Tensor): observation, laid out (N, H, W, C)
            player_idx (Tensor): agent indicator features
            tax_rates (Tensor): current tax rates
            action (Tensor): action to provide log-probabilities for
            epsilon (float): weight of the uniform distribution mixed into the policy
        Returns:
            tensordict (TensorDict): a tensordict holding results of net forward pass
        """
        embedding = self._embed(obs, player_idx, tax_rates)
        values = self.critic(embedding).flatten()
        mixed_probs = self._mixed_policy(embedding, epsilon)

        return TensorDict(
            {
                "actions": action,
                "logprobs": mixed_probs.log_prob(action),
                "entropy": mixed_probs.entropy(),
                "values": values,
            }
        )


class PrincipalAgent(nn.Module):
    """
    Principal policy network.
    Takes in agent sigma values and outputs tax rate suggestions using a discrete action
    head for each tax bracket.

    Architecture: an orthogonally initialised ReLU MLP trunk, one categorical actor head per
    tax bracket (``num_actions_per_bracket`` choices each), and a scalar critic head.
    """

    def __init__(self, principal_obs_length, num_brackets, hidden_dim, num_hidden_layers, num_actions_per_bracket=22):
        super().__init__()
        self.num_actions_per_bracket = num_actions_per_bracket

        # Input layer, then (num_hidden_layers - 1) further hidden layers.
        layers = [self.layer_init(nn.Linear(principal_obs_length, hidden_dim)), nn.ReLU()]
        for _ in range(num_hidden_layers - 1):
            layers.append(self.layer_init(nn.Linear(hidden_dim, hidden_dim)))
            layers.append(nn.ReLU())
        self.mlp = nn.Sequential(*layers)

        # One action head for each tax bracket, outputting tax suggestions over a discretized
        # [0,1] interval, or a NO-OP to leave that bracket unchanged. The default of 22 choices
        # preserves the original layout; the exact value-to-index mapping lives with the caller.
        self.actor_heads = nn.ModuleList(
            [
                self.layer_init(nn.Linear(hidden_dim, num_actions_per_bracket), std=0.01)
                for _ in range(num_brackets)
            ]
        )
        self.critic = self.layer_init(nn.Linear(hidden_dim, 1), std=1)

    def layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise ``layer.weight`` with gain ``std``, fill ``layer.bias`` with
        ``bias_const``, and return the layer."""
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def forward(self, fixed_obs, action=None):
        """If not given an action, provides an action sampled from principal policy,
        its log-probabilities and the policy entropy. If an action is given, provides
        policy log-probabilities for that action and policy entropy.

        Operates on sigma values of one parallel game, yielding tax rate suggestions for that game.

        Args:
            fixed_obs (Tensor[..., principal_obs_length]): a single one as a fixed observation;
                may be batched (B, principal_obs_length) or unbatched (principal_obs_length,)
            action (Tensor[..., num_brackets], optional): action to provide log-probabilities for
                                                          - or if None one is sampled from policy

        Returns:
            tensordict (TensorDict): a tensordict holding results of net forward pass.
                "actions" is (..., num_brackets); "logprobs" and "entropy" are summed over
                brackets; "distribution" stacks per-bracket probabilities on dim 0, giving
                (num_brackets, ..., num_actions_per_bracket).
        """
        hidden = self.mlp(fixed_obs)
        probs = [Categorical(logits=head(hidden)) for head in self.actor_heads]
        if action is None:
            # Stack on the last dim so both batched and unbatched observations work.
            action = torch.stack([prob.sample() for prob in probs], dim=-1)
        # Brackets are sampled independently, so the joint log-prob/entropy is the sum over heads.
        logprobs = sum(prob.log_prob(action[..., i]) for i, prob in enumerate(probs))
        entropy = sum(prob.entropy() for prob in probs)

        values = self.critic(hidden).flatten()
        tensordict = TensorDict(
            {
                "actions": action,
                "logprobs": logprobs,
                "entropy": entropy,
                "values": values,
                "distribution": torch.stack([prob.probs for prob in probs]),
            }
        )
        return tensordict


class DesignerNet(nn.Module):
    """MLP designer mapping a principal observation to one tax value per bracket.

    The output is ``output_multiplier * sigmoid(mlp(obs) - sigmoid_shift)``. A positive
    ``sigmoid_shift`` therefore biases the initial outputs toward the low end of
    ``[0, output_multiplier]``.
    """

    def __init__(self, principal_obs_length, id_hidden_dimension, num_brackets, sigmoid_shift, output_multiplier):
        super().__init__()
        self.output_multiplier = output_multiplier
        self.sigmoid_shift = sigmoid_shift
        width = id_hidden_dimension
        # Two hidden ReLU layers of equal width, then a linear projection to one logit per bracket.
        self.mlp = nn.Sequential(
            nn.Linear(principal_obs_length, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, num_brackets),
        )

    def forward(self, obs):
        """Return per-bracket tax values for ``obs``, squashed into [0, output_multiplier]."""
        shifted_logits = self.mlp(obs) - self.sigmoid_shift
        return self.output_multiplier * torch.sigmoid(shifted_logits)


class DesignerNetMini(nn.Module):
    """Single-layer designer: ``output_multiplier * sigmoid(Linear(obs) - sigmoid_shift)``.

    The linear layer has one input feature, so ``obs`` must have a trailing dimension of size 1.
    """

    def __init__(self, num_brackets, sigmoid_shift, output_multiplier):
        super().__init__()
        self.mlp = nn.Linear(1, num_brackets)
        self.sigmoid_shift = sigmoid_shift
        self.output_multiplier = output_multiplier

    def forward(self, obs):
        """Return per-bracket tax values for ``obs``, squashed into [0, output_multiplier]."""
        squashed = torch.sigmoid(self.mlp(obs) - self.sigmoid_shift)
        return self.output_multiplier * squashed


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The ``pe`` buffer has shape (max_len, 1, embed_dim). Even feature columns hold
    ``sin(pos * f_k)`` and odd columns hold ``cos(pos * f_k)``, where
    ``f_k = 10000 ** (-2k / embed_dim)``. ``forward`` adds ``pe[:x.size(0)]``, so it treats
    dim 0 of its input as the sequence dimension and broadcasts over dim 1.
    Odd ``embed_dim`` values are supported.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))
        pe = torch.zeros(max_len, 1, embed_dim)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With odd embed_dim there is one fewer odd column than even column, so the last
        # frequency has no cosine partner; slicing avoids a shape-mismatch error.
        pe[:, 0, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add positional encodings along dim 0 (sequence) of ``x``."""
        return x + self.pe[: x.size(0)]


class LinearLayers(nn.Module):
    """Feed-forward sub-layer applied over the last dimension.

    It expands to 4x width, applies GELU, projects back to ``embed_dim``, and then applies dropout.
    """

    def __init__(self, embed_dim, dropout):
        super().__init__()
        expanded = 4 * embed_dim
        self.c_fc = nn.Linear(embed_dim, expanded)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(expanded, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply expand -> GELU -> project -> dropout to ``x``."""
        activated = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))


class AttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then feed-forward.

    Each sub-layer has a residual connection. ``nn.MultiheadAttention`` is built with its
    default ``batch_first=False``, so inputs are expected sequence-first:
    (seq_len, batch, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.mlp = LinearLayers(embed_dim, dropout)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual add."""
        # Normalise only the sub-layer input, so the residual stream keeps the un-normalised x.
        # This matches the feed-forward sub-layer below.
        normed = self.ln_1(x)
        # The attention weights are discarded, so skip computing them.
        attn_output, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_output
        x = x + self.mlp(self.ln_2(x))
        return x


class DesignerNetFlexible(nn.Module):
    """
    AID Designer network.
    Takes in downsampled episode agent reward trajectories and outputs tax rates directly
    using a sigmoid final activation layer.
    Consists of attention blocks to handle various episode lengths without changing architecture.
    """

    def __init__(self, num_brackets, embed_dim=32, num_heads=4, dropout=0.2, sigmoid_shift=2.0):
        super().__init__()
        # Subtracted before the final sigmoid; a positive shift biases initial tax rates low.
        self.sigmoid_shift = sigmoid_shift
        self.embedding = nn.Linear(1, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim)
        self.attention_blocks = nn.Sequential(
            AttentionBlock(embed_dim, num_heads, dropout=dropout),
            AttentionBlock(embed_dim, num_heads, dropout=dropout),
        )
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, num_brackets),
        )

    def forward(self, trajectory):
        """Suggest new tax rates for an episode.

        Args:
            trajectory (Tensor): concatenation of all agents' episode reward trajectories,
                shaped (num_parallel_games, trajectory_length), or (trajectory_length,) for
                a single game

        Returns:
            tax_vals: suggested new tax rates in [0, 1], shaped (num_parallel_games, num_brackets),
                or (num_brackets,) for single-game input
        """
        unbatched = trajectory.dim() == 1
        if unbatched:
            trajectory = trajectory.unsqueeze(0)

        # Project each scalar reward to embed_dim: (games, seq_len, embed_dim).
        embedding = self.embedding(trajectory.unsqueeze(-1))

        # PositionalEncoding adds pe[:x.size(0)], and nn.MultiheadAttention (batch_first=False)
        # expects (seq_len, batch, embed_dim). Both therefore treat dim 0 as the sequence, so
        # switch to sequence-first for them and back to (games, seq_len, embed_dim) afterwards.
        seq_first = embedding.transpose(0, 1)
        attended = self.attention_blocks(self.pos_encoder(seq_first)).transpose(0, 1)

        # Average across the sequence dimension to eliminate variability in trajectory length.
        pooled = attended.mean(dim=1)
        unsigmoided = self.mlp(pooled)
        tax_vals = torch.sigmoid(unsigmoided - self.sigmoid_shift)
        if unbatched:
            tax_vals = tax_vals.squeeze(0)
        return tax_vals
