import os
import torch

from lbmqt.conf import config
from lbmqt.nn.module import QModule
from lbmqt.optim import SGD, Adam, AdamW
from lbmqt.utils import GOR


class LowBitEngine(torch.nn.Module):
    """Engine that couples a ``QModule``-wrapped model with a low-bit optimizer.

    Construction wraps the given model in ``QModule``, builds optimizer
    parameter groups from the wrapped model's named parameters, creates the
    requested optimizer and initialises the global optimization recorder
    (``GOR``).

    Attributes set by the constructor:
        args: The run configuration object passed in.
        module: The ``QModule`` wrapping the user model (also registered as a
            child module under the name ``'module'``).
        optimizer: The optimizer instance created from ``optimizer``.
    """

    def __init__(
        self,
        args,
        model,
        optimizer,
        name_groups,
    ):
        """Build the engine.

        Args:
            args: Run configuration. This class reads ``lr``,
                ``num_micro_batches`` and ``workspace`` from it, plus
                ``momentum`` and ``weight_decay`` when the optimizer is
                ``'sgd'``.
            model: The model to wrap in ``QModule``.
            optimizer: Optimizer name: ``'sgd'``, ``'adam'`` or ``'adamw'``.
            name_groups: ``None`` to put all parameters in default groups, or
                an iterable of dicts, each with a ``'names'`` list of
                parameter names and optional extra per-group options.

        Raises:
            NotImplementedError: If ``optimizer`` is not a supported name.
        """
        super().__init__()
        self.args = args

        self._configure_model(model)

        named_parameters = self.module.get_named_model_parameters()
        param_scheme = self.module.scheme
        self._configure_optimizer(optimizer, named_parameters, param_scheme, name_groups)

        self._configure_recorder()

    def _configure_model(self, model):
        """Wrap ``model`` in ``QModule`` and expose it as ``self.module``."""
        # Per the original author's note QModule moves the model to cuda,
        # quantizes its parameters and rebuilds it; QModule itself is not
        # visible in this file.
        model = QModule(model)
        # Register the wrapper as a child module by hand ...
        self._modules['module'] = model
        # ... and also store it in the instance __dict__, so `self.module`
        # resolves through normal attribute lookup instead of falling back to
        # nn.Module.__getattr__.
        self.__dict__['module'] = model
        print('Engine configuring model ended')

    def _configure_optimizer(self, optimizer_name, named_parameters, param_scheme, name_groups):
        """Create the optimizer and attach it to the engine and the module.

        Args:
            optimizer_name: ``'sgd'``, ``'adam'`` or ``'adamw'``.
            named_parameters: Iterable of ``(name, parameter)`` pairs.
            param_scheme: Scheme object forwarded to the optimizer and used to
                decide which parameters are quantifiable.
            name_groups: See ``_configure_param_groups``.

        Raises:
            NotImplementedError: If ``optimizer_name`` is not supported.
        """
        param_groups = self._configure_param_groups(named_parameters, param_scheme, name_groups)

        # Pick the optimizer class and its optimizer-specific options first so
        # an unknown name fails before any hyper-parameter is read from args.
        if optimizer_name == 'sgd':
            optimizer_cls = SGD
            extra_options = {
                'momentum': self.args.momentum,
                'dampening': 0,
                'weight_decay': self.args.weight_decay,
                'nesterov': False,
            }
        elif optimizer_name == 'adam':
            optimizer_cls = Adam
            extra_options = {}
        elif optimizer_name == 'adamw':
            optimizer_cls = AdamW
            extra_options = {}
        else:
            raise NotImplementedError(
                f"unsupported optimizer {optimizer_name!r}; "
                "expected 'sgd', 'adam' or 'adamw'"
            )

        optimizer = optimizer_cls(
            params=param_groups,
            param_scheme=param_scheme,
            lr=self.args.lr,
            num_micro_batches=self.args.num_micro_batches,
            **extra_options,
        )
        self.optimizer = optimizer
        self.module.set_optimizer(optimizer)
        print('Engine configuring optimizer ended')

    def _configure_param_groups(self, named_parameters, param_scheme, name_groups):
        """Split parameters into optimizer groups tagged by quantifiability.

        Parameters are first separated by ``param_scheme.is_quantifiable(name)``
        and each side is then partitioned according to ``name_groups``. Every
        resulting group is a dict with ``'params'``, ``'names'`` and a boolean
        ``'quantifiable'`` entry, plus any extra options copied from the
        matching entry of ``name_groups``. Quantifiable groups come first.

        When ``name_groups`` is given, parameters whose name appears in no
        group are left out of the result, and groups that match no parameter
        are omitted.

        Args:
            named_parameters: Iterable of ``(name, parameter)`` pairs.
            param_scheme: Object providing ``is_quantifiable(name)``.
            name_groups: ``None`` or an iterable of dicts with a ``'names'``
                list and optional extra options.

        Returns:
            A list of parameter-group dicts.

        Raises:
            AssertionError: If no parameter is quantifiable.
        """
        # Divide parameters into quantifiable and non-quantifiable ones.
        qp_list = []
        qname_list = []
        unqp_list = []
        unqname_list = []
        for name, p in named_parameters:
            if param_scheme.is_quantifiable(name):
                qp_list.append(p)
                qname_list.append(name)
            else:
                unqp_list.append(p)
                unqname_list.append(name)

        def partition_parameters_by_names(name_list, p_list, name_groups):
            if name_groups is None:
                return [{
                    'names': name_list,
                    'params': p_list,
                }]

            # Index parameters by name once instead of rescanning the whole
            # list for every requested name. Values are lists so repeated
            # names keep every occurrence, in order.
            params_by_name = {}
            for n, p in zip(name_list, p_list):
                params_by_name.setdefault(n, []).append(p)

            param_groups = []
            for ng in name_groups:
                # Copy per-group options; 'params' and 'names' are rebuilt.
                pg = {
                    key: value
                    for key, value in ng.items()
                    if key != 'params' and key != 'names'
                }
                pg['params'] = []
                pg['names'] = []
                for name in ng['names']:
                    for p in params_by_name.get(name, ()):
                        pg['names'].append(name)
                        pg['params'].append(p)
                if pg['params']:
                    param_groups.append(pg)
            return param_groups

        # Organize param groups.
        assert len(qp_list) > 0, 'no quantifiable parameters found in the model'
        param_groups = partition_parameters_by_names(qname_list, qp_list, name_groups)
        for pg in param_groups:
            pg['quantifiable'] = True
        if unqp_list:
            unq_param_groups = partition_parameters_by_names(unqname_list, unqp_list, name_groups)
            for pg in unq_param_groups:
                pg['quantifiable'] = False
            param_groups += unq_param_groups

        print(f'param group number: {len(param_groups)}')
        return param_groups

    def _configure_recorder(self):
        """Initialise ``GOR`` with the print frequency, report path and names.

        The report file is ``optimization_recorder.json`` inside
        ``args.workspace``; parameter names come from ``self.module.q_names``.
        """
        args = self.args
        GOR.init(
            print_freq=config.debug_GOR_freq,
            # 'raport_path' is the keyword spelling used at this call site.
            raport_path=os.path.join(args.workspace, "optimization_recorder.json"),
            param_names=self.module.q_names,
        )

    def forward(self, input, *args, **kwargs):
        """Run the wrapped model (``self.module.model``) on the given inputs."""
        return self.module.model(input, *args, **kwargs)

    def train(self, mode=True):
        """Switch the engine, the global config and the module to ``mode``.

        Accepts the same ``mode`` argument as ``torch.nn.Module.train`` so
        calls such as ``engine.train(False)`` work.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            The engine itself.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        config.training = mode
        # Keep the engine's own flag in step with the wrapped module.
        self.training = mode
        if mode:
            self.module.train()
        else:
            self.module.eval()
        return self

    def eval(self):
        """Switch to evaluation mode; equivalent to ``train(False)``."""
        return self.train(False)

    # not used
    def state_dict(self):
        """Return a dict with the module state and the optimizer state.

        Returns:
            ``{'model': ..., 'optimizer': ...}`` built from the respective
            ``state_dict()`` calls.
        """
        return {
            'model': self.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Restore module and optimizer state from a ``state_dict()`` result.

        Args:
            state_dict: Mapping with ``'model'`` and ``'optimizer'`` entries.
        """
        self.module.load_state_dict(state_dict['model'])
        # Re-point the optimizer at the module's scheme before its own state
        # is loaded.
        self.optimizer.param_scheme = self.module.scheme
        self.optimizer.load_state_dict(state_dict['optimizer'])  # TODO