import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from dyn_module import Conv_DyN, Linear_DyN, MultiheadAttention_DyN, Linear_DyN_NoMat

#!/usr/bin/env python
# coding: utf-8

import os
import os.path as osp
import json
from datetime import datetime 
import torch
import torch.nn as nn
import time
import random
import torchvision.models as models

import copy

def get_module_by_name(model, module_name):
    """Resolve a dotted attribute path into a ``(parent, leaf)`` pair.

    ``module_name`` is split on ``"."``. Every component except the last is
    followed with ``getattr`` starting from ``model``; the last component is
    then looked up on the object reached that way.

    Returns:
        ``(parent, leaf)`` where ``parent`` is the object holding the final
        attribute and ``leaf`` is that attribute, or ``(None, None)`` as soon
        as any component of the path is missing.
    """
    *parent_path, leaf_name = module_name.split(".")

    parent = model
    for attr in parent_path:
        if not hasattr(parent, attr):
            return None, None
        parent = getattr(parent, attr)

    if not hasattr(parent, leaf_name):
        return None, None
    return parent, getattr(parent, leaf_name)

def update_module(model, module_name, new_module):
    """Replace the attribute at dotted path ``module_name`` with ``new_module``.

    The parent object is located with ``get_module_by_name`` and the last
    path component is overwritten on it via ``setattr``.

    Args:
        model: root object the dotted path is resolved against.
        module_name: dotted attribute path, e.g. ``"layer1.0.conv1"``.
        new_module: object to install at that path.

    Raises:
        AttributeError: if the path cannot be resolved on ``model``.
    """
    parent, _ = get_module_by_name(model, module_name)
    if parent is None:
        # Without this guard the failure surfaces as setattr(None, ...),
        # whose message does not say which path was missing.
        raise AttributeError(
            f"cannot replace {module_name!r}: path not found on {type(model).__name__}"
        )
    setattr(parent, module_name.split('.')[-1], new_module)

def replace_conv2d(model, num_CHs, q_dim, norm_p, SCALE_FACTOR_conv, vars_set=None, from_pretrain=False, pretrain_root=""):
    """Swap ``nn.Conv2d`` layers of ``model`` for ``Conv_DyN`` modules, in place.

    Args:
        model: module tree to modify; it is mutated and also returned.
        num_CHs, q_dim, norm_p, SCALE_FACTOR_conv: forwarded unchanged to the
            ``Conv_DyN`` constructor, after the convolution geometry.
        vars_set: optional container of module names. When given, only
            ``Conv2d`` layers whose ``named_modules()`` name is in it are
            replaced; ``None`` replaces every ``Conv2d``.
        from_pretrain: when true, every new module is initialised with
            ``load_state_dict`` from a file found in ``pretrain_root``.
        pretrain_root: directory of saved state dicts. The layer name is
            derived from each filename: the text before the first ``'#'``,
            cut at the first ``'.weight'``.

    Returns:
        The same ``model`` object.

    Raises:
        KeyError: if ``from_pretrain`` is set and no file maps to a layer
            that is being replaced.

    Note:
        Only element ``[0]`` of ``kernel_size``, ``stride`` and ``padding`` is
        forwarded, and ``dilation``/``groups``/``bias`` are not read at all.
        NOTE(review): presumably only square, default-configured convolutions
        are expected here -- confirm against ``Conv_DyN``.
    """
    pretrain_dyns_dict = {}
    if from_pretrain:
        # NOTE(review): torch.load unpickles the file contents; pretrain_root
        # should only ever point at trusted checkpoints.
        for filename in os.listdir(pretrain_root):
            module_name = filename.split('#')[0].split('.weight')[0]
            pretrain_dyns_dict[module_name] = torch.load(osp.join(pretrain_root, filename))

    # First pass: snapshot the geometry of every selected layer. Replacement
    # happens in a second pass, after the traversal has finished.
    conv_layers = {
        name: {
            'kernel_size': module.kernel_size[0],
            'in_channels': module.in_channels,
            'out_channels': module.out_channels,
            'stride': module.stride[0],
            'padding': module.padding[0],
        }
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Conv2d) and (vars_set is None or name in vars_set)
    }

    for name, spec in conv_layers.items():
        conv_dyn = Conv_DyN(
            spec['kernel_size'],
            spec['in_channels'],
            spec['out_channels'],
            spec['stride'],
            spec['padding'],
            num_CHs,
            q_dim,
            norm_p,
            SCALE_FACTOR_conv,
        )
        if from_pretrain:
            print(f'load pretrained dyn from {pretrain_root}')
            conv_dyn.load_state_dict(pretrain_dyns_dict[name])
        update_module(model, name, conv_dyn)
    return model

