import torch
from copy import deepcopy
import torch.nn as nn
from sota.cnn.operations import *
from sota.cnn.genotypes import Genotype
import sys
sys.path.insert(0, '../../')
from sota.cnn.model_search import Network

class ImageNetNetworkProj(Network):
    def __init__(self, C, num_classes, layers, criterion, primitives, args,
                 steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0, nettype='imagenet'):
        """Build the base search network, then attach the projection state.

        Every argument is forwarded unchanged to ``Network.__init__``; see
        that class for their meaning.
        """
        base_kwargs = dict(
            steps=steps,
            multiplier=multiplier,
            stem_multiplier=stem_multiplier,
            drop_path_prob=drop_path_prob,
            nettype=nettype,
        )
        super().__init__(C, num_classes, layers, criterion, primitives, args, **base_kwargs)

        # Projection bookkeeping. These helpers read attributes such as
        # ``num_edges`` and ``alphas_normal`` (presumably set by the base
        # constructor), so they must run after it.
        self._initialize_flags()
        self._initialize_proj_weights()
        self._initialize_topology_dicts()

    #### proj flags
    def _initialize_topology_dicts(self):
        self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
        self.nid2selected_eids = {
            'normal': {0:[],1:[],2:[]},
            'reduce': {0:[],1:[],2:[]},
        }
    
    def _initialize_flags(self):
        self.candidate_flags = {
            'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
            'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
        } # must be in this order
        self.candidate_flags_edge = {
            'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
            'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
        }

    def _initialize_proj_weights(self):
        ''' data structures used for proj '''
        if isinstance(self.alphas_normal, list):
            alphas_normal = torch.stack(self.alphas_normal, dim=0)
            alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
        else:
            alphas_normal = self.alphas_normal
            alphas_reduce = self.alphas_reduce

        self.proj_weights = { # for hard/soft assignment after project
            'normal': torch.zeros_like(alphas_normal),
            'reduce': torch.zeros_like(alphas_reduce),
        }
    
    #### proj function
    def project_op(self, eid, opid, cell_type):
        self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
        self.candidate_flags[cell_type][eid] = False
        
    def project_edge(self, nid, eids, cell_type):
        for eid in self.nid2eids[nid]:
            if eid not in eids: # not top2
                self.proj_weights[cell_type][eid].data.fill_(0)
        self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
        self.candidate_flags_edge[cell_type][nid] = False

    #### critical function
    def get_projected_weights(self, cell_type):
        ''' used in forward and genotype '''
        weights = self.get_softmax()[cell_type]

        ## proj op
        for eid in range(self.num_edges):
            if not self.candidate_flags[cell_type][eid]:
                weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        ## proj edge
        for nid in self.nid2eids:
            if not self.candidate_flags_edge[cell_type][nid]: ## projected node
                for eid in self.nid2eids[nid]:
                    if eid not in self.nid2selected_eids[cell_type][nid]:
                        weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        return weights
    
    def get_all_projected_weights(self, cell_type):
        weights = self.get_softmax()[cell_type]

        for eid in range(self.num_edges):
            weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        for nid in self.nid2eids:
            for eid in self.nid2eids[nid]:
                weights[eid].data.copy_(self.proj_weights[cell_type][eid])
        
        return weights

    def forward(self, input, weights_dict=None, using_proj=False):
        if using_proj:
            weights_normal = self.get_all_projected_weights('normal')
            weights_reduce = self.get_all_projected_weights('reduce')
        else:
            if weights_dict is None or 'normal' not in weights_dict:
                weights_normal = self.get_projected_weights('normal')
            else:
                weights_normal = weights_dict['normal']
            if weights_dict is None or 'reduce' not in weights_dict:
                weights_reduce = self.get_projected_weights('reduce')
            else:
                weights_reduce = weights_dict['reduce']
        


        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                weights = weights_reduce
            else:
                weights = weights_normal

            s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
            
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0),-1))

        return logits

    def reset_arch_parameters(self):
        self._initialize_flags()
        self._initialize_proj_weights()
        self._initialize_topology_dicts()

    #### utils
    def printing(self, logging, option='all'):
        weights_normal = self.get_projected_weights('normal')
        weights_reduce = self.get_projected_weights('reduce')

        if option in ['all', 'normal']:
            logging.info('\n%s', weights_normal)
        if option in ['all', 'reduce']:
            logging.info('\n%s', weights_reduce)

    def genotype(self):
        """Derive the discrete architecture from the current projected weights.

        For every intermediate node, the two input edges whose strongest
        non-'none' operation has the largest weight are kept, and that
        operation is selected on each kept edge.

        Returns:
            A ``Genotype`` whose ``normal`` / ``reduce`` fields are lists of
            ``(op_name, input_index)`` tuples (``input_index`` is the edge's
            position among the node's inputs), with the same concat range for
            both cell types.
        """
        def _parse(weights, normal=True):
            # One primitive list per edge, indexed by the GLOBAL edge id.
            primitives = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']

            gene = []
            start = 0
            for i in range(self._steps):
                fan_in = i + 2
                end = start + fan_in
                W = weights[start:end].copy()

                # Best admissible op per input edge. ``j`` is local to this
                # node, so the matching primitive list is ``primitives[start + j]``;
                # indexing with the bare local id would consult another edge's ops.
                best_op = []
                for j in range(fan_in):
                    edge_ops = primitives[start + j]
                    none_idx = edge_ops.index('none') if 'none' in edge_ops else None
                    admissible = [k for k in range(len(W[j])) if k != none_idx]
                    if not admissible:
                        # Edge offers nothing but 'none': fall back to all ops
                        # rather than failing on an empty candidate set.
                        admissible = list(range(len(W[j])))
                    k_best = admissible[0]
                    for k in admissible[1:]:
                        if W[j][k] > W[j][k_best]:  # strict: first maximum wins on ties
                            k_best = k
                    best_op.append(k_best)

                # Keep the two edges with the strongest best op (stable sort).
                edges = sorted(range(fan_in), key=lambda j: -W[j][best_op[j]])[:2]

                for j in edges:
                    gene.append((primitives[start + j][best_op[j]], j))
                start = end
            return gene

        weights_normal = self.get_projected_weights('normal')
        weights_reduce = self.get_projected_weights('reduce')
        gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
        gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)

        concat = range(2+self._steps-self._multiplier, self._steps+2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat
        )
        return genotype
    
    def get_state_dict(self, epoch, architect, scheduler):
        model_state_dict = {
            'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
            'state_dict': self.state_dict(),
            'alpha': self.arch_parameters(),
            'optimizer': self.optimizer.state_dict(),
            'arch_optimizer': architect.optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            #### projection
            'nid2eids': self.nid2eids,
            'nid2selected_eids': self.nid2selected_eids,
            'candidate_flags': self.candidate_flags,
            'candidate_flags_edge': self.candidate_flags_edge,
            'proj_weights': self.proj_weights,
        }
        return model_state_dict

    def set_state_dict(self, architect, scheduler, checkpoint):
        #### common
        self.load_state_dict(checkpoint['state_dict'])
        self.set_arch_parameters(checkpoint['alpha'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        #### projection
        self.nid2eids = checkpoint['nid2eids']
        self.nid2selected_eids = checkpoint['nid2selected_eids']
        self.candidate_flags = checkpoint['candidate_flags']
        self.candidate_flags_edge = checkpoint['candidate_flags_edge']
        self.proj_weights = checkpoint['proj_weights']