import math
import copy
from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from rff.layers import GaussianEncoding, PositionalEncoding

from experiments.utils import make_coordinates


class Sine(nn.Module):
    """Sine activation with a frequency multiplier: ``sin(w0 * x)``."""

    def __init__(self, w0=1.0):
        super().__init__()
        # Scale applied to the input before taking the sine.
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise sine of ``x`` scaled by ``w0``."""
        scaled = x * self.w0
        return torch.sin(scaled)


def params_to_tensor(params):
    """Flatten a collection of tensors into a single 1-D tensor.

    Args:
        params: Iterable of tensors. One-shot iterables (generators such as
            ``module.parameters()``) are supported.

    Returns:
        A ``(flat, shapes)`` tuple: ``flat`` is the concatenation of every
        flattened tensor in iteration order, and ``shapes`` is the list of the
        original shapes in that same order.
    """
    # The input is traversed twice below; materialize it first so a generator
    # is not exhausted by the first pass (which would leave `shapes` empty).
    params = tuple(params)
    flat = torch.cat([p.flatten() for p in params])
    shapes = [p.shape for p in params]
    return flat, shapes


def tensor_to_params(tensor, shapes):
    """Split a flat tensor back into tensors of the given shapes.

    Consecutive slices of ``tensor`` (along its first dimension) are taken,
    one per entry of ``shapes``, and each slice is reshaped to that entry.

    Args:
        tensor: Flat tensor holding all values back to back.
        shapes: Iterable of shapes (``torch.Size``, tuple/list of ints, or a
            bare int). A scalar shape ``()`` consumes exactly one element.

    Returns:
        Tuple of reshaped tensors, in the order of ``shapes``.
    """
    params = []
    start = 0
    for shape in shapes:
        dims = (shape,) if isinstance(shape, int) else tuple(shape)
        # Pure-Python product: always an int (also for the empty shape, where
        # it is 1), so it is a valid slice bound and no temporary tensor is
        # allocated per shape.
        size = math.prod(dims)
        params.append(tensor[start : start + size].reshape(dims))
        start += size
    return tuple(params)


def wrap_func(func, shapes):
    """Adapt ``func`` so that it accepts its parameters as one flat tensor.

    Args:
        func: Callable whose first argument is a tuple of parameter tensors.
        shapes: Shapes used to split the flat tensor via ``tensor_to_params``.

    Returns:
        A callable taking the flat tensor first; all remaining positional and
        keyword arguments are forwarded to ``func`` unchanged.
    """

    def wrapped_func(params, *args, **kwargs):
        # Rebuild the per-parameter tensors before delegating.
        unflattened = tensor_to_params(params, shapes)
        return func(unflattened, *args, **kwargs)

    return wrapped_func


class Siren(nn.Module):
    """Single SIREN layer: an affine map followed by an activation.

    Args:
        dim_in: Number of input features.
        dim_out: Number of output features.
        w0: Frequency factor; used by the default ``Sine`` activation and by
            the weight initialization of non-first layers.
        c: Constant in the initialization bound ``sqrt(c / dim_in) / w0``.
        is_first: When true, weights are drawn from ``U(-1/dim_in, 1/dim_in)``.
        use_bias: Whether the layer has a (zero-initialized) bias parameter.
        activation: Module applied after the affine map; ``Sine(w0)`` if None.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first
        self.w0 = w0
        self.c = c

        # Initialize plain tensors first, then wrap them as parameters.
        init_weight = torch.zeros(dim_out, dim_in)
        init_bias = torch.zeros(dim_out) if use_bias else None
        self.init_(init_weight, init_bias, c=c, w0=w0)

        self.weight = nn.Parameter(init_weight)
        if use_bias:
            self.bias = nn.Parameter(init_bias)
        else:
            self.bias = None

        if activation is None:
            activation = Sine(w0)
        self.activation = activation

    def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
        """Fill ``weight`` uniformly in place and zero ``bias`` when given."""
        fan_in = self.dim_in

        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0
        weight.uniform_(-bound, bound)

        if bias is None:
            return
        # The bias always starts at zero (it is not sampled).
        bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map and then the activation."""
        return self.activation(F.linear(x, self.weight, self.bias))


class INR(nn.Module):
    """Implicit neural representation built from ``Siren`` layers.

    The network is an optional input encoding, a first ``Siren`` layer,
    ``n_layers - 2`` further ``Siren`` layers of width ``in_dim * up_scale``,
    and a final ``nn.Linear`` projecting to ``out_channels``.

    Args:
        in_dim: Dimensionality of the input coordinates.
        n_layers: Layer count including the first ``Siren`` and final linear.
        up_scale: Hidden width multiplier (hidden width = in_dim * up_scale).
        out_channels: Number of output features.
        pe_features: If given, enables an input encoding with this many
            features; if None, coordinates go straight into the first layer.
        fix_pe: Selects ``PositionalEncoding`` (True) or ``GaussianEncoding``
            (False) when ``pe_features`` is set.
    """

    def __init__(
        self,
        in_dim: int = 2,
        n_layers: int = 3,
        up_scale: int = 4,
        out_channels: int = 1,
        pe_features: Optional[int] = None,
        fix_pe=True,
    ):
        super().__init__()
        hidden_dim = in_dim * up_scale

        if pe_features is None:
            stack = [Siren(dim_in=in_dim, dim_out=hidden_dim)]
        else:
            if fix_pe:
                encoder = PositionalEncoding(sigma=10, m=pe_features)
                encoded_dim = in_dim * pe_features * 2
            else:
                encoder = GaussianEncoding(
                    sigma=10, input_size=in_dim, encoded_size=pe_features
                )
                encoded_dim = pe_features * 2
            stack = [encoder, Siren(dim_in=encoded_dim, dim_out=hidden_dim)]

        # Hidden layers between the first Siren and the linear head.
        stack.extend(Siren(hidden_dim, hidden_dim) for _ in range(n_layers - 2))
        stack.append(nn.Linear(hidden_dim, out_channels))

        # Plain list kept alongside the registered Sequential container.
        self.layers = stack
        self.seq = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layer stack and shift the result by a constant 0.5."""
        out = self.seq(x)
        return out + 0.5


