from collections import defaultdict
import torch
from torch import nn
from torch.nn import functional as F


def get_conv(in_dim, out_dim, kernel_size, stride, padding, zero_bias=True, zero_weights=False, groups=1, scaled=False):
    """Create an ``nn.Conv2d`` and optionally zero out its bias and/or weights.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        zero_bias: when true, the freshly initialised bias is multiplied by 0.
        zero_weights: when true, the freshly initialised weight is multiplied by 0.
        groups: forwarded to ``nn.Conv2d``.
        scaled: accepted for signature compatibility; not used by this function.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride, padding=padding, groups=groups)
    # Bias first, then weight; each is zeroed in place only when its flag is set.
    for should_zero, param in ((zero_bias, conv.bias), (zero_weights, conv.weight)):
        if should_zero:
            param.data.mul_(0.0)
    return conv


def get_3x3(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1, scaled=False):
    """Build a 3x3 convolution (stride 1, padding 1) through ``get_conv``."""
    return get_conv(
        in_dim,
        out_dim,
        kernel_size=3,
        stride=1,
        padding=1,
        zero_bias=zero_bias,
        zero_weights=zero_weights,
        groups=groups,
        scaled=scaled,
    )


def get_1x1(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1, scaled=False):
    """Build a 1x1 convolution (stride 1, no padding) through ``get_conv``."""
    conv_kwargs = dict(
        kernel_size=1,
        stride=1,
        padding=0,
        zero_bias=zero_bias,
        zero_weights=zero_weights,
        groups=groups,
        scaled=scaled,
    )
    return get_conv(in_dim, out_dim, **conv_kwargs)


class Block(nn.Module):
    """Four-convolution bottleneck block with optional residual and resolution change.

    The input goes through ``c1`` (1x1), ``c2`` and ``c3`` (3x3, or 1x1 when
    ``use_3x3`` is false) and ``c4`` (1x1); a GELU is applied to the input of
    every convolution.

    Args:
        in_width: number of input channels.
        middle_width: number of channels used by the inner convolutions.
        out_width: number of channels produced by ``c4``.
        down_rate: selects what ``forward`` returns (see ``forward``); when
            truthy it is also the kernel size and stride of the average pool.
        residual: when true, the input is added to the convolution output.
        use_3x3: use 3x3 kernels for ``c2``/``c3`` instead of 1x1.
        zero_last: initialise the weights of ``c4`` to zero.
        pool: when true (and ``down_rate`` is truthy) resize with average
            pooling; otherwise the ``downsample`` module is used.
        diff: when true (and pooling is active) return the difference between
            the full-resolution output and its upsampled pooled version
            instead of the full-resolution output itself.
    """

    def __init__(self, in_width, middle_width, out_width, down_rate=None, residual=False, use_3x3=True, zero_last=False, pool=True, diff=False):
        super().__init__()
        self.down_rate = down_rate
        self.residual = residual
        self.c1 = get_1x1(in_width, middle_width)
        self.c2 = get_3x3(middle_width, middle_width) if use_3x3 else get_1x1(middle_width, middle_width)
        self.c3 = get_3x3(middle_width, middle_width) if use_3x3 else get_1x1(middle_width, middle_width)
        self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)
        # Both flags are only active when a truthy down_rate is given; they
        # hold the result of ``flag and down_rate`` rather than a strict bool.
        self.pool = (pool and down_rate)
        self.diff = (diff and down_rate)
        if not self.pool:
            # Created whenever pooling is inactive (this includes
            # down_rate=None/False); kept so parameter names stay unchanged.
            # NOTE(review): with kernel 4, stride 2, padding 1 a
            # ConvTranspose2d doubles the spatial size, so despite its name
            # this module enlarges its input - confirm this is intended.
            self.downsample = nn.ConvTranspose2d(out_width, out_width, 4, stride=2, padding=1)

    def forward(self, x):
        """Run the block.

        Returns:
            * ``down_rate is None``: the output tensor.
            * ``down_rate`` falsy but not ``None`` (e.g. ``False``): ``[out]``.
            * otherwise a pair ``(resized, second)`` where ``resized`` is the
              average-pooled output (or the ``downsample`` output when pooling
              is inactive) and ``second`` is the full-resolution output, or,
              with ``diff`` active, the output minus its pooled-then-upsampled
              version.
        """
        xhat = self.c1(F.gelu(x))
        xhat = self.c2(F.gelu(xhat))
        xhat = self.c3(F.gelu(xhat))
        xhat = self.c4(F.gelu(xhat))
        out = x + xhat if self.residual else xhat
        if self.down_rate is None:
            return out
        if not self.down_rate:
            return [out]
        if not self.pool:
            return self.downsample(out), out
        out_pooled = F.avg_pool2d(out, kernel_size=self.down_rate, stride=self.down_rate)
        if not self.diff:
            return out_pooled, out
        # Upsample by the same factor that was used for pooling so the result
        # lines up with ``out``; a fixed factor of 2 only works for down_rate=2.
        out_mean = F.interpolate(out_pooled, scale_factor=self.down_rate, mode='nearest')
        return out_pooled, out - out_mean


class TopDownBlockBase(nn.Module):
    """Base module holding width settings read from ``cfgs`` plus a variance helper.

    Attributes set in ``__init__``:
        width: copied from ``cfgs.width``.
        use_3x3: always ``True``.
        cond_width: ``int(width * cfgs.bottleneck_multiple)``.
        zdim: copied from ``cfgs.zdim``.
    """

    def __init__(self, cfgs):
        super().__init__()
        width = cfgs.width
        self.width = width
        self.use_3x3 = True
        self.cond_width = int(width * cfgs.bottleneck_multiple)
        self.zdim = cfgs.zdim

    # NOTE(review): stubs for sample / sample_uncond / forward / forward_uncond
    # (each raising NotImplementedError) were left commented out here;
    # presumably subclasses are expected to provide them - confirm.

    def get_param_q(self, log_var, quantizer, log_param_q_scalar, param_var_q):
        """Compute ``param_q`` according to the ``param_var_q`` mode.

        Modes:
            * ``"vmf"`` / ``"gaussian_1"``: ``exp(log_param_q_scalar)``.
            * ``"gaussian_2"``: ``exp(mean of log_var over dims 1, 2, 3) + exp(log_param_q_scalar)``.
            * ``"gaussian_3"``: ``exp(mean of log_var over dim 1) + exp(log_param_q_scalar)``.
            * ``"gaussian_4"``: ``exp(log_var) + exp(log_param_q_scalar)``.

        Means are taken with ``keepdim=True``. ``quantizer`` is accepted but
        not used.

        Raises:
            Exception: for any other ``param_var_q`` value.
        """
        if param_var_q in ("vmf", "gaussian_1"):
            return log_param_q_scalar.exp()

        if param_var_q == "gaussian_2":
            reduced_log_var = log_var.mean(dim=(1, 2, 3), keepdim=True)
        elif param_var_q == "gaussian_3":
            reduced_log_var = log_var.mean(dim=1, keepdim=True)
        elif param_var_q == "gaussian_4":
            reduced_log_var = log_var
        else:
            raise Exception("Undefined param_var_q")

        return reduced_log_var.exp() + log_param_q_scalar.exp()


def parse_bottom_up_layer_string(s):
    """Parse a comma-separated layer spec into a list of ``(res, rate)`` tuples.

    Each comma-separated token is one of:
        * ``"<res>x<count>"``: ``count`` copies of ``(res, False)``.
        * ``"<res>d<rate>"``: a single ``(res, rate)`` entry.
        * ``"<res>"``: a single ``(res, False)`` entry.

    The ``x`` form is checked first. Numeric fields are converted with ``int``.
    """
    parsed = []
    for token in s.split(','):
        if 'x' in token:
            res_text, repeat_text = token.split('x')
            repeats = int(repeat_text)
            parsed.extend((int(res_text), False) for _ in range(repeats))
            continue
        if 'd' in token:
            res, rate = [int(field) for field in token.split('d')]
            parsed.append((res, rate))
            continue
        parsed.append((int(token), False))
    return parsed


def parse_top_down_layer_string(s):
    """Parse a comma-separated ``"<res>x<before>+<after>"`` spec.

    Returns a list with one ``(res, before, after)`` tuple of ints per
    comma-separated entry, in input order.
    """
    def _parse_entry(entry):
        # Split off the resolution first, then the two '+'-separated counts.
        res_part, counts_part = entry.split('x')
        res = int(res_part)
        before_part, after_part = counts_part.split('+')
        return res, int(before_part), int(after_part)

    return [_parse_entry(entry) for entry in s.split(',')]




def parse_bu_layer_string(s):
    """Turn a comma-separated layer spec into ``(res, rate)`` tuples.

    Token forms, tested in this order:
        * contains ``x`` (``"<res>x<count>"``): ``count`` times ``(res, False)``.
        * contains ``d`` (``"<res>d<rate>"``): one ``(res, rate)``.
        * anything else: one ``(int(token), False)``.
    """
    layers = []
    for spec in s.split(','):
        if 'x' in spec:
            res_field, count_field = spec.split('x')
            n_repeat = int(count_field)
            new_entries = [(int(res_field), False) for _ in range(n_repeat)]
        elif 'd' in spec:
            res, rate = tuple(int(part) for part in spec.split('d'))
            new_entries = [(res, rate)]
        else:
            new_entries = [(int(spec), False)]
        layers.extend(new_entries)
    return layers


def parse_td_layer_string(s):
    """Parse a comma-separated layer spec into a list of ``(res, rate)`` tuples.

    Each token expands as follows (the ``x`` form takes precedence):
        * ``"<res>x<count>"``: ``count`` copies of ``(res, False)``.
        * ``"<res>u<rate>"``: a single ``(res, rate)`` entry.
        * ``"<res>"``: a single ``(res, False)`` entry.
    """
    def _entries(token):
        # Expand one token into the list of tuples it contributes.
        if 'x' in token:
            res_text, count_text = token.split('x')
            count = int(count_text)
            return [(int(res_text), False) for _ in range(count)]
        if 'u' in token:
            res, rate = [int(piece) for piece in token.split('u')]
            return [(res, rate)]
        return [(int(token), False)]

    return [entry for token in s.split(',') for entry in _entries(token)]


def parse_sq_layer_string(s):
    """Convert a comma-separated layer spec into ``(res, rate)`` tuples.

    ``"<res>x<count>"`` yields ``count`` entries ``(res, False)``;
    ``"<res>u<rate>"`` yields ``(res, rate)``; a bare ``"<res>"`` yields
    ``(res, False)``. A token containing ``x`` is always treated as the
    repeat form.
    """
    result = []
    for chunk in s.split(','):
        is_repeat = 'x' in chunk
        if not is_repeat and 'u' not in chunk:
            result += [(int(chunk), False)]
        elif is_repeat:
            res_str, times_str = chunk.split('x')
            times = int(times_str)
            result += [(int(res_str), False) for _ in range(times)]
        else:
            values = [int(v) for v in chunk.split('u')]
            res, rate = values
            result += [(res, rate)]
    return result
