import torch
import torch.nn as nn
import torch.nn.functional as F
#torch.manual_seed(0)

class FullyConnectedNN(nn.Module):
    """Fully connected classifier built from Linear -> (BatchNorm1d) -> ReLU blocks.

    Args:
        num_classes: Number of output features of the final linear layer.
        layers: Sequence of layer widths. ``layers[0]`` is the input feature
            size; every later entry adds one hidden block of that width.
        use_batch_norm: Insert ``nn.BatchNorm1d`` after each hidden linear layer.
        bias: Whether the hidden linear layers carry a bias term. The output
            layer does not receive this flag and keeps ``nn.Linear``'s default.
        weight_init: ``"xavier"`` selects ``xavier_uniform_``; any other value
            falls back to ``kaiming_normal_``.
    """

    def __init__(self, num_classes, layers, use_batch_norm=False, bias=True, weight_init="xavier"):
        super().__init__()
        self.layers = nn.ModuleList()
        input_size = layers[0]
        if weight_init == "xavier":
            init_func = torch.nn.init.xavier_uniform_
        else:
            init_func = torch.nn.init.kaiming_normal_

        for layer_size in layers[1:]:
            linear = nn.Linear(input_size, layer_size, bias=bias)
            init_func(linear.weight)
            self.layers.append(linear)
            if use_batch_norm:
                self.layers.append(nn.BatchNorm1d(layer_size))
            self.layers.append(nn.ReLU())
            input_size = layer_size

        # Classification head; only its weight is re-initialised, the bias
        # keeps nn.Linear's default initialisation.
        self.output_layer = nn.Linear(input_size, num_classes)
        init_func(self.output_layer.weight)

    def forward(self, x):
        """Run ``x`` through every hidden block and then the output layer."""
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

    def get_weights(self):
        """Return detached CPU copies of every linear weight matrix.

        The list holds the hidden ``nn.Linear`` weights in network order,
        followed by the output layer weight. Biases and batch-norm parameters
        are not included.
        """
        weights = []
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                weights.append(layer.weight.detach().cpu().clone())
        weights.append(self.output_layer.weight.detach().cpu().clone())
        return weights

    def flatten_weights(self, weights):
        """Concatenate an iterable of tensors into one 1-D tensor.

        Args:
            weights: Tensors to flatten, e.g. the result of ``get_weights()``.

        Returns:
            A 1-D tensor containing the elements of every input tensor, in order.
        """
        # reshape() instead of view(): view(-1) raises on non-contiguous
        # tensors (for example a transposed matrix), reshape() copies if needed.
        flat_weights = torch.cat([w.reshape(-1) for w in weights])
        return flat_weights


class EmbeddingConcatFFModel(nn.Module):
    """Two-token model: embed both tokens, concatenate, then a one-hidden-layer MLP.

    Args:
        p: Number of rows in the embedding table and width of the output layer.
        embed_dim: Size of each token embedding.
        hidden: Width of the hidden linear layer.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, p, embed_dim, hidden, *args, **kwargs):
        super().__init__()
        self.embed = nn.Embedding(p, embed_dim)
        # 2 * embed_dim because the two embedded tokens are concatenated.
        self.linear1 = nn.Linear(2 * embed_dim, hidden)
        self.linear2 = nn.Linear(hidden, p)
        self.init_weights()

    def forward(self, x):
        """Compute output scores for a batch of token pairs.

        ``x`` is indexed as ``x[:, 0]`` and ``x[:, 1]``, so it must be a 2-D
        tensor of embedding indices with at least two columns.
        """
        x1 = x[:, 0]
        x2 = x[:, 1]
        x1 = self.embed(x1)
        x2 = self.embed(x2)
        x = torch.cat((x1, x2), dim=1)  # Concatenate the embedding of the two tokens
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def init_weights(self):
        """Xavier-normal initialise embedding/linear weights; zero linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Embedding):
                nn.init.xavier_normal_(m.weight)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def get_weights(self):
        """Return detached CPU copies of the embedding and both linear weights."""
        weights = []
        weights.append(self.embed.weight.detach().cpu().clone())
        weights.append(self.linear1.weight.detach().cpu().clone())
        weights.append(self.linear2.weight.detach().cpu().clone())
        return weights

    def flatten_weights(self, weights):
        """Concatenate an iterable of tensors into one 1-D tensor.

        Args:
            weights: Tensors to flatten, e.g. the result of ``get_weights()``.

        Returns:
            A 1-D tensor containing the elements of every input tensor, in order.
        """
        # reshape() instead of view(): view(-1) raises on non-contiguous
        # tensors (for example a transposed matrix), reshape() copies if needed.
        flat_weights = torch.cat([w.reshape(-1) for w in weights])
        return flat_weights


