import sys
sys.path.insert(0, '../../')
import numpy as np
import torch
import logging
import torch.utils
from copy import deepcopy
from correlation.foresight.pruners import *

torch.set_printoptions(precision=4, sci_mode=False)

def sample_op(model, input, target, args, cell_type, selected_eid=None):
    """Draw a random operation (and, if needed, a random edge) for a cell.

    Args:
        model: object exposing ``num_ops`` and ``candidate_flags[cell_type]``
            (a tensor whose nonzero entries mark edges still to be decided).
        input: unused; kept so the signature matches ``project_op``.
        target: unused; kept so the signature matches ``project_op``.
        args: unused; kept so the signature matches ``project_op``.
        cell_type: key into ``model.candidate_flags``.
        selected_eid: edge to use. When ``None`` an edge is drawn uniformly
            from the remaining candidate edges.

    Returns:
        Tuple ``(selected_eid, select_opid)``.

    Raises:
        ValueError: if ``selected_eid`` is ``None`` and no candidate edge
            remains for ``cell_type``.
    """
    #### select an edge
    if selected_eid is None:
        flags = model.candidate_flags[cell_type]
        remain_eids = torch.nonzero(flags).cpu().numpy().T[0]
        if remain_eids.size == 0:
            # np.random.choice would raise the same exception type, but with
            # a message that does not say which cell ran out of edges.
            raise ValueError(
                'no candidate edges left for cell type %r' % (cell_type,))
        selected_eid = np.random.choice(remain_eids, size=1)[0]
        logging.info('selected edge: %d %s', selected_eid, cell_type)

    #### draw an operation uniformly at random
    select_opid = np.random.choice(np.arange(model.num_ops), size=1)[0]
    return selected_eid, select_opid

def project_op(model, input, target, args, cell_type, proj_queue=None, selected_eid=None):
    """Pick the operation to keep on one edge of the ``cell_type`` cell.

    Each of the ``model.num_ops`` operations on the chosen edge is removed in
    turn and the resulting network is scored.  The operation whose removal
    yields the lowest score (NaN entries ignored) is returned.

    Args:
        model: supernet exposing ``num_ops``, ``candidate_flags``,
            ``proj_weights`` and ``get_projected_weights``.
        input: batch forwarded to ``Jocab_Score`` (``'jacob'`` criterion only).
        target: forwarded to ``Jocab_Score`` together with ``input``.
        args: namespace providing ``proj_crit`` (dict keyed by cell type),
            ``edge_decision`` and ``dataset``.
        cell_type: key selecting the cell, e.g. ``'normal'`` or ``'reduce'``.
        proj_queue: data loader handed to ``predictive.find_measures`` for
            every criterion other than ``'jacob'``.
        selected_eid: edge to project.  When ``None`` it is chosen from the
            remaining candidate edges: at random for ``edge_decision ==
            'random'``, the last one for ``'reverse'``, otherwise the first.

    Returns:
        Tuple ``(selected_eid, best_opid)``.
    """
    #### macros
    num_ops = model.num_ops
    proj_crit = args.proj_crit[cell_type]

    #### select an edge
    if selected_eid is None:
        remain_eids = torch.nonzero(model.candidate_flags[cell_type]).cpu().numpy().T[0]
        if args.edge_decision == "random":
            selected_eid = np.random.choice(remain_eids, size=1)[0]
        elif args.edge_decision == 'reverse':
            selected_eid = remain_eids[-1]
        else:
            selected_eid = remain_eids[0]
        logging.info('selected edge: %d %s', selected_eid, cell_type)

    # Number of classes passed to the zero-cost measure; anything that is not
    # one of the two named datasets is treated as a 10-class problem.
    n_classes = {'cifar100': 100, 'imagenet16-120': 120}.get(args.dataset, 10)

    #### score the network with each operation removed in turn
    crit_list = []
    for opid in range(num_ops):
        if proj_crit == 'jacob':
            ## projection: zero out only the candidate operation on this edge
            weights = model.get_projected_weights(cell_type)
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask

            ## proj evaluation (forward only, so no autograd graph is needed)
            with torch.no_grad():
                crit = Jocab_Score(model, cell_type, input, target, weights=weights)
        else:
            flags = model.candidate_flags[cell_type]
            # Snapshot the flag as a plain bool.  Indexing the flag tensor
            # returns a view that would follow the write below, so restoring
            # from it would always write back False.
            cached_flag = bool(flags[selected_eid])
            try:
                # Temporarily treat the edge as projected onto a uniform
                # mixture of every operation except `opid`.
                for idx in range(num_ops):
                    model.proj_weights[cell_type][selected_eid][idx] = (
                        0 if idx == opid else 1.0 / num_ops)
                flags[selected_eid] = False

                measures = predictive.find_measures(model,
                                                    proj_queue,
                                                    ('random', 1, n_classes),
                                                    torch.device("cuda"),
                                                    measure_names=[proj_crit])
            finally:
                # Undo the temporary projection even if the measure fails.
                # NOTE(review): weights are reset to 0 rather than to their
                # previous values; this matches the prior behaviour and
                # presumably relies on undecided edges holding zeros -- confirm.
                for idx in range(num_ops):
                    model.proj_weights[cell_type][selected_eid][idx] = 0
                flags[selected_eid] = cached_flag
            crit = measures[proj_crit]

        crit_list.append(crit)

    # The operation whose removal hurts the score most is the one to keep.
    best_opid = int(np.nanargmin(crit_list))

    #### project
    logging.info('best opid: %d', best_opid)
    logging.info(crit_list)
    return selected_eid, best_opid
    
