import torch.nn as nn
from GeneralModules import activation, init_xavier
import numpy as np
from debug_tools import *


def init_weights(m):
    """
    Initialise a single ``nn.Linear`` module in place; intended for ``Module.apply``.

    The weight receives Kaiming-uniform initialisation with the tanh gain. The
    bias, if the layer has one, is set to zero. Every other module type is
    left untouched.

    Args:
        m: the module currently visited by ``Module.apply``.
    """
    # Exact type check on purpose: nn.LazyLinear subclasses nn.Linear, but its
    # parameters are uninitialised placeholders until the first forward pass,
    # so an isinstance() check would try to initialise them and fail.
    if type(m) is nn.Linear:
        # NOTE: `a` only enters the gain for 'leaky_relu'; with 'tanh' it has no effect.
        torch.nn.init.kaiming_uniform_(m.weight, a=0.25, nonlinearity='tanh')
        # Layers built with bias=False expose m.bias as None.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class Resnet(nn.Module):
    """
    Implements the physics-informed neural network as a residual MLP.

    Layout:
      * ``input_layer`` followed by the activation;
      * ``(n_hidden_layers - 1) // 2`` ``ResidualBlock``s, each followed by dropout;
      * when ``n_hidden_layers - 1`` is odd, one extra
        Linear -> activation -> BatchNorm1d -> dropout stage;
      * a linear ``output_layer``.

    Args:
        input_dimension: number of input features.
        output_dimension: number of output features.
        network_architecture: dict with keys ``"n_hidden_layers"``,
            ``"neurons"``, ``"act_string"``, ``"retrain"`` and ``"dropout_rate"``.
    """

    def __init__(self, input_dimension, output_dimension, network_architecture):
        super().__init__()
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.n_hidden_layers = network_architecture["n_hidden_layers"]
        self.neurons = network_architecture["neurons"]
        self.act_string = network_architecture["act_string"]
        # NOTE(review): stored and forwarded to the residual blocks, but unlike
        # FeedForwardNN / TrunkNet / BranchNet it is not used to seed torch here.
        self.retrain = network_architecture["retrain"]
        self.p = network_architecture["dropout_rate"]
        self.activation = activation(self.act_string)

        self.input_layer = nn.Linear(self.input_dimension, self.neurons)
        # Each residual block consumes two hidden layers; an odd leftover
        # becomes a single plain layer.
        n_res_blocks = (self.n_hidden_layers - 1) // 2
        n_remaining_layers = (self.n_hidden_layers - 1) % 2
        self.residual_blocks = nn.ModuleList(
            [ResidualBlock(self.neurons, self.act_string, self.retrain) for _ in range(n_res_blocks)])
        if n_remaining_layers == 0:
            self.remaining_layers = None
            self.remaining_batch_layers = None
        else:
            self.remaining_layers = nn.ModuleList(
                [nn.Linear(self.neurons, self.neurons) for _ in range(n_remaining_layers)])
            self.remaining_batch_layers = nn.ModuleList(
                [nn.BatchNorm1d(self.neurons) for _ in range(n_remaining_layers)])

        self.output_layer = nn.Linear(self.neurons, self.output_dimension)
        self.dropout = nn.Dropout(self.p)
        init_xavier(self)

    def forward(self, x):
        """Map inputs ``x`` through the residual MLP to the output layer."""
        x = self.activation(self.input_layer(x))
        for block in self.residual_blocks:
            x = self.dropout(block(x))
        # remaining_layers / remaining_batch_layers are None when
        # n_hidden_layers - 1 is even; zip(None, None) would raise TypeError.
        if self.remaining_layers is not None:
            for linear, batch_norm in zip(self.remaining_layers, self.remaining_batch_layers):
                x = self.dropout(batch_norm(self.activation(linear(x))))
        return self.output_layer(x)

    def print_size(self):
        """Print the parameter count and approximate memory size; return the count."""
        nparams = 0
        nbytes = 0

        for param in self.parameters():
            nparams += param.numel()
            nbytes += param.data.element_size() * param.numel()

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams


class ResidualBlock(nn.Module):
    """
    Residual block of two equally wide linear layers.

    ``forward(x) = act(BN2(L2(BN1(act(L1(x))))) + x)``

    Args:
        num_neurons: width of both linear layers (the input and output width).
        activation_f: activation name handed to ``activation(...)``.
        retrain: stored on the instance; not used inside this class.
    """

    def __init__(self, num_neurons, activation_f, retrain):
        super().__init__()
        self.neurons = num_neurons
        self.act_string = activation_f
        self.activation = activation(activation_f)

        width = self.neurons
        # Keep creation order (layer_1, batch_1, layer_2, batch_2) so the RNG
        # draws made while constructing the layers stay the same.
        self.layer_1 = nn.Linear(width, width)
        self.batch_1 = nn.BatchNorm1d(width)
        self.layer_2 = nn.Linear(width, width)
        self.batch_2 = nn.BatchNorm1d(width)

        self.retrain = retrain
        init_xavier(self)

    def forward(self, x):
        """Apply both layers and add the identity shortcut before the final activation."""
        shortcut = x
        hidden = self.activation(self.layer_1(x))
        hidden = self.batch_1(hidden)
        hidden = self.batch_2(self.layer_2(hidden))
        return self.activation(hidden + shortcut)


class FeedForwardNN(nn.Module):
    """
    Plain MLP with batch normalisation after every hidden layer.

    The input layer is followed by the activation. Each of the
    ``n_hidden_layers - 1`` hidden stages then computes
    ``BN(act(dropout(Linear(x))))``. A linear output layer comes last.

    Args:
        input_dimension: number of input features.
        output_dimension: number of output features.
        network_architecture: dict with keys ``"n_hidden_layers"``,
            ``"neurons"``, ``"act_string"``, ``"retrain"`` and ``"dropout_rate"``.
    """

    def __init__(self, input_dimension, output_dimension, network_architecture):
        super().__init__()
        arch = network_architecture
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.n_hidden_layers = arch["n_hidden_layers"]
        self.neurons = arch["neurons"]
        self.act_string = arch["act_string"]
        self.retrain = arch["retrain"]
        self.dropout_rate = arch["dropout_rate"]

        # Seeds the global torch RNG so the layer construction below is reproducible.
        torch.manual_seed(self.retrain)

        n_inner = self.n_hidden_layers - 1
        width = self.neurons
        self.input_layer = nn.Linear(self.input_dimension, width)
        self.hidden_layers = nn.ModuleList(nn.Linear(width, width) for _ in range(n_inner))
        self.batch_layers = nn.ModuleList(nn.BatchNorm1d(width) for _ in range(n_inner))
        self.output_layer = nn.Linear(width, self.output_dimension)

        self.activation = activation(self.act_string)
        self.dropout = nn.Dropout(self.dropout_rate)

        init_xavier(self)

    def forward(self, x):
        """Run ``x`` through the input layer, the hidden stages and the output layer."""
        out = self.activation(self.input_layer(x))
        for linear, norm in zip(self.hidden_layers, self.batch_layers):
            out = norm(self.activation(self.dropout(linear(out))))
        return self.output_layer(out)

    def print_size(self):
        """Print the parameter count and approximate memory size; return the count."""
        params = list(self.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.data.element_size() * p.numel() for p in params)
        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')
        return nparams


class TrunkNet(nn.Module):
    """
    Trunk network of a DeepONet: maps query points ``y`` to ``p`` outputs.

    Layout: Linear(d, width) + act, ``n_hidden_layers`` x [Linear(width, width) + act],
    then Linear(width, p) + act. One activation object is shared by every stage.

    Args:
        input_dimension: dimensionality ``d`` of y.
        output_dimension: number ``p`` of outputs.
        network_architecture: dict with keys ``"act_string"``,
            ``"n_hidden_layers"``, ``"neurons"`` and ``"retrain"``.
    """

    def __init__(self, input_dimension, output_dimension, network_architecture):
        super().__init__()

        self.d = input_dimension
        self.p = output_dimension
        act = activation(network_architecture["act_string"])

        self.depth = network_architecture["n_hidden_layers"]
        self.width = network_architecture["neurons"]
        self.retrain = network_architecture["retrain"]
        # Seeds the global torch RNG before any layer is built.
        torch.manual_seed(self.retrain)

        self.first_modules = nn.Sequential(nn.Linear(self.d, self.width), act)
        self.hidden_modules = nn.ModuleList(
            nn.Sequential(nn.Linear(self.width, self.width), act)
            for _ in range(self.depth))
        self.last_modules = nn.Sequential(nn.Linear(self.width, self.p), act)

        # Initialisation happens in three passes, in this order:
        # 1) init_weights on every submodule;
        self.apply(init_weights)
        # 2) every bias is redrawn from N(0, 0.1^2);
        for pname, param in self.named_parameters():
            if pname.endswith('bias'):
                nn.init.normal_(param, mean=0, std=1e-1)
        # 3) the output-layer bias is overwritten with p evenly spaced values in [-1, 1].
        self.last_modules[0].bias.data = torch.linspace(-1, +1, self.p)

    def forward(self, y):
        """Evaluate the trunk at query points ``y``."""
        out = self.first_modules(y)
        for stage in self.hidden_modules:
            out = stage(out)
        return self.last_modules(out)


