import torch
from torch import nn
import monotonicnetworks as lmn
from hetreg.models import get_activation, get_head
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
from .activation import ExU
from collections import OrderedDict
def stable_softplus(input):
    """Return ``softplus(input)`` offset by ``1e-8``.

    The constant offset presumably keeps the result bounded away from zero
    when softplus underflows for very negative inputs.
    """
    positive = F.softplus(input)
    return positive + 1e-8


def get_head_activation(act_str):
    """Map an activation name to the callable used by the distribution heads.

    Args:
        act_str: ``'exp'`` or ``'softplus'``.

    Returns:
        ``torch.exp`` for ``'exp'``, ``stable_softplus`` for ``'softplus'``.

    Raises:
        ValueError: if ``act_str`` is neither of the supported names.
    """
    if act_str not in ('exp', 'softplus'):
        raise ValueError('invalid activation')
    return torch.exp if act_str == 'exp' else stable_softplus


class NaturalHead(nn.Module):
    """Head that maps columns ``(f1, f2)`` to ``(f1, -0.5 * act(f2))``.

    NOTE(review): the name suggests the two outputs are the natural
    parameters of a Gaussian -- confirm against the likelihood that
    consumes them.

    Args:
        activation: name passed to ``get_head_activation``
            (``'softplus'`` or ``'exp'``).
    """

    def __init__(self, activation='softplus') -> None:
        super().__init__()
        self.act_fn = get_head_activation(activation)

    def forward(self, input):
        """Transform an input indexed as ``input[:, 0]`` / ``input[:, 1]``.

        Column 0 is passed through unchanged; column 1 goes through the
        activation and is scaled by -0.5.  The two results are stacked
        along dim 1.
        """
        first = input[:, 0]
        second = self.act_fn(input[:, 1]) * -0.5
        return torch.stack((first, second), dim=1)


class GaussianHead(nn.Module):
    """Head that maps columns ``(f1, f2)`` to ``(f1, act(f2))``.

    NOTE(review): presumably the second output is a positive scale or
    variance of a Gaussian -- confirm against the consuming loss.

    Args:
        activation: name passed to ``get_head_activation``
            (``'softplus'`` or ``'exp'``).
    """

    def __init__(self, activation='softplus') -> None:
        super().__init__()
        self.act_fn = get_head_activation(activation)

    def forward(self, input):
        """Keep column 0, apply the activation to column 1, stack on dim 1."""
        untouched = input[:, 0]
        activated = self.act_fn(input[:, 1])
        return torch.stack((untouched, activated), dim=1)


def get_head(head_str):
    """Return the head *class* (not an instance) selected by ``head_str``.

    ``'natural'`` selects ``NaturalHead`` and ``'gaussian'`` selects
    ``GaussianHead``; any other value falls back to ``nn.Identity``.
    """
    head_cls = nn.Identity
    if head_str == 'natural':
        head_cls = NaturalHead
    elif head_str == 'gaussian':
        head_cls = GaussianHead
    return head_cls


def initialize_prior_knowledge(list_covariates, dict_prior, type='pred'):
    """Fill in default priors for covariates that are missing from ``dict_prior``.

    Entries already present in ``dict_prior`` are left untouched.  The
    default inserted for a missing covariate depends on ``type``:

    * ``'pred'``: the scalar ``0`` (prior values are scalars in {-1, 0, 1}).
    * ``'corr'``: a dict over all covariates with ``1`` for the covariate
      itself and ``0`` for every other covariate.

    Args:
        list_covariates: iterable of covariate names (hashable).
        dict_prior: mapping covariate -> prior.  It is updated **in place**
            and also returned.
        type: ``'pred'`` or ``'corr'``.  (The name shadows the builtin but
            is kept so keyword callers keep working.)

    Returns:
        The same ``dict_prior`` object, now containing every covariate.

    Raises:
        ValueError: if a default has to be imputed and ``type`` is neither
            ``'pred'`` nor ``'corr'``.  Previously this case only printed a
            message and silently stored ``None`` as the prior.
    """
    def impute_prior(current_cov):
        if type == 'pred':
            return 0
        if type == 'corr':
            return {other: (1 if other == current_cov else 0)
                    for other in list_covariates}
        raise ValueError(
            f"Wrong type for prior: {type!r} (expected 'pred' or 'corr')")

    for ith_cov in list_covariates:
        # Direct dict membership test instead of rebuilding a key list each time.
        if ith_cov not in dict_prior:
            dict_prior[ith_cov] = impute_prior(ith_cov)
    return dict_prior


