import warnings
from collections import defaultdict
import torch
import pickle
import os

from lbmqt.conf import config
from lbmqt.functional import compute_outlier, get_dynamic_map, dequantize_data_general, \
    group_tensor, compute_absmax_per_group, quantize_data_general, \
    compute_absmax_before_grouping, quantize_and_pack_linear_without_rounding, recon_grouped_tensor


class QScheme(object):
    r"""Per-parameter quantization scheme organized in named groups.

    Parameter *names* are arranged in groups (``self.param_name_groups``), in
    the same spirit as ``torch.optim.Optimizer.param_groups``. Each group dict
    carries the quantization options ``bits``, ``group_size``, ``enable``,
    ``range_mode``, ``num_mode``, ``stochastic``, ``qmap`` and ``simulate``.

    ``self.state`` maps every registered name to a per-parameter dict. That dict
    holds a back-reference to its group under ``'idx2group'``, plus statistics
    recorded during quantization (``'shape'``, ``'scale'``, ``'min'``,
    ``'absmax'``, ...).
    """

    def __init__(self, name, param_names, bits, group_size, enable: bool, range_mode=None, num_mode=None, stochastic=True, simulate=config.simulate) -> None:
        """Validate the default options and register every name group.

        Args:
            name: label used by ``__repr__`` and the per-group log line.
            param_names: one of
                - a single parameter-name string;
                - an iterable of name strings, which forms one group using the
                  defaults below;
                - an iterable of group dicts, each accepted by
                  :meth:`add_name_group`.
            bits: default bit width; must be a positive int.
            group_size: default group size; must be a positive int.
            enable: default for whether quantization is applied.
            range_mode: default range mode. ``None`` means
                ``config.default_range_mode``. Must be in ``config.qt_range_modes``.
            num_mode: default numerical mode. ``None`` means
                ``config.default_numerical_mode``. Must be in
                ``config.qt_numerical_modes``.
            stochastic: default ``stochastic`` flag.
            simulate: default ``simulate`` flag. The default value is read from
                ``config.simulate`` once, when this class is defined.

        Raises:
            ValueError: on an empty name list or an invalid option.
        """
        self.name = name

        if isinstance(param_names, str):
            # list() on a bare string would split it into single characters.
            param_names = [param_names]
        name_groups = list(param_names)
        if len(name_groups) == 0:
            raise ValueError("qscheme got an empty parameter name list")
        if not isinstance(name_groups[0], dict):
            # A plain list of names becomes one group that uses the defaults.
            name_groups = [{
                'param_names': name_groups
            }]

        if not (isinstance(bits, int) and (bits > 0)):
            raise ValueError(f'bits must be positive integer, but found {bits}')
        if not (isinstance(group_size, int) and (group_size > 0)):
            raise ValueError(f'group_size must be positive integer, but found {group_size}')
        if range_mode is None:
            range_mode = config.default_range_mode
        if range_mode not in config.qt_range_modes:
            raise ValueError('unsupported range_mode: {}'.format(range_mode))
        if num_mode is None:
            num_mode = config.default_numerical_mode
        if num_mode not in config.qt_numerical_modes:
            raise ValueError('unsupported num_mode: {}'.format(num_mode))
        self.default_simulate = simulate

        # 'qmap' is only a placeholder here; add_name_group always overwrites it.
        self.defaults = dict(bits=bits, group_size=group_size, enable=enable,
                             range_mode=range_mode, num_mode=num_mode, stochastic=stochastic, qmap=dict())
        self.state = defaultdict(dict)
        self.param_name_groups = []

        for name_group in name_groups:
            self.add_name_group(name_group)

    def __repr__(self) -> str:
        """Render every group's options followed by each member's per-parameter state."""
        ret = f'QScheme {self.name}:\n'
        for idx, ng in enumerate(self.param_name_groups):
            group_line = f'\tgroup {idx}: '
            for key, val in ng.items():
                if key != 'param_names':
                    group_line += f'{key}={val}, '
            ret += group_line + '\n'
            for p_name in ng['param_names']:
                p_line = f'\t\t{p_name}: '
                for key, val in self.state[p_name].items():
                    # 'idx2group' is the back-reference to the group printed above.
                    if key != 'idx2group':
                        p_line += f'{key}={val}, '
                ret += p_line + '\n'
        return ret

    def add_name_group(self, name_group):
        r"""Add a name group to this scheme's ``param_name_groups``.

        This mirrors ``torch.optim.Optimizer.add_param_group``. It is useful
        when more parameters must be quantized after construction, e.g. frozen
        layers that become trainable during fine-tuning.

        ``name_group`` is modified in place:
            - ``'param_names'`` is normalized to a list;
            - missing options are filled from ``self.defaults``;
            - ``'qmap'`` and ``'simulate'`` are always (re)computed.

        Args:
            name_group (dict): must contain ``'param_names'`` (a str or an
                ordered iterable of str). It may override any key of
                ``self.defaults``.

        Raises:
            AssertionError: if ``name_group`` is not a dict.
            TypeError: if ``'param_names'`` is a set or contains a non-str.
            ValueError: if a name already belongs to another group.
        """
        assert isinstance(name_group, dict), "name group must be a dict"

        param_names = name_group['param_names']
        if isinstance(param_names, str):
            name_group['param_names'] = [param_names]
        elif isinstance(param_names, set):
            raise TypeError('quantized parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            name_group['param_names'] = list(param_names)

        for param_name in name_group['param_names']:
            if not isinstance(param_name, str):
                raise TypeError('one of the param_name type is ' + torch.typename(param_name))

        # Fill in any group option that was not given explicitly.
        for key, default in self.defaults.items():
            name_group.setdefault(key, default)

        b = name_group['bits']
        gp_sz = name_group['group_size']
        num_mode = name_group['num_mode']
        stochastic = name_group['stochastic']
        if num_mode == 'nonlinear':
            name_group['qmap'] = get_dynamic_map(bits=b)
        else:
            name_group['qmap'] = {'dynamic': None, 'udynamic': None}

        # Only these bit-width / group-size combinations may honour the caller's
        # ``simulate`` choice; everything else is forced to simulate.
        # NOTE(review): presumably these are the layouts the packed (non-simulated)
        # path supports — confirm against lbmqt.functional.
        if b <= 8 and 8 % b == 0 and gp_sz % (32 // b) == 0:
            name_group['simulate'] = self.default_simulate
        elif 8 < b <= 16 and num_mode == 'nonlinear':
            name_group['simulate'] = self.default_simulate
        else:
            name_group['simulate'] = True
        # Keep deterministic (non-stochastic) linear quantization on the simulated path.
        if not stochastic and num_mode == 'linear':
            name_group['simulate'] = True

        param_names = name_group['param_names']
        if len(param_names) != len(set(param_names)):
            warnings.warn('qscheme contains a name group with duplicate parameter names; '
                          'in future, this will cause an error; ', stacklevel=3)

        registered = set()
        for group in self.param_name_groups:
            registered.update(group['param_names'])
        if not registered.isdisjoint(param_names):
            raise ValueError('some parameters appear in more than one parameter group')

        self.param_name_groups.append(name_group)
        for p_name in param_names:
            self.state[p_name]['idx2group'] = name_group

        # Log the effective group options.
        simulate = name_group['simulate']
        enable = name_group['enable']
        print(f'{self.name} scheme: enable={enable}, simulate={simulate}, bits={b}, group_size={gp_sz}, num_mode={num_mode}, stochastic={stochastic}')

    def get_quantization_config(self, p_name):
        """Return a fresh dict merging ``p_name``'s group options with its recorded stats.

        Returns ``None`` when ``p_name`` is not registered. ``shape``,
        ``scale``, ``min`` and ``absmax`` are ``None`` until recorded.
        """
        # .get() rather than [] so that unknown names do not create
        # defaultdict entries.
        p_state = self.state.get(p_name, {})
        group = p_state.get('idx2group', None)
        if group is None:
            return None
        return {
            'enable': group['enable'],
            'bits': group['bits'],
            'group_size': group['group_size'],
            'shape': p_state.get('shape', None),
            'scale': p_state.get('scale', None),
            'min': p_state.get('min', None),
            'absmax': p_state.get('absmax', None),
            'range_mode': group['range_mode'],
            'num_mode': group['num_mode'],
            'qmap': group['qmap'],
            'simulate': group['simulate'],
            'stochastic': group['stochastic'],
        }

    def is_quantifiable(self, p_name: str) -> bool:
        """Return True if ``p_name`` was registered through :meth:`add_name_group`."""
        # Test for the group back-reference, not mere key presence: ``self.state``
        # is a defaultdict, so indexing it with an unknown name (e.g. via
        # reset_state) leaves behind an empty entry that must not count.
        return self.state.get(p_name, {}).get('idx2group', None) is not None

    def is_quantized(self, p_name: str) -> bool:
        """Return the effective ``enable`` flag for ``p_name``, or False if it is unregistered."""
        if not self.is_quantifiable(p_name):
            return False
        return self.get_state('enable', p_name)

    def get_state(self, name, p_name):
        """Look up ``name`` for ``p_name``: per-parameter state first, then its group.

        A stored value of ``None`` is treated as missing.

        Raises:
            ValueError: if neither the per-parameter state nor the group has a
                non-``None`` value. This also covers unregistered names.
        """
        p_state = self.state.get(p_name, {})
        val = p_state.get(name, None)
        if val is not None:
            return val
        group = p_state.get('idx2group', None)
        if group is not None:
            val = group.get(name, None)
            if val is not None:
                return val
        raise ValueError('param name {} does not have state {} in qscheme'.format(
            p_name, name
        ))

    def set_state(self, name, p_name, val):
        """Store ``val`` under ``name`` in ``p_name``'s per-parameter state.

        Raises:
            ValueError: if ``p_name`` is not registered. No state entry is
                created in that case.
        """
        if self.state.get(p_name, {}).get('idx2group', None) is None:
            raise ValueError('get None param_group when set state in qscheme')
        self.state[p_name][name] = val

    def reset_state(self, p_name, key):
        """Set ``key`` in ``p_name``'s per-parameter state to ``None``."""
        self.state[p_name][key] = None

    def reset_quantization_state(self, p_name):
        """Clear the recorded ``scale``, ``min`` and ``absmax`` of a quantized parameter.

        ``shape`` is intentionally left untouched. Does nothing when
        ``p_name`` is not quantized.
        """
        if not self.is_quantized(p_name):
            return
        self.reset_state(p_name, 'scale')
        self.reset_state(p_name, 'min')
        self.reset_state(p_name, 'absmax')

    def quantize_data(self, input, p_name, signed=True, dilate=1):
        """Quantize ``input`` with ``p_name``'s scheme.

        If quantization is disabled or ``p_name`` is unknown, ``input`` is
        returned unchanged.

        On first use this records ``input.shape``. When no ``absmax`` is
        stored, it computes one with ``compute_absmax_before_grouping`` and
        records it. The result of ``quantize_data_general`` is returned.

        Note: the recorded shape is not cleared by :meth:`reset_quantization_state`.
        """
        if not self.is_quantized(p_name):
            return input
        q_config = self.get_quantization_config(p_name)
        q_config['signed'] = signed
        q_config['dilate'] = dilate
        if q_config['shape'] is None:
            self.set_state('shape', p_name, input.shape)
            q_config['shape'] = input.shape
        if q_config['absmax'] is None:
            absmax = compute_absmax_before_grouping(input, q_config['group_size'], q_config['dilate'])
            self.set_state('absmax', p_name, absmax)
            q_config['absmax'] = absmax
        return quantize_data_general(input, q_config)

    def dequantize_data(self, q_input, p_name, signed=True):
        """Dequantize ``q_input`` using ``p_name``'s recorded configuration.

        If quantization is disabled or ``p_name`` is unknown, ``q_input`` is
        returned unchanged.
        """
        if not self.is_quantized(p_name):
            return q_input
        q_config = self.get_quantization_config(p_name)
        q_config['signed'] = signed
        # NOTE(review): 'dilate' is not part of q_config here, unlike in
        # quantize_data — confirm dequantize_data_general does not need it.
        return dequantize_data_general(q_input, q_config)

    def calculate_bin_change_num(self, p_name, update, signed=True):
        r"""Histogram of ``|update|`` measured in quantization steps of ``p_name``.

        Only meaningful for linear absmax quantization of parameters; do not
        call this function in other cases. It relies on ``'absmax'`` and
        ``'shape'`` having been recorded already, e.g. by :meth:`quantize_data`.

        Returns:
            A tensor with 8 entries: the fraction of elements whose magnitude
            falls in the bins ``<=1/8``, ``(1/8, 1/4]``, ``(1/4, 1/2]``,
            ``(1/2, 1]``, ``(1, 2]``, ``(2, 4]``, ``(4, 8]`` and ``>8``.
        """
        q_config = self.get_quantization_config(p_name)
        input_groups = group_tensor(update, gp_sz=q_config['group_size'])
        q_update = quantize_and_pack_linear_without_rounding(input_groups, q_config['absmax'], q_config['bits'], signed)
        rounding_proba = recon_grouped_tensor(q_update.abs(), q_config['shape'])

        total = rounding_proba.numel()
        hist_list = []
        accumulated_num = 0
        for factor in (1/8, 1/4, 1/2, 1, 2, 4, 8):
            # Cumulative count minus the running total gives this bin's count.
            num = (rounding_proba <= factor).sum().item() - accumulated_num
            accumulated_num += num
            hist_list.append(num)
        hist_list.append(total - accumulated_num)
        return torch.tensor(hist_list) / total


class QGradScheme(QScheme):
    r"""Quantization scheme for gradients; only the ``'absmax'`` range mode is supported.

    In addition to :class:`QScheme`, every registered name carries two
    per-parameter entries.

    ``'warmup'`` starts as ``True``. While it is set,
    :meth:`midst_process_inside_microbatch` clears the stored scale/min/absmax,
    so the next :meth:`quantize_data` recomputes absmax from its input. After
    :meth:`turn_off_warmup`, :meth:`reset_quantization_state` instead installs
    the ``'absmax_next'`` value recorded by
    :meth:`midst_process_inside_microbatch`.

    ``'outlier'`` starts as ``None``. When ``config.save_grad_outlier`` is set,
    it holds the result of ``compute_outlier`` from :meth:`quantize_data`.
    :meth:`dequantize_data` adds it back (via ``.to_dense()``) and then clears it.
    """

    def __init__(self, name, param_names, bits, group_size, enable: bool, range_mode=None, num_mode=None, stochastic=True, simulate=config.simulate) -> None:
        """Same arguments as :class:`QScheme`; the default range mode must be ``'absmax'``."""
        # QScheme.__init__ registers every group through self.add_name_group, so
        # the override below initializes 'warmup'/'outlier' for each name.
        super(QGradScheme, self).__init__(name, param_names, bits, group_size, enable, range_mode, num_mode, stochastic, simulate)
        assert self.defaults['range_mode'] == 'absmax', 'QGradScheme only support absmax mode'

    def add_name_group(self, name_group):
        """Register ``name_group`` as in :class:`QScheme`, then initialize gradient-specific state.

        Doing this here rather than only in ``__init__`` means groups added
        later also get ``'warmup'``/``'outlier'``. The other methods index
        these keys directly and would otherwise raise ``KeyError``.
        """
        super(QGradScheme, self).add_name_group(name_group)
        # The base class has normalized 'param_names' to a list by now.
        for p_name in name_group['param_names']:
            self.state[p_name]['warmup'] = True
            self.state[p_name]['outlier'] = None

    def turn_off_warmup(self, p_name):
        """Leave warmup for ``p_name``; a no-op when it is not quantized."""
        if not self.is_quantized(p_name):
            return
        self.state[p_name]['warmup'] = False

    def midst_process_inside_microbatch(self, input, p_name):
        """Record per-group absmax of ``input`` as ``'absmax_next'`` for ``p_name``.

        During warmup the current scale/min/absmax are cleared first.
        """
        if not self.is_quantized(p_name):
            return
        if self.state[p_name]['warmup']:
            # Base-class reset: clears scale/min/absmax without the
            # absmax_next hand-over done by this class's override.
            super(QGradScheme, self).reset_quantization_state(p_name)
        q_config = self.get_quantization_config(p_name)
        input_groups = group_tensor(input, gp_sz=q_config['group_size'])
        absmax = compute_absmax_per_group(input_groups, scale=1)
        self.set_state('absmax_next', p_name, absmax)

    def reset_quantization_state(self, p_name):
        """Clear scale/min/absmax; outside warmup, install ``'absmax_next'`` as the new absmax.

        Raises:
            ValueError: outside warmup, if no ``'absmax_next'`` has been recorded.
        """
        if not self.is_quantized(p_name):
            return
        super(QGradScheme, self).reset_quantization_state(p_name)
        if not self.state[p_name]['warmup']:
            absmax_next = self.get_state('absmax_next', p_name)
            self.set_state('absmax', p_name, absmax_next)

    def quantize_data(self, input, p_name, signed=True, dilate=1):
        """Quantize as :meth:`QScheme.quantize_data` does, optionally saving outliers.

        ``dilate`` is forwarded to the base class, keeping this override
        call-compatible with :meth:`QScheme.quantize_data`.
        """
        if not self.is_quantized(p_name):
            return input
        q_input = super(QGradScheme, self).quantize_data(input, p_name, signed, dilate)
        if config.save_grad_outlier:
            q_config = self.get_quantization_config(p_name)
            self.state[p_name]['outlier'] = compute_outlier(input, q_config)
        return q_input

    def dequantize_data(self, q_input, p_name, signed=True):
        """Dequantize as :meth:`QScheme.dequantize_data` does, then add back any saved outlier.

        The saved outlier is consumed: it is reset to ``None`` after use.
        """
        if not self.is_quantized(p_name):
            return q_input
        input = super(QGradScheme, self).dequantize_data(q_input, p_name, signed)
        if config.save_grad_outlier:
            outlier = self.state[p_name]['outlier']
            if outlier is not None:
                # In-place add onto the tensor returned by the base class.
                input += outlier.to_dense()
                self.state[p_name]['outlier'] = None
        return input
