import torch
import torch.nn as nn
import torch.nn.functional as F
from backbone import my_backbone


class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with ReLU between them.

    The stack maps ``input_dim`` to ``output_dim`` through
    ``num_layers - 1`` hidden layers of width ``hidden_dim``. No activation
    is applied after the final layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries of ``dims`` are the (in, out) sizes of each layer.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last."""
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return x


class DynamicGPE(nn.Module):
    """Prototype-distance classifier on top of the project backbone.

    Features from ``my_backbone`` are projected to ``hidden_dim`` by a
    three-layer MLP, normalised per sample, and scored against learnable
    prototypes. The logit of a class is the negative Euclidean distance to
    its closest prototype.

    Args:
        hidden_dim: Size of the embedding the prototypes live in.
        num_classes: Number of classes (first prototype dimension).
        num_prototypes: Number of new prototypes per class.
        old_prototypes: Optional tensor of extra prototypes whose per-class
            minimum distance must have the same ``(batch, num_classes)``
            shape as the new prototypes' (e.g.
            ``(num_classes, n_old, hidden_dim)``). It is copied, so the
            caller's tensor is not modified by training.
    """

    def __init__(self, hidden_dim=64, num_classes=10, num_prototypes=5,
                 old_prototypes=None):
        super().__init__()
        # feature extractor
        self.backbone = my_backbone()
        output_dim = self.backbone.output_dim

        self.linear = MLP(output_dim, output_dim // 2, hidden_dim, 3)

        # learnable prototypes: (num_classes, num_prototypes, hidden_dim)
        self.new_prototypes = nn.Parameter(
            torch.rand(num_classes, num_prototypes, hidden_dim),
            requires_grad=True,
        )
        if old_prototypes is not None:
            # Clone so optimiser steps do not write into the caller's tensor
            # (nn.Parameter would otherwise share its storage).
            self.old_prototypes = nn.Parameter(
                old_prototypes.detach().clone(), requires_grad=True
            )
        else:
            self.old_prototypes = None

        # Per-sample normalisation over the feature dimension without affine
        # parameters. This is what InstanceNorm1d(hidden_dim) computed when
        # handed a 2-D (batch, hidden_dim) tensor, but without relying on the
        # "unbatched (C, L)" interpretation or triggering its num_features
        # mismatch warning.
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)

    @staticmethod
    def _min_distance(x, prototypes):
        """Return the distance from each row of ``x`` to the nearest prototype of each class."""
        # vector_norm has a well-defined (zero) gradient at zero distance,
        # unlike sqrt(sum(square(.))) which yields NaN there.
        distance = torch.linalg.vector_norm(
            x[:, None, None, :] - prototypes, dim=-1
        )
        return distance.min(dim=-1).values

    def get_logits(self, x):
        """Score embeddings against the prototypes.

        Args:
            x: 2-D tensor of embeddings, ``(batch, hidden_dim)``.

        Returns:
            ``(batch, num_classes)`` tensor holding the negative Euclidean
            distance to the closest prototype (new or old) of every class.
        """
        distance = self._min_distance(x, self.new_prototypes)
        if self.old_prototypes is not None:
            distance = torch.minimum(
                distance, self._min_distance(x, self.old_prototypes)
            )
        return -distance

    def forward(self, x):
        """Run backbone, projection, normalisation and prototype scoring."""
        # x: T, channel (per the original note; the backbone is defined
        # elsewhere -- TODO confirm the expected input layout)
        feat = self.backbone(x)  # batch, channel
        feat = self.linear(feat)  # batch, hidden_dim
        feat = self.norm(feat)  # batch, hidden_dim
        logits = self.get_logits(feat)  # negative L2 distance per class
        return logits


class mlpGPE(nn.Module):
    """Prototype-distance classifier with a two-layer MLP feature extractor.

    The input is flattened per sample, passed through
    ``Linear(input_dim, 100) -> ReLU -> Linear(100, hidden_dim) -> ReLU``,
    normalised per sample, and scored against learnable prototypes. The
    logit of a class is the negative Euclidean distance to its closest
    prototype.

    Args:
        hidden_dim: Size of the embedding the prototypes live in.
        num_classes: Number of classes (first prototype dimension).
        num_prototypes: Number of new prototypes per class.
        old_prototypes: Optional tensor of extra prototypes whose per-class
            minimum distance must have the same ``(batch, num_classes)``
            shape as the new prototypes' (e.g.
            ``(num_classes, n_old, hidden_dim)``). It is copied, so the
            caller's tensor is not modified by training.
        input_dim: Number of features per sample after flattening.
            Defaults to 784, the previously hard-coded value.
    """

    def __init__(self, hidden_dim=64, num_classes=10, num_prototypes=5,
                 old_prototypes=None, input_dim=784):
        super(mlpGPE, self).__init__()

        self.fc1 = nn.Linear(input_dim, 100)
        self.fc2 = nn.Linear(100, hidden_dim)

        # fc1/fc2 stay registered both directly and inside ``features`` so
        # existing state_dict keys keep loading.
        self.features = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            self.fc2,
            nn.ReLU(),
        )

        # learnable prototypes: (num_classes, num_prototypes, hidden_dim)
        self.new_prototypes = nn.Parameter(
            torch.rand(num_classes, num_prototypes, hidden_dim),
            requires_grad=True,
        )
        if old_prototypes is not None:
            # Clone so optimiser steps do not write into the caller's tensor
            # (nn.Parameter would otherwise share its storage).
            self.old_prototypes = nn.Parameter(
                old_prototypes.detach().clone(), requires_grad=True
            )
        else:
            self.old_prototypes = None

        # Per-sample normalisation over the feature dimension without affine
        # parameters. This is what InstanceNorm1d(hidden_dim) computed when
        # handed a 2-D (batch, hidden_dim) tensor, but without relying on the
        # "unbatched (C, L)" interpretation or triggering its num_features
        # mismatch warning.
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)

    @staticmethod
    def _min_distance(x, prototypes):
        """Return the distance from each row of ``x`` to the nearest prototype of each class."""
        # vector_norm has a well-defined (zero) gradient at zero distance,
        # unlike sqrt(sum(square(.))) which yields NaN there.
        distance = torch.linalg.vector_norm(
            x[:, None, None, :] - prototypes, dim=-1
        )
        return distance.min(dim=-1).values

    def get_logits(self, x):
        """Score embeddings against the prototypes.

        Args:
            x: 2-D tensor of embeddings, ``(batch, hidden_dim)``.

        Returns:
            ``(batch, num_classes)`` tensor holding the negative Euclidean
            distance to the closest prototype (new or old) of every class.
        """
        distance = self._min_distance(x, self.new_prototypes)
        if self.old_prototypes is not None:
            distance = torch.minimum(
                distance, self._min_distance(x, self.old_prototypes)
            )
        return -distance

    def forward(self, x):
        """Flatten each sample, embed it, normalise it and score it."""
        x = x.reshape(x.shape[0], -1)  # (batch, input_dim)
        feat = self.features(x)  # (batch, hidden_dim)
        feat = self.norm(feat)
        logits = self.get_logits(feat)
        return logits


if __name__ == '__main__':
    # Smoke test: push one random batch through DynamicGPE and print the
    # shape of the returned logits.
    dummy_batch = torch.rand(32, 1, 28, 28)
    # Pass a tensor such as torch.rand(10, 20, 64) here instead of None to
    # exercise the old-prototype branch as well.
    old_prototypes = None
    model = DynamicGPE(old_prototypes=old_prototypes)
    logits = model(dummy_batch)
    print(logits.shape)
    
    