# subnetwork for NAM, with MONO backbone
class BaseLipBlock(nn.Module):
    """Sub-network built from ``monotonicnetworks`` (``lmn``) layers.

    The backbone is a stack of ``lmn.direct_norm``-wrapped linear layers
    separated by ``lmn.GroupSort(4)``; with ``hidden_layers == 0`` it is a
    single plain ``Linear``.  The backbone is wrapped in
    ``lmn.MonotonicWrapper`` with one all-zero constraint row per geo
    feature followed by a single ``[val_prior, 0]`` row.

    NOTE(review): every constraint row has two entries and only one
    covariate row is appended, which looks like it assumes
    ``in_cov_features == 1``, ``out_features == 2`` and geo features placed
    before the covariate in the input -- confirm against callers.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        val_prior: monotonicity prior for the covariate (presumably in
            {-1, 0, 1}; verify against the caller).
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden layers (0 gives a linear model).
        out_features: output width of the backbone.
        outermost_linear: accepted but not used by this block.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 val_prior: 0,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 outermost_linear=False,):
        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features
        self.input_size = self.in_cov_features + self.in_geo_features

        hidden_sizes = [hidden_features] * hidden_layers

        if hidden_layers == 0:
            # Depth zero: plain linear model without norm constraints.
            self.net = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, out_features))
        else:
            layers = []
            fan_ins = [self.input_size, *hidden_sizes[:-1]]
            for depth, (fan_in, fan_out) in enumerate(zip(fan_ins, hidden_sizes)):
                # Only the first layer uses the "one-inf" norm.
                norm_kind = "one-inf" if depth == 0 else "inf"
                layers.append(lmn.direct_norm(
                    torch.nn.Linear(fan_in, fan_out), kind=norm_kind))
                layers.append(lmn.GroupSort(4))
            # Output layer.
            layers.append(lmn.direct_norm(
                torch.nn.Linear(hidden_sizes[-1], out_features), kind="inf"))
            self.net = nn.ModuleList(layers)

        # Zero rows for the geo features (an empty list when there are none),
        # then the covariate row.
        geo_rows = np.zeros((self.in_geo_features, 2), dtype=int).tolist()
        mono_cons = geo_rows + [[val_prior, 0]]

        backbone = nn.Sequential(*self.net)
        self.network = lmn.MonotonicWrapper(backbone, monotonic_constraints=mono_cons)

        print(self.net)

    def forward(self, model_input):
        """Run the monotonic-wrapped backbone on ``model_input``."""
        return self.network(model_input)




class LipBlock(nn.Module):
    """Multi-covariate sub-network built from ``monotonicnetworks`` layers.

    The backbone is a stack of ``lmn.direct_norm``-wrapped linear layers
    separated by ``lmn.GroupSort(4)``; with ``hidden_layers == 0`` it is a
    single plain ``Linear``.  It is wrapped in ``lmn.MonotonicWrapper`` with
    one all-zero constraint row per geo feature followed by one
    ``[dict_prior[cov], 0]`` row per entry of ``list_covariates``.

    NOTE(review): each row has two entries, which looks like it assumes
    ``out_features == 2`` and geo columns placed before the covariates in
    the input; ``len(list_covariates)`` presumably equals
    ``in_cov_features`` -- confirm against callers.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        list_covariates: covariate names, in input-column order.
        dict_prior: mapping covariate -> monotonicity prior; must contain
            every name in ``list_covariates`` (``KeyError`` otherwise).
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden layers (0 gives a linear model).
        out_features: output width of the backbone.
        outermost_linear: accepted but not used by this block.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 list_covariates: list,
                 dict_prior: dict,
                 hidden_features: int,
                 hidden_layers: int,
                 out_features: int,
                 outermost_linear=False,):
        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features
        self.input_size = self.in_cov_features + self.in_geo_features
        hidden_sizes = hidden_layers * [hidden_features]
        self.net = nn.ModuleList()

        if hidden_layers == 0:  # i.e. when depth == 0.
            # Linear Model
            self.net = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, out_features))
        else:
            # MLP
            in_outs = zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)

            for i, (in_size, out_size) in enumerate(in_outs):
                # Only the first layer uses the "one-inf" norm.
                kind = "one-inf" if i == 0 else "inf"
                self.net.append(lmn.direct_norm(
                    torch.nn.Linear(in_size, out_size),
                    kind=kind))
                self.net.append(lmn.GroupSort(4))

            # final layer
            self.net.append(lmn.direct_norm(
                torch.nn.Linear(hidden_sizes[-1], out_features),
                kind="inf"))

        # One unconstrained (all-zero) row per geo feature.  The previous
        # check `in_geo_features == 1` dropped these rows whenever there was
        # more than one geo feature, so the constraint matrix had fewer rows
        # than the network has inputs.  With zero geo features this is [].
        mono_cons = np.zeros((self.in_geo_features, 2), dtype=int).tolist()
        # One row per covariate, in the order given by list_covariates.
        mono_cons += [[dict_prior[ith_cov], 0] for ith_cov in list_covariates]

        model = nn.Sequential(*self.net)
        self.network = lmn.MonotonicWrapper(model, monotonic_constraints=mono_cons)

        print(self.net)

    def forward(self, model_input):
        """Run the monotonic-wrapped backbone on ``model_input``."""
        output = self.network(model_input)
        return output