def project_global_op(model, input, target, args, infer=None, cell_type=None, selected_eid=None):
    """Search all remaining edges for the single best (edge, operation) pair.

    For every candidate edge and every operation, the operation is masked out
    and the network is scored with ``Jocab_Score``; the pair with the lowest
    score wins.

    Args:
        model: supernet exposing ``num_ops``, ``candidate_flags`` and
            ``get_projected_weights``.
        input: batch forwarded to ``Jocab_Score``.
        target: forwarded to ``Jocab_Score`` together with ``input``.
        args: unused; kept for signature compatibility.
        infer: unused; kept (now optional) for signature compatibility.
        cell_type: key selecting the cell, e.g. ``'normal'`` or ``'reduce'``.
            Required, despite the default needed to make ``infer`` optional.
        selected_eid: unused; kept for signature compatibility.

    Returns:
        Tuple ``(best_eid, best_opid)``.  ``best_eid`` is ``None`` (and
        ``best_opid`` is 0) when no candidate edge remains.

    Raises:
        TypeError: if ``cell_type`` is not supplied.
    """
    if cell_type is None:
        raise TypeError("project_global_op() missing required argument: 'cell_type'")

    #### macros
    num_ops = model.num_ops
    remain_eids = torch.nonzero(model.candidate_flags[cell_type]).cpu().numpy().T[0]

    #### select the best operation
    best_opid = 0
    best_eid = None
    crit_extrema = None
    for eid in remain_eids:
        for opid in range(num_ops):
            ## projection
            weights = model.get_projected_weights(cell_type)
            proj_mask = torch.ones_like(weights[eid])
            proj_mask[opid] = 0
            weights[eid] = weights[eid] * proj_mask

            ## proj evaluation
            with torch.no_grad():
                crit = Jocab_Score(model, cell_type, input, target, weights=weights)

            if crit_extrema is None:
                improved = True
            elif np.isnan(crit):
                # A NaN score never displaces an existing candidate.
                improved = False
            else:
                # A finite score always displaces a NaN placeholder.
                improved = bool(np.isnan(crit_extrema)) or crit < crit_extrema

            if improved:
                crit_extrema = crit
                best_opid = opid
                best_eid = eid

    #### project
    logging.info('best opid: %d', best_opid)
    return best_eid, best_opid

