from collections import defaultdict
from sklearn.decomposition import PCA
import torch
import torch.nn as nn


# Names of the baseline initialization strategies.
# NOTE(review): 'mean' has no matching function in this part of the file —
# presumably implemented elsewhere; confirm.
BASELINE_TYPE = ['random_init', 'mean', 'uniform', 'l1norm_sorting', 'copy_paste_first']

# Module-wide target device: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def random_init(tc_stage_params, st_stage_params, depth, random_type='kaiming_normal'):
    '''
    Randomly initialize the student stage and pack it into per-block state dicts.

    Every student conv weight is re-initialized IN PLACE with Kaiming-normal
    (mode='fan_out', nonlinearity="relu"); the student BatchNorm tensors are
    passed through unchanged. All tensors are moved to the module-level
    ``device``.

    Args:
        tc_stage_params: teacher stage parameters; only
            ``tc_stage_params['num_batches_tracked'][0]`` is read.
        st_stage_params: mapping with the lists 'conv', 'weight', 'bias',
            'running_mean' and 'running_var' (one entry per conv/BN layer).
        depth: number of student blocks; each block consumes two consecutive
            conv/BN entries.
        random_type: accepted but currently not consulted — Kaiming-normal is
            always used.

    Returns:
        A list with ``depth`` dicts, each keyed by
        'conv{1,2}.weight' and 'bn{1,2}.{weight,bias,running_mean,running_var,
        num_batches_tracked}'.
    '''
    bn_fields = ('weight', 'bias', 'running_mean', 'running_var')
    layers_per_block = 2

    fresh = defaultdict(list)

    # Conv weights: in-place Kaiming-normal re-initialization.
    for conv_w in st_stage_params['conv']:
        reinit = nn.init.kaiming_normal_(conv_w, mode='fan_out', nonlinearity="relu")
        fresh['conv'].append(reinit.to(device))

    # BatchNorm tensors: kept as they are, only moved to the target device.
    for field in bn_fields:
        for bn_tensor in st_stage_params[field]:
            fresh[field].append(bn_tensor.to(device))

    # Assemble one state dict per block from the flat per-layer lists.
    blocks = []
    for block_idx in range(depth):
        entry = {}
        for layer in range(1, layers_per_block + 1):
            flat = block_idx * layers_per_block + (layer - 1)
            entry[f'conv{layer}.weight'] = fresh['conv'][flat]
            for field in bn_fields:
                entry[f'bn{layer}.{field}'] = fresh[field][flat]
            # Every BN layer reuses the teacher's first num_batches_tracked entry.
            entry[f'bn{layer}.num_batches_tracked'] = tc_stage_params['num_batches_tracked'][0]
        blocks.append(entry)

    return blocks