#
#
#
# class BaseMLP_correlation(nn.Module):
#     def __init__(self,
#                  in_features,
#                  hidden_features,
#                  hidden_layers,
#                  out_features,
#                  dict_cor_prior,
#                  list_covariates,
#                  outermost_linear=True,):
#         super().__init__()
#
#         self.input_size = in_features
#
#         hidden_sizes = hidden_layers * [hidden_features]
#         self.net = nn.ModuleList()
#
#         if hidden_layers == 0:  # i.e. when depth == 0.
#             # Linear Model
#             self.net = torch.nn.Sequential(
#                 torch.nn.Linear(self.in_features, out_features))
#         else:
#             # MLP
#             in_outs = zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)
#
#             for i, (in_size, out_size) in enumerate(in_outs):
#                 # kind
#                 if i == 0:
#                     kind = "one-inf"
#                 else:
#                     kind = "inf"
#                 # add layer
#                 self.net.append(lmn.direct_norm(
#                     torch.nn.Linear(in_size, out_size),
#                     kind=kind))
#                 self.net.append(lmn.GroupSort(4))
#
#             # final layer
#             self.net.append(lmn.direct_norm(
#                 torch.nn.Linear(hidden_sizes[-1], out_features),
#                 kind="inf"))
#
#
#         mono_constr = []
#         for ith_cov in list_covariates:
#             mono_constr.append(dict_cor_prior[ith_cov])
#             mono_constr.append(0)
#         model = nn.Sequential(*self.net)
#         self.network = lmn.MonotonicWrapper(model, monotonic_constraints=[mono_constr])#[[1, 0, 1, 0, 1, 0]])
#
#         print(self.net)
#
#     def forward(self, model_input):
#         output = self.network(model_input)
#         return output



