import torch
import torch.nn as nn
import torch.optim as optim
from utils.options import args
import copy

import os
from importlib import import_module
from model.resnet_2 import BasicBlock, Bottleneck
from utils.common import cluster_weight, random_project, direct_project

def cluster_resnet(oristate_dict, cfg, num_classes, init_state_dict=None):
    """Build a narrower ResNet whose block widths come from filter clustering.

    A full-width ``model.resnet_2`` network for ``cfg`` is built and loaded from
    ``oristate_dict``. For every BasicBlock the output filters of ``conv1`` are
    clustered. For every Bottleneck the output filters of ``conv1`` and ``conv2``
    are clustered. The number of centroids returned by ``cluster_weight`` becomes
    that layer's new width. The input channels of the following conv are reduced
    with ``random_project`` or ``direct_project``, selected by
    ``args.init_method``. A new network is then built with those widths.

    Weights are copied into the new network only when ``args.init_method`` is
    ``'random_project'`` or ``'centroids'``. Otherwise the new network keeps its
    own initialisation. When weights are copied:

    * the clustered/projected conv weights come from the centroids;
    * the ``bn1`` (and, for Bottleneck, ``bn2``) weight/bias/running stats of
      each block are skipped and keep the new network's initial values;
    * every other entry comes from ``init_state_dict`` if given, else from the
      full-width network.

    Args:
        oristate_dict: state dict of the full-width network (loaded with
            ``strict=False``).
        cfg: architecture key: 'resnet18', 'resnet34', 'resnet50', 'resnet101'
            or 'resnet152'.
        num_classes: number of output classes for both networks.
        init_state_dict: optional state dict used as the weight source instead
            of the full-width network. Centroid rows are overwritten with the
            filters it holds at the clustering indices, projections are computed
            from its weights, and unmatched entries are copied from it.

    Returns:
        A tuple ``(model, honey)``: the new network and the list of per-layer
        widths used to build it (the stem width 64 is prepended).

    Raises:
        KeyError: if ``cfg`` is not one of the supported architectures.
    """
    ori_honey = {
        'resnet18': [64] + [64] * 2 + [128] * 2 + [256] * 2 + [512] * 2,
        'resnet34': [64] + [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3,
        'resnet50': [64] + [64] * 3 * 2 + [128] * 4 * 2 + [256] * 6 * 2 + [512] * 3 * 2,
        'resnet101': [64] + [64] * 3 * 2 + [128] * 4 * 2 + [256] * 23 * 2 + [512] * 3 * 2,
        'resnet152': [64] + [64] * 3 * 2 + [128] * 8 * 2 + [256] * 36 * 2 + [512] * 3 * 2
    }
    # Per architecture: (block counter threshold, beta used from that block on).
    # Blocks before the threshold are clustered with args.preference_beta.
    block_num_betas = {
        'resnet18': (6, 0.86),
        'resnet34': (14, 0.92),
        'resnet50': (13, 0.85),
        'resnet101': (28, 0.92),
        'resnet152': (45, 0.92)
    }

    resnet_module = import_module('model.resnet_2')
    origin_model = resnet_module.resnet(cfg, honey=ori_honey[cfg], num_classes=num_classes)
    origin_model.load_state_dict(oristate_dict, strict=False)

    # Both dicts share the same keys, so an unknown cfg already failed above.
    late_threshold, late_beta = block_num_betas[cfg]

    honey = []                  # new per-layer widths, in module traversal order
    centroids_state_dict = {}   # state-dict key -> clustered / projected weight
    prune_keys = set()          # bn entries left at the new model's initial values
    indices = []                # Bottleneck conv1 cluster indices, in block order

    def _bn_keys(prefix):
        # Parameters and running stats of a bn layer that follows a pruned conv.
        return (prefix + '.weight', prefix + '.bias',
                prefix + '.running_var', prefix + '.running_mean')

    def _copy_selected_filters(centroids, indice, source_key):
        # Overwrite centroid row i with filter indice[i] of init_state_dict[source_key].
        # NOTE(review): assumes cluster_weight returns centroids as a 2-D numpy
        # array (one flattened filter per row) - the reshape(1, -1) relies on it.
        source = init_state_dict[source_key]
        for i, j in enumerate(indice):
            centroids[i, :] = source[j, :, :, :].cpu().numpy().reshape(1, -1)

    def _project(weight, centroids, indice):
        # Shrink the input channels of `weight` to match the clustered layer.
        if args.init_method == 'random_project':
            return random_project(weight, len(centroids))
        return direct_project(weight, indice)

    current_conv_layer_index = 0

    for name, module in origin_model.named_modules():
        if isinstance(module, BasicBlock):
            beta = late_beta if current_conv_layer_index >= late_threshold else args.preference_beta
            _, centroids, indice = cluster_weight(module.conv1.weight.data, beta)
            current_conv_layer_index += 1
            honey.append(len(centroids))

            if init_state_dict is None:
                conv2_source = module.conv2.weight.data
            else:
                _copy_selected_filters(centroids, indice, name + '.conv1.weight')
                conv2_source = init_state_dict[name + '.conv2.weight']
            centroids_state_dict[name + '.conv1.weight'] = centroids
            centroids_state_dict[name + '.conv2.weight'] = _project(conv2_source, centroids, indice)

            prune_keys.update(_bn_keys(name + '.bn1'))

        elif isinstance(module, Bottleneck):
            # --- conv1: cluster output filters (conv2's input is projected later). ---
            conv1_weight = module.conv1.weight.data
            if current_conv_layer_index >= late_threshold:
                # NOTE(review): past the threshold, conv1 uses late_beta only for
                # resnet50 with preference_beta == 0.7 and passes None otherwise -
                # preserved as-is; confirm this is intended.
                beta = late_beta if cfg == 'resnet50' and args.preference_beta == 0.7 else None
            else:
                beta = args.preference_beta
            _, centroids, indice = cluster_weight(conv1_weight, beta)
            honey.append(len(centroids))
            indices.append(indice)
            if init_state_dict is not None:
                _copy_selected_filters(centroids, indice, name + '.conv1.weight')
            centroids_state_dict[name + '.conv1.weight'] = centroids

            prune_keys.update(_bn_keys(name + '.bn1'))

            # --- conv2: cluster output filters, project conv3's input now. ---
            conv2_weight = module.conv2.weight.data
            beta = late_beta if current_conv_layer_index >= late_threshold else args.preference_beta
            _, centroids, indice = cluster_weight(conv2_weight, beta)
            # The counter advances once per block, so conv1 and conv2 share it.
            current_conv_layer_index += 1
            honey.append(len(centroids))
            if init_state_dict is not None:
                _copy_selected_filters(centroids, indice, name + '.conv2.weight')
            centroids_state_dict[name + '.conv2.weight'] = centroids.reshape(
                (-1, conv2_weight.size(1), conv2_weight.size(2), conv2_weight.size(3)))

            if init_state_dict is None:
                conv3_source = module.conv3.weight.data
            else:
                conv3_source = init_state_dict[name + '.conv3.weight']
            centroids_state_dict[name + '.conv3.weight'] = _project(conv3_source, centroids, indice)

            prune_keys.update(_bn_keys(name + '.bn2'))

    honey.insert(0, 64)
    model = resnet_module.resnet(cfg, honey=copy.deepcopy(honey), num_classes=num_classes)

    if args.init_method in ('random_project', 'centroids'):
        pretrain_state_dict = origin_model.state_dict()
        state_dict = model.state_dict()

        # Bottleneck conv2 weights still have the full conv1 width on their input
        # axis. Project them using the matching conv1 cluster indices. Keys were
        # inserted in block order, the same order `indices` was filled in.
        if cfg != 'resnet18' and cfg != 'resnet34':
            bottleneck_index = 0
            for key in list(centroids_state_dict):
                if not key.endswith('.conv2.weight'):
                    continue
                conv2_centroids = torch.FloatTensor(centroids_state_dict[key])
                if args.init_method == 'random_project':
                    centroids_state_dict[key] = random_project(conv2_centroids,
                                                               len(indices[bottleneck_index]))
                else:
                    centroids_state_dict[key] = direct_project(conv2_centroids, indices[bottleneck_index])
                bottleneck_index += 1

        fallback_state_dict = pretrain_state_dict if init_state_dict is None else init_state_dict
        for key in list(state_dict):
            if key in prune_keys:
                continue
            if key in centroids_state_dict:
                state_dict[key] = torch.FloatTensor(centroids_state_dict[key]).view_as(state_dict[key])
            else:
                state_dict[key] = fallback_state_dict[key]
        model.load_state_dict(state_dict)

    return model, honey