import math
from re import A
from typing import List, Type

import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target, source, tau, update_idx=0):
    """Blend the parameters of ``source`` into ``target`` in place.

    Every target parameter is overwritten with
    ``(1 - tau) * target + tau * source``. Parameters are paired
    positionally, in the order returned by ``parameters()``.

    Args:
        target: module whose parameters are updated.
        source: module providing the values to blend in.
        tau: interpolation weight given to ``source``.
        update_idx: accepted for call compatibility; not used here.
    """
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        blended = dst.data * (1.0 - tau) + src.data * tau
        dst.data.copy_(blended)


@torch.no_grad()
def hard_update(target, source):
    """Copy every parameter of ``source`` into ``target`` in place.

    Parameters are paired positionally, in the order returned by
    ``parameters()``.

    Args:
        target: module whose parameters are overwritten.
        source: module whose parameter values are copied.
    """
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        dst.data.copy_(src.data)


def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """
    Utility function for computing output of convolutions
    takes a tuple of (h,w) and returns a tuple of (h,w)

    Every argument may be given either as a single value, which is used for
    both height and width, or as a ``(height, width)`` tuple/list. This also
    holds for ``dilation``.

    Args:
        h_w: input spatial size.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        pad: padding added on both sides of each spatial dimension.
        dilation: spacing between kernel elements.

    Returns:
        Tuple ``(h, w)`` with the output spatial size.
    """

    def _as_pair(value):
        # isinstance instead of an exact type comparison so that tuple
        # subclasses and lists are treated as pairs rather than being wrapped.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    h_w = _as_pair(h_w)
    kernel_size = _as_pair(kernel_size)
    stride = _as_pair(stride)
    pad = _as_pair(pad)
    dilation = _as_pair(dilation)

    h = (h_w[0] + (2 * pad[0]) - (dilation[0] * (kernel_size[0] - 1)) - 1) // stride[
        0
    ] + 1
    w = (h_w[1] + (2 * pad[1]) - (dilation[1] * (kernel_size[1] - 1)) - 1) // stride[
        1
    ] + 1

    return h, w