# class MLP(nn.Sequential):
#     def __init__(self,
#                  input_size: int,
#                  width: int,
#                  depth: int,
#                  output_size: int,
#                  activation='gelu',
#                  #activation_first='exu',
#                  dropout=0.2):
#         super(MLP, self).__init__()
#         self.input_size = input_size
#         self.width = width
#         self.depth = depth
#         hidden_sizes = depth * [width]
#         self.activation = activation
#         act = get_activation(activation)
#         self.rep_layer = f'layer{depth}'
#
#         #self.add_module('flatten', nn.Flatten())
#         if len(hidden_sizes) == 0:  # i.e. when depth == 0.
#             # Linear Model
#             self.add_module('lin_layer', nn.Linear(self.input_size, output_size, bias=True))
#         else:
#             # MLP
#             in_outs = zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)
#             for i, (in_size, out_size) in enumerate(in_outs):
#                 self.add_module(f'layer{i+1}', nn.Linear(in_size, out_size, bias=True))
#                 if dropout > 0.0:
#                     self.add_module(f'dropout{i+1}', nn.Dropout(p=dropout))
#                 self.add_module(f'{activation}{i+1}', act())
#
#             self.add_module('out_layer', nn.Linear(hidden_sizes[-1], output_size, bias=True))
#
#     def reset_parameters(self):
#         for module in self.modules():
#             if isinstance(module, nn.Linear):
#                 module.reset_parameters()
#
#     def representation(self, input):
#         activation = {}
#         def get_activation(name):
#             def hook(model, input, output):
#                 activation[name] = output.detach()
#             return hook
#         handle = getattr(self, self.rep_layer).register_forward_hook(get_activation(self.rep_layer))
#         self.forward(input)
#         rep = activation[self.rep_layer]
#         handle.remove()
#         return rep.detach()



import torch
from torch import nn
from hetreg.models import get_activation  # 用你已有的 get_activation 实现


