import torch
import torch.nn as nn
import math
from copy import deepcopy


class LoRAParameter(nn.Module):
    """Wrap a single trainable tensor in an ``nn.Module``; ``forward`` simply returns it.

    Args:
        weight: initial value. A ``torch.Tensor`` is copied (detached, so the new
            parameter shares neither storage nor autograd history with the source);
            anything else (nested lists, scalars, arrays) goes through ``torch.tensor``.
    """

    def __init__(self, weight):
        super(LoRAParameter, self).__init__()
        if isinstance(weight, torch.Tensor):
            # torch.tensor(<Tensor>) works but emits a UserWarning; an explicit
            # detached copy is PyTorch's recommended equivalent.
            data = weight.detach().clone()
        else:
            data = torch.tensor(weight)
        self.weight = nn.Parameter(data)

    def forward(self):
        """Return the wrapped parameter tensor."""
        return self.weight


def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Create a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride == 1`` the padding keeps the spatial size unchanged.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class DownsampleB(nn.Module):
    """Parameter-free shortcut: max-pool the input, then append the same number of zero channels.

    The output therefore has twice the input's channel count, and its spatial size
    is reduced by the pooling window.
    """

    def __init__(self, stride: int = 2) -> None:
        super().__init__()
        # MaxPool2d's stride defaults to its kernel size, so window == stride here.
        self.max = nn.MaxPool2d(stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` and concatenate an all-zero tensor of the same shape along dim 1."""
        pooled = self.max(x)
        return torch.cat([pooled, torch.zeros_like(pooled)], dim=1)


def depthwise_conv(in_channels: int, kernel_size: int = 3, stride: int = 1,
                   padding: int = 0, dilation: int = 1, bias: bool = False):
    """Create a depthwise ``nn.Conv2d``: one filter per input channel (``groups == in_channels``).

    The channel count is preserved (output channels == ``in_channels``).
    """
    channels = in_channels
    return nn.Conv2d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=channels,
        bias=bias,
    )


def pointwise_conv(
    in_channels: int, out_channels: int, bias: bool = False):
    """Create a 1x1 ``nn.Conv2d`` that only mixes channels (stride 1, no padding, no grouping)."""
    layer_kwargs = dict(
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=bias,
    )
    return nn.Conv2d(in_channels, out_channels, **layer_kwargs)


class SeparableConv2d(nn.Module):
    """Task-conditioned separable convolution: a 1x1 pointwise conv followed by a depthwise conv.

    One pointwise/depthwise pair is allocated per task (``nb_tasks`` of each). When
    ``forward_transfer`` is enabled, tasks other than 0 are evaluated with the task-0
    weights plus a scaled low-rank update ``B @ A``.

    The LoRA containers ``lora_A_pointwise``, ``lora_B_pointwise``,
    ``lora_A_depthwise`` and ``lora_B_depthwise`` are NOT created here. They must be
    attached to the instance elsewhere before the forward-transfer path or
    ``reset_parameters`` can act on them. Each entry must expose a ``.weight`` tensor,
    and entry ``task_id - 1`` serves task ``task_id``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0,
                 dilation: int = 1, bias: bool = False, nb_tasks: int = 10, forward_transfer=False):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.bias = bias
        self.nb_tasks = nb_tasks
        self.forward_transfer = forward_transfer

        # Multiplier applied to every LoRA update B @ A.
        self.scaling = 1

        self.pointwise = nn.ModuleList([pointwise_conv(in_channels, out_channels, bias) for _ in range(nb_tasks)])
        # The depthwise conv runs on the pointwise output, hence out_channels.
        self.depthwise = nn.ModuleList([depthwise_conv(out_channels, kernel_size, stride, padding, dilation, bias=bias)
                                        for _ in range(nb_tasks)])

    def _lora_conv(self, conv: nn.Conv2d, lora_B, lora_A, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Apply ``conv`` to ``x`` with its weight shifted by the scaled LoRA update for ``task_id``.

        The base weight and bias are detached, so only the LoRA factors receive
        gradients. The shifted weight is a new tensor, so the stored base weight is
        never modified.
        """
        idx = task_id - 1  # LoRA entry 0 belongs to task 1
        base_weight = conv.weight.detach()
        delta = (lora_B[idx].weight @ lora_A[idx].weight).view(base_weight.shape) * self.scaling
        weight = base_weight + delta.to(base_weight.dtype)
        # Keep the module's bias (None when bias=False) so this path matches the plain one.
        bias = None if conv.bias is None else conv.bias.detach()
        return conv._conv_forward(x, weight, bias)

    def forward(self, x: torch.Tensor, task_id: int, training=False) -> torch.Tensor:
        """Run the pointwise then depthwise convolution selected by ``task_id``.

        Args:
            x: input batch fed to the pointwise conv.
            task_id: index of the task whose parameters to use.
            training: when True and ``task_id != 0``, the LAST pointwise/depthwise
                pair (index -1) is used regardless of ``task_id``.
                NOTE(review): presumably callers rely on this; confirm.

        Returns:
            The depthwise conv output.
        """
        if training or task_id == 0:
            idx = 0 if task_id == 0 else -1
            return self.depthwise[idx](self.pointwise[idx](x))

        if self.forward_transfer:
            x_pt = self._lora_conv(self.pointwise[0], self.lora_B_pointwise, self.lora_A_pointwise, x, task_id)
            return self._lora_conv(self.depthwise[0], self.lora_B_depthwise, self.lora_A_depthwise, x_pt, task_id)

        return self.depthwise[task_id](self.pointwise[task_id](x))

    def reset_parameters(self):
        """Re-initialize the most recently added LoRA factors, if LoRA containers are attached.

        ``A`` gets Kaiming-uniform init and ``B`` is zeroed, so ``B @ A`` starts at zero.
        Only the last entry (index -1) of each container is touched. Presence is checked
        via ``lora_A_pointwise`` only; the other three containers are assumed to exist
        alongside it.
        """
        if hasattr(self, 'lora_A_pointwise'):
            nn.init.kaiming_uniform_(self.lora_A_depthwise[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B_depthwise[-1].weight)

            nn.init.kaiming_uniform_(self.lora_A_pointwise[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B_pointwise[-1].weight)


class BasicBlock(nn.Module):
    """Separable conv -> per-task BatchNorm -> ReLU -> separable conv -> per-task BatchNorm.

    Both convolutions use a 3x3 kernel with padding 1, and only the first applies
    ``stride``. ``task_id`` selects the convolution path as well as the BatchNorm
    layers. No shortcut/residual connection is added inside this block.
    """

    def __init__(self, in_filters: int, out_filters: int, nb_tasks: int = 10, stride: int = 1, forward_transfer=False):
        super().__init__()
        shared = dict(bias=False, nb_tasks=nb_tasks, forward_transfer=forward_transfer)

        # Registration order (conv1, bns1, relu, conv2, bns2) is kept so parameter
        # and state_dict ordering are unchanged.
        self.SeparableConv2d1 = SeparableConv2d(in_filters, out_filters, 3, stride=stride, padding=1, **shared)
        self.bns1 = nn.ModuleList(nn.BatchNorm2d(out_filters) for _ in range(nb_tasks))
        self.relu = nn.ReLU(inplace=True)
        self.SeparableConv2d2 = SeparableConv2d(out_filters, out_filters, 3, stride=1, padding=1, **shared)
        self.bns2 = nn.ModuleList(nn.BatchNorm2d(out_filters) for _ in range(nb_tasks))

    def forward(self, x: torch.Tensor, task_id: int, training=False) -> torch.Tensor:
        """Apply both conv/BN stages for ``task_id``; ``training`` is forwarded to each conv."""
        hidden = self.relu(self.bns1[task_id](self.SeparableConv2d1(x, task_id, training)))
        return self.bns2[task_id](self.SeparableConv2d2(hidden, task_id, training))
