import os
import torch
import torch.nn as nn

from tqdm import tqdm
from transformers import PreTrainedModel
from typing import (
    Dict,
    List,
    Tuple,
)

from .model_utils import find_modules
from ..models import MoLoSLlamaMLP
from ..models.modeling_svd_llm import SVD_LlamaMLP


@torch.no_grad()
def old_profile(
    model: PreTrainedModel,
    calibration_dataset: List[Dict[str, torch.Tensor]],
    device: str,
    eps: float = 1e-6,
) -> Dict[int, Dict[str, torch.Tensor]]:
    """ Profile the model to get the scaling diagonal matrices.
        This function follows the `profile_svdllm_low_resource()` function in SVD-LLM repository,
        except that the per-module accumulators are actually released after use (the SVD-LLM
        cleanup deletes a misnamed attribute and leaks them) and that genuine errors raised
        during the catch pass are no longer swallowed.

    For every decoder layer, a forward hook on each submodule returned by `find_modules()`
    accumulates the sum of `X^T X` over the calibration samples, where `X` is the submodule's
    input in float32 (assumed to be (batch, sequence, features) — TODO confirm). The
    lower-triangular Cholesky factor of that matrix, computed in float64, is the scaling
    matrix; if the matrix is not positive-definite, its spectrum is first shifted so that the
    smallest eigenvalue becomes `eps`.

    Args:
        model (PreTrainedModel): The model to be profiled. It must expose `model.model.layers`,
            `model.model.embed_tokens`, `model.model.norm`, `model.model.rotary_emb` and a
            `sequence_length` attribute.
        calibration_dataset (List[Dict[str, torch.Tensor]]): The calibration dataset. Each item is
            passed as `model(**item)` and must hold a single sequence (batch size 1), because each
            captured hidden state fills one slot of a (num_samples, sequence_length, hidden_size) buffer.
        device (str): The device to be used.
        eps (float): A small value to stabilize the computation. Defaults to 1e-6.

    Returns:
        Dict[int, Dict[str, torch.Tensor]]: The scaling diagonal matrix for each layer, keyed by
            layer index and then by submodule name (float64 CPU tensors).

    Raises:
        ValueError: If `calibration_dataset` is empty.
    """

    if len(calibration_dataset) == 0:
        raise ValueError('`calibration_dataset` must contain at least one sample.')

    # Dedicated signal so that a genuine ValueError raised by the model is not mistaken for it.
    class _StopForward(Exception):
        """ Raised by the Catcher to abort the forward pass after the first layer's inputs are recorded. """

    ## Get `attention_mask` and `position_ids` from the first layer.
    model.model.embed_tokens = model.model.embed_tokens.to(device=device)
    model.model.norm = model.model.norm.to(device=device)
    model.model.rotary_emb = model.model.rotary_emb.to(device=device)

    layers = model.model.layers
    layers[0] = layers[0].to(device=device)

    dtype = next(iter(model.parameters())).dtype
    inputs = torch.zeros(
        size=(
            len(calibration_dataset),
            model.sequence_length,
            model.config.hidden_size,
        ),
        dtype=dtype,
        device=device,
    )
    # Collected per sample and concatenated once (repeated `torch.cat` would be quadratic).
    captured_masks: List[torch.Tensor] = []
    captured_position_ids: List[torch.Tensor] = []

    # Build a Catcher to catch `attention_mask` and `position_ids` from the first layer.
    class Catcher(nn.Module):

        def __init__(
            self,
            module: nn.Module,
        ):
            super().__init__()
            self.module = module

        def forward(
            self,
            hidden_states: torch.Tensor,
            **kwargs: torch.Tensor,
        ):
            # The number of samples caught so far is the slot index of this one.
            inputs[len(captured_masks)] = hidden_states
            captured_masks.append(kwargs['attention_mask'].cpu())
            captured_position_ids.append(kwargs['position_ids'].cpu())

            # Stop the forward pass; nothing past the first layer is needed.
            raise _StopForward

    layers[0] = Catcher(module=layers[0])
    try:
        for batch in calibration_dataset:
            batch = {k: v.to(device=device) for k, v in batch.items()}
            try:
                model(
                    **batch,
                    output_attentions=True,
                )
            except _StopForward:
                # Expected: the Catcher aborts the forward pass at the first layer.
                pass
    finally:
        # Recover the first layer and the embedding stage even if the catch pass failed.
        layers[0] = layers[0].module.cpu()
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        model.model.norm = model.model.norm.cpu()
        model.model.rotary_emb = model.model.rotary_emb.cpu()
        torch.cuda.empty_cache()

    attention_masks = torch.cat(tensors=captured_masks, dim=0)
    position_ids = torch.cat(tensors=captured_position_ids, dim=0)
    ## -----

    print('Start profiling layers to get the scaling diagonal matrices.')

    # Accumulate `raw_scaling_diag_matrix`: the batch-summed X^T X of each hooked submodule's input.
    def hook(
        module: nn.Module,
        args: Tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> None:
        x = args[0].detach().float()
        module.scaling_diag_matrix += torch.matmul(x.transpose(1, 2), x).sum(dim=0)
        del x
        torch.cuda.empty_cache()

    layers_profiling = {}
    for i in tqdm(
            iterable=range(len(layers)),
            desc='[Profiling Layers]',
            dynamic_ncols=True,
    ):
        layer = layers[i].to(device=device)
        subset = find_modules(module=layer)

        # Register the hooks.
        handles = []
        for submodule in subset.values():
            submodule.scaling_diag_matrix = 0
            handles.append(submodule.register_forward_hook(hook=hook))

        # Calculate `raw_scaling_diag_matrix` for the current layer.
        try:
            for j in range(inputs.shape[0]):
                # Overwrite sample j in place: its input is fully consumed before the assignment
                # and samples are processed independently, so no second buffer is needed.
                inputs[j] = layer(
                    inputs[j].unsqueeze(dim=0),
                    attention_mask=\
                        attention_masks[j].unsqueeze(dim=0).to(device=device),
                    position_ids=\
                        position_ids[j].unsqueeze(dim=0).to(device=device),
                )[0]
        finally:
            for handle in handles:
                handle.remove()

        ## Release GPU memory before the decompositions.
        for submodule in subset.values():
            submodule.scaling_diag_matrix = submodule.scaling_diag_matrix.cpu()

        layers[i] = layer.cpu()
        torch.cuda.empty_cache()
        ## -----

        ## Do Cholesky decomposition to get `scaling_diag_matrix`.
        layer_profiling = {}
        for name, submodule in subset.items():
            raw_scaling_diag_matrix = \
                submodule.scaling_diag_matrix.double().to(device=device)
            # Detach the accumulator from the module so it does not outlive profiling.
            del submodule.scaling_diag_matrix

            try:
                scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix)
            except RuntimeError:
                print(
                    'Warning: `scaling_diag_matrix` is NOT a positive-definite matrix.'
                )

                # Eigenvalues are ascending, so [0] is the smallest; shift it up to `eps`.
                eigenvalues = torch.linalg.eigvalsh(raw_scaling_diag_matrix)
                raw_scaling_diag_matrix += (eps - eigenvalues[0]) * torch.eye(
                    n=raw_scaling_diag_matrix.shape[0],
                    dtype=raw_scaling_diag_matrix.dtype,
                    device=raw_scaling_diag_matrix.device,
                )
                scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix)
                del eigenvalues

            layer_profiling[name] = scaling_diag_matrix.cpu()

            # Release GPU memory.
            del raw_scaling_diag_matrix, scaling_diag_matrix
            torch.cuda.empty_cache()

        layers_profiling[i] = layer_profiling

    return layers_profiling