class MLP(nn.Module):
    """Residual MLP: input projection, ``depth`` residual blocks, output projection.

    Each hidden block is ``Linear -> activation -> Dropout`` (``Identity``
    when ``dropout`` is not positive) and is applied as ``x = block(x) + x``.
    No activation is applied directly after the input projection.

    One ``LayerNorm`` per block is still created and registered in
    ``norm_layers`` so ``state_dict`` keys stay unchanged, but it is not
    applied: normalisation is disabled in ``forward``.

    Args:
        input_size: number of input features.
        width: width of the hidden blocks.
        depth: number of residual hidden blocks (0 leaves only the input
            and output projections).
        output_size: number of output features.
        activation: name resolved through ``get_activation``; the result is
            called with no arguments to build the activation module.
        dropout: dropout probability inside each hidden block.
    """

    def __init__(self,
                 input_size: int,
                 width: int,
                 depth: int,
                 output_size: int,
                 activation='gelu',
                 dropout=0.2):
        super().__init__()
        self.input_size = input_size
        self.width = width
        self.depth = depth
        self.activation_name = activation

        act_layer = get_activation(activation)

        self.input_layer = nn.Linear(input_size, width)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()  # one LayerNorm per hidden block (currently unused)

        for _ in range(depth):
            self.hidden_layers.append(nn.Sequential(
                nn.Linear(width, width),
                act_layer(),
                nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
            ))
            self.norm_layers.append(nn.LayerNorm(width))

        self.output_layer = nn.Linear(width, output_size)

    def forward(self, x):
        """Return the network output for ``x``."""
        x = self.input_layer(x)
        for layer in self.hidden_layers:
            x = layer(x) + x  # residual connection; layer norm intentionally disabled
        return self.output_layer(x)

    def reset_parameters(self):
        """Re-initialise every ``nn.Linear`` in the module (LayerNorms are left as is)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.reset_parameters()

    def representation(self, input, layer_idx=-1):
        """Return the detached hidden state after a given residual block.

        The computation mirrors ``forward`` exactly.  It previously applied
        the ``LayerNorm`` modules that ``forward`` skips, so the returned
        features were not the ones the network actually computes.

        Args:
            input: same input as for ``forward``.
            layer_idx: index of the hidden block whose output is returned.
                Negative values count from the end (default ``-1`` is the
                last block).  An index past the last block, or a network
                with no hidden blocks, yields the last available state.

        Returns:
            Detached tensor.  With the default ``layer_idx`` it is the input
            of ``output_layer`` (up to dropout noise in training mode).
        """
        n_hidden = len(self.hidden_layers)
        target = layer_idx + n_hidden if layer_idx < 0 else layer_idx
        x = self.input_layer(input)
        for i, layer in enumerate(self.hidden_layers):
            x = layer(x) + x
            if i == target:
                break
        return x.detach()
class BaseLipBlock_sep(nn.Module):
    """Monotonic sub-network with single-entry constraint rows.

    The backbone is a stack of ``lmn.direct_norm``-wrapped linear layers
    separated by ``lmn.GroupSort(4)``; with ``hidden_layers == 0`` it is a
    single plain ``Linear``.  It is wrapped in ``lmn.MonotonicWrapper`` with
    one ``[0]`` row per geo feature followed by one covariate row:
    ``[val_prior]`` when ``val_prior`` is an ``int``, otherwise ``val_prior``
    itself (a list) is used as the row.

    NOTE(review): the single covariate row looks like it assumes
    ``in_cov_features == 1`` with geo columns first in the input, and a
    list-valued ``val_prior`` longer than one entry would not match the
    one-entry geo rows -- confirm against callers.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        val_prior: monotonicity prior, an ``int`` or a list.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden layers (0 gives a linear model).
        out_features: output width of the backbone.
        outermost_linear: accepted but not used by this block.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 val_prior: int or list,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 outermost_linear=False, ):
        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features
        self.input_size = self.in_cov_features + self.in_geo_features

        hidden_sizes = [hidden_features] * hidden_layers

        if hidden_layers == 0:
            # Depth zero: plain linear model without norm constraints.
            self.net = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, out_features))
        else:
            layers = []
            fan_ins = [self.input_size, *hidden_sizes[:-1]]
            for depth, (fan_in, fan_out) in enumerate(zip(fan_ins, hidden_sizes)):
                # Only the first layer uses the "one-inf" norm.
                norm_kind = "one-inf" if depth == 0 else "inf"
                layers.append(lmn.direct_norm(
                    torch.nn.Linear(fan_in, fan_out), kind=norm_kind))
                layers.append(lmn.GroupSort(4))
            # Output layer.
            layers.append(lmn.direct_norm(
                torch.nn.Linear(hidden_sizes[-1], out_features), kind="inf"))
            self.net = nn.ModuleList(layers)

        # Covariate row: wrap a scalar prior, use a list prior as given.
        if isinstance(val_prior, int):
            cov_rows = [[val_prior, ]]
        else:
            cov_rows = [val_prior]
        # Zero rows for the geo features (an empty list when there are none).
        geo_rows = np.zeros((self.in_geo_features, 1), dtype=int).tolist()
        mono_cons = geo_rows + cov_rows

        backbone = nn.Sequential(*self.net)
        self.network = lmn.MonotonicWrapper(backbone, monotonic_constraints=mono_cons)

        print(self.net)

    def forward(self, model_input):
        """Run the monotonic-wrapped backbone on ``model_input``."""
        return self.network(model_input)



