import os
from scipy.stats import truncnorm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from params_remapping import remap_copy_paste, random_init, uniform, l1norm_sorting, pca_init


class PR:
    '''
    Stage-wise parameter remapping (PR) from a teacher network to a student network.

    Both networks must expose their stages as the attributes ``layer1`` ..
    ``layer4``. Each stage is iterated block by block, and every block's
    ``state_dict()`` may only contain keys with ``conv`` or ``bn`` in their
    name (see ``_get_stage_params``).

    ``param_remapping`` then builds, for every stage, the student state dicts
    according to ``pr_type``.
    '''
    def __init__(self, device, n_stage, tc_net, st_net,
                 st_depth_config, st_channel_widths, pr_type, args):
        '''
        Collect the per-stage parameters of the teacher and the student.

        Args:
            device: target passed to ``.to(device)`` for every stage of both nets.
            n_stage: number of stages processed by ``param_remapping``.
            tc_net: teacher network; switched to eval mode here.
            st_net: student network.
            st_depth_config: per-stage depth of the student, indexed by stage.
            st_channel_widths: per-stage iterable of student channel widths.
            pr_type: remapping strategy, ``None`` meaning "no remapping".
            args: configuration object providing ``image_size``, ``tc_net_name``,
                ``tc_stage_depth``, ``base_channel_dict``, ``exclude_first_block``,
                ``tc_stage_default_channel_widths``, ``channel_mul`` and,
                optionally, ``metapr_ver``.
        '''
        ## General
        self.args = args
        self.device = device
        self.image_size = args.image_size

        ## PR type
        self.pr_type = pr_type

        ## Search Space setting
        self.tc_net_name = args.tc_net_name
        self.tc_stage_depth = args.tc_stage_depth
        self.base_channel_dict = args.base_channel_dict
        # Relies on the dict's iteration order matching the stage order.
        self.base_channels = list(self.base_channel_dict.values())
        self.exclude_first_block = args.exclude_first_block

        ## Teacher Net
        self.n_stage = n_stage
        self.tc_stage_default_channel_widths = args.tc_stage_default_channel_widths
        self.channel_mul = args.channel_mul
        self.tc_stage_channel_widths = [
            int(width * self.channel_mul)
            for width in self.tc_stage_default_channel_widths
        ]

        self.tc_net = tc_net
        self.tc_net.eval()
        self.tc_stages = self._collect_stages(self.tc_net, device)
        self.tc_stage_params = [self._get_stage_params(stage) for stage in self.tc_stages]

        ## Student Net
        self.st_net = st_net
        self.st_stages = self._collect_stages(self.st_net, device)
        self.st_stage_params = [self._get_stage_params(stage) for stage in self.st_stages]
        self.st_depth_config = st_depth_config
        self.st_channel_widths = st_channel_widths
        # Student widths expressed as integer multiples of the stage's base channel.
        self.st_channel_expansions = []
        for i in range(self.n_stage):
            base_channel = self.base_channels[i]
            self.st_channel_expansions.append(
                [int(width / base_channel) for width in self.st_channel_widths[i]]
            )

        # Optional setting: fall back to 0 when args does not carry it.
        self.metapr_ver = getattr(args, 'metapr_ver', 0)

    @staticmethod
    def _collect_stages(net, device):
        '''Return ``[net.layer1, ..., net.layer4]`` after calling ``.to(device)`` on each.'''
        stages = [net.layer1, net.layer2, net.layer3, net.layer4]
        return [stage.to(device) for stage in stages]

    def _get_stage_params(self, stage):
        '''
        Group the state-dict entries of every block in ``stage`` by kind.

        Args:
            stage: iterable of blocks, each providing ``state_dict()``.

        Returns:
            defaultdict(list): all entries whose key contains ``conv`` are
            appended under ``'conv'``; entries whose key contains ``bn`` are
            appended under the key with its first four characters removed
            (e.g. ``'bn1.weight'`` -> ``'weight'``). Lists follow block order,
            then state-dict order.

        Raises:
            ValueError: for a key containing neither ``conv`` nor ``bn``.
        '''
        stage_params = defaultdict(list)
        for block in stage:
            for k, v in block.state_dict().items():
                if 'conv' in k:
                    stage_params['conv'].append(v)
                elif 'bn' in k:
                    # weight, bias, running_mean, running_var, ...
                    # k[4:] drops a 4-character prefix such as 'bn1.'.
                    stage_params[k[4:]].append(v)
                else:
                    raise ValueError(k)
        return stage_params

    def _keep_student_params(self, tc_stage_params, st_stage_params, depth):
        '''
        Build per-block state dicts from the student's own parameters (no PR).

        Each block is assumed to hold two conv/bn pairs, stored consecutively
        in ``st_stage_params``.

        Returns:
            list[dict]: one state dict per block, ``depth`` entries.
        '''
        num_layer = 2
        # NOTE(review): num_batches_tracked is taken from the teacher's first
        # BN entry rather than from the student -- confirm this is intended.
        num_batches_tracked = tc_stage_params['num_batches_tracked'][0] if depth > 0 else None

        st_dict_list = []
        for d in range(depth):
            st_dict = {}
            for j in range(num_layer):
                idx = d * num_layer + j
                st_dict[f'conv{j+1}.weight'] = st_stage_params['conv'][idx]
                for key in ['weight', 'bias', 'running_mean', 'running_var']:
                    st_dict[f'bn{j+1}.{key}'] = st_stage_params[key][idx]
                st_dict[f'bn{j+1}.num_batches_tracked'] = num_batches_tracked
            st_dict_list.append(st_dict)
        return st_dict_list

    def param_remapping(self):
        '''
        Produce the student state dicts for every stage according to ``pr_type``.

        Returns:
            list: one entry per stage (``n_stage`` entries). With
            ``pr_type is None`` each entry is a list of per-block state dicts;
            for the other strategies it is whatever the corresponding
            ``params_remapping`` helper returns.

        Raises:
            NotImplementedError: for ``'pca_init'`` / ``'pca_init_reverse'``,
                whose integration is still a TODO.
            ValueError: for an unknown ``pr_type``.
        '''
        st_dict_lists = []
        for i in range(self.n_stage):
            tc_stage_params = self.tc_stage_params[i]
            st_stage_params = self.st_stage_params[i]
            depth = self.st_depth_config[i]
            base_channel = self.base_channels[i]
            pr_type = self.pr_type

            if pr_type is None:  # No PR
                st_dict_list = self._keep_student_params(
                    tc_stage_params, st_stage_params, depth)

            elif pr_type == 'random_init':
                random_type = 'kaiming_normal'
                st_dict_list = random_init(tc_stage_params, st_stage_params, depth, random_type)

            elif pr_type == 'uniform':
                st_dict_list = uniform(tc_stage_params, st_stage_params, depth,
                                       base_channel, self.exclude_first_block)

            elif pr_type == 'l1norm_sorting':
                st_dict_list = l1norm_sorting(tc_stage_params, st_stage_params, depth)

            elif 'copy_paste' in pr_type:
                # Any pr_type containing 'copy_paste' is forwarded as the variant name.
                st_dict_list = remap_copy_paste(tc_stage_params, st_stage_params,
                                                depth, btype=pr_type)

            elif pr_type in ('pca_init', 'pca_init_reverse'):
                # TODO: wire up pca_init. The previous call passed an undefined
                # `task` variable and could only fail with a NameError.
                raise NotImplementedError(
                    f"pr_type '{pr_type}' is not implemented yet")

            else:
                raise ValueError(pr_type)

            st_dict_lists.append(st_dict_list)

        return st_dict_lists