def sample_edge(model, input, target, args, cell_type, selected_eid=None):
    """Draw a random node and randomly keep two of its edges.

    Args:
        model: object exposing ``candidate_flags_edge[cell_type]`` (a tensor
            whose nonzero entries mark nodes still to be decided) and
            ``nid2eids`` (mapping from node id to its edge ids).
        input: unused; kept so the signature matches ``project_edge``.
        target: unused; kept so the signature matches ``project_edge``.
        args: unused; kept so the signature matches ``project_edge``.
        cell_type: key into ``model.candidate_flags_edge``.
        selected_eid: unused; kept for signature compatibility.

    Returns:
        Tuple ``(selected_nid, eids)`` where ``eids`` is a new list holding
        at most two of the node's edge ids.  ``model.nid2eids`` itself is
        left untouched.

    Raises:
        ValueError: if no candidate node remains for ``cell_type``.
    """
    #### select a node
    flags = model.candidate_flags_edge[cell_type]
    remain_nids = torch.nonzero(flags).cpu().numpy().T[0]
    if remain_nids.size == 0:
        # np.random.choice would raise the same exception type, but with a
        # message that does not say which cell ran out of nodes.
        raise ValueError(
            'no candidate nodes left for cell type %r' % (cell_type,))
    selected_nid = np.random.choice(remain_nids, size=1)[0]
    logging.info('selected node: %d %s', selected_nid, cell_type)

    #### randomly discard edges until two are left
    eids = list(model.nid2eids[selected_nid])  # copy so the model's list survives
    while len(eids) > 2:
        dropped_eid = np.random.choice(eids, size=1)[0]
        eids.remove(dropped_eid)

    return selected_nid, eids

def project_edge(model, input, target, args, cell_type):
    """Pick a node of the ``cell_type`` cell and the two edges it should keep.

    Starting from all edges of the chosen node, the edge whose removal leaves
    the highest ``Jocab_Score`` is discarded, and this repeats until only two
    edges remain.

    Args:
        model: supernet exposing ``candidate_flags_edge``, ``nid2eids`` and
            ``get_projected_weights``.
        input: batch forwarded to ``Jocab_Score``.
        target: forwarded to ``Jocab_Score`` together with ``input``.
        args: namespace providing ``edge_decision``.
        cell_type: key selecting the cell, e.g. ``'normal'`` or ``'reduce'``.

    Returns:
        Tuple ``(selected_nid, eids)`` with ``eids`` a copy of the node's
        edge list reduced to the kept edges.
    """
    #### select a node
    remain_nids = torch.nonzero(model.candidate_flags_edge[cell_type]).cpu().numpy().T[0]
    if args.edge_decision == 'reverse':
        selected_nid = remain_nids[-1]
    else:
        # NOTE(review): every mode other than 'reverse' draws the node at
        # random, whereas project_op falls back to the first remaining edge.
        # Kept as is -- confirm this is intended.
        selected_nid = np.random.choice(remain_nids, size=1)[0]
    logging.info('selected node: %d %s', selected_nid, cell_type)

    #### select top2 edges
    eids = deepcopy(model.nid2eids[selected_nid])
    while len(eids) > 2:
        eid_todel = None
        highest_crit = None
        for eid in eids:
            # Score the network with this edge switched off entirely.
            weights = model.get_projected_weights(cell_type)
            weights[eid].data.fill_(0)

            ## proj evaluation
            with torch.no_grad():
                crit = Jocab_Score(model, cell_type, input, target, weights=weights)

            # "not <" (rather than ">=") keeps the tie and NaN behaviour:
            # on equal or incomparable scores the later edge is chosen.
            if highest_crit is None or not crit < highest_crit:
                highest_crit = crit
                eid_todel = eid

        eids.remove(eid_todel)

    #### project
    logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
    return selected_nid, eids