class FeatureNN(nn.Module):
    """Two-branch feature network with a monotonic first output.

    ``self.mean`` is a ``BaseLipBlock_sep`` with one output and ``self.std``
    is an unconstrained ``MLP`` with one output.  Both receive the same
    input and their outputs are concatenated along dim 1.  The second
    column is the raw MLP output; no positivity transform is applied here.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        val_prior: monotonicity prior forwarded to ``BaseLipBlock_sep``.
        hidden_features: hidden width used by both branches.
        hidden_layers: hidden depth used by both branches.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 val_prior: 0,
                 hidden_features,
                 hidden_layers):

        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features

        # Monotonic branch (built first so parameter creation order is kept).
        self.mean = BaseLipBlock_sep(
            self.in_cov_features,
            self.in_geo_features,
            val_prior,
            hidden_features,
            hidden_layers,
            1,
            outermost_linear=True)

        # Unconstrained branch operating on the same input columns.
        total_inputs = self.in_cov_features + self.in_geo_features
        self.std = MLP(total_inputs, hidden_features, hidden_layers, 1,
                       activation='gelu')

    def forward(self, model_input):
        """Return ``cat([mean(model_input), std(model_input)], dim=1)``."""
        return torch.cat((self.mean(model_input), self.std(model_input)), dim=1)




class SubnetMonoNN(nn.Module):
    """Sub-network with a monotonic branch and an unconstrained branch.

    ``self.mean`` is a ``BaseLipBlock_sep`` and ``self.var`` an ``MLP``;
    each produces ``output_size`` columns from the same input, and the two
    results are concatenated along dim 1.  The ``var`` columns are the raw
    MLP output; no positivity transform is applied here.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        val_prior: monotonicity prior forwarded to ``BaseLipBlock_sep``.
        output_size: number of columns produced by each branch.
        hidden_features: hidden width used by both branches.
        hidden_layers: hidden depth used by both branches.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 val_prior: int = 0,
                 output_size: int = 1,
                 hidden_features: int=128,
                 hidden_layers: int=1):

        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features

        # Monotonic branch (built first so parameter creation order is kept).
        self.mean = BaseLipBlock_sep(
            self.in_cov_features,
            self.in_geo_features,
            val_prior,
            hidden_features,
            hidden_layers,
            output_size,
            outermost_linear=True)

        # Unconstrained branch operating on the same input columns.
        total_inputs = self.in_cov_features + self.in_geo_features
        self.var = MLP(total_inputs, hidden_features, hidden_layers, output_size,
                       activation='gelu')

    def forward(self, model_input):
        """Return ``cat([mean(model_input), var(model_input)], dim=1)``."""
        return torch.cat((self.mean(model_input), self.var(model_input)), dim=1)

    # def forward(self, model_input):
    #     return self.network(model_input)






class SubnetMLPNN(nn.Module):
    """Unconstrained sub-network made of two independent ``MLP`` branches.

    ``self.mean`` and ``self.std`` are configured identically, each mapping
    the full input to ``output_size`` columns.  Their outputs are
    concatenated along dim 1.  Both are raw MLP outputs; no positivity
    transform is applied here.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        output_size: number of columns produced by each branch.
        hidden_features: hidden width of both MLPs.
        hidden_layers: hidden depth of both MLPs.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 output_size: int,
                 hidden_features,
                 hidden_layers,):

        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features

        total_inputs = self.in_cov_features + self.in_geo_features

        # `mean` is built before `std` so parameter creation order is kept.
        self.mean = MLP(total_inputs, hidden_features, hidden_layers, output_size,
                        activation='gelu')
        self.std = MLP(total_inputs, hidden_features, hidden_layers, output_size,
                       activation='gelu')

    def forward(self, model_input):
        """Return ``cat([mean(model_input), std(model_input)], dim=1)``."""
        return torch.cat((self.mean(model_input), self.std(model_input)), dim=1)












class BaseLipBlock_for_distribution(nn.Module):
    """Monotonic sub-network whose backbone emits ``2 * out_features`` columns.

    The backbone is a stack of ``lmn.direct_norm``-wrapped linear layers
    separated by ``lmn.GroupSort(4)``; with ``hidden_layers == 0`` it is a
    single plain ``Linear``.  It is wrapped in ``lmn.MonotonicWrapper`` with
    one all-zero constraint row per geo feature followed by a single
    ``[val_prior, 0]`` row.

    NOTE(review): the two-entry constraint rows look like they assume
    ``out_features == 1`` (two output columns), ``in_cov_features == 1``
    and geo columns first in the input -- confirm against callers.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        val_prior: monotonicity prior for the covariate (presumably in
            {-1, 0, 1}; verify against the caller).
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden layers (0 gives a linear model).
        out_features: half of the backbone's output width.
        outermost_linear: accepted but not used by this block.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 val_prior: 0,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 outermost_linear=False, ):
        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features
        self.input_size = self.in_cov_features + self.in_geo_features

        hidden_sizes = [hidden_features] * hidden_layers

        if hidden_layers == 0:
            # Depth zero: plain linear model without norm constraints.
            self.net = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, out_features*2))
        else:
            layers = []
            fan_ins = [self.input_size, *hidden_sizes[:-1]]
            for depth, (fan_in, fan_out) in enumerate(zip(fan_ins, hidden_sizes)):
                # Only the first layer uses the "one-inf" norm.
                norm_kind = "one-inf" if depth == 0 else "inf"
                layers.append(lmn.direct_norm(
                    torch.nn.Linear(fan_in, fan_out), kind=norm_kind))
                layers.append(lmn.GroupSort(4))
            # Output layer: two columns per requested output feature.
            layers.append(lmn.direct_norm(
                torch.nn.Linear(hidden_sizes[-1], out_features*2), kind="inf"))
            self.net = nn.ModuleList(layers)

        # Zero rows for the geo features (an empty list when there are none),
        # then the covariate row.
        geo_rows = np.zeros((self.in_geo_features, 2), dtype=int).tolist()
        mono_cons = geo_rows + [[val_prior, 0]]

        backbone = nn.Sequential(*self.net)
        self.network = lmn.MonotonicWrapper(backbone, monotonic_constraints=mono_cons)

        print(self.net)

    def forward(self, model_input):
        """Run the monotonic-wrapped backbone on ``model_input``."""
        return self.network(model_input)



