import torch
import torch.nn as nn
import torch.nn.functional as F

from genotypes_201 import PRIMITIVES, Genotype
from operations_201 import OPS, ReLUConvBN, FactorizedReduce

# Directed edges (source_node, target_node) of the 4-node NAS-Bench-201 cell,
# grouped by target node and ordered by source within each group:
#   (0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)
# The position of an edge in this list is its row index in the weight matrix.
EDGE_INDEX_NB201 = [(src, dst) for dst in range(1, 4) for src in range(dst)]

class MixedOp(nn.Module):
    """Weighted sum of every candidate primitive applied to one edge's input.

    One operation is instantiated per entry of ``PRIMITIVES`` (in that order),
    each built as ``OPS[name](C, 1, False)``.
    """

    def __init__(self, C: int):
        super().__init__()
        self._ops = nn.ModuleList([OPS[name](C, 1, False) for name in PRIMITIVES])

    def forward(self, x: torch.Tensor, weights_row: torch.Tensor) -> torch.Tensor:
        """Return ``sum_k weights_row[k] * op_k(x)``.

        ``weights_row`` is iterated in lockstep with the candidate ops; it is
        presumably a 1-D tensor of length ``len(PRIMITIVES)`` — TODO confirm.
        """
        total = 0
        for weight, candidate in zip(weights_row, self._ops):
            total = total + weight * candidate(x)
        return total