class INR_AF(INR):
    """``INR`` variant whose forward pass exposes every intermediate output."""

    def forward(self, x: torch.Tensor):
        """Return ``[x, out_1, ..., out_n]`` for the layers of ``self.seq``.

        The last entry has the same constant 0.5 offset added that
        ``INR.forward`` applies to its output.
        """
        activations = [x]
        current = x
        for layer in self.seq:
            current = layer(current)
            activations.append(current)
        activations[-1] = activations[-1] + 0.5
        return activations


def make_functional(mod, disable_autograd_tracking=False):
    """Split a module into a stateless callable and its parameter values.

    Args:
        mod: Module to convert. It is left untouched; a deep copy moved to
            the ``meta`` device serves as the parameter-free skeleton.
        disable_autograd_tracking: When true, the returned parameter values
            are detached from the autograd graph.

    Returns:
        ``(fmodel, params_values)`` where ``params_values`` is a tuple of the
        module's parameters in ``named_parameters()`` order, and
        ``fmodel(new_params_values, *args, **kwargs)`` evaluates the module
        with those values substituted via ``torch.func.functional_call``.
    """
    named = dict(mod.named_parameters())
    names = tuple(named)
    values = tuple(named.values())

    skeleton = copy.deepcopy(mod)
    skeleton.to("meta")

    def fmodel(new_params_values, *args, **kwargs):
        # Pair the supplied values with parameter names positionally.
        substituted = dict(zip(names, new_params_values))
        return torch.func.functional_call(skeleton, substituted, args, kwargs)

    if disable_autograd_tracking:
        values = tuple(v.detach() for v in values)
    return fmodel, values


class INRWrapper(nn.Module):
    """Evaluate a batch of INRs from externally supplied weights and biases.

    An ``INR`` built from ``inr_kwargs`` is turned into a functional model,
    wrapped to take one flat parameter vector, and vectorized with
    ``torch.vmap`` so each batch element is evaluated with its own parameters.
    """

    def __init__(self, inr_kwargs) -> None:
        super().__init__()
        template = INR(**inr_kwargs)
        functional_inr, template_params = make_functional(template)

        # Only the shapes are needed to unflatten per-sample parameter vectors.
        _, param_shapes = params_to_tensor(template_params)
        self.sirens = torch.vmap(wrap_func(functional_inr, param_shapes))

        # NOTE hard coded maps: exactly three weight/bias pairs are handled.
        # Each weight pattern swaps the two inner axes before flattening and
        # drops the trailing singleton axis.
        self.reshape_w0 = Rearrange("b i h0 1 -> b (h0 i)")
        self.reshape_w1 = Rearrange("b h0 h1 1 -> b (h1 h0)")
        self.reshape_w2 = Rearrange("b h1 h2 1 -> b (h2 h1)")

        self.reshape_b0 = Rearrange("b h0 1 -> b h0")
        self.reshape_b1 = Rearrange("b h1 1 -> b h1")
        self.reshape_b2 = Rearrange("b h2 1 -> b h2")

    def forward(self, weights, biases, inputs=None):
        """Run the vectorized INR on per-sample parameters.

        Args:
            weights: Indexable with at least three tensors, each 4-D with a
                trailing singleton axis (as the rearrange patterns require).
            biases: Indexable with at least three tensors, each 3-D with a
                trailing singleton axis.
            inputs: Optional batched inputs. When None, they come from
                ``make_coordinates((28, 28), batch_size)`` (presumably a
                28x28 coordinate grid per sample; confirm in
                ``experiments.utils``).

        Returns:
            Output of the vmapped INR for each batch element.
        """
        reshapers = (
            (self.reshape_w0, self.reshape_b0),
            (self.reshape_w1, self.reshape_b1),
            (self.reshape_w2, self.reshape_b2),
        )
        # Interleave weight/bias per layer: w0, b0, w1, b1, w2, b2.
        pieces = []
        for idx, (flatten_w, flatten_b) in enumerate(reshapers):
            pieces.append(flatten_w(weights[idx]))
            pieces.append(flatten_b(biases[idx]))
        params_flat = torch.cat(pieces, dim=-1)

        if inputs is None:
            batch_size = params_flat.size(0)
            inputs = make_coordinates((28, 28), batch_size).to(params_flat.device)

        # vmap maps over dim 0 of both arguments; inputs are not expanded here.
        return self.sirens(params_flat, inputs)

# Example usage (kept commented out for reference):
# model = INR_AF(in_dim=2, n_layers=3, up_scale=16)
# fmodel, params = make_functional(model)
#
# vparams, vshapes = params_to_tensor(params)
# vwfunc = torch.vmap(wrap_func(fmodel, vshapes))
# vwfunc(torch.stack([vparams]*2), torch.ones(2,3,2))