class FeatureLipNN(nn.Module):
    """Thin wrapper around a single ``BaseLipBlock_for_distribution``.

    The wrapped block is stored as ``self.mean_with_var`` and ``forward``
    returns its output unchanged.

    Args:
        in_cov_features: number of covariate input columns.
        in_geo_features: number of geo input columns.
        val_prior: monotonicity prior forwarded to the wrapped block.
        hidden_features: hidden width of the wrapped block.
        hidden_layers: hidden depth of the wrapped block.
        out_features: forwarded to the wrapped block, which emits
            ``2 * out_features`` columns.
        outermost_linear: ignored; the wrapped block is always built with
            ``outermost_linear=True``.
    """

    def __init__(self,
                 in_cov_features: int,
                 in_geo_features: int,
                 val_prior: 0,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 outermost_linear=False, ):

        super().__init__()

        self.in_cov_features = in_cov_features
        self.in_geo_features = in_geo_features

        self.mean_with_var = BaseLipBlock_for_distribution(
            self.in_cov_features,
            self.in_geo_features,
            val_prior,
            hidden_features,
            hidden_layers,
            out_features,
            outermost_linear=True)

    def forward(self, model_input):
        """Delegate to the wrapped block."""
        return self.mean_with_var(model_input)






class Sine(nn.Module):
    """Sine activation ``sin(omega_0 * x)``.

    Args:
        omega_0: frequency factor multiplied into the input before the
            sine.  Defaults to 30, the value previously hard-coded here.
    """

    def __init__(self, omega_0=30):
        # The constructor used to be spelled `__init`, so it was never run
        # as a constructor and only existed as a name-mangled dead method.
        super().__init__()
        self.omega_0 = omega_0

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(self.omega_0 * input)



class BatchLinear(nn.Linear, ): #MetaModule):
    # A linear meta-layer that can deal with batched weight matrices and biases, as for instance
    # output by a hypernetwork.  The class docstring is inherited from nn.Linear.
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        # Apply the affine map using either the layer's own parameters or an
        # externally supplied mapping.
        #
        # input:  tensor whose last dimension is the input feature dimension.
        # params: optional mapping with a 'weight' entry and an optional 'bias'
        #         entry; both may carry extra leading (batch) dimensions.
        #         When None, the module's own parameters are used.
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # Swap only the last two weight dims so leading batch dims are kept.
        output = input.matmul(weight.transpose(-1, -2))
        # A layer built with bias=False (or params without a 'bias' entry) has
        # no bias; this used to fail with AttributeError on None.unsqueeze.
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output