class Cell(nn.Module):
    """One searchable NAS-Bench-201-style cell.

    The input is first preprocessed (``FactorizedReduce`` for reduction cells,
    otherwise a 1x1 ``ReLUConvBN``) to ``C`` channels and becomes node 0.
    Nodes 1..3 are each the sum of the ``MixedOp`` outputs on their incoming
    edges from ``EDGE_INDEX_NB201``. The cell output is the concatenation of
    nodes 1, 2 and 3 along dim 1 (the channel dim, assuming NCHW input).
    """

    def __init__(self, C_in: int, C: int, reduction: bool):
        super().__init__()
        self.reduction = reduction
        self.C = C
        # NOTE(review): FactorizedReduce presumably halves spatial resolution — confirm in operations_201.
        self.preprocess = (
            FactorizedReduce(C_in, C, affine=False)
            if reduction
            else ReLUConvBN(C_in, C, 1, 1, 0, affine=False)
        )
        self.edges = list(EDGE_INDEX_NB201)
        self.num_edges = len(self.edges)
        self.num_ops = len(PRIMITIVES)
        # One MixedOp per edge; _ops[e] serves self.edges[e].
        self._ops = nn.ModuleList([MixedOp(C) for _ in range(self.num_edges)])
        self._multiplier = 3

    def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Run the cell.

        Args:
            x: Cell input.
            weights: Indexed by edge; ``weights[e]`` is passed to ``_ops[e]``
                (presumably shape ``(num_edges, num_ops)`` — TODO confirm).

        Returns:
            ``torch.cat([node1, node2, node3], dim=1)``.
        """
        node_states = [self.preprocess(x), None, None, None]
        # Edges are ordered so every source node is complete before it is read.
        for edge_idx, (src, dst) in enumerate(self.edges):
            contribution = self._ops[edge_idx](node_states[src], weights[edge_idx])
            if node_states[dst] is None:
                node_states[dst] = contribution
            else:
                node_states[dst] = node_states[dst] + contribution
        return torch.cat(node_states[1:], dim=1)

    @property
    def out_channels(self) -> int:
        """Output channel count: three concatenated nodes of ``C`` channels each."""
        return self.C * 3

class Network(nn.Module):
    """NAS-Bench-201-style one-shot supernet.

    A 3x3 conv + BatchNorm stem is followed by ``layers`` cells. Cells at
    indices ``layers // 3`` and ``2 * layers // 3`` are reduction cells and
    double the channel count. Global average pooling and a linear classifier
    follow. Every cell reads the same architecture matrix ``self.alphas`` of
    shape ``(num_edges, num_ops)``; row ``e`` belongs to
    ``EDGE_INDEX_NB201[e]``.

    Args:
        C: Base channel count of the first (non-reduction) cells.
        num_classes: Number of output logits.
        layers: Number of cells.
        criterion: Loss callable invoked as ``criterion(logits, targets)`` by ``_loss``.
        stem_multiplier: Stem output channels are ``C * stem_multiplier``.
    """

    def __init__(self, C: int, num_classes: int, layers: int, criterion, stem_multiplier: int = 1):
        super().__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        C_stem = C * stem_multiplier
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_stem, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_stem),
        )
        self.cells = nn.ModuleList()
        reduction_layers = (layers // 3, 2 * layers // 3)
        C_in = C_stem
        C_curr = C
        for i in range(layers):
            reduction = i in reduction_layers
            if reduction:
                C_curr *= 2
            cell = Cell(C_in, C_curr, reduction=reduction)
            self.cells.append(cell)
            # The next cell consumes this cell's concatenated output.
            C_in = cell.out_channels
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_in, num_classes)
        self._initialize_alphas()

    def _initialize_alphas(self):
        """Create the architecture logits: one row per edge, one column per primitive."""
        self.num_edges = len(EDGE_INDEX_NB201)
        self.num_ops = len(PRIMITIVES)
        # requires_grad=False: autograd never updates these logits.
        self.alphas = nn.Parameter(
            1e-3 * torch.randn(self.num_edges, self.num_ops), requires_grad=False
        )
        self._arch_params = [self.alphas]

    def arch_parameters(self):
        """Return the list of architecture parameters (just ``[self.alphas]``)."""
        return self._arch_params

    @torch.no_grad()
    def get_alphas(self) -> torch.Tensor:
        """Return the raw architecture logits (the parameter itself, not a copy)."""
        return self.alphas

    @torch.no_grad()
    def get_weights(self) -> torch.Tensor:
        """Return the per-edge softmax of the alphas, shape ``(num_edges, num_ops)``."""
        # NOTE(review): because of @torch.no_grad() no gradient can reach the
        # alphas through forward(); confirm this is intended if the alphas are
        # ever made trainable.
        return F.softmax(self.alphas, dim=-1)

    def get_projected_weights(self) -> torch.Tensor:
        """Return the weights used by ``forward``; currently identical to ``get_weights``."""
        return self.get_weights()

    def forward(self, x: torch.Tensor, weights_dict: "torch.Tensor | None" = None) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Input image batch (``nn.Conv2d`` stem with 3 input channels).
            weights_dict: Optional per-edge weight matrix that overrides
                ``get_projected_weights()``; despite the name it is a tensor
                indexed by edge, shared by all cells.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        weights = self.get_projected_weights() if weights_dict is None else weights_dict
        s = self.stem(x)
        for cell in self.cells:
            s = cell(s, weights)
        out = self.global_pooling(s)
        return self.classifier(out.view(out.size(0), -1))

    def _loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``criterion(self(x), y)``."""
        return self._criterion(self(x), y)

    @torch.no_grad()
    def _select_ops(self, ban_none: bool) -> list:
        """Return the highest-probability primitive name per edge, in EDGE_INDEX_NB201 order.

        When ``ban_none`` is true and ``'none'`` is one of the primitives, it is
        excluded from the arg-max. When ``'none'`` is absent there is nothing to
        ban; the original looked up its index unconditionally and raised ValueError.
        """
        probs = F.softmax(self.alphas, dim=-1)
        if ban_none and 'none' in PRIMITIVES:
            # softmax returns a fresh tensor, so this in-place write is local.
            probs[:, PRIMITIVES.index('none')] = float('-inf')
        # torch.argmax returns the first maximal index on ties, as per-row argmax did.
        return [PRIMITIVES[int(k)] for k in probs.argmax(dim=-1).tolist()]

    @torch.no_grad()
    def genotype(self, ban_none: bool = False) -> Genotype:
        """Return the discretized cell as a Genotype of ``(op_name, source_node)`` pairs.

        Args:
            ban_none: Exclude the ``'none'`` primitive when picking each edge's op.
        """
        ops = self._select_ops(ban_none)
        gene = [(op_name, src) for op_name, (src, _dst) in zip(ops, EDGE_INDEX_NB201)]
        return Genotype(normal=gene, normal_concat=[1, 2, 3])

    @torch.no_grad()
    def export_nb201_string(self, ban_none: bool = True) -> str:
        """Return the discretized cell as a NAS-Bench-201 architecture string.

        Format: ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|`` (one group per target
        node 1..3). Local primitive names are translated via ``op_map``; names
        not in the map are emitted unchanged instead of raising KeyError.

        Args:
            ban_none: Exclude the ``'none'`` primitive when picking each edge's op.
        """
        op_map = {
            'conv_1x1':     'nor_conv_1x1',
            'conv_3x3':     'nor_conv_3x3',
            'avg_pool_3x3': 'avg_pool_3x3',
            'skip_connect': 'skip_connect',
            'none':         'none',
        }
        ops = self._select_ops(ban_none)
        best = {edge: op_map.get(name, name) for edge, name in zip(EDGE_INDEX_NB201, ops)}
        groups = []
        for dst in (1, 2, 3):
            inputs = "|".join(f"{best[(src, dst)]}~{src}" for src in range(dst))
            groups.append(f"|{inputs}|")
        return "+".join(groups)

    def get_edges(self):
        """Return the shared ``(source, target)`` edge list used by every cell."""
        return EDGE_INDEX_NB201

    def get_primitives(self):
        """Return the candidate primitive names (column order of ``alphas``)."""
        return PRIMITIVES