class ResidualConnection(nn.Module):
    """Residual block that projects features to a latent width and back.

    The input path is Linear -> (BatchNorm1d) -> ReLU, the output path is
    Linear -> (BatchNorm1d). The block input is added to the output path's
    result and a final ReLU is applied.

    Args:
        input_size: Feature width of the block's input and output.
        latent_dim_size: Feature width between the two projections.
        batch_norm: Add ``nn.BatchNorm1d`` after each linear projection.
        bias: Whether the two linear projections carry a bias term.
    """

    def __init__(self, input_size, latent_dim_size, batch_norm = True, bias=False):
        super().__init__()
        # Both projections are created before anything else so parameter
        # creation order stays: input projection first, output projection second.
        inward = [nn.Linear(input_size, latent_dim_size, bias=bias)]
        outward = [nn.Linear(latent_dim_size, input_size, bias=bias)]
        if batch_norm:
            inward.append(nn.BatchNorm1d(latent_dim_size))
            outward.append(nn.BatchNorm1d(input_size))
        inward.append(nn.ReLU())

        self.project_in = nn.ModuleList(inward)
        self.project_out = nn.ModuleList(outward)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(project_out(project_in(x)) + x)``."""
        shortcut = x
        hidden = x
        for module in [*self.project_in, *self.project_out]:
            hidden = module(hidden)
        return self.relu(hidden + shortcut)


class FullyConnectedResidualNN(nn.Module):
    """Fully connected network made of a linear projection, residual blocks and a linear head.

    Args:
        num_classes: Number of output features of the final linear layer.
        layers: Sequence of sizes. ``layers[0]`` is the input feature size,
            ``layers[1]`` is the latent width shared by all residual blocks,
            and every entry from ``layers[2]`` on adds one
            ``ResidualConnection`` whose inner width is that entry.
        use_batch_norm: Accepted but currently not used.
        bias: Accepted but currently not used.
        weight_init: Accepted but currently not used.

    NOTE: because the three flags above are not forwarded, every residual
    block is built with ``ResidualConnection``'s own defaults and all layers
    keep PyTorch's default initialisation.
    """

    def __init__(self, num_classes, layers, use_batch_norm=False, bias=True, weight_init="xavier"):
        super().__init__()
        self.layers = nn.ModuleList()
        input_size = layers[0]
        latent_size = layers[1]
        self.latent_projector = nn.Linear(input_size, latent_size)
        for layer_size in layers[2:]:
            self.layers.append(ResidualConnection(latent_size, layer_size))

        self.output_layer = nn.Linear(latent_size, num_classes)

    def forward(self, x):
        """Project ``x`` to the latent width, apply every residual block, then the head."""
        x = self.latent_projector(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

    def get_weights(self):
        """Return detached CPU copies of the linear weight matrices.

        Order: latent projector, then for each residual block its first
        ``project_in`` and first ``project_out`` module weight, then the
        output layer. Biases and batch-norm parameters are not included.
        """
        weights = []
        weights.append(self.latent_projector.weight.detach().cpu().clone())
        for layer in self.layers:
            if isinstance(layer, ResidualConnection):
                embed_weights = layer.project_in[0].weight
                unembed_weights = layer.project_out[0].weight
                weights.append(embed_weights.detach().cpu().clone())
                weights.append(unembed_weights.detach().cpu().clone())
        weights.append(self.output_layer.weight.detach().cpu().clone())
        return weights

    def flatten_weights(self, weights):
        """Concatenate an iterable of tensors into one 1-D tensor.

        Args:
            weights: Tensors to flatten, e.g. the result of ``get_weights()``.

        Returns:
            A 1-D tensor containing the elements of every input tensor, in order.
        """
        # reshape() instead of view(): view(-1) raises on non-contiguous
        # tensors (for example a transposed matrix), reshape() copies if needed.
        flat_weights = torch.cat([w.reshape(-1) for w in weights])
        return flat_weights


class SimpleConvNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, conv_channels=[16, 32, 64],
                 weight_init="xavier", use_batch_norm=False, bias=True):
        """
        A simple convolutional network that builds a series of convolutional blocks.
        Each block: Conv2d -> (optional BatchNorm2d) -> ReLU -> MaxPool2d.
        An AdaptiveAvgPool2d layer then produces a fixed-size output for the final FC layer.

        Args:
            num_classes: Number of output features of the final linear layer.
            input_channels: Channel count of the input.
            conv_channels: Output channel count of each convolutional block, in
                order. The default list is only iterated, never mutated.
            weight_init: ``"xavier"`` selects ``xavier_uniform_``; any other
                value falls back to ``kaiming_normal_``.
            use_batch_norm: Insert ``nn.BatchNorm2d`` after each convolution.
            bias: Whether the convolutions carry a bias term. The final linear
                layer does not receive this flag.
        """
        super().__init__()
        self.conv_blocks = nn.ModuleList()
        prev_channels = input_channels

        # Set the initialization function
        if weight_init == "xavier":
            init_func = torch.nn.init.xavier_uniform_
        else:
            init_func = torch.nn.init.kaiming_normal_

        # Build convolutional blocks
        for out_channels in conv_channels:
            conv = nn.Conv2d(prev_channels, out_channels, kernel_size=3, padding=1, bias=bias)
            init_func(conv.weight)
            block = [conv]
            if use_batch_norm:
                block.append(nn.BatchNorm2d(out_channels))
            block.append(nn.ReLU(inplace=True))
            block.append(nn.MaxPool2d(kernel_size=2))
            self.conv_blocks.append(nn.Sequential(*block))
            prev_channels = out_channels

        # Use global adaptive pooling to generate a fixed-size tensor regardless of input dimensions
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(prev_channels, num_classes)
        init_func(self.fc.weight)

    def forward(self, x):
        """Apply every conv block, global-average-pool, flatten, and classify."""
        for block in self.conv_blocks:
            x = block(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # Flatten for the FC layer
        x = self.fc(x)
        return x

    def get_weights(self):
        """Collects weights from each Conv2d layer and the final FC layer.

        Returns detached CPU copies; biases and batch-norm parameters are not included.
        """
        weights = []
        for block in self.conv_blocks:
            for layer in block:
                if isinstance(layer, nn.Conv2d):
                    weights.append(layer.weight.detach().cpu().clone())
        weights.append(self.fc.weight.detach().cpu().clone())
        return weights

    def flatten_weights(self, weights):
        """Flattens a list of weight tensors into a single 1D tensor.

        Args:
            weights: Tensors to flatten, e.g. the result of ``get_weights()``.

        Returns:
            A 1-D tensor containing the elements of every input tensor, in order.
        """
        # reshape() instead of view(): view(-1) raises on non-contiguous
        # tensors, which is what cloned conv weights are once the model has
        # been converted to channels_last memory format.
        flat_weights = torch.cat([w.reshape(-1) for w in weights])
        return flat_weights


########################################################################
# Convolutional Residual Block & Network with Residual Connections
########################################################################

class ConvResidualBlock(nn.Module):
    """Residual block of two 3x3, padding-1 convolutions that keep the channel count.

    Data flow: conv1 -> (bn1) -> ReLU -> conv2 -> (bn2) -> add input -> ReLU.

    Args:
        channels: Input and output channel count of both convolutions.
        weight_init: ``"xavier"`` selects ``xavier_uniform_``; any other value
            falls back to ``kaiming_normal_``.
        use_batch_norm: Apply ``nn.BatchNorm2d`` after each convolution.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, channels, weight_init="xavier", use_batch_norm=True, bias=False):
        super().__init__()
        initialise = (
            torch.nn.init.xavier_uniform_
            if weight_init == "xavier"
            else torch.nn.init.kaiming_normal_
        )
        conv_options = dict(kernel_size=3, padding=1, bias=bias)

        # Each convolution is initialised right after it is created.
        self.conv1 = nn.Conv2d(channels, channels, **conv_options)
        initialise(self.conv1.weight.data)
        self.conv2 = nn.Conv2d(channels, channels, **conv_options)
        initialise(self.conv2.weight.data)

        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.bn1 = nn.BatchNorm2d(channels)
            self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv path(x) + x)``; the output has the same shape as ``x``."""
        shortcut = x
        y = self.conv1(x)
        y = self.bn1(y) if self.use_batch_norm else y
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y) if self.use_batch_norm else y
        y += shortcut  # skip connection, added in place on the conv path's tensor
        return self.relu(y)


class ConvResNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, conv_channels=[16, 32, 64],
                 num_residual_blocks=2, weight_init="xavier", use_batch_norm=True, bias=False):
        """
        A convolutional network that uses residual connections.
        It starts with an initial convolutional block, then builds stages.
        Each stage begins with a convolution that changes the channel number and downsamples the spatial size,
        followed by a series of ConvResidualBlock modules.

        Args:
            num_classes: Number of output features of the final linear layer.
            input_channels: Channel count of the input.
            conv_channels: ``conv_channels[0]`` is the width of the initial
                convolution; each later entry adds one stage of that width.
                The default list is only read, never mutated.
            num_residual_blocks: Number of ConvResidualBlock modules per stage.
            weight_init: ``"xavier"`` selects ``xavier_uniform_``; any other
                value falls back to ``kaiming_normal_``.
            use_batch_norm: Apply ``nn.BatchNorm2d`` after each convolution.
            bias: Whether the convolutions carry a bias term. The final linear
                layer does not receive this flag.
        """
        super().__init__()
        if weight_init == "xavier":
            init_func = torch.nn.init.xavier_uniform_
        else:
            init_func = torch.nn.init.kaiming_normal_

        # Initial convolutional block
        self.conv1 = nn.Conv2d(input_channels, conv_channels[0], kernel_size=3, padding=1, bias=bias)
        init_func(self.conv1.weight)
        if use_batch_norm:
            # bn1 only exists when batch norm is enabled; forward() and
            # get_weights() test for it with hasattr().
            self.bn1 = nn.BatchNorm2d(conv_channels[0])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        self.layers = nn.ModuleList()
        in_channels = conv_channels[0]

        # Build stages: each stage changes the channel size and downsamples, then adds residual blocks.
        for out_channels in conv_channels[1:]:
            # Convolution to change dimensions & downsample (stride=2)
            conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=bias)
            init_func(conv_layer.weight)
            stage = [conv_layer]
            if use_batch_norm:
                stage.append(nn.BatchNorm2d(out_channels))
            # The same parameter-free ReLU instance is reused in every stage.
            stage.append(self.relu)
            self.layers.append(nn.Sequential(*stage))

            # Append a series of residual blocks for the current stage
            for _ in range(num_residual_blocks):
                self.layers.append(ConvResidualBlock(out_channels, weight_init=weight_init,
                                                     use_batch_norm=use_batch_norm, bias=bias))
            in_channels = out_channels

        # Global pooling and final classification layer
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, num_classes)
        init_func(self.fc.weight)

    def forward(self, x):
        """Apply the initial block, every stage/residual block, pooling, and the FC layer."""
        x = self.conv1(x)
        if hasattr(self, 'bn1'):
            x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        for layer in self.layers:
            x = layer(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_weights(self):
        """Collects the weights from the initial conv, all Conv2d layers in the stages and residual blocks, and the FC layer.

        When batch norm is enabled, the weight of the initial ``bn1`` layer is
        also included, directly after the initial conv weight. No other
        batch-norm weights and no biases are collected. All entries are
        detached CPU copies.
        """
        weights = []
        weights.append(self.conv1.weight.detach().cpu().clone())
        if hasattr(self, 'bn1'):
            weights.append(self.bn1.weight.detach().cpu().clone())
        for layer in self.layers:
            for mod in layer.modules():
                if isinstance(mod, nn.Conv2d):
                    weights.append(mod.weight.detach().cpu().clone())
        weights.append(self.fc.weight.detach().cpu().clone())
        return weights

    def flatten_weights(self, weights):
        """Flattens the collected weight tensors into one vector.

        Args:
            weights: Tensors to flatten, e.g. the result of ``get_weights()``.

        Returns:
            A 1-D tensor containing the elements of every input tensor, in order.
        """
        # reshape() instead of view(): view(-1) raises on non-contiguous
        # tensors, which is what cloned conv weights are once the model has
        # been converted to channels_last memory format.
        flat_weights = torch.cat([w.reshape(-1) for w in weights])
        return flat_weights