class BranchNet(nn.Module):
    """
    Branch network of a DeepONet with a skip connection around its hidden stack.

    ``forward(x) = last(hidden_stack(first(x)) + first(x))``, where ``first``
    is Linear(m, width) + act, the hidden stack is ``n_hidden_layers`` x
    [Linear(width, width) + act] and ``last`` is Linear(width, p).

    Args:
        input_dimension: number ``m`` of sensors (input features).
        output_dimension: number ``p`` of outputs.
        network_architecture: dict with keys ``"act_string"``,
            ``"n_hidden_layers"``, ``"neurons"`` and ``"retrain"``.
    """

    def __init__(self, input_dimension, output_dimension, network_architecture):
        super().__init__()

        self.m = input_dimension
        self.p = output_dimension
        act = activation(network_architecture["act_string"])

        self.depth = network_architecture["n_hidden_layers"]
        self.width = network_architecture["neurons"]
        self.retrain = network_architecture["retrain"]
        # Seeds the global torch RNG before any layer is built.
        torch.manual_seed(self.retrain)

        self.first_modules = nn.Sequential(nn.Linear(self.m, self.width), act)
        self.hidden_modules = nn.ModuleList(
            nn.Sequential(nn.Linear(self.width, self.width), act)
            for _ in range(self.depth))
        self.last_modules = nn.Linear(self.width, self.p)

        self.apply(init_weights)
        # First layer: Kaiming-uniform with the linear gain (overrides init_weights).
        nn.init.kaiming_uniform_(self.first_modules[0].weight.data, nonlinearity='linear')
        # Last hidden linear layer starts with all-zero weights.
        if len(self.hidden_modules) > 0:
            nn.init.zeros_(self.hidden_modules[-1][0].weight.data)

    def forward(self, x):
        """Encode sensor values ``x`` into ``p`` outputs."""
        features = self.first_modules(x)
        hidden = features
        for stage in self.hidden_modules:
            hidden = stage(hidden)
        # Skip connection around the hidden stack (doubles `features` when depth is 0).
        return self.last_modules(hidden + features)


class DeepOnet2(nn.Module):
    """
    Unstacked DeepONet combining a branch and a trunk network.

    ``forward(u_, x_)`` returns ``branch(u_) @ trunk(x_).T + b0``. For 2-D
    network outputs this has one row per input in ``u_`` and one column per
    query point in ``x_``.

    Args:
        branch: network applied to the input functions ``u_``.
        trunk: network applied to the query points ``x_``.
    """

    def __init__(self, branch, trunk):
        super().__init__()
        self.branch = branch
        self.trunk = trunk
        # Learnable scalar bias added to every output entry.
        self.b0 = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, u_, x_):
        """Combine branch coefficients with trunk basis values."""
        coefficients = self.branch(u_)
        basis_values = self.trunk(x_)
        return coefficients @ basis_values.T + self.b0

    def print_size(self):
        """Print the parameter count and approximate memory size; return the count."""
        params = list(self.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.data.element_size() * p.numel() for p in params)
        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')
        return nparams