def combine_ensemble_actions(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    """
    Concatenate a state tensor with an action tensor along the last dimension.

    Supported layouts:
    - ``x`` and ``a`` have the same rank and identical leading dimensions:
      they are concatenated directly.
    - ``x`` is 2D and ``a`` is 3D: ``x`` is repeated ``a.shape[0]`` times along
      a new leading dimension and then concatenated with ``a``.

    Raises:
        ValueError: for any other combination of shapes.
    """
    same_rank = x.dim() == a.dim()
    if same_rank and x.shape[:-1] == a.shape[:-1]:
        return torch.cat((x, a), dim=-1)

    if x.dim() == 2 and a.dim() == 3:
        num_copies = a.shape[0]
        tiled_state = x.unsqueeze(0).repeat(num_copies, 1, 1)
        return torch.cat((tiled_state, a), dim=-1)

    raise ValueError(f"Invalid shapes: x={x.shape}, a={a.shape}")


def orthogonal_init(m):
    """Initialize ``nn.Linear`` and ``nn.Conv2d`` modules orthogonally.

    Linear weights use the default gain, Conv2d weights use the ReLU gain.
    In both cases an existing bias is set to zero. Any other module type is
    left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    elif isinstance(m, nn.Conv2d):
        relu_gain = nn.init.calculate_gain("relu")
        nn.init.orthogonal_(m.weight.data, relu_gain)
    else:
        return

    # Shared by both supported layer types.
    if m.bias is not None:
        nn.init.zeros_(m.bias)


@torch.no_grad()
def rescale_grad(module, scale):
    """Multiply every existing parameter gradient of ``module`` by ``scale``.

    Parameters whose ``grad`` is ``None`` are skipped.
    """
    for param in module.parameters():
        if param.grad is None:
            continue
        param.grad *= scale


class MLP(nn.Module):
    """Feed-forward network assembled as a single ``nn.Sequential``.

    With ``hidden_layers == 0`` the network is one linear map from
    ``input_dim`` to ``output_dim``, optionally preceded by
    ``LayerNorm(input_dim)`` and ``Tanh`` when ``normalize_input`` is set.

    Otherwise the network is:
      * ``Linear(input_dim, hidden_dim)`` followed by ``LayerNorm`` + ``Tanh``
        (``normalize_input=True``) or by ``nonlinearity()``;
      * ``hidden_layers - 1`` blocks of ``Linear(hidden_dim, hidden_dim)``,
        optional ``BatchNorm1d`` (``batch_norm=True``) and ``nonlinearity()``;
        with ``apply_spectral_norm`` only the last of these linear layers is
        wrapped in spectral normalization;
      * an optional ``NormLayer`` (``normalize_last_layer=True``);
      * ``Linear(hidden_dim, output_dim)``.

    The modules are also kept, in order, in the plain list ``self.layers``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        hidden_layers,
        normalize_input: bool = True,
        apply_spectral_norm: bool = False,
        normalize_last_layer: bool = False,
        batch_norm: bool = False,
        nonlinearity: Type[nn.Module] = nn.ELU,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.apply_spectral_norm = apply_spectral_norm
        self.normalize_last_layer = normalize_last_layer
        self.batch_norm = batch_norm

        stack: List[nn.Module] = []
        if hidden_layers == 0:
            if normalize_input:
                stack += [nn.LayerNorm(input_dim), nn.Tanh()]
            stack.append(nn.Linear(input_dim, output_dim))
        else:
            stack.append(nn.Linear(input_dim, hidden_dim))
            if normalize_input:
                stack += [nn.LayerNorm(hidden_dim), nn.Tanh()]
            else:
                stack.append(nonlinearity())

            # Index of the final hidden-to-hidden layer, the only one that may
            # receive spectral normalization.
            last_hidden_idx = hidden_layers - 2
            for idx in range(hidden_layers - 1):
                linear = nn.Linear(hidden_dim, hidden_dim)
                if apply_spectral_norm and idx == last_hidden_idx:
                    linear = nn.utils.parametrizations.spectral_norm(linear)
                stack.append(linear)
                if self.batch_norm:
                    stack.append(nn.BatchNorm1d(hidden_dim))
                stack.append(nonlinearity())

            if self.normalize_last_layer:
                stack.append(NormLayer())
            stack.append(nn.Linear(hidden_dim, output_dim))

        self.layers: List[nn.Module] = stack
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential network to ``x``."""
        return self.net(x)


class DoubleHeadMLP(MLP):
    """MLP trunk followed by two independent linear output heads.

    The constructor first runs ``MLP.__init__`` (which stores the
    configuration attributes and builds a full network) and then replaces
    ``self.net`` with a trunk that ends at the last hidden representation.
    ``forward`` returns ``(head1(features), head2(features))``.

    With ``hidden_layers == 0`` the trunk is ``nn.Identity`` and both heads
    map directly from ``input_dim`` to ``output_dim``; in that case
    ``self.layers`` is left as set by ``MLP.__init__``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        hidden_layers,
        normalize_input: bool = True,
        apply_spectral_norm: bool = False,
        normalize_last_layer: bool = False,
        batch_norm: bool = False,
        nonlinearity: Type[nn.Module] = nn.ELU,
    ):
        super().__init__(
            input_dim,
            output_dim,
            hidden_dim,
            hidden_layers,
            normalize_input,
            apply_spectral_norm,
            normalize_last_layer,
            batch_norm,
            nonlinearity,
        )

        if hidden_layers == 0:
            # No trunk: the heads see the raw input, so they must be sized for
            # input_dim rather than hidden_dim.
            self.net = nn.Identity()
            head_in_dim = input_dim
        else:
            self.layers: List[nn.Module] = [
                nn.Linear(input_dim, hidden_dim),
            ]
            if normalize_input:
                self.layers.append(nn.LayerNorm(hidden_dim))
                self.layers.append(nn.Tanh())
            else:
                self.layers.append(nonlinearity())
            for i in range(hidden_layers - 1):
                # Spectral normalization is applied to the last hidden layer only.
                if apply_spectral_norm and i == hidden_layers - 2:
                    self.layers.append(
                        nn.utils.parametrizations.spectral_norm(
                            nn.Linear(hidden_dim, hidden_dim)
                        )
                    )
                else:
                    self.layers.append(nn.Linear(hidden_dim, hidden_dim))
                if self.batch_norm:
                    self.layers.append(nn.BatchNorm1d(hidden_dim))
                self.layers.append(nonlinearity())
            if normalize_last_layer:
                # NOTE(review): MLP appends NormLayer() for this flag, whereas
                # BatchNorm1d is used here - confirm which one is intended.
                self.layers.append(nn.BatchNorm1d(hidden_dim))
            self.net = nn.Sequential(*self.layers)
            head_in_dim = hidden_dim
        self.head1 = nn.Linear(head_in_dim, output_dim)
        self.head2 = nn.Linear(head_in_dim, output_dim)

        # Not used by forward(); kept so the module's attributes are unchanged.
        self.norm = NormLayer()

    def forward(self, x):
        """Return the outputs of both heads applied to the trunk features."""
        x = self.net(x)
        return self.head1(x), self.head2(x)