def uniform(tc_stage_params, st_stage_params, depth, base_channel, exclude_first_block=True):
    '''
    Uniform remapping, chunk-by-chunk, of teacher parameters onto the student.

    When ``exclude_first_block`` is true, the first conv/BN pair of the student
    is filled by uniformly subsampling the teacher's first conv/BN pair
    (output channels, and for the conv also input channels). The remaining
    teacher tensors are concatenated along dim 0, cut into chunks of
    ``base_channel`` channels, and the chunks are uniformly subsampled so that
    every remaining student layer receives ``out_ch / base_channel`` chunks.
    Conv input channels are then uniformly subsampled to the student's width.

    Args:
        tc_stage_params: teacher mapping with the lists 'conv', 'weight',
            'bias', 'running_mean', 'running_var' and 'num_batches_tracked'.
            The convs after the first are concatenated with ``torch.cat`` and
            therefore must share input-channel count and kernel size.
        st_stage_params: student mapping with the same per-layer lists
            (only the tensor sizes are used).
        depth: number of student blocks; each block consumes two consecutive
            conv/BN entries.
        base_channel: chunk size (in channels) used for the remapping.
        exclude_first_block: treat the first conv/BN pair separately (see above).

    Returns:
        A list with ``depth`` dicts keyed by 'conv{1,2}.weight' and
        'bn{1,2}.{weight,bias,running_mean,running_var,num_batches_tracked}'.

    Raises:
        ValueError: if no student layer is left to remap, or every remaining
            student conv is narrower than ``base_channel``. (These inputs
            previously failed with UnboundLocalError / ZeroDivisionError.)
    '''
    param_keys = ['conv', 'weight', 'bias', 'running_mean', 'running_var']
    bn_keys = param_keys[1:]

    # Number of base_channel-sized chunks each remaining student conv needs.
    st_rest_convs = st_stage_params['conv'][exclude_first_block:]
    st_out_expansions = [int(p.size(0) / base_channel) for p in st_rest_convs]
    st_out_expansion_sum = sum(st_out_expansions)
    if st_out_expansion_sum == 0:
        raise ValueError(
            'uniform remapping needs at least one student conv with '
            f'out_channels >= base_channel ({base_channel})'
        )

    st_params = defaultdict(list)

    ## first conv/BN pair: uniform subsampling of the teacher's first pair
    if exclude_first_block:
        for key in param_keys:
            teacher = tc_stage_params[key][0]
            student = st_stage_params[key][0]
            out_s = len(teacher) // len(student)
            # Stride over output channels, then drop any surplus rows.
            params = teacher[::out_s][:len(student)]
            if key == 'conv':
                st_in_ch = student.size(1)
                in_s = teacher.size(1) // st_in_ch
                params = params[:, ::in_s][:, :st_in_ch]
            st_params[key].append(params)

    for key in param_keys:
        tc_rest = torch.cat(tc_stage_params[key][exclude_first_block:], dim=0)
        if key == 'conv':
            _, in_ch, _, k = tc_rest.size()
            chunks = tc_rest.view(-1, base_channel, in_ch, k, k)
        else:
            chunks = tc_rest.view(-1, base_channel)

        # Uniformly pick as many teacher chunks as the student needs in total.
        out_s = chunks.size(0) // st_out_expansion_sum
        sampled_params = chunks[::out_s][:st_out_expansion_sum]

        # Hand consecutive runs of sampled chunks to the student layers.
        cur = 0
        for i, se in enumerate(st_out_expansions):
            mapped_params = sampled_params[cur:cur + se]
            cur += se
            st_target = st_stage_params[key][exclude_first_block:][i]
            if key == 'conv':
                mapped_params = mapped_params.reshape(-1, in_ch, k, k)
                st_in_ch = st_target.size(1)
                in_s = in_ch // st_in_ch
                mapped_params = mapped_params[:, ::in_s][:, :st_in_ch]
            else:
                mapped_params = mapped_params.reshape(st_target.size())
            st_params[key].append(mapped_params)

    ## Build student state_dict
    st_dict_list = []
    num_layer = 2
    for d in range(depth):
        st_dict = {}
        for j in range(num_layer):
            st_dict[f'conv{j+1}.weight'] = st_params['conv'][d*num_layer+j]
            for key in bn_keys:
                st_dict[f'bn{j+1}.{key}'] = st_params[key][d*num_layer+j]
            # Every BN layer reuses the teacher's first num_batches_tracked entry.
            st_dict[f'bn{j+1}.num_batches_tracked'] = tc_stage_params['num_batches_tracked'][0]
        st_dict_list.append(st_dict)

    return st_dict_list


def _tc_stage_params_sorting(tc_stage_params):
    '''
    Reorder the teacher's output channels by descending filter L1 norm.

    For every conv weight the filters (dim 0) are ranked by the sum of their
    absolute values. The same permutation is then applied to the BatchNorm
    tensors at the same list index, so that channel ``j`` of the sorted conv
    and channel ``j`` of its 'weight' / 'bias' / 'running_mean' /
    'running_var' still describe the same teacher channel. (Sorting each BN
    tensor by its own magnitude would pair a filter with statistics that
    belong to a different channel.)

    A BN tensor without a matching conv permutation (no conv at that index, or
    a different channel count) falls back to being sorted by its own absolute
    value.

    Args:
        tc_stage_params: mapping with the per-layer lists 'conv', 'weight',
            'bias', 'running_mean' and 'running_var'.

    Returns:
        defaultdict(list) with the same five keys holding the reordered tensors.
    '''
    sorted_tc_stage_params = defaultdict(list)

    # One permutation per conv layer, reused for the BN tensors of that layer.
    filter_orders = []
    for v in tc_stage_params['conv']:
        out_ch = v.size(0)
        l1norm = torch.sum(torch.abs(v.reshape(out_ch, -1)), dim=1)  # (out_ch, )
        _, indices = torch.sort(l1norm, descending=True)
        filter_orders.append(indices)
        sorted_tc_stage_params['conv'].append(v[indices])

    for key in ['weight', 'bias', 'running_mean', 'running_var']:
        for i, v in enumerate(tc_stage_params[key]):
            if i < len(filter_orders) and filter_orders[i].numel() == v.size(0):
                indices = filter_orders[i].to(v.device)
            else:
                # No usable conv permutation: rank by the tensor's own magnitude.
                _, indices = torch.sort(torch.abs(v), descending=True)
            sorted_tc_stage_params[key].append(v[indices])

    return sorted_tc_stage_params