class SineLayer(nn.Module):
    """Linear layer followed by ``sin(omega_0 * .)`` (SIREN-style layer).

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for the
    discussion of ``omega_0``.

    Weight initialisation, as implemented in ``init_weights``:

    * ``is_first=True``: weights ~ U(-1 / in_features, 1 / in_features).
    * ``is_first=False``: weights ~ U(-b, b) with
      ``b = sqrt(6 / in_features) / omega_0``.

    In both cases the pre-activation is multiplied by ``omega_0`` in
    ``forward``.  The bias keeps the default ``nn.Linear`` initialisation.

    Args:
        in_features: input width of the linear layer.
        out_features: output width of the linear layer.
        bias: whether the linear layer has a bias.
        is_first: selects the first-layer initialisation.
        omega_0: frequency factor applied before the sine.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30.):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from the uniform range described on the class."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        """Return ``(sin(z), z)`` with ``z = omega_0 * linear(input)``.

        Useful for visualising activation distributions.
        """
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation

class BaseDeepSDFSiren(nn.Module):
    """SIREN decoder conditioned on an embedding, with coordinate re-injection.

    The network is a ``SineLayer`` stack followed by a final ``Linear``
    (optionally followed by ``Tanh``).  The input is ``cat(coords, embedding)``.
    After every layer whose index is listed in ``latent_in`` the coordinates
    are concatenated to that layer's output again; such a layer is built
    ``in_features`` narrower so the concatenated width is ``hidden_features``.

    Args:
        in_features: width of ``coords``.
        latent_size: width of ``embedding``.
        hidden_features: width of the hidden sine layers.
        hidden_layers: number of hidden sine layers after the first one.
        out_features: output width.
        latent_in: indices of layers (position inside ``self.net``) after
            which ``coords`` are re-concatenated.  Defaults to ``[4]``.
        outermost_linear: if True the last module is a bare ``Linear``,
            otherwise ``Sequential(Linear, Tanh)``.
        first_omega_0: ``omega_0`` of the first sine layer.
        hidden_omega_0: ``omega_0`` of the hidden sine layers.
        zero_init_last_layer: if True the final ``Linear`` is zero
            initialised; otherwise spectral norm is applied to it.
    """

    def __init__(self,
                 in_features,
                 latent_size,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 latent_in=None,
                 outermost_linear=False,
                 first_omega_0=30,
                 hidden_omega_0=30.,
                 zero_init_last_layer=False):
        super().__init__()

        self.in_features = in_features
        # A fresh list per instance instead of a shared mutable default argument.
        self.latent_in = [4] if latent_in is None else latent_in

        layers = [SineLayer(in_features + latent_size, hidden_features,
                            is_first=True, omega_0=first_omega_0)]

        for i in range(hidden_layers):
            # Layer i + 1 is narrower when coords are re-attached to its output.
            if i + 1 in self.latent_in:
                layer_width = hidden_features - in_features
            else:
                layer_width = hidden_features
            layers.append(SineLayer(hidden_features, layer_width,
                                    is_first=False, omega_0=hidden_omega_0))

        final_linear = nn.Linear(hidden_features, out_features)
        if outermost_linear:
            layers.append(final_linear)
        else:
            layers.append(nn.Sequential(final_linear, nn.Tanh()))

        self.net = nn.Sequential(*layers)
        print(self.net)

        if zero_init_last_layer:
            # Operate on the final Linear directly.  The non-linear-output
            # case used to index self.net[-2], which is a SineLayer without
            # a `.weight` attribute and therefore raised AttributeError.
            torch.nn.init.constant_(final_linear.weight, 0)
            torch.nn.init.constant_(final_linear.bias, 0)
        else:
            nn.utils.spectral_norm(final_linear)

    def forward(self, embedding, coords):
        """Decode ``coords`` conditioned on ``embedding``.

        Both tensors are concatenated along the last dimension, so they must
        agree in all leading dimensions.
        """
        model_input = torch.cat((coords, embedding), dim=-1)

        # All modules except the last; coords are re-attached after the
        # layers listed in latent_in.
        for net_i, layer in enumerate(self.net[:-1]):
            output = layer(model_input)
            if net_i in self.latent_in:
                model_input = torch.cat((coords, output), dim=-1)
            else:
                model_input = output

        return self.net[-1](model_input)

