import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
from operations import *
from genotypes import PRIMITIVES
from genotypes import Genotype
from utils import count_parameters


class MixedOp(nn.Module):
    """Continuous relaxation over every candidate operation in ``PRIMITIVES``.

    All candidate ops are applied to the same input and their outputs are
    summed, each scaled by its own weight (one weight per ``PRIMITIVES`` entry).
    """

    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        # The third OPS argument is passed as False -- presumably the affine flag
        # of the op's normalisation layers; confirm in `operations`.
        # An extra BatchNorm after pooling primitives is deliberately not added.
        self.ops = nn.ModuleList([OPS[name](C, stride, False) for name in PRIMITIVES])

    def forward(self, x, weights):
        """Return ``sum_k weights[k] * ops[k](x)``.

        Also stores ``self.count_params``: the same weighted sum taken over
        ``count_parameters(op)``, which the owning module reads after this call.
        """
        candidates = list(zip(weights, self.ops))
        self.count_params = sum(w * count_parameters(op) for w, op in candidates)
        return sum(w * op(x) for w, op in candidates)


class MixedInpOp(nn.Module):
    """A MixedOp whose input is itself a weighted blend of several states."""

    def __init__(self, C, stride):
        super(MixedInpOp, self).__init__()
        self.mixed_op = MixedOp(C, stride)

    def forward(self, inputs, input_weights, weights):
        """Blend ``inputs`` with ``input_weights``, then run the MixedOp on the blend.

        Args:
            inputs: sequence of candidate input tensors.
            input_weights: one weight per entry of ``inputs`` (pairs are formed
                with ``zip``, so the shorter of the two decides the count).
            weights: per-primitive weights forwarded to the MixedOp.

        Returns:
            The MixedOp output for the blended input.
        """
        blended = sum(w * t for w, t in zip(input_weights, inputs))
        result = self.mixed_op(blended, weights=weights)
        # Surface the weighted parameter count the MixedOp recorded during forward.
        self.count_params = self.mixed_op.count_params
        return result
    
