import gc
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from .AnyPrecisionLinear_train_whole import (
    AnyPrecisionLinear_train_whole3456
)
from any_precision.analyzer.analyzer import get_analyzer


def replace_module_by_name(layer, module_name, new_module):
    """Replace the attribute addressed by a dotted path under ``layer``.

    ``module_name`` is split on ``'.'``. Every component except the last is
    used to walk down from ``layer``: purely numeric components index into
    the current object, all others are attribute lookups. The final
    component is then assigned ``new_module`` with ``setattr``.
    """
    *parent_path, leaf_name = module_name.split('.')

    parent = layer
    for part in parent_path:
        if part.isdigit():
            # Numeric path components address container entries by position.
            parent = parent[int(part)]
        else:
            parent = getattr(parent, part)

    setattr(parent, leaf_name, new_module)

class AnyPrecisionForCausalLM_train_whole3456(nn.Module):
    """Causal LM wrapper whose linear layers are any-precision modules.

    The wrapped model is built from ``config`` on empty weights, the linear
    modules reported by the analyzer are swapped for
    ``AnyPrecisionLinear_train_whole3456`` instances, and the checkpoint at
    ``model_path`` is then loaded with every tensor mapped to the CPU.
    """

    def __init__(
            self,
            model_path,
            config,
            precisions=None,
            torch_dtype=torch.float16,
            fuse_layers=False,
            trust_remote_code=True,
            th_init_dict=None,
            max_mem_dict=None,
    ):
        """Build the model skeleton, install the AP linears and load weights.

        Args:
            model_path: Checkpoint location handed to
                ``load_checkpoint_and_dispatch``.
            config: Model config; must expose an ``anyprec`` mapping with
                ``seed_precision``, ``parent_precision`` and ``arch_config``.
            precisions: Optional list of unique bit-widths, each within
                ``[seed_precision, parent_precision]``. Defaults to all
                supported bit-widths.
            torch_dtype: Dtype used for the model and the AP linears.
            fuse_layers: When true, ``self.fuse_layers()`` is called after
                loading.
            trust_remote_code: Forwarded to ``AutoModelForCausalLM.from_config``.
            th_init_dict: Optional mapping ``(layer_index, module_name) ->
                th_init``; missing entries yield ``th_init=None``.
            max_mem_dict: Mapping ``(layer_index, module_name) -> maxmem``.
                An entry is required for every replaced linear module.
        """
        super().__init__()

        self.config = config

        self.supported_bits = list(range(self.config.anyprec['seed_precision'],
                                         self.config.anyprec['parent_precision'] + 1))
        if precisions is None:
            self.precisions = self.supported_bits
        else:
            assert len(precisions) == len(set(precisions)), "Precisions must be unique"
            assert all(bit in self.supported_bits for bit in precisions), \
                f"Supported bits {precisions} must be a subset of model supported bits {self.supported_bits}"
            self.precisions = precisions

        # Start out at the highest requested precision.
        self.precision = max(self.precisions)

        # Instantiate the architecture without allocating real weights.
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
                attn_implementation="flash_attention_2",
            )

        self.analyzer = get_analyzer(self.model)

        self.ap_linears = []
        # Replace the linear modules with AnyPrecisionLinear layers
        self._load_quantized_modules(dtype=torch_dtype, th_init_dict=th_init_dict, max_mem_dict=max_mem_dict)

        self.tie_weights()

        # Every tensor in the state dict is placed on the CPU.
        device_map = {key: 'cpu' for key in self.model.state_dict()}

        # Loads the checkpoint weights into the modules according to
        # the device map built above.
        load_checkpoint_and_dispatch(
            self.model,
            checkpoint=model_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            dtype=torch_dtype,
        )

        # Optional layer fusion (not implemented yet, see fuse_layers).
        if fuse_layers:
            self.fuse_layers()

        self.evaled = False

        self.prune_precisions()

    def forward(self, *args, **kwargs):
        """Run the wrapped model's ``forward``.

        An optional ``precision`` keyword switches every AP linear to that
        precision for the duration of the call only; the previous precision
        is restored afterwards, even if the call raises. ``precision=None``
        (or omitting it) keeps the current precision.
        """
        precision = kwargs.pop('precision', None)
        if precision is None:
            return self.model.forward(*args, **kwargs)

        prev_precision = self.precision
        try:
            # Inside the try so a partially applied switch is rolled back too.
            self.set_precision(precision)
            return self.model.forward(*args, **kwargs)
        finally:
            self.set_precision(prev_precision)

    def generate(self, *args, **kwargs):
        """Run the wrapped model's ``generate`` under ``torch.inference_mode``.

        Accepts the same optional, temporary ``precision`` keyword as
        ``forward``; the previous precision is restored even on error.
        """
        precision = kwargs.pop('precision', None)
        prev_precision = self.precision
        try:
            if precision is not None:
                self.set_precision(precision)
            with torch.inference_mode():
                return self.model.generate(*args, **kwargs)
        finally:
            if precision is not None:
                self.set_precision(prev_precision)

    @staticmethod
    def _load_config(
            model_path,
            trust_remote_code=True,
    ):
        """Load the model config from ``model_path`` via ``AutoConfig``."""
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        return config

    @classmethod
    def from_quantized(
            cls,
            quant_model_path,
            trust_remote_code=True,
            fuse_layers=False,
            precisions=None,
            th_init_dict=None,
            torch_dtype=torch.float16,
            max_mem_dict=None,
    ):
        """Alternate constructor: read the config from ``quant_model_path``
        and build the model from that same path.

        All remaining arguments are forwarded unchanged to ``__init__``.
        """
        config = cls._load_config(quant_model_path, trust_remote_code)

        ap_model = cls(
            model_path=quant_model_path,
            precisions=precisions,
            config=config,
            fuse_layers=fuse_layers,
            trust_remote_code=trust_remote_code,
            th_init_dict=th_init_dict,
            max_mem_dict=max_mem_dict,
            torch_dtype=torch_dtype,
        )

        return ap_model

    def _load_quantized_modules(self, dtype=torch.float16, th_init_dict=None, max_mem_dict=None):
        """Swap every linear module found by the analyzer for an AP linear.

        Args:
            dtype: Dtype passed to each new AP linear.
            th_init_dict: Optional ``(layer_index, module_name) -> th_init``
                mapping; a missing entry results in ``th_init=None``.
            max_mem_dict: ``(layer_index, module_name) -> maxmem`` mapping.

        Raises:
            KeyError: If ``max_mem_dict`` has no entry for a replaced module.
        """
        if max_mem_dict is None:
            max_mem_dict = {}

        # Get blocks of model
        layers = self.analyzer.get_layers()

        for i, layer in enumerate(tqdm(layers, desc="Loading AP Layers")):
            # Get every linear layer in a block
            named_linears = self.analyzer.get_modules(layer)

            # Replace nn.Linear with AnyPrecisionLinear
            for name, module in named_linears.items():
                # The per-module dictionaries are keyed by the last path
                # component only, paired with the layer index.
                key = (i, name.split(".")[-1])

                th_init = th_init_dict.get(key) if th_init_dict is not None else None

                try:
                    maxmem = max_mem_dict[key]
                except KeyError:
                    raise KeyError(
                        f"max_mem_dict has no entry for (layer index, module name) {key!r}"
                    ) from None

                wqlinear = AnyPrecisionLinear_train_whole3456(
                    module.in_features, module.out_features,
                    self.supported_bits,
                    bias=module.bias is not None,
                    precisions=self.precisions,
                    device=module.weight.device,
                    dtype=dtype,
                    th_init=th_init,
                    maxmem=maxmem,
                )
                self.ap_linears.append(wqlinear)
                replace_module_by_name(layer, name, wqlinear)

            torch.cuda.empty_cache()
            gc.collect()

    def prune_precisions(self):
        """Call ``prune_precisions`` on every AP linear, then free caches."""
        for ap_linear in self.ap_linears:
            ap_linear.prune_precisions()

        torch.cuda.empty_cache()
        gc.collect()

    def set_precision(self, precision):
        """Set ``precision`` on every AP linear and record it on the wrapper."""
        for ap_linear in self.ap_linears:
            ap_linear.set_precision(precision)
        self.precision = precision

    def set_precision_dual(self, b_l, b_h):
        """Forward ``(b_l, b_h)`` to ``set_precision_dual`` of every AP linear.

        Note that ``self.precision`` is not updated by this call.
        """
        for ap_linear in self.ap_linears:
            ap_linear.set_precision_dual(b_l, b_h)

    def tie_weights(self):
        """Tie the wrapped model's weights if it provides ``tie_weights``."""
        if hasattr(self.model, "tie_weights"):
            self.model.tie_weights()

    def create_th(self):
        """Call ``create_th`` on every AP linear."""
        for ap_linear in self.ap_linears:
            ap_linear.create_th()

    def get_model_layers(self):
        """Return the layer container named by ``config.anyprec['arch_config']``.

        ``model_name`` is a dotted attribute path starting at the wrapped
        model; ``layers_name`` is the attribute on the object it resolves to.
        """
        arch_config = self.config.anyprec['arch_config']
        module = self.model
        for attrib_name in arch_config['model_name'].split('.'):
            module = getattr(module, attrib_name)
        return getattr(module, arch_config['layers_name'])

    def fuse_layers(self):
        """Fuse layers (placeholder; the fusion itself is not implemented).

        Raises:
            NotImplementedError: If the architecture config declares no
                ``fuse_target_layers``.
        """
        # NOTE(review): this previously read `self.model_config`, which is
        # never assigned and therefore raised AttributeError. The arch config
        # is assumed to be the intended source -- confirm.
        arch_config = self.config.anyprec['arch_config']
        if 'fuse_target_layers' not in arch_config:
            raise NotImplementedError("This model does not support layer fusion")
        # TODO implement layer fusion

    def eval(self):
        """Switch to eval mode the first time this is called.

        Subsequent calls are no-ops guarded by ``self.evaled``.
        """
        # NOTE(review): `evaled` is never reset, so eval() after train() does
        # nothing. Left as-is in case the guard is deliberate -- confirm.
        if not self.evaled:
            super().eval()
            self.evaled = True
        return self

    @property
    def layer_type(self):
        """Class name of the first layer whose class name ends with
        ``"DecoderLayer"``, or ``None`` if there is none."""
        for layer in self.get_model_layers():
            layer_class_name = layer.__class__.__name__
            if layer_class_name.endswith("DecoderLayer"):
                return layer_class_name
        return None

    @property
    def device(self):
        """Device reported by the wrapped model."""
        return self.model.device