from collections import defaultdict

import torch
import torch.nn.functional as F
from torch import nn

import operations

from . import N_OPS


class DartsMixedOp(nn.Module):
    """Apply every DARTS candidate operation to a stack of input states.

    The input carries ``group`` states concatenated along the channel axis
    (``filters`` channels each).  Three parameter-light ops (average pool,
    max pool, skip/reduce) and four depthwise-conv based ops (sep 3x3,
    sep 5x5, dil 3x3, dil 5x5) are evaluated, and the result is reshaped so
    that dim 1 enumerates the individual candidates.

    Args:
        filters: channels per input state.
        group: number of input states stacked along the channel axis.
        stride: stride used by the pooling ops, the first stage of the
            separable convs and the dilated convs; with ``stride == 1`` the
            skip branch is ``operations.Identity``, otherwise
            ``operations.FactorizedReduce``.
    """

    def __init__(self, filters, group, stride=1):
        super().__init__()
        self.filters = filters
        self.group = group
        self.red = stride == 2

        # Channel count of the stacked input (all `group` states together).
        width = filters * group

        def depthwise(kernel_size, dilation=1, conv_stride=1):
            # "Same"-padded depthwise conv: one filter per channel, no bias.
            return nn.Conv2d(
                width,
                width,
                kernel_size=kernel_size,
                padding=dilation * (kernel_size - 1) // 2,
                dilation=dilation,
                groups=width,
                bias=False,
                stride=conv_stride,
            )

        def pointwise(n_branches):
            # Grouped 1x1 conv over `n_branches` concatenated branches; the
            # group count keeps every (branch, state) slice separate.
            channels = width * n_branches
            return nn.Conv2d(
                channels, channels, kernel_size=1, groups=group * n_branches
            )

        if stride == 1:
            skip = operations.Identity()
        else:
            skip = operations.FactorizedReduce(width, width, affine=False)

        self.reg_ops = nn.ModuleList(
            [
                nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
                nn.MaxPool2d(3, stride=stride, padding=1),
                skip,
            ]
        )

        # Stage 1: first depthwise conv of the two separable convs.
        self.conv_ops_1 = nn.ModuleList(
            [
                depthwise(3, conv_stride=stride),  # sep 3x3
                depthwise(5, conv_stride=stride),  # sep 5x5
            ]
        )
        self.conv_1_epilogue = nn.Sequential(
            pointwise(2),
            nn.BatchNorm2d(width * 2, affine=False),
            nn.ReLU(inplace=False),
        )

        # Stage 2: second depthwise conv of the separable convs (stride was
        # already applied in stage 1) plus the two dilated convs.
        self.conv_ops_2 = nn.ModuleList(
            [
                depthwise(3),  # sep 3x3
                depthwise(5),  # sep 5x5
                depthwise(3, dilation=2, conv_stride=stride),  # dil 3x3
                depthwise(5, dilation=2, conv_stride=stride),  # dil 5x5
            ]
        )
        self.conv_2_epilogue = nn.Sequential(
            pointwise(4),
            nn.BatchNorm2d(width * 4, affine=False),
        )

    def forward(self, x):
        """Return all candidate outputs for ``x``.

        The result has shape ``(batch, N_OPS * group, filters, *spatial)``;
        it is produced by reshaping the channel-wise concatenation of the
        regular ops followed by the conv ops.
        """
        # Pooling / skip branches operate on the raw input.
        regular = torch.cat([op(x) for op in self.reg_ops], dim=1)

        # Every conv branch starts from the ReLU-activated input.
        activated = F.relu(x)

        # Stage 1 of the separable convs, then split back per branch.
        stage_1 = torch.cat([op(activated) for op in self.conv_ops_1], dim=1)
        stage_1 = torch.chunk(self.conv_1_epilogue(stage_1), 2, dim=1)

        # Stage 2: separable convs continue from stage 1, dilated convs read
        # the activated input directly.
        stage_2_inputs = (*stage_1, activated, activated)
        stage_2 = torch.cat(
            [op(inp) for inp, op in zip(stage_2_inputs, self.conv_ops_2)],
            dim=1,
        )
        stage_2 = self.conv_2_epilogue(stage_2)

        # Expose the candidates on their own axis.
        merged = torch.cat([regular, stage_2], dim=1)
        batch, spatial = merged.shape[0], merged.shape[2:]
        return merged.reshape(
            [batch, N_OPS * self.group, self.filters, *spatial]
        )