class Cell(nn.Module):
    """Search cell with ``steps`` intermediate nodes, each fed by two MixedInpOps.

    Every MixedInpOp sees all states produced so far (the two preprocessed cell
    inputs plus earlier nodes) and blends them with learned input weights, so a
    node's input connectivity is searched alongside its operations.
    """

    def __init__(self, steps, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction
        self.steps = steps
        self.ops_per_step = 2

        # Bring both cell inputs to C channels. After a reduction cell, s0 goes
        # through FactorizedReduce instead of a 1x1 ReLUConvBN (presumably to
        # match s1's spatial size -- confirm in `operations`).
        if reduction_prev:
            pre0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            pre0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        pre1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)

        # A reduction cell downsamples its inputs up front with a stride-2
        # max-pool; every mixed op below then runs at stride 1.
        if reduction:
            pre0 = nn.Sequential(pre0, nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            pre1 = nn.Sequential(pre1, nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.preprocess0 = pre0
        self.preprocess1 = pre1

        self.ops = nn.ModuleList(
            [MixedInpOp(C, stride=1) for _ in range(self.steps * self.ops_per_step)]
        )

    def forward(self, s0, s1, weights, output_weights, input_weights):
        """Run the cell.

        Args:
            s0, s1: outputs of the two preceding cells.
            weights: primitive weights indexed by op position ``2 * step + slot``.
            output_weights: optional per-node output weights; ``None`` means the
                intermediate nodes are concatenated unweighted.
            input_weights: per step, one input-blend weight row per op slot.

        Returns:
            The intermediate node outputs concatenated along dim 1. Also sets
            ``self.count_params`` (preprocess parameter counts plus each op's
            weighted count).
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        param_counts = [count_parameters(self.preprocess0), count_parameters(self.preprocess1)]

        for step in range(self.steps):
            # Both op slots of this node read every state produced so far.
            contributions = [
                self.ops[2 * step + slot](
                    states, input_weights[step][slot], weights[2 * step + slot]
                )
                for slot in range(2)
            ]
            states.append(sum(contributions))

        intermediates = states[2:]
        if output_weights is not None:
            intermediates = [w * t for w, t in zip(output_weights, intermediates)]

        param_counts.extend(op.count_params for op in self.ops)
        self.count_params = sum(param_counts)
        return torch.cat(intermediates, dim=1)


class Network(nn.Module):
    """Supernet for differentiable architecture search.

    A stem convolution is followed by ``layers`` search cells. The cells at
    positions ``layers // 3`` and ``2 * layers // 3`` are reduction cells, and
    the channel count doubles at each of them. Architecture logits are plain
    tensors kept in ``self._arch_parameters``; they are not registered as
    module parameters. They form two consecutive groups, normal cell first and
    then the reduction cell, each laid out as::

        [alphas_mixed_op, alphas_output, alphas_inputs[0], ..., alphas_inputs[steps - 1]]

    Args:
        C: base channel count.
        num_classes: classifier output size.
        layers: number of cells.
        steps: intermediate nodes per cell.
        multiplier: used for inter-cell channel bookkeeping and as the number of
            outputs kept by ``genotype``.
        stem_multiplier: the stem has ``stem_multiplier * C`` channels.
        init_alphas: scale of the uniform random initialisation of the logits.
        gumbel: if True, weights are drawn with hard Gumbel-softmax instead of softmax.
        out_weight: if True, cell outputs are scaled by learned output weights.
    """

    def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3,
                 init_alphas=0.01, gumbel=False, out_weight=False):
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._steps = steps
        self._multiplier = multiplier
        self._init_alphas = init_alphas
        self._gumbel = gumbel
        self._tau = 1  # Gumbel-softmax temperature
        self._out_weight = out_weight

        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr, affine=False),
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            reduction = i in (layers // 3, 2 * layers // 3)
            if reduction:
                C_curr *= 2
            cell = Cell(steps, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells.append(cell)
            # NOTE(review): a cell concatenates all `steps` node outputs, so this
            # bookkeeping presumably requires multiplier == steps -- confirm.
            C_prev_prev, C_prev = C_prev, multiplier * C_curr

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        self.initialize_alphas()

    def forward(self, input):
        """Classify ``input``; also sets ``self.count_params`` (sum over cells)."""
        s0 = s1 = self.stem(input)
        count_params = []

        for cell in self.cells:
            # A fresh sample is drawn per cell (relevant when Gumbel sampling is
            # on), but only for the cell type actually being run.
            weights, output_weights, input_weights = self.sample(reduction=cell.reduction)
            s0, s1 = s1, cell(s0, s1, weights, output_weights, input_weights)
            count_params.append(cell.count_params)

        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))

        self.count_params = sum(count_params)

        return logits

    def initialize_alphas(self, device=None):
        """Create the architecture logits for the normal and the reduction cell.

        Args:
            device: where the logits live. Defaults to CUDA when available,
                otherwise CPU (previously CUDA was hard-coded).
        """
        NUM_CELL_INPUTS = 2
        TYPES_OF_CELL = 2
        OPS_PER_STEP = 2
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def _logits(*shape):
            # Draw from the CPU generator, then move, so seeded runs get the same
            # values as before; the result is a trainable leaf tensor.
            return (self._init_alphas * torch.rand(*shape).to(device)).requires_grad_()

        self._arch_parameters = []
        for _ in range(TYPES_OF_CELL):
            self.alphas_mixed_op = _logits(self._steps * OPS_PER_STEP, len(PRIMITIVES))
            self.alphas_output = _logits(self._steps)
            # Node j (0-based) chooses among NUM_CELL_INPUTS + j earlier states.
            self.alphas_inputs = [
                _logits(OPS_PER_STEP, n_inputs)
                for n_inputs in range(NUM_CELL_INPUTS, self._steps + NUM_CELL_INPUTS)
            ]
            self._arch_parameters += [self.alphas_mixed_op, self.alphas_output, *self.alphas_inputs]
        self._cell_params_length = len(self._arch_parameters) // TYPES_OF_CELL

    def get_cell_arch_params(self, start_idx, alphas=None):
        """Split one cell's group out of a flat arch-parameter list.

        Args:
            start_idx: 0 for the normal cell, ``self._cell_params_length`` for reduction.
            alphas: list laid out like ``self._arch_parameters`` (the default).

        Returns:
            ``(weights, output_weights, input_weights)``; ``input_weights`` holds
            only this cell's per-node input tensors.
        """
        if alphas is None:
            alphas = self._arch_parameters
        end_idx = start_idx + self._cell_params_length
        return alphas[start_idx], alphas[start_idx + 1], alphas[start_idx + 2:end_idx]

    def sample(self, reduction=False):
        """Normalise one cell type's logits into mixing weights.

        Returns:
            ``(weights, output_weights or None, [input weights per step])``.
        """
        start_idx = self._cell_params_length if reduction else 0
        weights, output_weights, input_weights = self.get_cell_arch_params(start_idx)
        weights = self.normalize(weights, self._gumbel)
        output_weights = self.normalize(output_weights, self._gumbel) if self._out_weight else None
        new_input_weights = [self.normalize(w, self._gumbel) for w in input_weights[:self._steps]]
        return weights, output_weights, new_input_weights

    def arch_param_grad_norm(self, grads=None):
        """L2 norm (with a small epsilon) of the architecture-parameter gradients.

        Args:
            grads: explicit gradient tensors; ``None`` entries are skipped. When
                omitted, the ``.grad`` of each tensor in ``self._arch_parameters``
                is used.

        Returns:
            A scalar tensor ``sqrt(sum ||g||^2 + eps)``; ``sqrt(eps)`` when no
            gradients are present.
        """
        eps = 1e-5
        if grads is None:
            tensors = [p.grad for p in self._arch_parameters if p.grad is not None]
        else:
            tensors = [g for g in grads if g is not None]
        if not tensors:
            return torch.tensor(eps).sqrt()
        return (sum((t ** 2).sum() for t in tensors) + eps).sqrt()

    def reset_zero_grads(self):
        """Zero module gradients and the (unregistered) arch-parameter gradients."""
        self.zero_grad()
        for p in self._arch_parameters:
            if p.grad is not None:
                p.grad.zero_()

    def normalize(self, x, gumbel=False):
        """Softmax over the last dim, or hard Gumbel-softmax at temperature ``self._tau``."""
        if gumbel:
            return F.gumbel_softmax(x, dim=-1, hard=True, tau=self._tau)
        else:
            return F.softmax(x, dim=-1)

    @staticmethod
    def _select_step_inputs(w):
        """Pick the input-state index for each of a node's two op slots.

        ``w`` has one row per op slot and one column per candidate input state.
        The strongest entry fixes one slot; the other slot takes its best input
        that differs from the first slot's input.

        Returns:
            ``[input of slot 0, input of slot 1]`` as plain ints.
        """
        op1_row, op1_col = np.unravel_index(np.argmax(w, axis=None), w.shape)
        other_row = 1 - int(op1_row)
        ranked = np.argsort(w[other_row])[::-1]  # best first
        op2_col = ranked[0] if ranked[0] != op1_col else ranked[1]
        if op1_row == 0:
            return [int(op1_col), int(op2_col)]
        return [int(op2_col), int(op1_col)]

    @staticmethod
    def _select_outputs(output_weights, num_outputs):
        """Return the state indices of the ``num_outputs`` strongest nodes, ascending.

        Position k in the output weights is state k + 2 (states 0 and 1 are the
        cell inputs), matching the default concat range used by ``genotype``.
        """
        top = np.argsort(output_weights, axis=-1)[-num_outputs:]
        return sorted(int(k) + 2 for k in top)

    def genotype(self, alphas=None):
        """Derive a discrete ``Genotype`` from architecture logits.

        Args:
            alphas: list laid out like ``self._arch_parameters`` (the default).

        Returns:
            ``Genotype(normal, normal_concat, reduce, reduce_concat)``. Each gene
            is ``(primitive_name, input_state_index)``, two per intermediate node.
        """
        def _to_numpy(t):
            return t.detach().cpu().numpy()

        def _parse(start_idx):
            weights, output_weights, input_weights = self.get_cell_arch_params(start_idx, alphas)
            ops_idx = np.argmax(_to_numpy(weights), axis=-1)
            inp_idx = []
            for w in input_weights[:self._steps]:
                inp_idx += self._select_step_inputs(_to_numpy(w))
            gene = [(PRIMITIVES[int(op)], inp_idx[i]) for i, op in enumerate(ops_idx)]
            if self._out_weight:
                concat = self._select_outputs(_to_numpy(output_weights), self._multiplier)
            else:
                concat = list(range(2 + self._steps - self._multiplier, self._steps + 2))
            return gene, concat

        gene_normal, out_normal = _parse(0)
        gene_reduce, out_reduce = _parse(self._cell_params_length)

        return Genotype(
            normal=gene_normal,
            normal_concat=out_normal,
            reduce=gene_reduce,
            reduce_concat=out_reduce,
        )