class ShiftDeepONet(nn.Module):
    """
    Shift-DeepONet: the query points are scaled and shifted per input function
    before they are fed to the trunk network.

    For each sample ``b`` in the batch::

        res[b] = norm_fact * trunk(y * scale_net(u)[b] + shift_net(u)[b]) @ branch_net(u)[b] + b0

    with ``norm_fact = 1.696 / sqrt(n_basis)``.

    Args:
        branch: network producing per-sample coefficients; ``branch(u_m)[b]`` is
            reshaped to a column vector, so its size must match the trunk's last dim.
        trunk: network applied to the scaled/shifted query points.
        shift_net: network whose output ``[b]`` is added to ``y * scale``.
        scale_net: network whose output ``[b]`` multiplies ``y``.
        bias_net: accepted for interface compatibility; not used.
        n_basis: number of basis functions ``p`` (only used in ``norm_fact``).
        device: device the output is allocated on; ``y`` is moved there too.
    """

    def __init__(self, branch, trunk, shift_net, scale_net, bias_net, n_basis, device):
        super().__init__()

        self.device = device
        self.p = n_basis

        self.shift_net = shift_net
        self.scale_net = scale_net

        # Learnable scalar bias added to every output entry.
        self.b0 = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)

        self.branch_net = branch
        self.trunk_net = trunk

    def forward(self, u_m, y):
        """Return a ``(u_m.shape[0], y.shape[0])`` tensor of operator evaluations."""
        scale_result = self.scale_net(u_m)
        shift_result = self.shift_net(u_m)
        branch_result = self.branch_net(u_m)
        # Bring the query points to the output device (ShiftDeepONet2D does the
        # same), so CPU-side y works with a model on another device.
        y = y.to(self.device)
        n_funcs, n_points = u_m.shape[0], y.shape[0]
        # Allocate directly on the target device and in the network's dtype; a
        # default-dtype buffer would silently downcast float64 results.
        res = torch.zeros(n_funcs, n_points, dtype=branch_result.dtype, device=self.device)
        norm_fact = 1.696 / np.sqrt(self.p)

        for b in range(n_funcs):
            offset_b = y * scale_result[b] + shift_result[b]
            trunk_result_b = self.trunk_net(offset_b)
            coeffs_b = branch_result[b].view(-1, 1)
            prod_b = torch.matmul(trunk_result_b, coeffs_b)
            res[b] = (norm_fact * prod_b + self.b0).T.reshape(-1)
        return res

    def print_size(self):
        """Print the parameter count and approximate memory size; return the count."""
        nparams = 0
        nbytes = 0

        for param in self.parameters():
            nparams += param.numel()
            nbytes += param.data.element_size() * param.numel()

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams


class ShiftDeepONet2D(nn.Module):
    """
    Shift-DeepONet for two-dimensional query points.

    Both coordinates of ``y`` (columns 0 and 1) are scaled and shifted by the
    same per-sample ``scale_net(u)[b]`` / ``shift_net(u)[b]``. The results are
    concatenated and fed to the trunk::

        res[b] = norm_fact * trunk([x1 * s_b + t_b, x2 * s_b + t_b]) @ branch_net(u)[b] + b0

    with ``norm_fact = 1.696 / sqrt(n_basis)``.

    Args:
        branch: network producing per-sample coefficients; ``branch(u_m)[b]`` is
            reshaped to a column vector, so its size must match the trunk's last dim.
        trunk: network applied to the scaled/shifted query points.
        shift_net: network whose output ``[b]`` is added to each coordinate.
        scale_net: network whose output ``[b]`` multiplies each coordinate.
        bias_net: accepted for interface compatibility; not used.
        n_basis: number of basis functions ``p`` (only used in ``norm_fact``).
        device: device the output is allocated on; the coordinates are moved there.
    """

    def __init__(self, branch, trunk, shift_net, scale_net, bias_net, n_basis, device):
        super().__init__()

        self.device = device
        self.p = n_basis

        self.shift_net = shift_net
        self.scale_net = scale_net

        # Learnable scalar bias added to every output entry.
        self.b0 = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)

        self.branch_net = branch
        self.trunk_net = trunk

    def forward(self, u_m, y):
        """Return a ``(u_m.shape[0], y.shape[0])`` tensor of operator evaluations."""
        scale_result = self.scale_net(u_m)
        shift_result = self.shift_net(u_m)
        branch_result = self.branch_net(u_m)
        n_funcs, n_points = u_m.shape[0], y.shape[0]
        # Allocate directly on the target device and in the network's dtype; a
        # default-dtype buffer would silently downcast float64 results.
        res = torch.zeros(n_funcs, n_points, dtype=branch_result.dtype, device=self.device)
        norm_fact = 1.696 / np.sqrt(self.p)
        x_1 = y[:, 0].view(-1, 1).to(self.device)
        x_2 = y[:, 1].view(-1, 1).to(self.device)

        for b in range(n_funcs):
            offset_b_x1 = x_1 * scale_result[b] + shift_result[b]
            offset_b_x2 = x_2 * scale_result[b] + shift_result[b]
            offset_b = torch.cat((offset_b_x1, offset_b_x2), -1)
            trunk_result_b = self.trunk_net(offset_b)
            coeffs_b = branch_result[b].view(-1, 1)
            prod_b = torch.matmul(trunk_result_b, coeffs_b)
            res[b] = (norm_fact * prod_b + self.b0).T.reshape(-1)
        return res

    def print_size(self):
        """Print the parameter count and approximate memory size; return the count."""
        nparams = 0
        nbytes = 0

        for param in self.parameters():
            nparams += param.numel()
            nbytes += param.data.element_size() * param.numel()

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams


# Branch-net for the compressible Euler equations
class Block1D(nn.Module):
    """
    Conv1d -> activation -> BatchNorm1d building block.

    Args:
        inChannels: input channels of the convolution.
        outChannels: output channels of the convolution and batch norm.
        kernel_size, stride, padding: forwarded to ``nn.Conv1d``.
        activation: activation module applied after the convolution. ``None``
            (the default) creates a fresh ``nn.ReLU()`` for this block.
    """

    def __init__(self, inChannels, outChannels, kernel_size=1, stride=1, padding=0, activation=None):
        super().__init__()
        self.conv = nn.Conv1d(inChannels, outChannels, kernel_size, padding=padding, stride=stride)
        # A module default (`activation=nn.ReLU()`) is evaluated once at
        # definition time and shared by every block; build one per instance.
        self.activation = nn.ReLU() if activation is None else activation
        self.batch_norm = nn.BatchNorm1d(outChannels)

    def forward(self, x):
        """Apply convolution, activation, then batch normalisation."""
        return self.batch_norm(self.activation(self.conv(x)))


class Block2D(nn.Module):
    """
    Conv2d -> activation -> BatchNorm2d building block.

    Args:
        inChannels: input channels of the convolution.
        outChannels: output channels of the convolution and batch norm.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        activation: activation module applied after the convolution. ``None``
            (the default) creates a fresh ``nn.ReLU()`` for this block.
    """

    def __init__(self, inChannels, outChannels, kernel_size=1, stride=1, padding=0, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(inChannels, outChannels, kernel_size, padding=padding, stride=stride)
        # A module default (`activation=nn.ReLU()`) is evaluated once at
        # definition time and shared by every block; build one per instance.
        self.activation = nn.ReLU() if activation is None else activation
        self.batch_norm = nn.BatchNorm2d(outChannels)

    def forward(self, x):
        """Apply convolution, activation, then batch normalisation."""
        return self.batch_norm(self.activation(self.conv(x)))


class BranchNetConv1D(nn.Module):
    """
    Convolutional branch network for 1-D inputs.

    Layout: a stride-2 ``Block1D`` stem, ``n_hidden_layers`` further stride-2
    ``Block1D`` blocks, a flatten, a lazily sized ``Linear(?, 256)`` followed
    by the activation, and a final ``Linear(256, p)``.

    Args:
        input_dimension: number of input channels of the first convolution.
        output_dimension: number ``p`` of outputs.
        network_architecture: dict with keys ``"act_string"``,
            ``"n_hidden_layers"``, ``"neurons"`` and ``"retrain"``.
        print: if True, ``forward`` prints the tensor shape after every stage.
    """

    def __init__(self, input_dimension, output_dimension, network_architecture, print=False):
        super().__init__()

        self.m = input_dimension
        self.p = output_dimension
        # `print` shadows the builtin only inside __init__; forward uses the builtin.
        self.print = print
        self.act = activation(network_architecture["act_string"])

        self.depth = network_architecture["n_hidden_layers"]
        self.width = network_architecture["neurons"]
        self.retrain = network_architecture["retrain"]
        torch.manual_seed(self.retrain)

        self.first_modules = Block1D(input_dimension, self.width, kernel_size=3, padding=1,
                                     stride=2, activation=self.act)
        self.hidden_modules = nn.ModuleList()
        for _ in range(self.depth):
            self.hidden_modules.append(Block1D(self.width, self.width, kernel_size=3, padding=1,
                                               stride=2, activation=self.act))
        # NOTE(review): the LazyLinear is materialised on the first forward call,
        # so its weights come from the RNG state at that time rather than right
        # after the manual_seed above.
        self.l1 = nn.LazyLinear(256)
        self.last_modules = nn.Linear(256, self.p)

        self.apply(init_weights)
        nn.init.kaiming_uniform_(self.first_modules.conv.weight.data, nonlinearity='linear')
        if len(self.hidden_modules) > 0:
            nn.init.zeros_(self.hidden_modules[-1].conv.weight.data)

    def forward(self, x):
        """
        Encode ``x`` into ``p`` outputs.

        Presumably ``x`` is ``(batch, input_dimension, length)``. Everything
        after dim 0 is flattened before the linear head. TODO confirm against callers.
        """
        self._debug_shape(x)
        x = self.first_modules(x)
        self._debug_shape(x)
        for hidden_layer in self.hidden_modules:
            x = hidden_layer(x)
            self._debug_shape(x)

        # Same as nn.Flatten() (start_dim=1) without building a module every call.
        x = torch.flatten(x, start_dim=1)
        self._debug_shape(x)
        x = self.act(self.l1(x))
        self._debug_shape(x)
        x = self.last_modules(x)
        self._debug_shape(x)
        return x

    def _debug_shape(self, x):
        """Print ``x.shape`` when debug printing is enabled."""
        if self.print:
            print(x.shape)


class BranchNetConv(nn.Module):
    """
    Convolutional branch network for 2-D inputs.

    Layout: a stride-2 ``Block2D`` stem, ``n_hidden_layers`` further stride-2
    ``Block2D`` blocks, a flatten, a lazily sized ``Linear(?, 256)`` followed
    by the activation, and a final ``Linear(256, p)``.

    Args:
        input_dimension: number of input channels of the first convolution.
        output_dimension: number ``p`` of outputs.
        network_architecture: dict with keys ``"act_string"``,
            ``"n_hidden_layers"``, ``"neurons"`` and ``"retrain"``.
        print: if True, ``forward`` prints the tensor shape after every stage.
    """

    def __init__(self, input_dimension, output_dimension, network_architecture, print=False):
        super().__init__()

        self.m = input_dimension
        self.p = output_dimension
        # `print` shadows the builtin only inside __init__; forward uses the builtin.
        self.print = print
        self.act = activation(network_architecture["act_string"])

        self.depth = network_architecture["n_hidden_layers"]
        self.width = network_architecture["neurons"]
        self.retrain = network_architecture["retrain"]
        torch.manual_seed(self.retrain)

        self.first_modules = Block2D(input_dimension, self.width, kernel_size=3, padding=1,
                                     stride=2, activation=self.act)
        self.hidden_modules = nn.ModuleList()
        for _ in range(self.depth):
            self.hidden_modules.append(Block2D(self.width, self.width, kernel_size=3, padding=1,
                                               stride=2, activation=self.act))
        # NOTE(review): the LazyLinear is materialised on the first forward call,
        # so its weights come from the RNG state at that time rather than right
        # after the manual_seed above.
        self.l1 = nn.LazyLinear(256)
        self.last_modules = nn.Linear(256, self.p)

        self.apply(init_weights)
        nn.init.kaiming_uniform_(self.first_modules.conv.weight.data, nonlinearity='linear')
        if len(self.hidden_modules) > 0:
            nn.init.zeros_(self.hidden_modules[-1].conv.weight.data)

    def forward(self, x):
        """
        Encode ``x`` into ``p`` outputs.

        Presumably ``x`` is ``(batch, input_dimension, H, W)``. Everything
        after dim 0 is flattened before the linear head. TODO confirm against callers.
        """
        self._debug_shape(x)
        x = self.first_modules(x)
        self._debug_shape(x)
        for hidden_layer in self.hidden_modules:
            x = hidden_layer(x)
            self._debug_shape(x)

        # Same as nn.Flatten() (start_dim=1) without building a module every call.
        x = torch.flatten(x, start_dim=1)
        self._debug_shape(x)
        x = self.act(self.l1(x))
        self._debug_shape(x)
        x = self.last_modules(x)
        self._debug_shape(x)
        return x

    def _debug_shape(self, x):
        """Print ``x.shape`` when debug printing is enabled."""
        if self.print:
            print(x.shape)
