from typing import Callable, Tuple, Optional
import types
import math
import torch
import torch.nn as nn


class Encoder(torch.nn.Module):
    """Learnable affine projection followed by ReLU, applied on the last axis.

    Holds a ``weight`` of shape ``(n_components, in_channels)`` and a ``bias``
    of shape ``(n_components,)``; ``forward`` maps ``[..., in_channels]`` to
    ``[..., n_components]``.
    """

    def __init__(self, in_channels: int, n_components: int, device=None, dtype=None):
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)

        self.weight = nn.Parameter(torch.empty((n_components, in_channels), **tensor_opts))
        self.bias = nn.Parameter(torch.empty(n_components, **tensor_opts))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` and ``bias`` in place."""
        # kaiming_uniform_ with a=sqrt(5) is equivalent to sampling from
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)); see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        # [..., in_channels] x [n_components, in_channels] -> [..., n_components]
        projected = torch.einsum("...c,kc->...k", x, self.weight)
        return torch.relu(projected + self.bias)


class head(torch.nn.Module):
    """Learnable affine map (no activation) applied on the last axis.

    Holds a ``weight`` of shape ``(out_channels, n_components)`` and a ``bias``
    of shape ``(out_channels,)``; ``forward`` maps ``[..., n_components]`` to
    ``[..., out_channels]``.
    """

    def __init__(self, n_components: int, out_channels: int, device=None, dtype=None):
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)

        self.weight = nn.Parameter(torch.empty((out_channels, n_components), **tensor_opts))
        self.bias = nn.Parameter(torch.empty(out_channels, **tensor_opts))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` and ``bias`` in place."""
        # kaiming_uniform_ with a=sqrt(5) is equivalent to sampling from
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)); see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        # [..., n_components] x [out_channels, n_components] -> [..., out_channels]
        projected = torch.einsum("...k,ok->...o", x, self.weight)
        return projected + self.bias


class RemainderModel(nn.Module):
    """Encoder -> head branch applied along one (channel) axis of the input.

    Args:
        in_channels: Size of the input's channel axis.
        n_components: Width of the hidden representation between the encoder
            and the head.
        out_channels: Size of the channel axis in the output.
        dim: Index of the channel axis for inputs with more than two
            dimensions. Negative values count from the end.

    For inputs with more than two dimensions, axis ``dim`` is moved to the
    last position, the encoder and head are applied, and the resulting channel
    axis is moved back to ``dim``. Inputs with two or fewer dimensions are
    processed along their last axis as-is.
    """

    def __init__(self, in_channels: int, n_components: int, out_channels: int, dim: int):
        super().__init__()
        self.dim = dim
        self.encoder = Encoder(in_channels=in_channels, n_components=n_components)
        self.head = head(n_components=n_components, out_channels=out_channels)

    def forward(self, x):
        # The encoder/head contract over the last axis, so the channel axis is
        # relocated there first. movedim keeps the remaining axes in order and
        # also accepts negative axis indices.
        relocate = x.dim() > 2
        if relocate:
            x = torch.movedim(x, self.dim, -1)

        x = self.head(self.encoder(x))

        if relocate:
            x = torch.movedim(x, -1, self.dim)

        return x


def add_remainder(model: nn.Module, forward: Callable, in_channels: int, n_components: int, out_channels: int, dim=1) -> nn.Module:
    """
    Attach a remainder branch to ``model`` and patch its ``forward``/``train``.

              +-----+                       +-----+                      +-----+
    input --> | ... | --+-----------------> | ... | ----------------+--> | ... | --> output
              +-----+   |                   +-----+                 |    +-----+
                        |                                           |
                        |    +----------------------------------+   |
                        +--> |            remainder             | --+
                             |    +---------+     +------+      |
                             |    | Encoder | --> | Head |      |
                             |    +---------+     +------+      |
                             +----------------------------------+

    The model is modified in place: a ``RemainderModel`` is registered as
    ``model.remainder``, ``model.forward`` is rebound to ``forward``, and
    ``model.train`` is replaced so that only the remainder branch ever enters
    training mode. Where the branch taps into and merges back with the
    backbone is decided entirely by ``forward``.

    Args:
        model (nn.Module): The original model to which the branch is added.
        forward (Callable): Replacement forward, called as ``forward(model, x)``;
            it may use ``model.remainder``.
        in_channels (int): Channel count of the tensor fed to the remainder.
        n_components (int): Hidden width of the remainder branch.
        out_channels (int): Channel count produced by the remainder.
        dim (int): Channel axis of the tensor fed to the remainder.

    Returns:
        nn.Module: The same ``model`` object, now carrying the remainder branch.
    """
    remainder = RemainderModel(in_channels, n_components, out_channels, dim)
    # Place the branch next to the backbone's weights; a parameter-free model
    # gives no device hint, so the branch stays on the default device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        remainder = remainder.to(first_param.device)
    print("Number of parameters:", sum(p.numel() for p in remainder.parameters()))

    model.forward = types.MethodType(forward, model)
    model.remainder = remainder

    # Overwrite the `train` method
    def train(self, mode=True):
        # The parent implementation sets the mode on this module and all children.
        super(type(self), self).train(mode)
        if mode:
            # Keep every original sub-module in eval mode; only the remainder
            # branch stays in training mode. This toggles train/eval behaviour
            # only and does not change `requires_grad`.
            for name, module in self.named_children():
                if name != 'remainder':
                    module.eval()
        # nn.Module.train returns the module; `eval()` forwards this return
        # value, so `model = model.eval()` must get the module back.
        return self

    model.train = types.MethodType(train, model)

    return model


def mobilenet_v2_wrapper(model: nn.Module, n_components: int):
    """Attach a remainder branch to a model exposing ``features`` and ``classifier``.

    The branch reads the input of ``features[11]`` (64 channels), is
    average-pooled by 2, and is added to the output of the last feature
    module (1280 channels) before global pooling and classification.

    Args:
        model: Model to patch in place.
        n_components: Hidden width of the remainder branch.

    Returns:
        The patched model, as returned by ``add_remainder``.
    """
    # Baseline: "acc@1": 71.878, "acc@5": 90.286
    #
    # Example 1:
    #     `n_components`: 16, `lr`: 0.01
    #     Number of parameters: 44320
    #     Acc@1 71.882 Acc@5 90.258
    tap_index = 11

    # Replacement for the model's `forward` method
    def forward(self, x):
        for idx, layer in enumerate(self.features):
            if idx == tap_index:
                # Branch off from the activation entering this layer.
                side = nn.functional.avg_pool2d(self.remainder(x), 2, 2)
            x = layer(x)
        merged = x + side
        pooled = nn.functional.adaptive_avg_pool2d(merged, (1, 1))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

    return add_remainder(
        model,
        forward,
        in_channels=64,
        n_components=n_components,
        out_channels=1280,
        dim=1,
    )


def resnet_wrapper(model: nn.Module, n_components: int):
    """Attach a remainder branch around ``layer4`` of a ResNet-style model.

    The branch reads the output of ``layer3``, is average-pooled by 2, and is
    added to the output of ``layer4``.

    The branch widths are read from the model so that backbones of different
    width are handled: the output width is ``model.fc.in_features`` and the
    input width is the input channel count of ``layer4``'s first convolution.
    When those attributes are unavailable the original defaults of 1024 and
    2048 are used. The attribute lookups assume a torchvision-style ResNet
    layout.

    Args:
        model: Model to patch in place.
        n_components: Hidden width of the remainder branch.

    Returns:
        The patched model, as returned by ``add_remainder``.
    """
    # Baseline (1024 -> 2048 widths): "acc@1": 76.130, "acc@5": 92.862
    #
    # Example 1:
    #     `n_components`: 16, `lr`: 0.01
    #     Number of parameters: 51216
    #     Acc@1 76.162 Acc@5 92.974
    dim = 1
    # `fc` consumes the pooled `layer4` output, so its input width is the
    # channel count the branch must produce.
    out_channels = getattr(getattr(model, "fc", None), "in_features", 2048)
    try:
        # The first convolution of `layer4` consumes the `layer3` output.
        in_channels = model.layer4[0].conv1.in_channels
    except (AttributeError, IndexError, TypeError):
        in_channels = out_channels // 2

    # Overwrite the `forward` method
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Both terms consume the `layer3` output.
        x = self.layer4(x) + torch.nn.functional.avg_pool2d(self.remainder(x), 2, 2)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        output = self.fc(x)

        return output

    return add_remainder(model, forward, in_channels, n_components, out_channels, dim)


def resnext_wrapper(model: nn.Module, n_components: int):
    """Attach a remainder branch around ``layer4`` of a ResNeXt-style model.

    The branch reads the 1024-channel output of ``layer3``, is average-pooled
    by 2, and is added to the 2048-channel output of ``layer4``.

    Args:
        model: Model to patch in place.
        n_components: Hidden width of the remainder branch.

    Returns:
        The patched model, as returned by ``add_remainder``.
    """
    # Baseline: "acc@1": 77.618, "acc@5": 93.698
    #
    # Example 1:
    #     `n_components`: 16, `lr`: 0.01
    #     Number of parameters: -----
    #     Acc@1 77.590 Acc@5 93.688

    # Replacement for the model's `forward` method
    def forward(self, x):
        stem_and_early_stages = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
        )
        for stage in stem_and_early_stages:
            x = stage(x)

        # Remainder first and pooling next: Acc@1 77.626 Acc@5 93.682
        trunk = self.layer4(x)
        side = torch.nn.functional.avg_pool2d(self.remainder(x), 2, 2)

        pooled = self.avgpool(trunk + side)
        return self.fc(torch.flatten(pooled, 1))

    return add_remainder(
        model,
        forward,
        in_channels=1024,
        n_components=n_components,
        out_channels=2048,
        dim=1,
    )


def transformer_wrapper(model: nn.Module, n_components: int):
    """Attach a remainder branch in parallel with the heads of a ViT-style model.

    The branch reads the class-token embedding produced by the encoder and
    its output is added to the output of ``model.heads``.

    The branch widths are read from ``model.hidden_dim`` and
    ``model.num_classes`` so that variants of different width are handled.
    When those attributes are missing the original defaults of 768 and 1000
    are used. The attribute names assume a torchvision-style
    ``VisionTransformer``.

    Args:
        model: Model to patch in place.
        n_components: Hidden width of the remainder branch.

    Returns:
        The patched model, as returned by ``add_remainder``.
    """
    # Example 1:
    #     `n_components`: 16, `dof`: 16 `lr`: 0.00001
    #     Number of parameters: 208912
    #     Acc@1 80.888 Acc@5 95.208
    in_channels = getattr(model, "hidden_dim", 768)
    out_channels = getattr(model, "num_classes", 1000)
    dim = 1

    # Overwrite the `forward` method
    def forward(self, x):
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        output = self.heads(x) + self.remainder(x)

        return output

    return add_remainder(model, forward, in_channels, n_components, out_channels, dim)


def wrapper(name: str, model: nn.Module, n_components: int):
    """Dispatch to the architecture-specific remainder wrapper for ``name``.

    Args:
        name: Architecture identifier, e.g. ``"resnet50"`` or ``"vit_b_16"``.
        model: Model instance to patch.
        n_components: Hidden width of the remainder branch.

    Returns:
        Whatever the selected wrapper returns (the patched model).

    Raises:
        ValueError: If ``name`` is not one of the supported architectures.
    """
    if name == "mobilenet_v2":
        return mobilenet_v2_wrapper(model, n_components)
    # `resnet_wrapper` takes exactly (model, n_components); passing an extra
    # positional width argument raises TypeError.
    if name in ("resnet18", "resnet50"):
        return resnet_wrapper(model, n_components)
    if name == "resnext50_32x4d":
        return resnext_wrapper(model, n_components)
    if name in ("vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14"):
        return transformer_wrapper(model, n_components)
    raise ValueError(f"Unsupported model name: {name!r}")