def replace_linear(model, num_CHs, q_dim, norm_p, SCALE_FACTOR_conv, vars_set=None, from_pretrain=False, pretrain_root=""):
    """Swap ``nn.Linear`` layers of ``model`` for ``Linear_DyN`` modules, in place.

    Args:
        model: module tree to modify; it is mutated and also returned.
        num_CHs, q_dim, norm_p, SCALE_FACTOR_conv: forwarded unchanged to the
            ``Linear_DyN`` constructor, after ``in_features``/``out_features``.
        vars_set: optional container of module names. When given, only
            ``Linear`` layers whose ``named_modules()`` name is in it are
            replaced; ``None`` replaces every ``Linear``.
        from_pretrain: when true, every new module is initialised with
            ``load_state_dict`` from a file found in ``pretrain_root``.
        pretrain_root: directory of saved state dicts. The layer name is
            derived from each filename: the text before the first ``'#'``,
            cut at the first ``'.weight'``.

    Returns:
        The same ``model`` object.
    """
    pretrained = {}
    if from_pretrain:
        for fname in os.listdir(pretrain_root):
            key = fname.split('#')[0].split('.weight')[0]
            pretrained[key] = torch.load(osp.join(pretrain_root, fname))

    # Pass 1: record (in_features, out_features) for each selected layer.
    restrict = vars_set is not None
    targets = {}
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if restrict and name not in vars_set:
            continue
        targets[name] = (module.in_features, module.out_features)

    # Pass 2: build the replacements and install them on the model.
    for name, (n_in, n_out) in targets.items():
        linear_dyn = Linear_DyN(n_in, n_out, num_CHs, q_dim, norm_p, SCALE_FACTOR_conv)
        if from_pretrain:
            print(f'load pretrained dyn from {pretrain_root}')
            linear_dyn.load_state_dict(pretrained[name])
        update_module(model, name, linear_dyn)
    return model


def replace_linear_nomat(model, num_CHs, q_dim, norm_p, SCALE_FACTOR_conv, vars_set=None, from_pretrain=False, pretrain_root=""):
    """Swap ``nn.Linear`` layers of ``model`` for ``Linear_DyN_NoMat`` modules, in place.

    Args:
        model: module tree to modify; it is mutated and also returned.
        num_CHs, q_dim, norm_p, SCALE_FACTOR_conv: forwarded unchanged to the
            ``Linear_DyN_NoMat`` constructor, after
            ``in_features``/``out_features``.
        vars_set: optional container of module names. When ``None``, every
            ``Linear`` except the one named exactly ``"lm_head"`` is replaced.
            When given, exactly the ``Linear`` layers named in it are
            replaced, with no ``"lm_head"`` exclusion.
        from_pretrain: when true, every new module is initialised with
            ``load_state_dict`` from a file found in ``pretrain_root``.
        pretrain_root: directory of saved state dicts. The layer name is
            derived from each filename: the text before the first ``'#'``,
            cut at the first ``'.weight'``.

    Returns:
        The same ``model`` object.
    """
    def _is_selected(layer_name):
        # The "lm_head" exclusion only applies when no explicit set is given.
        if vars_set is None:
            return layer_name != "lm_head"
        return layer_name in vars_set

    state_by_name = {}
    if from_pretrain:
        for entry in os.listdir(pretrain_root):
            layer_key = entry.split('#')[0].split('.weight')[0]
            state_by_name[layer_key] = torch.load(osp.join(pretrain_root, entry))

    # Collect the shapes first; the model is only modified afterwards.
    shapes = [
        (name, module.in_features, module.out_features)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and _is_selected(name)
    ]

    for name, in_features, out_features in shapes:
        linear_dyn = Linear_DyN_NoMat(
            in_features, out_features, num_CHs, q_dim, norm_p, SCALE_FACTOR_conv
        )
        if from_pretrain:
            print(f'load pretrained dyn from {pretrain_root}')
            linear_dyn.load_state_dict(state_by_name[name])
        update_module(model, name, linear_dyn)
    return model


def replace_multihead_attention(model, num_CHs, q_dim, norm_p, SCALE_FACTOR_conv, vars_set=None, from_pretrain=False, pretrain_root=""):
    """Swap ``nn.MultiheadAttention`` layers for ``MultiheadAttention_DyN``, in place.

    Args:
        model: module tree to modify; it is mutated and also returned.
        num_CHs, q_dim, norm_p, SCALE_FACTOR_conv: forwarded unchanged to the
            ``MultiheadAttention_DyN`` constructor, between ``num_heads`` and
            ``dropout``.
        vars_set: optional container of module names. When given, only
            attention layers whose ``named_modules()`` name is in it are
            replaced; ``None`` replaces every ``MultiheadAttention``.
        from_pretrain: when true, each direct child of a new attention module
            is initialised with ``load_state_dict`` from the entry keyed
            ``"<layer name>.<child name>"``.
        pretrain_root: directory of saved state dicts. The key is derived
            from each filename: the text before the first ``'#'``, cut at the
            first ``'.weight'``.

    Returns:
        The same ``model`` object.
    """
    def _ctor_args(attn):
        # Tuple order mirrors the positional order passed to the constructor:
        # embed_dim, num_heads, then the optional attention settings.
        return (
            attn.embed_dim,
            attn.num_heads,
            attn.dropout,
            attn.out_proj.bias is not None,
            attn.bias_k is not None,
            attn.add_zero_attn,
            attn.kdim,
            attn.vdim,
            attn.batch_first,
        )

    saved_states = {}
    if from_pretrain:
        for fname in os.listdir(pretrain_root):
            key = fname.split('#')[0].split('.weight')[0]
            saved_states[key] = torch.load(osp.join(pretrain_root, fname))

    # Capture the configuration of every selected layer before replacing any.
    selected = {}
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.MultiheadAttention):
            continue
        if vars_set is None or name in vars_set:
            selected[name] = _ctor_args(module)

    for name, (embed_dim, num_heads, *optional) in selected.items():
        mulatt_dyn = MultiheadAttention_DyN(
            embed_dim,
            num_heads,
            num_CHs,
            q_dim,
            norm_p,
            SCALE_FACTOR_conv,
            *optional,
        )
        if from_pretrain:
            print(f'load pretrained dyn from {pretrain_root}')
            for child_name, child in mulatt_dyn.named_children():
                print('load {}.{}'.format(name, child_name))
                child.load_state_dict(saved_states["{}.{}".format(name, child_name)])
        update_module(model, name, mulatt_dyn)
    return model