def l1norm_sorting(tc_stage_params, st_stage_params, depth):
    '''
    Initialize the student from the leading channels of the sorted teacher.

    The teacher tensors are first reordered by ``_tc_stage_params_sorting``;
    student layer ``i`` then takes the leading slice of teacher layer ``i``
    that matches the student tensor's size (first ``out_ch`` / ``in_ch``
    channels for convs, first ``len`` entries for BatchNorm tensors).

    Args:
        tc_stage_params: teacher mapping with the per-layer lists 'conv',
            'weight', 'bias', 'running_mean', 'running_var' and
            'num_batches_tracked'.
        st_stage_params: student mapping with the same per-layer lists
            (only the tensor sizes are used).
        depth: number of student blocks; each block consumes two consecutive
            conv/BN entries.

    Returns:
        A list with ``depth`` dicts keyed by 'conv{1,2}.weight' and
        'bn{1,2}.{weight,bias,running_mean,running_var,num_batches_tracked}'.
    '''
    bn_fields = ('weight', 'bias', 'running_mean', 'running_var')
    layers_per_block = 2

    ranked = _tc_stage_params_sorting(tc_stage_params)

    picked = defaultdict(list)
    for idx, st_conv in enumerate(st_stage_params['conv']):
        n_out, n_in, _, _ = st_conv.size()
        picked['conv'].append(ranked['conv'][idx][:n_out, :n_in])

    for field in bn_fields:
        for idx, st_vec in enumerate(st_stage_params[field]):
            picked[field].append(ranked[field][idx][:len(st_vec)])

    # Assemble one state dict per block from the flat per-layer lists.
    blocks = []
    for block_idx in range(depth):
        entry = {}
        for layer in range(1, layers_per_block + 1):
            flat = block_idx * layers_per_block + (layer - 1)
            entry[f'conv{layer}.weight'] = picked['conv'][flat]
            for field in bn_fields:
                entry[f'bn{layer}.{field}'] = picked[field][flat]
            # Every BN layer reuses the teacher's first num_batches_tracked entry.
            entry[f'bn{layer}.num_batches_tracked'] = tc_stage_params['num_batches_tracked'][0]
        blocks.append(entry)

    return blocks


def _remap_copy_paste_first(tc_stage_params, st_stage_params):
    '''
    Copy the leading channels of teacher layer ``i`` into student layer ``i``.

    Convs keep the first ``out_ch`` filters and the first ``in_ch`` input
    channels; BatchNorm tensors keep their first ``len(student)`` entries.

    Returns:
        defaultdict(list) with the keys 'conv', 'weight', 'bias',
        'running_mean' and 'running_var' (one tensor per student layer).
    '''
    remapped = defaultdict(list)

    for idx, student_w in enumerate(st_stage_params['conv']):
        n_out, n_in, _, _ = student_w.size()
        teacher_w = tc_stage_params['conv'][idx]
        remapped['conv'].append(teacher_w[:n_out, :n_in])

    for field in ('weight', 'bias', 'running_mean', 'running_var'):
        for idx, student_v in enumerate(st_stage_params[field]):
            teacher_v = tc_stage_params[field][idx]
            remapped[field].append(teacher_v[:len(student_v)])

    return remapped


def _remap_copy_paste_last(tc_stage_params, st_stage_params):
    '''
    Copy the trailing channels of the teacher's last layers into the student.

    Student layer ``i`` is paired with teacher layer ``offset + i``, where
    ``offset`` is the difference between the teacher's and the student's conv
    counts. Convs keep the last ``out_ch`` filters and the last ``in_ch``
    input channels; BatchNorm tensors keep their last ``len(student)`` entries.

    Returns:
        defaultdict(list) with the keys 'conv', 'weight', 'bias',
        'running_mean' and 'running_var' (one tensor per student layer).
    '''
    # How many leading teacher layers are skipped.
    skip = len(tc_stage_params['conv']) - len(st_stage_params['conv'])

    remapped = defaultdict(list)

    for idx, student_w in enumerate(st_stage_params['conv']):
        n_out, n_in, _, _ = student_w.size()
        teacher_w = tc_stage_params['conv'][skip + idx]
        remapped['conv'].append(teacher_w[-n_out:, -n_in:])

    for field in ('weight', 'bias', 'running_mean', 'running_var'):
        for idx, student_v in enumerate(st_stage_params[field]):
            teacher_v = tc_stage_params[field][skip + idx]
            remapped[field].append(teacher_v[-len(student_v):])

    return remapped