class DartsSearchCell(nn.Module):
    """DARTS cell, search stage.

    The cell consumes the outputs of the two preceding cells (``pp`` and
    ``p``) and builds ``steps`` intermediate states.  Each intermediate state
    is a weighted sum over the candidates produced by the mixed ops, with the
    weights supplied by ``genotype.compute_coeffs``.

    ``n_red`` counts how many of the two inputs have a channel count that
    differs from ``filters``; the first ``n_red`` pre-processed inputs are
    routed through stride-2 mixed ops, all remaining states through stride-1
    mixed ops.
    """

    def __init__(
        self,
        steps,
        pp_filters,
        p_filters,
        filters,
        filter_multiplier,
    ):
        super().__init__()

        self.steps = steps
        self.filter_multiplier = filter_multiplier

        self.pp_filters = pp_filters
        self.p_filters = p_filters
        self.filters = filters

        # Input pre-processing; presumably a 1x1 conv with padding 0 that maps
        # each input to `filter_multiplier` channels -- confirm against
        # operations.ConvReLUBN.
        self.pre_pp = operations.ConvReLUBN(
            pp_filters, filter_multiplier, 1, 0, affine=False
        )
        self.pre_p = operations.ConvReLUBN(
            p_filters, filter_multiplier, 1, 0, affine=False
        )

        self.n_red = int(pp_filters != filters) + int(p_filters != filters)

        self.r_ops = nn.ModuleList()
        self.ops = nn.ModuleList()

        for i in range(self.steps):
            # Number of states handled by the stride-1 op at this step.
            n_regular = i + 2 - self.n_red
            if self.n_red != 0:
                self.r_ops.append(DartsMixedOp(filter_multiplier, self.n_red, stride=2))
            if n_regular != 0:
                self.ops.append(DartsMixedOp(filter_multiplier, n_regular))
            else:
                # Placeholder keeps `self.ops` indexable by step.
                self.ops.append(None)

    @staticmethod
    def _record_coeffs(coeffs_dict, coeffs, target, n_sources):
        """Append one stats row per entry of the 2-D ``coeffs`` tensor.

        Values are rounded to float16 before being stored as Python floats.
        The whole tensor is moved to the host with a single ``tolist()`` call
        instead of one ``item()`` call (and therefore one device
        synchronisation) per element.
        """
        n_samples, n_paths = coeffs.shape  # also enforces a 2-D tensor
        values = coeffs.detach().half().tolist()
        for bi in range(n_samples):
            row = values[bi]
            for pi in range(n_paths):
                coeffs_dict["source"].append(pi % n_sources)
                coeffs_dict["target"].append(target)
                coeffs_dict["op_idx"].append(pi // n_sources)
                coeffs_dict["coeff"].append(row[pi])
                coeffs_dict["sample_id"].append(bi)

    def forward(self, pp, p, genotype, collect_stats):
        """Run the cell.

        Args:
            pp: output of the cell two positions back.
            p: output of the previous cell.
            genotype: object exposing
                ``compute_coeffs(k_states, block_id, att_states)``, which must
                return a 2-D tensor of per-sample, per-candidate weights.
            collect_stats: when truthy, every coefficient is recorded in the
                returned dictionary.

        Returns:
            ``(out, coeffs_dict)``.  ``out`` is the channel-wise
            concatenation of the stride-1 states with the first
            ``(2 - n_red) * filter_multiplier`` channels dropped.
            ``coeffs_dict`` maps ``"source"``, ``"target"``, ``"op_idx"``,
            ``"coeff"`` and ``"sample_id"`` to parallel lists (empty when
            ``collect_stats`` is falsy).
        """
        pp_state, p_state = self.pre_pp(pp), self.pre_p(p)
        att_states = [pp_state, p_state]

        inputs = [pp_state, p_state]
        r_states = inputs[: self.n_red]
        states = inputs[self.n_red :]

        if r_states:
            r_states = torch.cat(r_states, dim=1)
        states = torch.cat(states, dim=1) if states else None

        # type, source, target, op_idx
        coeffs_dict = defaultdict(list)
        for i in range(self.steps):
            s = rs = None

            if self.n_red != 0:
                rs = self.r_ops[i](r_states)
            if i + 2 - self.n_red != 0:
                s = self.ops[i](states)

            if rs is None:
                x = s
            elif s is None:
                x = rs
            else:
                x = torch.cat([rs, s], dim=1)

            coeffs = genotype.compute_coeffs(
                k_states=x,
                block_id=i,
                att_states=att_states,
            )

            if collect_stats:
                # NOTE(review): when `rs` and `s` are both present, dim 1 of
                # `x` is two concatenated blocks rather than one layout over
                # all i + 2 sources, so the source/op_idx labels derived from
                # `pi % (i + 2)` may not match the real candidate -- confirm.
                self._record_coeffs(coeffs_dict, coeffs, i, i + 2)

            # Broadcast the weights over (filters, H, W) and contract the
            # candidate axis.
            coeffs = coeffs.reshape(*coeffs.shape, 1, 1, 1)
            s = torch.sum(x * coeffs, dim=1)

            att_states.append(s)

            if states is not None:
                states = torch.cat([states, s], dim=1)
            else:
                states = s

        return (
            states[:, (2 - self.n_red) * self.filter_multiplier :, ...],
            coeffs_dict,
        )


class NasBenchMixedOp(nn.Module):
    """Apply every NAS-Bench-201 candidate operation to each input state.

    The candidates are, in order: zero, 3x3 average pooling, identity,
    3x3 conv and 1x1 conv (both ``operations.ConvReLUBN``).
    """

    def __init__(self, filters):
        super().__init__()
        self.filters = filters

        # Settings shared by both conv candidates.
        conv_kwargs = dict(
            dilation=1,
            affine=False,
            stride=1,
            track_running_stats=False,
        )

        candidates = [
            operations.Zero(),
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            operations.Identity(),
            operations.ConvReLUBN(
                filters, filters, kernel_size=3, padding=1, **conv_kwargs
            ),
            operations.ConvReLUBN(
                filters, filters, kernel_size=1, padding=0, **conv_kwargs
            ),
        ]
        self.ops = nn.ModuleList(candidates)

    def forward(self, xs):
        """Return all candidate outputs stacked on a new dim 1.

        ``xs`` is a sequence of equally-shaped tensors.  The new axis is
        ordered op-major: entry ``k * len(xs) + j`` is op ``k`` applied to
        ``xs[j]``.
        """
        candidates = [op(state) for op in self.ops for state in xs]
        return torch.stack(candidates, dim=1)


class NasBenchSearchCell(nn.Module):
    """NAS-Bench-201 cell, search stage.

    Starting from a single input state, the cell builds ``steps`` further
    states.  Each new state is a weighted sum over every candidate operation
    applied to every earlier state, with the weights supplied by
    ``genotype.compute_coeffs``.
    """

    def __init__(
        self,
        steps,
        filters,
    ):
        super().__init__()

        self.steps = steps
        self.filters = filters

        # One mixed op per step; each is applied to all states known so far.
        self.ops = nn.ModuleList(
            [NasBenchMixedOp(filters) for _ in range(self.steps)]
        )

    @staticmethod
    def _record_coeffs(coeffs_dict, coeffs, target, n_sources):
        """Append one stats row per entry of the 2-D ``coeffs`` tensor.

        Values are rounded to float16 before being stored as Python floats.
        The whole tensor is moved to the host with a single ``tolist()`` call
        instead of one ``item()`` call (and therefore one device
        synchronisation) per element.
        """
        n_samples, n_paths = coeffs.shape  # also enforces a 2-D tensor
        values = coeffs.detach().half().tolist()
        for bi in range(n_samples):
            row = values[bi]
            for pi in range(n_paths):
                coeffs_dict["source"].append(pi % n_sources)
                coeffs_dict["target"].append(target)
                coeffs_dict["op_idx"].append(pi // n_sources)
                coeffs_dict["coeff"].append(row[pi])
                coeffs_dict["sample_id"].append(bi)

    def forward(self, p, genotype, collect_stats):
        """Run the cell.

        Args:
            p: input state of the cell.
            genotype: object exposing
                ``compute_coeffs(k_states, block_id, att_states)``, which must
                return a 2-D tensor of per-sample, per-candidate weights.
            collect_stats: when truthy, every coefficient is recorded in the
                returned dictionary.

        Returns:
            ``(out, coeffs_dict)`` where ``out`` is the last state built and
            ``coeffs_dict`` maps ``"source"``, ``"target"``, ``"op_idx"``,
            ``"coeff"`` and ``"sample_id"`` to parallel lists (empty when
            ``collect_stats`` is falsy).
        """
        states = [p]

        # type, source, target, op_idx
        coeffs_dict = defaultdict(list)
        for i in range(self.steps):
            # Candidates for step i: every op applied to each of the i + 1
            # states available so far.
            x = self.ops[i](states)

            coeffs = genotype.compute_coeffs(
                k_states=x,
                block_id=i,
                att_states=states,
            )

            if collect_stats:
                self._record_coeffs(coeffs_dict, coeffs, i, i + 1)

            # Broadcast the weights over the trailing three dims and contract
            # the candidate axis.
            coeffs = coeffs.reshape(*coeffs.shape, 1, 1, 1)
            s = torch.sum(x * coeffs, dim=1)

            states.append(s)

        return (
            states[-1],
            coeffs_dict,
        )