@torch.no_grad()
def _capture_first_layer_inputs(
    model: PreTrainedModel,
    calibration_dataset: List[Dict[str, torch.Tensor]],
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """ Record what each calibration sample feeds into the first decoder layer.

    The first layer is temporarily wrapped by a Catcher that stores its hidden-state input and
    its `attention_mask` / `position_ids` keyword arguments, then aborts the forward pass. The
    embedding stage and the first layer are moved to `device` for the pass and back to the CPU
    afterwards, even if the pass fails.

    Args:
        model (PreTrainedModel): The model whose first-layer inputs are captured.
        calibration_dataset (List[Dict[str, torch.Tensor]]): The calibration dataset. Each item is
            passed as `model(**item)` and must hold a single sequence (batch size 1).
        device (str): The device to be used.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The first-layer inputs stacked along dim 0
            (on `device`, in the dtype of the model's first parameter), and the caught
            `attention_mask` and `position_ids` concatenated along dim 0 (on CPU).

    Raises:
        ValueError: If `calibration_dataset` is empty.
    """

    if len(calibration_dataset) == 0:
        raise ValueError('`calibration_dataset` must contain at least one sample.')

    # Dedicated signal so that a genuine ValueError raised by the model is not mistaken for it.
    class _StopForward(Exception):
        """ Raised by the Catcher to abort the forward pass after the first layer's inputs are recorded. """

    model.model.embed_tokens = model.model.embed_tokens.to(device=device)
    model.model.norm = model.model.norm.to(device=device)
    model.model.rotary_emb = model.model.rotary_emb.to(device=device)

    layers = model.model.layers
    layers[0] = layers[0].to(device=device)

    dtype = next(iter(model.parameters())).dtype
    inputs = torch.zeros(
        size=(
            len(calibration_dataset),
            model.sequence_length,
            model.config.hidden_size,
        ),
        dtype=dtype,
        device=device,
    )
    # Collected per sample and concatenated once (repeated `torch.cat` would be quadratic).
    captured_masks: List[torch.Tensor] = []
    captured_position_ids: List[torch.Tensor] = []

    class Catcher(nn.Module):

        def __init__(
            self,
            module: nn.Module,
        ):
            super().__init__()
            self.module = module

        def forward(
            self,
            hidden_states: torch.Tensor,
            **kwargs: torch.Tensor,
        ):
            # The number of samples caught so far is the slot index of this one.
            inputs[len(captured_masks)] = hidden_states
            captured_masks.append(kwargs['attention_mask'].cpu())
            captured_position_ids.append(kwargs['position_ids'].cpu())

            # Stop the forward pass; nothing past the first layer is needed.
            raise _StopForward

    layers[0] = Catcher(module=layers[0])
    try:
        for batch in calibration_dataset:
            batch = {k: v.to(device=device) for k, v in batch.items()}
            try:
                model(
                    **batch,
                    output_attentions=True,
                )
            except _StopForward:
                # Expected: the Catcher aborts the forward pass at the first layer.
                pass
    finally:
        # Recover the first layer and the embedding stage even if the catch pass failed.
        layers[0] = layers[0].module.cpu()
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        model.model.norm = model.model.norm.cpu()
        model.model.rotary_emb = model.model.rotary_emb.cpu()
        torch.cuda.empty_cache()

    return (
        inputs,
        torch.cat(tensors=captured_masks, dim=0),
        torch.cat(tensors=captured_position_ids, dim=0),
    )


def _cholesky_with_shift(
    matrix: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """ Return the lower-triangular Cholesky factor of the symmetric matrix `matrix`.

    If `matrix` is not positive-definite, a warning is printed and the factor of
    `matrix + (eps - lambda_min) * I` is returned instead, i.e. the spectrum is shifted so that
    the smallest eigenvalue becomes `eps`. `matrix` itself is never modified.

    Args:
        matrix (torch.Tensor): A square symmetric matrix.
        eps (float): The smallest eigenvalue after the shift.

    Returns:
        torch.Tensor: The lower-triangular Cholesky factor, on `matrix`'s device and dtype.
    """

    try:
        return torch.linalg.cholesky(matrix)
    except RuntimeError:
        # `torch.linalg.LinAlgError` (a RuntimeError subclass) signals a non-PD input.
        print(
            'Warning: `scaling_diag_matrix` is NOT a positive-definite matrix.'
        )

    # Eigenvalues are returned in ascending order, so [0] is the smallest.
    eigenvalues = torch.linalg.eigvalsh(matrix)
    shifted = matrix + (eps - eigenvalues[0]) * torch.eye(
        n=matrix.shape[0],
        dtype=matrix.dtype,
        device=matrix.device,
    )
    return torch.linalg.cholesky(shifted)


@torch.no_grad()
def profile(
    base_model: PreTrainedModel,
    base_model_name: str,
    calibration_dataset: List[Dict[str, torch.Tensor]],
    calibration_dataset_name: str,
    cache_dir: str,
    local_rank: int,
    device: str,
    eps: float = 1e-6,
    only_return_name: bool = False,
) -> Tuple[Dict[int, Dict[str, torch.Tensor]], str]:
    """ Profile the model to get the scaling diagonal matrices.
        The execution time bottleneck of this function is NOT the data transfer between CPU and GPU.

    The result is cached at `{cache_dir}/{base_model_name}_{calibration_dataset_name}.pt`. On a
    cache hit nothing is computed. Otherwise, for every decoder layer, a forward hook on each
    submodule returned by `find_modules()` accumulates the sum of `X^T X` over the calibration
    samples, where `X` is the submodule's float32 input (assumed (batch, sequence, features) —
    TODO confirm). Its lower-triangular Cholesky factor, computed in float64 and shifted to be
    positive-definite if necessary, becomes the scaling matrix. Only rank 0 writes the cache;
    it writes a temporary file and atomically renames it, so an interrupted run never leaves a
    truncated cache behind.

    Args:
        base_model (PreTrainedModel): The base model to be profiled.
        base_model_name (str): The name of the base model.
        calibration_dataset (List[Dict[str, torch.Tensor]]): The calibration dataset.
        calibration_dataset_name (str): The name of the calibration dataset.
        cache_dir (str): The cache directory. It is created if it does not exist.
        local_rank (int): The local rank of the process.
        device (str): The device to be used.
        eps (float): A small value to stabilize the computation. Defaults to 1e-6.
        only_return_name (bool): If True, only return the name of the scaling diagonal matrix. Defaults to False.

    Returns:
        Tuple[Dict[int, Dict[str, torch.Tensor]], str]: The scaling diagonal matrices (keyed by
            layer index, then by submodule name) and their name. The matrices are None when
            `only_return_name` is True and the cache already exists.

    Raises:
        ValueError: If profiling is needed and `calibration_dataset` is empty.
    """

    layers_profiling_name = f'{base_model_name}_{calibration_dataset_name}'
    cache_path = os.path.join(
        cache_dir,
        f'{layers_profiling_name}.pt',
    )

    # Try to load the `layers_profiling` from the cache.
    if os.path.exists(cache_path):
        layers_profiling = None

        if not only_return_name:
            layers_profiling = torch.load(
                f=cache_path,
                weights_only=True,
            )

        return (
            layers_profiling,
            layers_profiling_name,
        )

    ## Get the first-layer inputs, `attention_mask` and `position_ids`.
    inputs, attention_masks, position_ids = _capture_first_layer_inputs(
        model=base_model,
        calibration_dataset=calibration_dataset,
        device=device,
    )
    ## -----

    # Accumulate `raw_scaling_diag_matrix`: the batch-summed X^T X of each hooked submodule's input.
    def hook(
        module: nn.Module,
        args: Tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> None:
        x = args[0].detach().float()

        # Dimension 0 is the batch size; sum the per-sample second moments over it.
        module.scaling_diag_matrix += torch.matmul(x.transpose(1, 2), x).sum(dim=0)

        del x
        torch.cuda.empty_cache()

    layers = base_model.model.layers
    layers_profiling = {}
    for layer_idx in tqdm(
            iterable=range(len(layers)),
            desc='[Profiling Layers]',
            disable=local_rank != 0,
            dynamic_ncols=True,
    ):
        layer = layers[layer_idx].to(device=device)
        submodules = find_modules(module=layer)

        # Register the hooks.
        handles = []
        for submodule in submodules.values():
            submodule.scaling_diag_matrix = 0
            handles.append(submodule.register_forward_hook(hook=hook))

        # Calculate `raw_scaling_diag_matrix` for the current layer.
        try:
            for data_idx in range(inputs.shape[0]):
                # Overwrite sample `data_idx` in place: its input is fully consumed before the
                # assignment and samples are processed independently, so no second buffer is needed.
                inputs[data_idx] = layer(
                    hidden_states=inputs[data_idx].unsqueeze(dim=0),
                    attention_mask=attention_masks[data_idx]\
                        .unsqueeze(dim=0).to(device=device),
                    position_ids=position_ids[data_idx]\
                        .unsqueeze(dim=0).to(device=device),
                )[0]
        finally:
            for handle in handles:
                handle.remove()

        ## Release GPU memory before the decompositions.
        for submodule in submodules.values():
            submodule.scaling_diag_matrix = submodule.scaling_diag_matrix.cpu()

        layers[layer_idx] = layer.cpu()
        torch.cuda.empty_cache()
        ## -----

        ## Do Cholesky decomposition to get `scaling_diag_matrix`.
        layer_profiling = {}
        for name, submodule in submodules.items():
            raw_scaling_diag_matrix = \
                submodule.scaling_diag_matrix.double().to(device=device)
            # Detach the accumulator from the module so it does not outlive profiling.
            del submodule.scaling_diag_matrix

            layer_profiling[name] = _cholesky_with_shift(
                matrix=raw_scaling_diag_matrix,
                eps=eps,
            ).cpu()

            # Release GPU memory.
            del raw_scaling_diag_matrix
            torch.cuda.empty_cache()

        layers_profiling[layer_idx] = layer_profiling

    if local_rank == 0:
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Write-then-rename so readers never observe a partially written cache file.
        tmp_path = f'{cache_path}.tmp'
        torch.save(
            obj=layers_profiling,
            f=tmp_path,
        )
        os.replace(tmp_path, cache_path)

    return (
        layers_profiling,
        layers_profiling_name,
    )


@torch.no_grad()
def old_whiten(
    model: PreTrainedModel,
    layers_profiling: Dict[int, Dict[str, torch.Tensor]],
    ratio: float,
    device: str,
    eps: float = 1e-6,
) -> None:
    """ Do whitening and SVD.
        This function is revised from the `whitening()` function in SVD-LLM repository.
        The only difference between this function and the `whitening()` function is: This function does NOT compress the attention layers.

    For every MLP projection `W` (names containing `down_proj`, `gate_proj` or `up_proj`) with
    scaling matrix `S` from `layers_profiling`, `W @ S` is decomposed with SVD and truncated to
    `rank = int(m * n * ratio / (m + n))` singular values, so that the two factors together hold
    about `ratio` times the parameters of `W`. The factors `U sqrt(Sigma)` and
    `sqrt(Sigma) V^T S^-1` are written into a new `SVD_LlamaMLP`, which replaces `layer.mlp`.
    Other modules returned by `find_modules()` are skipped, because their factors would be
    discarded anyway.

    Args:
        model (PreTrainedModel): The model to be whitened. It is modified in place.
        layers_profiling (Dict[int, Dict[str, torch.Tensor]]): The scaling diagonal matrices. Never modified.
        ratio (float): The expert ratio.
        device (str): The device to be used.
        eps (float): A small value added to the diagonal of a singular scaling matrix. Defaults to 1e-6.
    """

    model.eval()

    layers = model.model.layers

    print('Start whitening and doing SVD.')

    for i in tqdm(
            iterable=range(len(layers)),
            desc='[Whitening & Doing SVD]',
            dynamic_ncols=True,
    ):
        layer = layers[i]
        subset = find_modules(module=layer)

        svd_mlp = SVD_LlamaMLP(
            hidden_size=layer.hidden_size,
            intermediate_size=model.config.intermediate_size,
            hidden_act=model.config.hidden_act,
            ratio=ratio,
        )

        for name in subset:
            # Same precedence as a `down` / `gate` / `up` if-elif chain on the module name.
            prefix = next(
                (p for p in ('down', 'gate', 'up') if f'{p}_proj' in name),
                None,
            )
            if prefix is None:
                # Not an MLP projection: its factors would never be stored, so skip the SVD.
                continue

            W = subset[name].weight.data.float().to(device=device)
            # NOTE(review): `W` was already cast to float32, so this is always float32 and the
            # original weight dtype is NOT restored below — confirm this is intended.
            dtype = W.dtype

            scaling_diag_matrix = layers_profiling[i][name].to(device=device)

            # Calculate the inverse of the scaling diagonal matrix.
            try:
                scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
            except RuntimeError:
                print('Warning: `scaling_diag_matrix` is NOT full rank!')

                # Out of place: `.to()` returns the caller's own tensor when it is already on
                # `device`, and an in-place add would silently modify `layers_profiling`.
                scaling_diag_matrix = scaling_diag_matrix + eps * torch.eye(
                    n=scaling_diag_matrix.shape[0],
                    dtype=scaling_diag_matrix.dtype,
                    device=scaling_diag_matrix.device,
                )
                scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)

            ## Compute the SVD of the whitened weight matrix.
            W_scale = torch.matmul(W, scaling_diag_matrix.float())
            U, S, VT = torch.linalg.svd(W_scale, full_matrices=False)
            ## -----

            ## Truncate the singular values and vectors.
            rank = int(W.shape[0] * W.shape[1] * ratio /
                       (W.shape[0] + W.shape[1]))

            sqrt_sigma = torch.diag(torch.sqrt(S[:rank]))
            truc_v = torch.matmul(VT[:rank, :], scaling_matrix_inv.float())
            ## -----

            ## Compute the final U and V matrices.
            svd_u = torch.matmul(U[:, :rank], sqrt_sigma).cpu().to(dtype=dtype)
            svd_v = torch.matmul(sqrt_sigma, truc_v).cpu().to(dtype=dtype)
            ## -----

            ## Assign the U and V matrices to the matching projection pair.
            getattr(svd_mlp, f'{prefix}_u_proj').weight.data = svd_u
            getattr(svd_mlp, f'{prefix}_v_proj').weight.data = svd_v
            if prefix == 'up':
                layer.mlp = svd_mlp
            ## -----

            # Release GPU memory before the next projection.
            del W, W_scale, scaling_diag_matrix, scaling_matrix_inv
            del U, S, VT, sqrt_sigma, truc_v

        del layer
        torch.cuda.empty_cache()


@torch.no_grad()
def whiten(
    base_model: PreTrainedModel,
    model: PreTrainedModel,
    layers_profiling: Dict[int, Dict[str, torch.Tensor]],
    layers_profiling_name: str,
    ratio: float,
    cache_dir: str,
    local_rank: int,
    device: str,
    expert_idx: int,
    eps: float = 1e-6,
):
    """ Do whitening and SVD.

    For every layer, each MLP projection `W` of `base_model` (names containing `down_proj`,
    `gate_proj` or `up_proj`) with scaling matrix `S` from `layers_profiling` is factorized:
    `W @ S` is decomposed with SVD and truncated to `rank = int(m * n * ratio / (m + n))`
    singular values, and `U sqrt(Sigma)` / `sqrt(Sigma) V^T S^-1` are written into a
    `MoLoSLlamaMLP`. The MLP is cached per layer at
    `{cache_dir}/{layers_profiling_name}_{ratio with '.'->'d'}_{layer_idx}.pt`; only rank 0
    writes the cache, and it does so atomically. The MLP is finally assigned to
    `layer.moe.experts[expert_idx]` of `model`.

    Args:
        base_model (PreTrainedModel): The base model to be whitened.
        model (PreTrainedModel): The MoLoS model. Modified in place.
        layers_profiling (Dict[int, Dict[str, torch.Tensor]]): The scaling diagonal matrices. Never modified.
        layers_profiling_name (str): The name of the scaling diagonal matrices.
        ratio (float): The expert ratio.
        cache_dir (str): The cache directory. It is created if it does not exist.
        local_rank (int): The local rank of the process.
        device (str): The device to be used.
        expert_idx (int): The index of the expert to be assigned. This parameter is only used for debugging.
        eps (float): A small value added to the diagonal of a singular scaling matrix. Defaults to 1e-6.
    """

    base_model.eval()
    model.eval()

    base_layers = base_model.model.layers
    layers = model.model.layers
    ratio_tag = str(ratio).replace(".", "d")

    for layer_idx in tqdm(
            iterable=range(len(base_layers)),
            desc='[Whitening & Doing SVD]',
            disable=local_rank != 0,
            dynamic_ncols=True,
    ):
        base_layer = base_layers[layer_idx]
        layer = layers[layer_idx]

        # Constructed unconditionally, even on a cache hit: if the constructor initializes
        # parameters randomly, skipping it would make the downstream RNG state depend on the cache.
        molos_mlp = MoLoSLlamaMLP(
            hidden_size=layer.hidden_size,
            intermediate_size=layer.intermediate_size,
            hidden_act=model.config.hidden_act,
            ratio=ratio,
        )

        cache_path = os.path.join(
            cache_dir,
            f'{layers_profiling_name}_{ratio_tag}_{layer_idx}.pt',
        )

        # Try to load the MoLoS MLP from the cache.
        if os.path.exists(cache_path):
            # NOTE: `weights_only=False` unpickles arbitrary objects; only use trusted cache directories.
            molos_mlp = torch.load(
                f=cache_path,
                map_location='cpu',
                weights_only=False,
            )
        # Build the MoLoS MLP with SVD-LLM method.
        else:
            submodules = find_modules(module=base_layer)

            for name, submodule in submodules.items():
                # Same precedence as a `down` / `gate` / `up` if-elif chain on the module name.
                prefix = next(
                    (p for p in ('down', 'gate', 'up') if f'{p}_proj' in name),
                    None,
                )
                if prefix is None:
                    # Not an MLP projection: its factors would never be stored, so skip the SVD.
                    continue

                W = submodule.weight.data.float().to(device=device)
                # NOTE(review): `W` was already cast to float32, so `W_dtype` is always float32
                # and the original weight dtype is NOT restored below — confirm this is intended.
                W_dtype = W.dtype
                W_shape = W.shape

                scaling_diag_matrix = \
                    layers_profiling[layer_idx][name].to(device=device)

                # Calculate the inverse of the scaling diagonal matrix.
                try:
                    scaling_diag_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
                except RuntimeError:
                    print('Warning: `scaling_diag_matrix` is NOT full rank.')

                    # Out of place: `.to()` returns the caller's own tensor when it is already
                    # on `device`, and an in-place add would silently modify `layers_profiling`.
                    scaling_diag_matrix = scaling_diag_matrix + eps * torch.eye(
                        n=scaling_diag_matrix.shape[0],
                        dtype=scaling_diag_matrix.dtype,
                        device=scaling_diag_matrix.device,
                    )
                    scaling_diag_matrix_inv = torch.linalg.inv(scaling_diag_matrix)

                ## Compute the SVD of the whitened weight matrix (reusing `W` to save memory).
                W = torch.matmul(W, scaling_diag_matrix.float())
                U, S, VT = torch.linalg.svd(W, full_matrices=False)
                ## -----

                ## Truncate the singular values and vectors.
                rank = int((W_shape[0] * W_shape[1] * ratio) /
                           (W_shape[0] + W_shape[1]))

                sqrt_S = torch.diag(torch.sqrt(S[:rank]))
                VT = torch.matmul(VT[:rank, :], scaling_diag_matrix_inv.float())
                ## -----

                ## Compute the final U and V matrices.
                svd_U = torch.matmul(U[:, :rank], sqrt_S).cpu().to(dtype=W_dtype)
                svd_V = torch.matmul(sqrt_S, VT).cpu().to(dtype=W_dtype)
                ## -----

                ## Assign the U and V matrices to the matching projection pair.
                getattr(molos_mlp, f'{prefix}_u_proj').weight.data = svd_U
                getattr(molos_mlp, f'{prefix}_v_proj').weight.data = svd_V
                ## -----

                # Release GPU memory before the next projection.
                del W, scaling_diag_matrix, scaling_diag_matrix_inv
                del U, S, VT, sqrt_S, svd_U, svd_V

            if local_rank == 0:
                if cache_dir:
                    os.makedirs(cache_dir, exist_ok=True)

                # Write-then-rename so readers never observe a partially written cache file.
                tmp_path = f'{cache_path}.tmp'
                torch.save(
                    obj=molos_mlp,
                    f=tmp_path,
                )
                os.replace(tmp_path, cache_path)

        # Assign the MoLoS MLP to the specified expert.
        layer.moe.experts[expert_idx] = molos_mlp

        # Release memory.
        del base_layer
        del layer
        torch.cuda.empty_cache()