def _remap_copy_paste_uniform(tc_stage_params, st_stage_params, btype):
    '''
    Uniformly subsample teacher channels to the student's layer sizes.

    Student layer ``i`` is paired with teacher layer ``offset + i``; ``offset``
    is 0 unless ``btype`` contains 'last', in which case it is the difference
    between the teacher's and the student's conv counts. Channels are taken
    with an integer stride of ``teacher_size // student_size`` and any surplus
    is cut off at the end.

    Returns:
        defaultdict(list) with the keys 'conv', 'weight', 'bias',
        'running_mean' and 'running_var' (one tensor per student layer).
    '''
    if 'last' in btype:
        skip = len(tc_stage_params['conv']) - len(st_stage_params['conv'])
    else:
        skip = 0

    remapped = defaultdict(list)

    for idx, student_w in enumerate(st_stage_params['conv']):
        n_out, n_in, _, _ = student_w.size()
        teacher_w = tc_stage_params['conv'][skip + idx]
        t_out, t_in, _, _ = teacher_w.size()
        row_step = t_out // n_out
        col_step = t_in // n_in
        # Stride over filters, trim, then stride over input channels, trim.
        picked = teacher_w[::row_step][:n_out]
        picked = picked[:, ::col_step][:, :n_in]
        remapped['conv'].append(picked)

    for field in ('weight', 'bias', 'running_mean', 'running_var'):
        for idx, student_v in enumerate(st_stage_params[field]):
            teacher_v = tc_stage_params[field][skip + idx]
            wanted = len(student_v)
            step = len(teacher_v) // wanted
            remapped[field].append(teacher_v[::step][:wanted])

    return remapped


def remap_copy_paste(tc_stage_params, st_stage_params, depth, btype):
    '''
    Copy-paste teacher parameters into the student and build block state dicts.

    Args:
        tc_stage_params: teacher mapping with the per-layer lists 'conv',
            'weight', 'bias', 'running_mean', 'running_var' and
            'num_batches_tracked'.
        st_stage_params: student mapping with the same per-layer lists
            (only the tensor sizes are used by the helpers).
        depth: number of student blocks; each block consumes two consecutive
            conv/BN entries.
        btype: remapping variant — exactly 'copy_paste_first', exactly
            'copy_paste_last', or any string containing 'copy_paste_uniform'
            (passed on to the uniform helper, which also looks for 'last').

    Returns:
        A list with ``depth`` dicts keyed by 'conv{1,2}.weight' and
        'bn{1,2}.{weight,bias,running_mean,running_var,num_batches_tracked}'.

    Raises:
        ValueError: if ``btype`` matches none of the variants above.
    '''
    if btype == 'copy_paste_first':
        remapped = _remap_copy_paste_first(tc_stage_params, st_stage_params)
    elif btype == 'copy_paste_last':
        remapped = _remap_copy_paste_last(tc_stage_params, st_stage_params)
    elif 'copy_paste_uniform' in btype:
        remapped = _remap_copy_paste_uniform(tc_stage_params, st_stage_params, btype)
    else:
        raise ValueError(btype)

    bn_fields = ('weight', 'bias', 'running_mean', 'running_var')
    layers_per_block = 2

    # Assemble one state dict per block from the flat per-layer lists.
    blocks = []
    for block_idx in range(depth):
        entry = {}
        for layer in range(1, layers_per_block + 1):
            flat = block_idx * layers_per_block + (layer - 1)
            entry[f'conv{layer}.weight'] = remapped['conv'][flat]
            for field in bn_fields:
                entry[f'bn{layer}.{field}'] = remapped[field][flat]
            # Every BN layer reuses the teacher's first num_batches_tracked entry.
            entry[f'bn{layer}.num_batches_tracked'] = tc_stage_params['num_batches_tracked'][0]
        blocks.append(entry)

    return blocks