def pt_project(train_queue, model, args):
    """Project the supernet one decision per batch of ``train_queue``.

    The first ``model.num_edges`` batches each fix one operation in the
    normal cell and then one in the reduce cell; the following batches each
    fix the edges of one node in both cells.  Iteration stops after
    ``model.num_edges + len(model.nid2eids)`` batches.

    Args:
        train_queue: iterable of ``(input, target)`` batches; also handed to
            ``project_op`` as its ``proj_queue``.
        model: supernet, put into eval mode and modified in place through
            ``model.project_op`` / ``model.project_edge``.
        args: namespace providing ``proj_crit_normal``, ``proj_crit_reduce``
            and ``edge_decision``.  ``args.proj_crit`` is (re)assigned here.

    Returns:
        None.
    """
    model.eval()

    #### macros
    num_projs = model.num_edges + len(model.nid2eids)
    args.proj_crit = {'normal': args.proj_crit_normal, 'reduce': args.proj_crit_reduce}
    proj_queue = train_queue
    cell_types = ('normal', 'reduce')  # order matters: normal is projected first

    completed = 0
    for step, (inputs, targets) in enumerate(proj_queue):
        if step < model.num_edges:
            logging.info('project op')
            for cell_type in cell_types:
                if args.edge_decision == 'global_op_greedy':
                    # `infer` is a positional slot of project_global_op that
                    # the function never reads; it must still be supplied.
                    eid, opid = project_global_op(model, inputs, targets, args,
                                                  infer=None, cell_type=cell_type)
                elif args.edge_decision == 'sample':
                    eid, opid = sample_op(model, inputs, targets, args, cell_type=cell_type)
                else:
                    eid, opid = project_op(model, inputs, targets, args,
                                           proj_queue=proj_queue, cell_type=cell_type)
                model.project_op(eid, opid, cell_type=cell_type)
        else:
            logging.info('project edge')
            choose_edges = sample_edge if args.edge_decision == 'sample' else project_edge
            for cell_type in cell_types:
                nid, eids = choose_edges(model, inputs, targets, args, cell_type=cell_type)
                model.project_edge(nid, eids, cell_type=cell_type)

        completed = step + 1
        if completed == num_projs:
            break

    if completed < num_projs:
        # The queue ran dry before every decision was made.
        logging.warning('projection stopped after %d of %d steps: queue exhausted',
                        completed, num_projs)

    return

def Jocab_Score(ori_model, cell_type, input, target, weights=None):
    """Score a projected network from its activation patterns on one batch.

    A deep copy of ``ori_model`` is given the supplied projection weights for
    ``cell_type`` (and its own current projected weights for the other cell)
    and run once on ``input``.  For every module whose type name contains
    ``'ReLU'``, the binary pattern ``x = (module_input > 0)`` contributes
    ``x @ x.T + (1 - x) @ (1 - x.T)`` to a ``batch x batch`` matrix.  The
    score is the log-abs-determinant of that matrix.

    Args:
        ori_model: supernet; never modified, since all work is done on a copy.
        cell_type: ``'reduce'`` applies ``weights`` to the reduce cell; any
            other value applies them to the normal cell.
        input: batch tensor; its first dimension is taken as the batch size.
        target: unused; kept for signature compatibility.
        weights: projection weights to install for ``cell_type``.

    Returns:
        The value returned by ``hooklogdet`` for the accumulated matrix.
    """
    model = deepcopy(ori_model)
    model.eval()
    other_cell = 'normal' if cell_type == 'reduce' else 'reduce'
    this_cell = 'reduce' if cell_type == 'reduce' else 'normal'
    model.proj_weights[this_cell] = weights
    model.proj_weights[other_cell] = model.get_projected_weights(other_cell)

    # Run on whatever device the model lives on; fall back to CUDA (the
    # previous hard-coded behaviour) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    batch_size = input.shape[0]
    model.K = torch.zeros(batch_size, batch_size, device=device)

    def counting_forward_hook(module, inp, out):
        # Best effort: a module whose input cannot be flattened or whose
        # batch size differs is skipped instead of aborting the score.
        try:
            if isinstance(inp, tuple):
                inp = inp[0]
            inp = inp.view(inp.size(0), -1)
            x = (inp > 0).float()
            K = x @ x.t()
            K2 = (1. - x) @ (1. - x.t())
            model.K = model.K + K + K2
        except Exception as err:
            logging.debug('Jocab_Score: skipped %s (%s)', type(module).__name__, err)

    for module in model.modules():
        if 'ReLU' in str(type(module)):
            module.register_forward_hook(counting_forward_hook)

    input = input.to(device)

    # Only the forward activations are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        model(input, using_proj=True)
    score = hooklogdet(model.K.cpu().numpy())

    del model
    return score

def hooklogdet(K, labels=None):
    """Return the natural log of the absolute determinant of ``K``.

    The sign reported by ``np.linalg.slogdet`` is discarded.  ``labels`` is
    accepted but not used.
    """
    # slogdet yields (sign, log|det|); only the second entry is of interest.
    return np.linalg.slogdet(K)[1]