class EnsembleMLP(nn.Module):
    """MLP whose linear layers are ``EnsembleLinearLayer`` instances.

    ``ensemble_sizes`` is forwarded to every ``EnsembleLinearLayer`` as its
    number of members.

    With ``hidden_layers == 0`` ``self.net`` is a single
    ``EnsembleLinearLayer`` from ``input_dim`` to ``output_dim``. Otherwise
    ``self.net`` is an ``nn.Sequential`` of:
      * an input layer (never spectrally normalized) followed by
        ``LayerNorm`` + ``Tanh`` (``normalize_input=True``) or
        ``nonlinearity()``;
      * ``hidden_layers - 1`` hidden layers, each followed by
        ``nonlinearity()``;
      * an output layer.
    Hidden and output layers receive ``apply_spectral_norm``; every layer
    receives ``spectral_norm_clip_factor`` as its clip value.

    NOTE(review): ``spectral_norm_clip_factor`` defaults to 0.0 here but to 1.0
    in ``DoubleHeadEnsembleMLP`` - confirm the default is intended before
    enabling ``apply_spectral_norm`` without setting it explicitly.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        hidden_layers,
        ensemble_sizes,
        normalize_input: bool = True,
        apply_spectral_norm: bool = False,
        spectral_norm_clip_factor: float = 0.0,
        nonlinearity: Type[nn.Module] = nn.ELU,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.ensemble_sizes = ensemble_sizes
        self.apply_spectral_norm = apply_spectral_norm
        self.spectral_norm_clip_factor = spectral_norm_clip_factor

        if hidden_layers == 0:
            self.net = EnsembleLinearLayer(
                ensemble_sizes,
                input_dim,
                output_dim,
                True,
                apply_spectral_norm,
                spectral_norm_clip_factor,
            )
        else:
            self.layers: List[nn.Module] = [
                EnsembleLinearLayer(
                    ensemble_sizes,
                    input_dim,
                    hidden_dim,
                    True,
                    False,
                    spectral_norm_clip_factor,
                ),
            ]
            if normalize_input:
                self.layers.append(nn.LayerNorm(hidden_dim))
                self.layers.append(nn.Tanh())
            else:
                self.layers.append(nonlinearity())

            for _ in range(hidden_layers - 1):
                self.layers.append(
                    EnsembleLinearLayer(
                        ensemble_sizes,
                        hidden_dim,
                        hidden_dim,
                        True,
                        apply_spectral_norm,
                        spectral_norm_clip_factor,
                    )
                )
                self.layers.append(nonlinearity())
            self.layers.append(
                EnsembleLinearLayer(
                    ensemble_sizes,
                    hidden_dim,
                    output_dim,
                    True,
                    apply_spectral_norm,
                    spectral_norm_clip_factor,
                )
            )

            self.net = nn.Sequential(*self.layers)

    def forward(self, x):
        """Apply the ensemble network to ``x``."""
        return self.net(x)

    def roll(self):
        """Call ``roll()`` on every ``EnsembleLinearLayer`` of the network.

        ``modules()`` is used instead of ``children()`` because it also yields
        ``self.net`` itself. With ``hidden_layers == 0`` the network *is* a
        single ``EnsembleLinearLayer`` with no children, which would otherwise
        be skipped.
        """
        for layer in self.net.modules():
            if isinstance(layer, EnsembleLinearLayer):
                layer.roll()


class DoubleHeadEnsembleMLP(EnsembleMLP):
    """Ensemble MLP trunk followed by two ensemble output heads.

    The constructor first runs ``EnsembleMLP.__init__`` (which stores the
    configuration attributes and builds a full network) and then replaces
    ``self.net`` with a trunk that ends at a ``hidden_dim``-wide
    representation:
      * ``hidden_layers == 0``: a single ``EnsembleLinearLayer`` from
        ``input_dim`` to ``hidden_dim``;
      * otherwise: an input layer followed by ``LayerNorm`` + ``Tanh``
        (``normalize_input=True``) or ``nonlinearity()``, then
        ``hidden_layers - 1`` hidden layers each followed by
        ``nonlinearity()``.
    All trunk layers receive ``apply_spectral_norm``; the two heads are built
    without spectral normalization. ``forward`` returns
    ``(head1(features), head2(features))``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        hidden_layers,
        ensemble_sizes,
        normalize_input: bool = True,
        apply_spectral_norm: bool = False,
        spectral_norm_clip_factor: float = 1.0,
        nonlinearity: Type[nn.Module] = nn.ELU,
    ):
        # The base constructor stores input_dim, output_dim, hidden_dim,
        # hidden_layers, ensemble_sizes, apply_spectral_norm and
        # spectral_norm_clip_factor on self.
        super().__init__(
            input_dim,
            output_dim,
            hidden_dim,
            hidden_layers,
            ensemble_sizes,
            normalize_input,
            apply_spectral_norm,
            spectral_norm_clip_factor,
            nonlinearity,
        )

        if hidden_layers == 0:
            self.net = EnsembleLinearLayer(
                ensemble_sizes,
                input_dim,
                hidden_dim,
                True,
                apply_spectral_norm,
                spectral_norm_clip_factor,
            )
        else:
            self.layers = [
                EnsembleLinearLayer(
                    ensemble_sizes,
                    input_dim,
                    hidden_dim,
                    True,
                    apply_spectral_norm,
                    spectral_norm_clip_factor,
                ),
            ]
            if normalize_input:
                self.layers.append(nn.LayerNorm(hidden_dim))
                self.layers.append(nn.Tanh())
            else:
                self.layers.append(nonlinearity())
            for _ in range(hidden_layers - 1):
                self.layers.append(
                    EnsembleLinearLayer(
                        ensemble_sizes,
                        hidden_dim,
                        hidden_dim,
                        True,
                        apply_spectral_norm,
                        spectral_norm_clip_factor,
                    )
                )
                self.layers.append(nonlinearity())

            self.net = nn.Sequential(*self.layers)
        self.head1 = EnsembleLinearLayer(
            ensemble_sizes,
            hidden_dim,
            output_dim,
            True,
            False,
            spectral_norm_clip_factor,
        )
        self.head2 = EnsembleLinearLayer(
            ensemble_sizes,
            hidden_dim,
            output_dim,
            True,
            False,
            spectral_norm_clip_factor,
        )
        # Not applied in forward(); kept because it owns registered parameters
        # that are part of the module's state_dict.
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Return the outputs of both heads applied to the trunk features."""
        x = self.net(x)
        return self.head1(x), self.head2(x)

    def roll(self):
        """Call ``roll()`` on every trunk ``EnsembleLinearLayer`` and both heads.

        ``modules()`` is used instead of ``children()`` because it also yields
        ``self.net`` itself. With ``hidden_layers == 0`` the trunk *is* a single
        ``EnsembleLinearLayer`` with no children; skipping it would roll the
        heads but not the trunk and leave the ensemble members misaligned.
        """
        for layer in self.net.modules():
            if isinstance(layer, EnsembleLinearLayer):
                layer.roll()
        self.head1.roll()
        self.head2.roll()


class EnsembleLinearLayer(nn.Module):
    """
    Efficient linear layer for ensemble models.
    Taken from https://github.com/facebookresearch/mbrl-lib/blob/main/mbrl/models/util.py

    Holds one weight matrix per ensemble member in a single tensor of shape
    ``(num_members, in_size, out_size)`` and, if ``bias`` is set, a bias of
    shape ``(num_members, 1, out_size)``.

    When ``spectral_norm`` is enabled, every forward pass runs one power
    iteration step per member (vectors ``u`` and ``v``) and divides each
    member's weight by ``max(1, sigma / (spectral_norm_clip_value + 1e-8))``,
    where ``sigma`` is the estimated spectral norm of that member.

    Args:
        num_members: number of ensemble members.
        in_size: input feature size.
        out_size: output feature size.
        bias: whether to add a per-member bias.
        spectral_norm: whether to clip the per-member spectral norm.
        spectral_norm_clip_value: spectral norm value above which weights are
            rescaled.
    """

    def __init__(
        self,
        num_members: int,
        in_size: int,
        out_size: int,
        bias: bool = True,
        spectral_norm: bool = False,
        spectral_norm_clip_value: float = 1.0,
    ):
        super().__init__()
        self.num_members = num_members
        self.in_size = in_size
        self.out_size = out_size

        self.has_spectral_norm = spectral_norm
        self.spectral_norm_clip_value = spectral_norm_clip_value

        # manual implementation so spectral norm is not computed jointly on all ensemble members
        u_init = torch.randn(num_members, in_size, 1)
        self.u = nn.Parameter(  # type: ignore
            u_init
            / (
                torch.linalg.norm(u_init, dim=1, keepdim=True) + 1e-10
            )  # initialize on unit sphere
        )
        self.v = nn.Parameter(  # type: ignore
            torch.randn(num_members, out_size, 1), requires_grad=False
        )  # overwritten in first iteration

        if bias:
            self.bias = nn.Parameter(torch.zeros(self.num_members, 1, self.out_size))  # type: ignore
            self.use_bias = True
        else:
            self.use_bias = False

        self._init_weight()

    def _init_weight(self):
        """Create the stacked weight with a Kaiming-uniform init per member."""
        weight = []
        for _ in range(self.num_members):
            # Initialized as (out, in) and then transposed to (in, out).
            x: torch.Tensor = torch.empty(self.out_size, self.in_size)
            x = nn.init.kaiming_uniform_(x, a=math.sqrt(5))
            weight.append(torch.transpose(x, 0, 1))
        self.weight = nn.Parameter(torch.stack(weight, dim=0))  # type: ignore

    def roll(self):
        """Shift all members one slot towards index 0.

        Member 0 is dropped, members ``1..N-1`` move to ``0..N-2`` and the
        last slot keeps a copy of the previous last member. This is applied to
        ``weight``, ``u``, ``v`` and, when the layer has one, ``bias``.

        Each tensor is replaced by a new ``nn.Parameter`` that keeps the
        ``requires_grad`` flag of the one it replaces. Because the parameter
        objects change, references to the old ones held elsewhere (e.g. by an
        optimizer created earlier) do not point to the new values.
        """
        names = ["weight", "u", "v"]
        if self.use_bias:
            # Layers built with bias=False have no ``bias`` attribute.
            names.append("bias")
        with torch.no_grad():
            for name in names:
                old = getattr(self, name)
                shifted = torch.cat((old[1:], old[-1:]), dim=0)
                setattr(
                    self,
                    name,
                    nn.Parameter(shifted, requires_grad=old.requires_grad),
                )

    def forward(self, x):
        """Apply every member's linear map to ``x``.

        A 2D ``x`` is multiplied with all member weights via broadcasting; any
        other input is passed to ``torch.bmm`` with the stacked weights, so it
        must be 3D with the member dimension first.
        """
        if self.has_spectral_norm:
            self._update_spectral_norm()
            norm = self.spectral_norm()
            # Only shrink members whose spectral norm exceeds the clip value.
            w = self.weight / torch.max(
                torch.ones_like(norm), norm / (self.spectral_norm_clip_value + 1e-8)
            )
        else:
            w = self.weight
        if x.dim() == 2:
            xw = x.matmul(w)
        else:
            xw = torch.bmm(x, w)
        if self.use_bias:
            return xw + self.bias
        else:
            return xw

    @torch.no_grad()
    def _update_spectral_norm(self):
        """Run one power-iteration step on ``u`` and ``v`` for every member."""
        w_T = torch.transpose(self.weight, 1, 2)
        # Both products are computed from the vectors of the previous step.
        v_product = torch.bmm(w_T, self.u)
        u_product = torch.bmm(self.weight, self.v)
        self.v[:] = v_product / (torch.norm(v_product, dim=1, keepdim=True) + 1e-8)
        self.u[:] = u_product / (torch.norm(u_product, dim=1, keepdim=True) + 1e-8)

    @torch.no_grad()
    def spectral_norm(self):
        """Return ``|u^T W v|`` per member, with shape ``(num_members, 1, 1)``."""
        return torch.abs(
            torch.bmm(self.u.transpose(1, 2), torch.bmm(self.weight, self.v))
            .detach()
            .clone()
        )


class NormLayer(nn.Module):
    """Divide the input by its L2 norm along the last dimension.

    A small constant (1e-8) is added to the norm so that all-zero vectors do
    not cause a division by zero.
    """

    def forward(self, x):
        denom = torch.norm(x, dim=-1, keepdim=True) + 1e-8
        return x / denom