def pca_init(task, exclude_first_block=True, reverse=False):
    '''
    PCA-based initialization of the student stage.

    When ``exclude_first_block`` is true, the first conv/BN pair of the student
    is filled by uniformly subsampling the teacher's first conv/BN pair. The
    remaining teacher conv weights are concatenated along the output-channel
    axis, flattened to ``(total_out_ch, in_ch * k * k)`` and transposed, so
    that PCA treats every teacher filter as one feature. PCA reduces them to
    as many components as the remaining student convs have output channels
    (rounded down to multiples of ``base_channel``); each component becomes
    one student filter, whose input channels are then uniformly subsampled to
    the student's width. The remaining BatchNorm tensors are taken from the
    student itself.

    Args:
        task: 11-tuple ``(ds_split, x_input, tc_block_features,
            tc_stage_params, stage, st_stage_params, ks, depth, channel_width,
            base_channel, st_expansion)``. Only ``tc_stage_params``,
            ``st_stage_params``, ``depth`` and ``base_channel`` are used.
        exclude_first_block: treat the first conv/BN pair separately (see above).
        reverse: flip the PCA result along both axes (component order and
            position within each component) before reshaping it into filters.

    Returns:
        A list with ``depth`` dicts keyed by 'conv{1,2}.weight' and
        'bn{1,2}.{weight,bias,running_mean,running_var,num_batches_tracked}'.
    '''
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The tuple must have exactly 11 entries; most of them are not needed here.
    (_ds_split, _x_input, _tc_block_features, tc_stage_params,
     _stage, st_stage_params, _ks, depth, _channel_width, base_channel, _st_expansion) = task

    param_keys = ['conv', 'weight', 'bias', 'running_mean', 'running_var']
    bn_keys = param_keys[1:]

    # Number of base_channel-sized chunks each remaining student conv needs.
    st_rest_convs = st_stage_params['conv'][exclude_first_block:]
    st_out_expansions = [int(p.size(0) / base_channel) for p in st_rest_convs]
    st_out_expansion_sum = sum(st_out_expansions)

    st_params = defaultdict(list)

    ## first conv/BN pair: uniform subsampling of the teacher's first pair
    if exclude_first_block:
        for key in param_keys:
            teacher = tc_stage_params[key][0]
            student = st_stage_params[key][0]
            out_s = len(teacher) // len(student)
            # Stride over output channels, then drop any surplus rows.
            params = teacher[::out_s][:len(student)]
            if key == 'conv':
                st_in_ch = student.size(1)
                in_s = teacher.size(1) // st_in_ch
                params = params[:, ::in_s][:, :st_in_ch]
            st_params[key].append(params)

    key = 'conv'
    tc_rest = torch.cat(tc_stage_params[key][exclude_first_block:], dim=0)
    _, in_ch, _, k = tc_rest.size()
    n_components = st_out_expansion_sum * base_channel

    # Rows = positions inside a filter (samples), columns = teacher filters (features).
    samples = torch.transpose(tc_rest.view(tc_rest.size(0), -1), 0, 1)

    pca = PCA(n_components=n_components)
    # detach(): .numpy() refuses tensors that require grad.
    components = pca.fit_transform(samples.detach().cpu().numpy())
    components = torch.transpose(torch.tensor(components), 0, 1)
    if reverse:
        # torch.flip returns a new tensor; the result must be kept.
        components = torch.flip(components, [0, 1])
    sampled_params = components.reshape(n_components, in_ch, k, k).to(device)

    # Hand consecutive runs of components to the remaining student convs.
    cur = 0
    for i, se in enumerate(st_out_expansions):
        width = se * base_channel
        mapped_params = sampled_params[cur:cur + width]
        cur += width
        st_in_ch = st_rest_convs[i].size(1)
        in_s = in_ch // st_in_ch
        mapped_params = mapped_params[:, ::in_s][:, :st_in_ch]
        st_params[key].append(mapped_params.to(device))

    # Remaining BatchNorm tensors come straight from the student.
    for key in bn_keys:
        for v in st_stage_params[key][exclude_first_block:]:
            st_params[key].append(v.to(device))

    ## Build student state_dict
    st_dict_list = []
    num_layer = 2
    for d in range(depth):
        st_dict = {}
        for j in range(num_layer):
            st_dict[f'conv{j+1}.weight'] = st_params['conv'][d*num_layer+j]
            for key in bn_keys:
                st_dict[f'bn{j+1}.{key}'] = st_params[key][d*num_layer+j]
            # Every BN layer reuses the teacher's first num_batches_tracked entry.
            st_dict[f'bn{j+1}.num_batches_tracked'] = tc_stage_params['num_batches_tracked'][0]
        st_dict_list.append(st_dict)

    return st_dict_list