from nets.SubsetNets.utils.utils import filter_table
from nets.Baseline.Transformer.transformer import TransformerEncoder, TransformerSeq2Seq

import numpy as np
import json

# Config tables passed to ``filter_table`` by the ``get_keep_factor`` static
# methods below. The paths are relative, so they are resolved against the
# current working directory: importing this module assumes the process runs
# from the repository root. JSON is UTF-8 by specification, so the encoding is
# pinned instead of relying on the platform/locale default.
with open('nets/SubsetNets/Transformer/tables/table__Subset_x64_STransformer.json', 'r', encoding='utf-8') as fd:
    _g_config_list_stransformer = json.load(fd)

with open('nets/SubsetNets/Transformer/tables/table__Subset_x64_STransformerSeq2Seq.json', 'r', encoding='utf-8') as fd:
    _g_config_list_stransformerseq2seq = json.load(fd)

class STransformer(TransformerEncoder):
    """``TransformerEncoder`` whose widths scale with ``keep_factor``.

    ``d_model`` and ``d_ff`` are both ``ceil(keep_factor * 128)``; every other
    hyper-parameter is fixed. A load-state-dict pre-hook (``sd_hook``)
    truncates matching checkpoint tensors to the leading slice that fits this
    model's shapes. It is intended for loading a wider checkpoint into a
    narrower model.
    """

    def __init__(self, keep_factor=1.0):
        """Build the encoder.

        Args:
            keep_factor: Fraction of the base width 128 to use for ``d_model``
                and ``d_ff``; the product is rounded up.
        """
        width = int(np.ceil(keep_factor * 128))
        # NOTE(review): with n_heads=2 the base class presumably needs an even
        # d_model; an odd ceil(keep_factor * 128) may be rejected — confirm.
        super().__init__(vocab_size=16000, seq_len=512, d_model=width, n_layers=6,
                         n_heads=2, p_drop=0.05, d_ff=width, pad_id=0)

        self._register_load_state_dict_pre_hook(self.sd_hook)

    @staticmethod
    def get_keep_factor(relative_resources):
        """Delegate to ``filter_table`` with the STransformer config table."""
        return filter_table(relative_resources, _g_config_list_stransformer)

    @staticmethod
    def _truncate_to(tensor, shape):
        """Return the leading slice of ``tensor`` selected by ``shape``.

        The rank of ``shape`` (this model's tensor) chooses the slicing rule:
        rank 1 and 2 are sliced on every dim. Rank 3 is sliced on dims 0 and
        2, with dim 1 kept whole. Rank 4 is sliced on dims 0 and 1, with dims
        2 and 3 kept whole. 0-d tensors have nothing to slice and are returned
        unchanged.

        Raises:
            NotImplementedError: For ranks above 4.
        """
        rank = len(shape)
        if rank == 0:
            return tensor
        if rank == 1:
            return tensor[:shape[0]]
        if rank == 2:
            return tensor[:shape[0], :shape[1]]
        if rank == 3:
            return tensor[:shape[0], :, :shape[2]]
        if rank == 4:
            return tensor[:shape[0], :shape[1], :, :]
        raise NotImplementedError(f"cannot truncate a rank-{rank} tensor")

    def sd_hook(self, state_dict, prefix="", *_):
        """Load-state-dict pre-hook: shrink checkpoint tensors to this model's shapes.

        Only keys that are present in ``state_dict`` are touched; they are
        replaced in place by their truncated slices.

        Args:
            state_dict: Mapping being loaded; modified in place.
            prefix: Key prefix that PyTorch passes to pre-hooks for this module
                (``""`` when this model is the root of ``load_state_dict``,
                non-empty when it is nested inside another module).
            *_: Remaining pre-hook arguments, unused.
        """
        # named_parameters()/named_buffers() give names relative to this
        # module, while state_dict keys carry the full prefix when nested.
        for named_tensors in (self.named_parameters(), self.named_buffers()):
            for name, tensor in named_tensors:
                key = prefix + name
                if key in state_dict:
                    state_dict[key] = self._truncate_to(state_dict[key], tensor.shape)


class STransformerSeq2Seq(TransformerSeq2Seq):
    """``TransformerSeq2Seq`` variant whose widths scale with ``keep_factor``.

    ``d_model`` and ``d_ff`` are both ``ceil(keep_factor * 128)``; every other
    hyper-parameter is fixed. A load-state-dict pre-hook (``sd_hook``)
    truncates matching checkpoint tensors to the leading slice that fits this
    model's shapes. It is intended for loading a wider checkpoint into a
    narrower model.
    """

    def __init__(self, keep_factor=1.0):
        """Build the model.

        Args:
            keep_factor: Fraction of the base width 128 to use for ``d_model``
                and ``d_ff``; the product is rounded up.
        """
        width = int(np.ceil(keep_factor * 128))
        # NOTE(review): super(TransformerSeq2Seq, self) skips
        # TransformerSeq2Seq.__init__ and runs the next initializer in the MRO
        # (STransformer uses plain super()). Left unchanged; confirm the bypass
        # is intentional. Also, with n_heads=2 an odd width may be rejected.
        super(TransformerSeq2Seq, self).__init__(vocab_size=81, seq_len=81, d_model=width, n_layers=6,
                                                 n_heads=2, p_drop=0.05, d_ff=width, pad_id=80, n_classes=81)

        self._register_load_state_dict_pre_hook(self.sd_hook)

    @staticmethod
    def get_keep_factor(relative_resources):
        """Delegate to ``filter_table`` with the STransformerSeq2Seq config table."""
        return filter_table(relative_resources, _g_config_list_stransformerseq2seq)

    @staticmethod
    def _truncate_to(tensor, shape):
        """Return the leading slice of ``tensor`` selected by ``shape``.

        The rank of ``shape`` (this model's tensor) chooses the slicing rule:
        rank 1 and 2 are sliced on every dim. Rank 3 is sliced on dims 0 and
        2, with dim 1 kept whole. Rank 4 is sliced on dims 0 and 1, with dims
        2 and 3 kept whole. 0-d tensors have nothing to slice and are returned
        unchanged.

        Raises:
            NotImplementedError: For ranks above 4.
        """
        rank = len(shape)
        if rank == 0:
            return tensor
        if rank == 1:
            return tensor[:shape[0]]
        if rank == 2:
            return tensor[:shape[0], :shape[1]]
        if rank == 3:
            return tensor[:shape[0], :, :shape[2]]
        if rank == 4:
            return tensor[:shape[0], :shape[1], :, :]
        raise NotImplementedError(f"cannot truncate a rank-{rank} tensor")

    def sd_hook(self, state_dict, prefix="", *_):
        """Load-state-dict pre-hook: shrink checkpoint tensors to this model's shapes.

        Only keys that are present in ``state_dict`` are touched; they are
        replaced in place by their truncated slices.

        Args:
            state_dict: Mapping being loaded; modified in place.
            prefix: Key prefix that PyTorch passes to pre-hooks for this module
                (``""`` when this model is the root of ``load_state_dict``,
                non-empty when it is nested inside another module).
            *_: Remaining pre-hook arguments, unused.
        """
        # named_parameters()/named_buffers() give names relative to this
        # module, while state_dict keys carry the full prefix when nested.
        for named_tensors in (self.named_parameters(), self.named_buffers()):
            for name, tensor in named_tensors:
                key = prefix + name
                if key in state_dict:
                    state_dict[key] = self._truncate_to(state_dict[key], tensor.shape)