import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers.evo_norm import EvoNormSample2d

class WSConv(nn.Conv2d):
    """Conv2d with Weight Standardization applied on the fly.

    On every forward pass, each output filter of ``self.weight`` is shifted to
    zero mean and divided by its standard deviation plus ``eps``. The standard
    deviation is ``Tensor.std``'s default, i.e. with Bessel's correction. The
    stored parameter itself is left untouched; only the kernel passed to
    ``F.conv2d`` is standardized.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        kernel_size, stride, padding, dilation, groups, bias: passed
            positionally to :class:`torch.nn.Conv2d`.
        eps: added to each filter's standard deviation to avoid division by
            zero. Defaults to ``1e-5``, the value previously hard-coded.

    Note:
        If a filter holds a single element (``ic // groups * kH * kW == 1``),
        its Bessel-corrected standard deviation is NaN. The standardized
        weight is then NaN too.
    """

    def __init__(self, ic, oc, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, eps=1e-5):
        super().__init__(ic, oc, kernel_size, stride, padding, dilation, groups, bias)
        self.eps = eps

    def _standardized_weight(self):
        """Return ``self.weight`` standardized per output filter (same shape as the weight)."""
        weight = self.weight
        oc = weight.shape[0]
        # Zero-mean each filter over (in_channels, kH, kW).
        weight = weight - weight.mean(dim=(1, 2, 3), keepdim=True)
        # Per-filter std, reshaped to (oc, 1, 1, 1) so it broadcasts over the kernel.
        std = weight.reshape(oc, -1).std(dim=1).reshape(oc, 1, 1, 1)
        return weight / (std + self.eps)

    def forward(self, x):
        """Convolve ``x`` with the standardized kernel, using this layer's conv settings."""
        return F.conv2d(x, self._standardized_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

class ConvBlk(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by a normalization and activation.

    Args:
        ic: number of input channels.
        oc: number of output channels. With ``'ws_gn'`` this must be divisible
            by 32, because ``nn.GroupNorm(32, oc)`` requires it.
        norm_type: one of

            - ``'ws_gn'``: :class:`WSConv` -> ``GroupNorm(32, oc)`` -> ReLU
            - ``'bn_relu'``: ``Conv2d`` -> ``BatchNorm2d`` -> ReLU
            - ``'evonorm_s0'``: ``Conv2d`` -> ``EvoNormSample2d``; no separate
              activation is applied (``nn.Identity``).

    Raises:
        ValueError: if ``norm_type`` is not one of the values above. Before
            this check, an unknown value left ``self.m``/``self.act`` undefined
            and only failed later in ``forward``.
    """

    def __init__(self, ic, oc, norm_type='ws_gn'):
        super().__init__()
        if norm_type == 'ws_gn':
            self.m = nn.Sequential(
                WSConv(ic, oc, 3, 1, 1),
                nn.GroupNorm(32, oc))
            self.act = nn.ReLU()
        elif norm_type == 'bn_relu':
            self.m = nn.Sequential(
                nn.Conv2d(ic, oc, 3, 1, 1),
                nn.BatchNorm2d(oc))
            self.act = nn.ReLU()
        elif norm_type == 'evonorm_s0':
            self.m = nn.Sequential(
                nn.Conv2d(ic, oc, 3, 1, 1),
                EvoNormSample2d(oc))
            # No extra activation here; presumably EvoNorm-S0 supplies the
            # nonlinearity itself (confirm against timm's implementation).
            self.act = nn.Identity()
        else:
            raise ValueError(
                f"unknown norm_type {norm_type!r}; "
                "expected 'ws_gn', 'bn_relu' or 'evonorm_s0'")

    def forward(self, x):
        """Apply conv + norm, then the activation; spatial size is preserved."""
        return self.act(self.m(x))

class UpConvBlk(nn.Module):
    """Upsample by 2, concatenate a skip tensor, then conv + norm (+ ReLU).

    ``x`` goes through ``ConvTranspose2d(ic, ic, 2, 2)`` and a ReLU. The result
    is concatenated with ``x_2`` along the channel dim and passed through a
    3x3 conv (stride 1, padding 1) and a normalization.

    Args:
        ic: channels of the tensor being upsampled (``x``).
        ic_2: channels of the skip tensor (``x_2``).
        oc: output channels. With ``'ws_gn'`` this must be divisible by 32,
            because ``nn.GroupNorm(32, oc)`` requires it.
        norm_type: one of

            - ``'ws_gn'``: :class:`WSConv` -> ``GroupNorm(32, oc)`` -> ReLU
            - ``'bn_relu'``: ``Conv2d`` -> ``BatchNorm2d`` -> ReLU
            - ``'evonorm_s0'``: ``Conv2d`` -> ``EvoNormSample2d``, with no ReLU.

    Raises:
        ValueError: if ``norm_type`` is not one of the values above.
    """

    def __init__(self, ic, ic_2, oc, norm_type='ws_gn'):
        super().__init__()
        self.up = nn.ConvTranspose2d(ic, ic, 2, 2)
        self.act = nn.ReLU()
        if norm_type == 'ws_gn':
            self.m = nn.Sequential(
                WSConv(ic + ic_2, oc, 3, 1, 1),
                nn.GroupNorm(32, oc),
                nn.ReLU())
        elif norm_type == 'bn_relu':
            self.m = nn.Sequential(
                nn.Conv2d(ic + ic_2, oc, 3, 1, 1),
                nn.BatchNorm2d(oc),
                nn.ReLU())
        elif norm_type == 'evonorm_s0':
            self.m = nn.Sequential(
                nn.Conv2d(ic + ic_2, oc, 3, 1, 1),
                EvoNormSample2d(oc))
        else:
            raise ValueError(
                f"unknown norm_type {norm_type!r}; "
                "expected 'ws_gn', 'bn_relu' or 'evonorm_s0'")

    def forward(self, x, x_2):
        """Upsample ``x``, concatenate ``x_2`` on dim 1, and apply the conv block.

        ``x_2`` must match the upsampled ``x`` (twice the spatial size of ``x``)
        in every dim except 1, as ``torch.cat`` requires.
        """
        o = self.act(self.up(x))
        o = torch.cat([o, x_2], 1)
        return self.m(o)

class UpConvBlkCond(nn.Module):
    """Conditioned variant of the upsample-concat-conv block.

    ``x`` goes through ``ConvTranspose2d(ic, ic, 2, 2)``. The upsampled tensor
    is modulated by :class:`FiLM` using ``x_cond``, passed through a ReLU, and
    concatenated with ``x_2`` on dim 1. The result then goes through a 3x3 conv
    (stride 1, padding 1) and a normalization.

    Args:
        ic: channels of the tensor being upsampled (``x``). FiLM modulates these.
        ic_2: channels of the skip tensor (``x_2``).
        oc: output channels. With ``'ws_gn'`` this must be divisible by 32,
            because ``nn.GroupNorm(32, oc)`` requires it.
        norm_type: ``'ws_gn'`` (WSConv -> GroupNorm(32) -> ReLU), ``'bn_relu'``
            (Conv2d -> BatchNorm2d -> ReLU), or ``'evonorm_s0'``
            (Conv2d -> EvoNormSample2d, with no ReLU).
        cond_dim: size of the last dim of ``x_cond``, i.e. the FiLM input
            width. Defaults to 1, the previously hard-coded value.

    Raises:
        ValueError: if ``norm_type`` is not one of the values above.
    """

    def __init__(self, ic, ic_2, oc, norm_type='ws_gn', cond_dim=1):
        super().__init__()
        self.up = nn.ConvTranspose2d(ic, ic, 2, 2)
        self.act = nn.ReLU()
        if norm_type == 'ws_gn':
            self.m = nn.Sequential(
                WSConv(ic + ic_2, oc, 3, 1, 1),
                nn.GroupNorm(32, oc),
                nn.ReLU())
        elif norm_type == 'bn_relu':
            self.m = nn.Sequential(
                nn.Conv2d(ic + ic_2, oc, 3, 1, 1),
                nn.BatchNorm2d(oc),
                nn.ReLU())
        elif norm_type == 'evonorm_s0':
            self.m = nn.Sequential(
                nn.Conv2d(ic + ic_2, oc, 3, 1, 1),
                EvoNormSample2d(oc))
        else:
            raise ValueError(
                f"unknown norm_type {norm_type!r}; "
                "expected 'ws_gn', 'bn_relu' or 'evonorm_s0'")
        self.film = FiLM(cond_dim, ic)

    def forward(self, x, x_2, x_cond):
        """Upsample ``x``, FiLM-modulate it with ``x_cond``, concat ``x_2``, apply the conv block."""
        o = self.up(x)
        o = self.film(o, x_cond)
        o = self.act(o)
        o = torch.cat([o, x_2], 1)
        return self.m(o)

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: ``out = x * gamma + beta``.

    A two-layer MLP (Linear -> ReLU -> Linear) maps the conditioning input to
    ``2 * oc`` values. The first ``oc`` are the per-channel scale ``gamma`` and
    the last ``oc`` the per-channel shift ``beta``.

    Args:
        ic: width of the last dim of ``cond``.
        oc: number of channels (dim 1) of ``x`` to modulate.
        hc: hidden width of the MLP.
    """

    def __init__(self, ic, oc, hc=128):
        super().__init__()
        self.oc = oc
        self.m = nn.Sequential(
            nn.Linear(ic, hc),
            nn.ReLU(),
            nn.Linear(hc, oc * 2))

    def forward(self, x, cond):
        """Modulate ``x`` channel-wise with parameters predicted from ``cond``.

        Args:
            x: features whose dim 1 has ``oc`` channels, e.g. ``(bs, oc, H, W)``.
                Any number of trailing dims (including none) is supported.
            cond: conditioning tensor, assumed ``(bs, ic)``.

        Returns:
            ``x * gamma + beta``, with gamma/beta broadcast over every dim
            after 1.
        """
        params = self.m(cond)                 # (bs, 2 * oc)
        gamma, beta = params.chunk(2, dim=1)  # each (bs, oc)
        # One singleton dim per trailing dim of x. The old code hard-coded two,
        # which broadcast wrongly for non-4D x (e.g. (bs, oc) -> (bs, oc, bs, oc)).
        shape = tuple(gamma.shape) + (1,) * (x.dim() - 2)
        return x * gamma.reshape(shape) + beta.reshape(shape)

class ConvBlkCond(nn.Module):
    """Conditioned 3x3 conv block: conv -> norm -> FiLM(x_cond) -> activation.

    Args:
        ic: number of input channels.
        oc: number of output channels. FiLM modulates these. With ``'ws_gn'``
            this must be divisible by 32, because ``nn.GroupNorm(32, oc)``
            requires it.
        norm_type: ``'ws_gn'`` (WSConv -> GroupNorm(32), then ReLU),
            ``'bn_relu'`` (Conv2d -> BatchNorm2d, then ReLU), or
            ``'evonorm_s0'`` (Conv2d -> EvoNormSample2d, then ``nn.Identity``).
        cond_dim: size of the last dim of ``x_cond``, i.e. the FiLM input
            width. Defaults to 1, the previously hard-coded value.

    Raises:
        ValueError: if ``norm_type`` is not one of the values above.
    """

    def __init__(self, ic, oc, norm_type='ws_gn', cond_dim=1):
        super().__init__()
        if norm_type == 'ws_gn':
            self.m = nn.Sequential(
                WSConv(ic, oc, 3, 1, 1),
                nn.GroupNorm(32, oc))
            self.act = nn.ReLU()
        elif norm_type == 'bn_relu':
            self.m = nn.Sequential(
                nn.Conv2d(ic, oc, 3, 1, 1),
                nn.BatchNorm2d(oc))
            self.act = nn.ReLU()
        elif norm_type == 'evonorm_s0':
            self.m = nn.Sequential(
                nn.Conv2d(ic, oc, 3, 1, 1),
                EvoNormSample2d(oc))
            self.act = nn.Identity()
        else:
            raise ValueError(
                f"unknown norm_type {norm_type!r}; "
                "expected 'ws_gn', 'bn_relu' or 'evonorm_s0'")
        self.film = FiLM(cond_dim, oc)

    def forward(self, x, x_cond):
        """Conv + norm, FiLM-modulate with ``x_cond``, then apply the activation."""
        o = self.m(x)
        o = self.film(o, x_cond)
        return self.act(o)