import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from ..ops.reduced_dynamic_layers_noskip import NaMixedOp, ScMixedOp, LaMixedOp
from ..ops.reduced_dynamic_layers_noskip import NA_PRIMITIVES, SC_PRIMITIVES, LA_PRIMITIVES
from ..backbone import ReducedSANEBackBone

class ReducedSANESearchSpace(nn.Module):
    '''
        Differentiable SANE search space.

        SANE can be seen as the combination of three kinds of searchable
        components: node aggregators (NA), skip connections (SC) and a layer
        aggregator (LA). No separate cell is needed, because the DAG itself is
        the whole search space; this class implements that DAG directly.

        Architecture weights (``na_alphas``, ``sc_alphas``, ``la_alphas``) are
        plain tensors with ``requires_grad=True`` returned by
        ``arch_parameters()``; they are not registered as module parameters.

        When ``config.fix_last`` is truthy, the last NA output (x3) is fed to
        the layer aggregator directly and only two SC ops are searched.
    '''

    def __init__(self, in_dim, out_dim, hidden_size, num_layers=3, dropout=0.5, epsilon=0.0, with_conv_linear=False, config=None, with_bn=False):
        super(ReducedSANESearchSpace, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_size = hidden_size
        # NOTE(review): forward() always stacks exactly three NA layers;
        # num_layers is only passed on to LaMixedOp -- confirm it must be 3.
        self.num_layers = num_layers
        self.dropout = dropout
        self.epsilon = epsilon
        self.explore_num = 0
        self.with_linear = with_conv_linear
        self.config = config
        # Kept so that new() can rebuild an identically configured model.
        self.with_bn = with_bn

        # node aggregator ops: input projection + three mixed NA layers
        self.lin1 = nn.Linear(in_dim, hidden_size)
        self.layer1 = NaMixedOp(hidden_size, hidden_size, self.with_linear)
        self.layer2 = NaMixedOp(hidden_size, hidden_size, self.with_linear)
        self.layer3 = NaMixedOp(hidden_size, hidden_size, self.with_linear)

        # skip-connection ops; with fix_last the third one is not searched
        self.layer4 = ScMixedOp()
        self.layer5 = ScMixedOp()
        if not self._fix_last():
            self.layer6 = ScMixedOp()

        # layer aggregator op and final classifier
        self.layer7 = LaMixedOp(hidden_size, num_layers)
        self.classifier = nn.Linear(hidden_size, out_dim)
        self._initialize_alphas()

    def _fix_last(self):
        """Return True when ``config.fix_last`` is set (False if config is None or lacks it)."""
        return bool(getattr(self.config, 'fix_last', False))

    def new(self):
        """Return a fresh search space with the same hyper-parameters.

        The architecture weights are copied from ``self``; the network weights
        of the new model are freshly initialized by its constructor.
        """
        # Forward every constructor argument: dropping ``config`` would change
        # the shape of sc_alphas (fix_last) and make copy_ below fail.
        model_new = ReducedSANESearchSpace(
            self.in_dim, self.out_dim, self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
            epsilon=self.epsilon,
            with_conv_linear=self.with_linear,
            config=self.config,
            with_bn=self.with_bn,
        )
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def forward(self, data, discrete=False):
        """Run the supernet on one graph batch.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``;
                presumably a torch_geometric ``Data`` -- confirm with callers.
            discrete: accepted for interface compatibility; currently unused.

        Returns:
            Output of the final linear classifier (last dim ``out_dim``).
        """
        x, edge_index = data.x, data.edge_index
        fix_last = self._fix_last()

        # generate mixing weights by softmax over the alphas; this happens on
        # every call, so it overwrites anything set via set_model_weights().
        self.na_weights = F.softmax(self.na_alphas, dim=-1)
        self.sc_weights = F.softmax(self.sc_alphas, dim=-1)
        self.la_weights = F.softmax(self.la_alphas, dim=-1)

        x = self.lin1(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x1 = self.layer1(x, self.na_weights[0], edge_index)
        x1 = F.dropout(x1, p=self.dropout, training=self.training)
        x2 = self.layer2(x1, self.na_weights[1], edge_index)
        x2 = F.dropout(x2, p=self.dropout, training=self.training)
        x3 = self.layer3(x2, self.na_weights[2], edge_index)
        x3 = F.dropout(x3, p=self.dropout, training=self.training)

        if fix_last:
            # x3 reaches the layer aggregator untouched; only x1/x2 go through SC ops.
            x4 = (x3, self.layer4(x1, self.sc_weights[0]), self.layer5(x2, self.sc_weights[1]))
        else:
            x4 = (self.layer4(x1, self.sc_weights[0]),
                  self.layer5(x2, self.sc_weights[1]),
                  self.layer6(x3, self.sc_weights[2]))

        x5 = self.layer7(x4, self.la_weights[0])
        x5 = F.dropout(x5, p=self.dropout, training=self.training)

        logits = self.classifier(x5)
        return logits

    def _initialize_alphas(self):
        """Create the architecture weights with small random values (scale 1e-3).

        Shapes: na_alphas (3, |NA|), sc_alphas (2 with fix_last else 3, |SC|),
        la_alphas (1, |LA|).
        """
        num_na_ops = len(NA_PRIMITIVES)
        num_sc_ops = len(SC_PRIMITIVES)
        num_la_ops = len(LA_PRIMITIVES)
        num_sc_rows = 2 if self._fix_last() else 3

        # Creation order (na, sc, la) is kept so seeded runs draw the same values.
        self.na_alphas = Variable(1e-3 * torch.randn(3, num_na_ops), requires_grad=True)
        self.sc_alphas = Variable(1e-3 * torch.randn(num_sc_rows, num_sc_ops), requires_grad=True)
        self.la_alphas = Variable(1e-3 * torch.randn(1, num_la_ops), requires_grad=True)
        self._arch_parameters = [
            self.na_alphas,
            self.sc_alphas,
            self.la_alphas,
        ]

    def arch_parameters(self):
        """Return the list [na_alphas, sc_alphas, la_alphas]."""
        return self._arch_parameters

    def get_prob_result(self):
        """Return a printable report: the genotype, then each primitive list with its raw alphas."""
        result = self.genotype() + '\n'
        result += str(NA_PRIMITIVES) + '\n'
        result += str(self.na_alphas) + '\n'
        result += str(SC_PRIMITIVES) + '\n'
        result += str(self.sc_alphas) + '\n'
        result += str(LA_PRIMITIVES) + '\n'
        result += str(self.la_alphas) + '\n'
        return result

    def genotype(self):
        """Return the argmax architecture as a '||'-joined string of op names.

        Order: one NA op per row of na_alphas, then one SC op per row of
        sc_alphas, then the LA op.
        """
        def _parse(na_weights, sc_weights, la_weights):
            gene = []
            groups = (
                (na_weights, NA_PRIMITIVES),
                (sc_weights, SC_PRIMITIVES),
                (la_weights, LA_PRIMITIVES),
            )
            for weights, primitives in groups:
                for k in torch.argmax(weights, dim=-1).tolist():
                    gene.append(primitives[k])
            return '||'.join(gene)

        gene = _parse(F.softmax(self.na_alphas, dim=-1).data.cpu(),
                      F.softmax(self.sc_alphas, dim=-1).data.cpu(),
                      F.softmax(self.la_alphas, dim=-1).data.cpu())
        return gene

    def sample_genotype(self):
        """Sample a random genotype string uniformly via NumPy's global RNG.

        Layout: 3 NA ops, 2 SC ops, 1 LA op, joined with '||'.
        NOTE(review): always 2 SC ops, which matches the fix_last variant;
        with fix_last unset the supernet searches 3 -- confirm intended.
        """
        gene = []
        for _ in range(3):
            op = np.random.choice(NA_PRIMITIVES, 1)[0]
            gene.append(op)
        for _ in range(2):
            op = np.random.choice(SC_PRIMITIVES, 1)[0]
            gene.append(op)
        op = np.random.choice(LA_PRIMITIVES, 1)[0]
        gene.append(op)
        return '||'.join(gene)

    @staticmethod
    def _one_hot_alphas(op_names, primitives):
        """Return a (len(op_names), len(primitives)) one-hot leaf tensor requiring grad."""
        alphas = torch.zeros(len(op_names), len(primitives))
        for row, name in enumerate(op_names):
            alphas[row, primitives.index(name)] = 1
        # Enable grad only after filling: in-place writes into a leaf tensor
        # that already requires grad raise a RuntimeError in autograd.
        return alphas.requires_grad_()

    def get_weights_from_arch(self, arch):
        """Convert a genotype string into one-hot architecture weights.

        Args:
            arch: '||'-joined op names in the layout produced by
                sample_genotype(): 3 NA ops, 2 SC ops, 1 LA op.

        Returns:
            [na_alphas (3, |NA|), sc_alphas (2, |SC|), la_alphas (1, |LA|)],
            each a leaf tensor with requires_grad=True and a single 1 per row.

        Raises:
            IndexError: if ``arch`` has fewer than 6 '||'-separated entries.
            ValueError: if an op name is not in its primitive list.
        """
        arch_ops = arch.split('||')
        # Explicit indexing (not slicing) so short genotypes still raise IndexError.
        na_names = [arch_ops[i] for i in range(0, 3)]
        sc_names = [arch_ops[i] for i in range(3, 5)]
        la_names = [arch_ops[5]]

        na_alphas = self._one_hot_alphas(na_names, NA_PRIMITIVES)
        sc_alphas = self._one_hot_alphas(sc_names, SC_PRIMITIVES)
        la_alphas = self._one_hot_alphas(la_names, LA_PRIMITIVES)

        arch_parameters = [na_alphas, sc_alphas, la_alphas]
        return arch_parameters

    def set_model_weights(self, weights):
        """Store ``weights`` as (na_weights, sc_weights, la_weights).

        NOTE(review): forward() recomputes these from the alphas on every
        call, so values set here do not affect forward().
        """
        self.na_weights = weights[0]
        self.sc_weights = weights[1]
        self.la_weights = weights[2]

    def sample_active_subnet(self, sample_mode='random', subnet_settings=None):
        """Attach a randomly sampled genotype to ``subnet_settings``.

        Args:
            sample_mode: only 'random' is handled; any other mode returns None.
            subnet_settings: settings object or dict to update; a new dict is
                created when None.

        Returns:
            The updated ``subnet_settings`` for 'random', otherwise None.
        """
        if sample_mode != 'random':
            return None

        if subnet_settings is None:
            subnet_settings = {}
        genotype = self.sample_genotype()
        try:
            subnet_settings.genotype = genotype
        except AttributeError:
            # Plain dicts reject attribute assignment; build_active_subnet
            # reads the value with item syntax, so store it as a key.
            subnet_settings['genotype'] = genotype
        return subnet_settings

    def build_active_subnet(self, subnet_settings):
        """Build a standalone ReducedSANEBackBone from ``subnet_settings``.

        ``subnet_settings`` must be subscriptable with keys 'genotype',
        'hidden_size', 'in_dropout', 'out_dropout', 'act' and 'config'.
        """
        subnet = ReducedSANEBackBone(subnet_settings['genotype'],
                                     self.in_dim, self.out_dim, subnet_settings['hidden_size'],
                                     self.num_layers, in_dropout=subnet_settings['in_dropout'],
                                     out_dropout=subnet_settings['out_dropout'],
                                     act=subnet_settings['act'],
                                     config=subnet_settings['config'])
        return subnet

def sane_search_space(**kwargs):
    """Factory for the reduced SANE search space.

    Every keyword argument is forwarded unchanged to ``ReducedSANESearchSpace``.
    """
    space = ReducedSANESearchSpace(**kwargs)
    return space

if __name__ == "__main__":
    # This module only defines the search space; nothing runs when executed directly.
    pass