import gc
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
)
from accelerate.big_modeling import (
    init_empty_weights,
    load_checkpoint_and_dispatch,
)

from .AnyPrecisionLinear_3456 import AnyPrecisionLinear_3456
from any_precision.analyzer.analyzer import get_analyzer


def replace_module_by_name(layer, module_name, new_module):
    """Replace the sub-object addressed by a dotted path under ``layer``.

    Args:
        layer: Root object to start the traversal from.
        module_name: Dotted path such as ``"self_attn.q_proj"``. Purely
            numeric intermediate components are used as integer indices
            (``parent[int(part)]``); all others are attribute names.
        new_module: Object to store at the final path component.
    """
    # Everything except the last component locates the parent container.
    *parent_path, leaf = module_name.split('.')

    parent = layer
    for part in parent_path:
        if part.isdigit():
            parent = parent[int(part)]
        else:
            parent = getattr(parent, part)

    # The final component is always assigned as an attribute, even if numeric.
    setattr(parent, leaf, new_module)


class AnyPrecisionForCausalLM_3456(nn.Module):
    """Causal-LM wrapper whose linear layers are ``AnyPrecisionLinear_3456``.

    The wrapped Hugging Face model is built from ``config`` with empty
    weights, every linear module reported by the analyzer is swapped for an
    ``AnyPrecisionLinear_3456``, and the checkpoint at ``model_path`` is then
    loaded into the result. The active bit-width can be switched for all
    replaced layers at once through :meth:`set_precision`, or temporarily per
    call through the ``precision`` keyword of :meth:`forward` and
    :meth:`generate`.
    """

    # Accepted values for the ``mode`` argument.
    _SUPPORTED_MODES = ("orig", "mqdecode", "jl", "decode", "oracle", "mq", "full", "random")
    # Modes that read ``path_dict["max_mem_dict_path"]``.
    _MAX_MEM_MODES = ("jl", "oracle", "full", "random", "mqdecode")
    # Projection names, grouped the way the "jl" mode and setMotherLayer treat them.
    _ATTN_INPUT_PROJS = ("q_proj", "k_proj", "v_proj", "qkv_proj")
    _MLP_INPUT_PROJS = ("gate_proj", "up_proj", "gate_up_proj")
    _OUTPUT_PROJS = ("o_proj", "down_proj")

    def __init__(
            self,
            model_path,
            config,
            precisions=None,
            torch_dtype=torch.float16,
            fuse_layers=False,
            trust_remote_code=True,
            mode="orig",
            prefill_as_decode=False,
            model="llama3",
            path_dict=None,
    ):
        """Build the model skeleton, swap in quantized linears and load weights.

        Args:
            model_path: Checkpoint location handed to
                ``load_checkpoint_and_dispatch``.
            config: Model config; must expose an ``anyprec`` mapping with
                ``seed_precision``, ``parent_precision`` and ``arch_config``.
            precisions: Bit-widths to keep. ``None`` keeps every bit-width
                from ``seed_precision`` to ``parent_precision`` inclusive.
            torch_dtype: dtype used for the model and the replaced layers.
            fuse_layers: If true, :meth:`fuse_layers` is called after loading.
            trust_remote_code: Forwarded to ``AutoModelForCausalLM.from_config``.
            mode: One of ``_SUPPORTED_MODES``; selects the ``err_mode`` and the
                auxiliary paths given to each ``AnyPrecisionLinear_3456``.
            prefill_as_decode: Forwarded to each ``AnyPrecisionLinear_3456``.
            model: Stored as ``self.model_name``.
            path_dict: Mapping with the entries the chosen ``mode`` needs
                (``corr_arr_path``, ``max_mem_dict_path``, ``targ_path_fn``,
                ``jl_path_fn``). ``None`` is treated as an empty mapping.

        Raises:
            AssertionError: If ``precisions`` has duplicates or contains a
                bit-width the model does not support, or ``mode`` is unknown.
        """
        super().__init__()

        self.config = config

        self.supported_bits = list(range(self.config.anyprec['seed_precision'],
                                         self.config.anyprec['parent_precision'] + 1))
        if precisions is None:
            self.precisions = self.supported_bits
        else:
            assert len(precisions) == len(set(precisions)), "Precisions must be unique"
            assert all(bit in self.supported_bits for bit in precisions), \
                f"Supported bits {precisions} must be a subset of model supported bits {self.supported_bits}"
            self.precisions = precisions

        # Start at the highest requested bit-width.
        self.precision = max(self.precisions)

        # Build the architecture without allocating real weight storage.
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_config(
                    config=config,
                    torch_dtype=torch_dtype,
                    trust_remote_code=trust_remote_code,
                    attn_implementation="flash_attention_2",
                )

        self.analyzer = get_analyzer(self.model)

        self.ap_linears = []
        # Replace the analyzer-reported linear modules with AnyPrecisionLinear_3456.
        self._load_quantized_modules(
            mode=mode,
            prefill_as_decode=prefill_as_decode,
            model=model,
            dtype=torch_dtype,
            path_dict=path_dict,
        )

        self.model_name = model

        self.tie_weights()

        # Every state-dict entry is mapped to the CPU; no accelerator
        # placement happens here.
        device_map = {key: 'cpu' for key in self.model.state_dict().keys()}

        # Load the checkpoint weights into the (so far empty) modules.
        load_checkpoint_and_dispatch(
            self.model,
            checkpoint=model_path,
            device_map=device_map,
            no_split_module_classes=[self.layer_type],
            dtype=torch_dtype,
        )

        # Optional layer fusion (currently unimplemented, see fuse_layers).
        if fuse_layers:
            self.fuse_layers()

        self.evaled = False

        self.prune_precisions()

    def forward(self, *args, **kwargs):
        """Run the wrapped model, optionally at a temporary precision.

        Args:
            *args: Forwarded to ``self.model.forward``.
            **kwargs: Forwarded to ``self.model.forward``, except ``precision``
                which, when given and not ``None``, is applied for this call
                only.

        Returns:
            Whatever ``self.model.forward`` returns.
        """
        precision = kwargs.pop('precision', None)
        if precision is None:
            return self.model.forward(*args, **kwargs)

        prev_precision = self.precision
        self.set_precision(precision)
        try:
            return self.model.forward(*args, **kwargs)
        finally:
            # Restore even when the forward pass raises, so a failed call
            # does not leave the model at the temporary precision.
            self.set_precision(prev_precision)

    def generate(self, *args, **kwargs):
        """Call ``self.model.generate`` under ``torch.inference_mode()``.

        Args:
            *args: Forwarded to ``self.model.generate``.
            **kwargs: Forwarded to ``self.model.generate``, except
                ``precision`` which, when given and not ``None``, is applied
                for this call only.

        Returns:
            Whatever ``self.model.generate`` returns.
        """
        precision = kwargs.pop('precision', None)
        if precision is None:
            with torch.inference_mode():
                return self.model.generate(*args, **kwargs)

        prev_precision = self.precision
        self.set_precision(precision)
        try:
            with torch.inference_mode():
                return self.model.generate(*args, **kwargs)
        finally:
            # Restore even when generation raises.
            self.set_precision(prev_precision)

    @staticmethod
    def _load_config(
            model_path,
            trust_remote_code=True,
    ):
        """Return the ``AutoConfig`` loaded from ``model_path``."""
        return AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

    @classmethod
    def from_quantized(
            cls,
            quant_model_path,
            trust_remote_code=True,
            fuse_layers=False,
            precisions=None,
            mode="orig",
            prefill_as_decode=False,
            model="llama3",
            torch_dtype=torch.float16,
            path_dict=None,
    ):
        """Load the config from ``quant_model_path`` and construct the model.

        All remaining arguments are passed unchanged to ``__init__``; see
        there for their meaning.

        Returns:
            A new instance of ``cls``.
        """
        config = cls._load_config(quant_model_path, trust_remote_code)

        return cls(
            model_path=quant_model_path,
            precisions=precisions,
            config=config,
            fuse_layers=fuse_layers,
            trust_remote_code=trust_remote_code,
            mode=mode,
            prefill_as_decode=prefill_as_decode,
            model=model,
            torch_dtype=torch_dtype,
            path_dict=path_dict,
        )

    def _resolve_error_config(self, mode, layer_idx, real_name, corr_dict, path_dict):
        """Pick the per-module ``err_mode`` and auxiliary paths for ``mode``.

        Returns:
            Tuple ``(err_mode, err_lin_param, jl_path, targ_path)``; entries
            that do not apply to ``mode`` are ``None``.

        Raises:
            RuntimeError: In ``"jl"`` mode, if ``real_name`` is not one of the
                known projection names.
            KeyError: If ``path_dict`` lacks an entry the mode needs.
        """
        err_mode = None
        err_lin_param = None
        jl_path = None
        targ_path = None

        if mode == "decode":
            err_mode = "decode"
        elif mode in ("mqdecode", "mq"):
            err_mode = mode
            targ_path = path_dict["targ_path_fn"](layer_idx, real_name)
        elif mode in ("oracle", "random"):
            err_mode = mode
        elif mode in ("jl", "full"):
            err_mode = "full"
            if mode == "jl":
                if real_name in self._ATTN_INPUT_PROJS:
                    # The first layer has no previous layer, so it stays "full".
                    if layer_idx > 0:
                        err_mode = "prev"
                elif real_name in self._MLP_INPUT_PROJS:
                    err_mode = "intra"
                elif real_name not in self._OUTPUT_PROJS:
                    raise RuntimeError(f"Unknown module: {real_name}")

            key = (layer_idx, real_name)
            if key in corr_dict:
                # Placeholder value; presumably only needs to be non-empty
                # for AnyPrecisionLinear_3456 -- TODO confirm.
                jl_path = "non_empty_string"
                # A stored (slope, intercept) pair switches to the "lin" variant.
                err_mode = {"full": "lin", "prev": "prevlin", "intra": "intralin"}[err_mode]
                err_lin_param = corr_dict[key]
            else:
                jl_path = path_dict["jl_path_fn"](layer_idx, real_name)

        if mode in ("jl", "oracle", "full", "random"):
            targ_path = path_dict["targ_path_fn"](layer_idx, real_name)

        return err_mode, err_lin_param, jl_path, targ_path

    def _load_quantized_modules(self, mode="orig", prefill_as_decode=False, model="llama3", dtype=torch.float16, path_dict=None):
        """Swap every analyzer-reported linear for an ``AnyPrecisionLinear_3456``.

        The created layers are appended to ``self.ap_linears`` in creation
        order.

        Args:
            mode: One of ``_SUPPORTED_MODES``.
            prefill_as_decode: Forwarded to each new layer.
            model: Unused here; kept for interface compatibility.
            dtype: dtype for the new layers.
            path_dict: Mapping with the entries ``mode`` needs; ``None`` is
                treated as an empty mapping.

        Raises:
            AssertionError: If ``mode`` is not supported.
            KeyError: If ``path_dict`` lacks an entry the mode needs.
        """
        path_dict = {} if path_dict is None else path_dict

        # Get blocks of model
        layers = self.analyzer.get_layers()

        assert mode in self._SUPPORTED_MODES

        # (layer index, module name) -> (slope, intercept)
        corr_dict = {}
        if mode in ("jl", "full"):
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load files from a trusted source.
            corr_arr = torch.load(path_dict["corr_arr_path"], weights_only=False)
            corr_dict = {
                (l, n): (slope, inter)
                for (l, n, slope, inter, _, _, _) in corr_arr
            }

        max_mem_dict = {}
        if mode in self._MAX_MEM_MODES:
            # NOTE(security): same unpickling caveat as above.
            max_mem_dict = torch.load(path_dict["max_mem_dict_path"], weights_only=False)
            print(f"Max_mem_dict: {str(list(max_mem_dict.values()))[:60]}...")

        for i, layer in enumerate(tqdm(layers, desc="Loading AP Layers")):
            # Get every linear layer in a block
            named_linears = self.analyzer.get_modules(layer)

            # Replace nn.Linear with AnyPrecisionLinear
            for name, module in named_linears.items():
                real_name = name.split(".")[-1]
                key = (i, real_name)

                # Modules without an explicit entry default to the last listed
                # precision.
                # NOTE(review): this equals max(self.precisions) only when the
                # list is sorted ascending -- confirm callers always sort.
                if key not in max_mem_dict:
                    max_mem_dict[key] = self.precisions[-1]

                err_mode, err_lin_param, jl_path, targ_path = self._resolve_error_config(
                    mode, i, real_name, corr_dict, path_dict
                )

                wqlinear = AnyPrecisionLinear_3456(
                    module.in_features, module.out_features,
                    self.supported_bits,
                    bias=module.bias is not None,
                    precisions=self.precisions,
                    dtype=dtype,
                    device=module.weight.device,
                    jl_path=jl_path,
                    err_mode=err_mode,
                    err_lin_param=err_lin_param,
                    targ_path=targ_path,
                    my_name=real_name,
                    my_layer=i,
                    prefill_as_decode=prefill_as_decode,
                    maxmem=max_mem_dict[key],
                )
                self.ap_linears.append(wqlinear)
                replace_module_by_name(layer, name, wqlinear)

            # Release the replaced modules before moving to the next block.
            torch.cuda.empty_cache()
            gc.collect()

    def get_effective_bits(self):
        """Return the parameter-weighted average bit-width recorded so far.

        Each layer's ``comp_count`` maps a bit-width to a counter. The result
        is ``sum(count * bits * n_weights) / (total_weights * total_count)``,
        where ``total_count`` is the counter sum of the first layer whose sum
        is non-zero. Returns ``0`` when no counts have been recorded.
        """
        total_bits = 0
        total_comps = 0
        total_params = 0
        for linear in self.ap_linears:
            n_weights = linear.in_features * linear.out_features
            total_params += n_weights

            layer_comps = 0
            for bits, comp in linear.comp_count.items():
                total_bits += (comp * bits) * n_weights
                layer_comps += comp

            # Only the first non-zero per-layer total is used as the divisor.
            if total_comps == 0:
                total_comps = layer_comps

        if total_comps > 0:
            return total_bits / (total_params * total_comps)
        return 0

    def clear_comp_count(self):
        """Reset every ``comp_count`` counter of every replaced layer to 0."""
        for linear in self.ap_linears:
            for bits in linear.comp_count.keys():
                linear.comp_count[bits] = 0

    def prune_precisions(self):
        """Call ``prune_precisions()`` on every replaced layer, then free caches."""
        for ap_linear in self.ap_linears:
            ap_linear.prune_precisions()

        torch.cuda.empty_cache()
        gc.collect()

    def set_precision(self, precision):
        """Set ``precision`` on every replaced layer and remember it."""
        for ap_linear in self.ap_linears:
            ap_linear.set_precision(precision)
        self.precision = precision

    def tie_weights(self):
        """Delegate to ``self.model.tie_weights()`` when the model has one."""
        if hasattr(self.model, "tie_weights"):
            self.model.tie_weights()

    def get_model_layers(self):
        """Return the layer container named by ``config.anyprec['arch_config']``.

        ``model_name`` is a dotted attribute path walked from ``self.model``;
        ``layers_name`` is the attribute read from the object it leads to.
        """
        arch_config = self.config.anyprec['arch_config']
        module = self.model
        for attrib_name in arch_config['model_name'].split('.'):
            module = getattr(module, attrib_name)
        return getattr(module, arch_config['layers_name'])

    def fuse_layers(self):
        """Placeholder for layer fusion.

        Raises:
            NotImplementedError: If the architecture config has no
                ``fuse_target_layers`` entry.
        """
        # NOTE(review): the flag is looked up in config.anyprec['arch_config'],
        # the same mapping get_model_layers reads -- confirm this is where
        # 'fuse_target_layers' is meant to live.
        anyprec_cfg = self.config.anyprec
        arch_config = anyprec_cfg['arch_config'] if 'arch_config' in anyprec_cfg else {}
        if 'fuse_target_layers' not in arch_config:
            raise NotImplementedError("This model does not support layer fusion")
        # TODO implement layer fusion
        pass

    def setMotherLayer(self):
        """Attach ``mother_layer`` / ``mother_ln`` references to input projections.

        For attention input projections, ``mother_layer`` is the previous
        block (not set for the first block) and ``mother_ln`` is the block's
        ``input_layernorm``. For MLP input projections, ``mother_layer`` is
        the block itself and ``mother_ln`` is its ``post_attention_layernorm``.
        Output projections are left untouched.

        Raises:
            RuntimeError: If a module name is not a known projection.
        """
        layers = self.analyzer.get_layers()
        for i, layer in enumerate(layers):
            # Get every linear layer in a block
            named_linears = self.analyzer.get_modules(layer)

            for name, module in named_linears.items():
                real_name = name.split(".")[-1]
                if real_name in self._ATTN_INPUT_PROJS:
                    if i > 0:
                        module.mother_layer = layers[i-1]
                    module.mother_ln = layer.input_layernorm
                elif real_name in self._MLP_INPUT_PROJS:
                    module.mother_layer = layer
                    module.mother_ln = layer.post_attention_layernorm
                elif real_name in self._OUTPUT_PROJS:
                    pass
                else:
                    raise RuntimeError(f"Set Mother Layer Failed: Unknown module {name}")

    def eval(self):
        """Switch to eval mode on the first call; later calls do nothing.

        Returns:
            ``self``.
        """
        if not self.evaled:
            super().eval()
            self.evaled = True
        return self

    @property
    def layer_type(self):
        """Class name of the first layer ending in ``"DecoderLayer"``, else ``None``."""
        for layer in self.get_model_layers():
            layer_class_name = layer.__class__.__name__
            if layer_class_name.endswith("DecoderLayer"):
                return layer_class_name
        return None

    @property
    def device(self):
        """Device reported by the wrapped model."""
        return self.model.device
