from __future__ import annotations

from collections import defaultdict
from pathlib import Path

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from litgpt.config import Config
from litgpt.utils import num_parameters
from mup import get_shapes, make_base_shapes, coord_check

from saws.model import GPT_Scales


def get_effective_batch_size(
    device_count: int,
    accumulation_iters: int,
    micro_batch_size: int
) -> int:
    """Return the number of samples consumed per optimizer step.

    The result is the product of the device count, the number of gradient
    accumulation iterations and the per-device micro batch size, cast to int.
    """
    micro_batches_per_step = device_count * accumulation_iters
    return int(micro_batches_per_step * micro_batch_size)


def get_total_data_loader_steps(
    train_tokens: int,
    block_size: int,
    device_count: int,
    accumulation_iters: int,
    micro_batch_size: int,
) -> int:
    """Return how many micro batches are needed to consume `train_tokens`.

    The token budget is floor-divided by the tokens processed per optimizer
    step (effective batch size times `block_size`) and the resulting number of
    optimizer steps is multiplied by `accumulation_iters`.
    """
    samples_per_step = get_effective_batch_size(
        device_count=device_count,
        accumulation_iters=accumulation_iters,
        micro_batch_size=micro_batch_size,
    )
    tokens_per_step = samples_per_step * block_size
    # NOTE: not a strict calculation, could lead to overlaps in the first set of micro batches
    optimizer_steps = int(train_tokens // tokens_per_step)
    return optimizer_steps * accumulation_iters


def total_l2_norm(model: nn.Module) -> float:
    """Return the global L2 norm of the model's weights.

    Only parameters that require gradients and currently hold a gradient
    contribute; the per-parameter L2 norms are squared, summed and the square
    root of the sum is returned.
    """
    squared_sum = 0
    for weight in model.parameters():
        if weight.grad is None or not weight.requires_grad:
            continue
        squared_sum += weight.detach().norm(2).item() ** 2
    return squared_sum ** 0.5


def total_l1_norm(model: nn.Module) -> float:
    """Return the sum of the L1 norms of the model's weights.

    Parameters without a gradient, or with `requires_grad=False`, are skipped.
    """
    tracked_weights = (
        weight
        for weight in model.parameters()
        if weight.grad is not None and weight.requires_grad
    )
    accumulated = 0
    for weight in tracked_weights:
        accumulated += weight.detach().norm(1).item()
    return accumulated


def l2_norm_per_layer(model: nn.Module, global_step: int) -> dict:
    """Return the L2 norm of the weights of each transformer block.

    Parameters are grouped by the index that follows ``transformer.h.`` in
    their name. Only parameters that require gradients and currently hold a
    gradient are included.

    Args:
        model: Module whose parameter names contain ``transformer.h.<idx>.``.
        global_step: Unused; kept so existing call sites keep working.

    Returns:
        Mapping from block index (as a string) to that block's L2 weight norm.
    """
    block_marker = "transformer.h."
    squared_norms = defaultdict(list)
    for name, param in model.named_parameters():
        if block_marker not in name or param.grad is None or not param.requires_grad:
            continue
        # Split on the marker WITHOUT a leading dot so that names starting
        # directly with "transformer.h." (unwrapped model) yield the block
        # index too, instead of collapsing every block into "transformer".
        layer_id = name.split(block_marker)[-1].split(".")[0]
        squared_norms[layer_id].append(param.detach().norm(2).item() ** 2)
    # per-block norm: square root of the summed squared parameter norms
    return {layer_id: np.sum(values) ** 0.5 for layer_id, values in squared_norms.items()}


def l1_norm_per_layer(model: nn.Module, global_step: int) -> dict:
    """Return the L1 norm of the weights of each transformer block.

    Parameters are grouped by the index that follows ``transformer.h.`` in
    their name. Only parameters that require gradients and currently hold a
    gradient are included.

    Args:
        model: Module whose parameter names contain ``transformer.h.<idx>.``.
        global_step: Unused; kept so existing call sites keep working.

    Returns:
        Mapping from block index (as a string) to that block's L1 weight norm.
    """
    block_marker = "transformer.h."
    abs_sums = defaultdict(list)
    for name, param in model.named_parameters():
        if block_marker not in name or param.grad is None or not param.requires_grad:
            continue
        # Split on the marker WITHOUT a leading dot so that names starting
        # directly with "transformer.h." (unwrapped model) yield the block
        # index too, instead of collapsing every block into "transformer".
        layer_id = name.split(block_marker)[-1].split(".")[0]
        abs_sums[layer_id].append(param.detach().norm(1).item())
    # per-block L1 norm: plain sum of the per-parameter L1 norms
    return {layer_id: np.sum(values) for layer_id, values in abs_sums.items()}


def total_gradient_l2_norm(model: nn.Module) -> float:
    """Return the global L2 norm of the model's gradients.

    Every parameter that requires gradients and has a populated `.grad`
    contributes the square of its gradient's L2 norm; the square root of the
    accumulated sum is returned.
    """
    squared_sum = 0
    for param in model.parameters():
        has_grad = param.grad is not None and param.requires_grad
        if has_grad:
            grad_norm = param.grad.detach().norm(2)
            squared_sum += grad_norm.item() ** 2
    return squared_sum ** 0.5


def coord_check_util(
    model: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader, max_seq_length: int, metrics: list
) -> list:
    """Run one forward pass and collect muP coordinate-check records.

    A forward hook built by `mup.coord_check._record_coords` is attached to
    every `torch.nn.Linear` submodule, the first batch of `val_dataloader` is
    fed through the model under `torch.no_grad()`, and the hooks are removed
    again.

    Args:
        model: Model to inspect.
        val_dataloader: Loader whose first batch is used as input; the batch
            is indexed as `batch[:, :max_seq_length]`.
        max_seq_length: Number of leading columns of the batch to keep.
        metrics: Names selecting which entries of `coord_check.FDICT` are
            applied to the layer outputs.

    Returns:
        The list populated by the mup hooks (entry format is defined by mup).
    """
    batch = next(iter(val_dataloader))
    input_ids = batch[:, 0:max_seq_length].contiguous().long()
    output_fdict = {n: v for n, v in coord_check.FDICT.items() if n in metrics}
    records = []
    hooks = []
    try:
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                hooks.append(
                    module.register_forward_hook(
                        coord_check._record_coords(
                            records=records,
                            width=None,
                            modulename=name,
                            t=0,
                            output_fdict=output_fdict,
                            input_fdict={},
                            param_fdict={},
                        )
                    )
                )
        with torch.no_grad():
            model(input_ids)
    finally:
        # Always detach the hooks, even if registration or the forward pass
        # raises, so the model is not left with stale recording hooks.
        for hook in hooks:
            hook.remove()
    return records


def gradient_l2_norm_per_layer(model: nn.Module, global_step: int) -> dict:
    """Return the L2 norm of the gradients of each transformer block.

    Parameters are grouped by the index that follows ``transformer.h.`` in
    their name. Only parameters that require gradients and currently hold a
    gradient are included.

    Args:
        model: Module whose parameter names contain ``transformer.h.<idx>.``.
        global_step: Unused; kept so existing call sites keep working.

    Returns:
        Mapping from block index (as a string) to that block's gradient L2 norm.
    """
    block_marker = "transformer.h."
    squared_grad_norms = defaultdict(list)
    for name, param in model.named_parameters():
        if block_marker not in name or param.grad is None or not param.requires_grad:
            continue
        # Split on the marker WITHOUT a leading dot so that names starting
        # directly with "transformer.h." (unwrapped model) yield the block
        # index too, instead of collapsing every block into "transformer".
        layer_id = name.split(block_marker)[-1].split(".")[0]
        squared_grad_norms[layer_id].append(param.grad.detach().norm(2).item() ** 2)
    # per-block norm: square root of the summed squared gradient norms
    return {layer_id: np.sum(values) ** 0.5 for layer_id, values in squared_grad_norms.items()}


def weight_spectra(model: nn.Module) -> dict:
    """Return the singular values of every linear layer's weight matrix.

    The result maps the qualified module name of each `torch.nn.Linear`
    submodule to the tensor produced by `torch.linalg.svdvals` on its weight.
    """
    return {
        layer_name: torch.linalg.svdvals(layer.weight.data).detach()
        for layer_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear)
    }


def get_mup_shape_base(base_config: Config, target_config: Config, output_file: Path, verbose: bool = False) -> None:
    """Save the muP base-shape file derived from a base and a target config.

    A `GPT_Scales` model is instantiated (with `mup_init=True`) for each
    config, their shapes are extracted with `mup.get_shapes` and handed to
    `mup.make_base_shapes` together with `output_file`. With `verbose=True`
    the base/target parameter counts (in millions) are printed for the
    Kaplan, Chinchilla and LitGPT counters.

    Refer to the `../examples/save_model_base_shape.py` script for more details.

    """
    base_shapes = get_shapes(GPT_Scales(base_config, mup_init=True))
    delta_shapes = get_shapes(GPT_Scales(target_config, mup_init=True))

    # a plain string path is accepted as well and normalised to a Path
    output_file = Path(output_file) if isinstance(output_file, str) else output_file
    make_base_shapes(base_shapes, delta_shapes, output_file)
    print(f"Scaling shape saved to {output_file.absolute()}!")

    if not verbose:
        return

    # fresh models are built for every count, base first and target second
    kaplan_base = count_trainable_parameters_kaplan(GPT_Scales(base_config, mup_init=True)) / 1e6
    kaplan_target = count_trainable_parameters_kaplan(GPT_Scales(target_config, mup_init=True)) / 1e6
    print(f"\nNumber of base:target parameters (Kaplan): {kaplan_base}M:{kaplan_target}M")

    chinchilla_base = count_trainable_parameters_chinchilla(GPT_Scales(base_config, mup_init=True)) / 1e6
    chinchilla_target = count_trainable_parameters_chinchilla(GPT_Scales(target_config, mup_init=True)) / 1e6
    print(f"\nNumber of base:target parameters (Chinchilla): {chinchilla_base}M:{chinchilla_target}M")

    litgpt_base = num_parameters(GPT_Scales(base_config, mup_init=True), requires_grad=True) / 1e6
    litgpt_target = num_parameters(GPT_Scales(target_config, mup_init=True), requires_grad=True) / 1e6
    print(f"\nNumber of base:target parameters (LitGPT): {litgpt_base}M:{litgpt_target}M")


def count_trainable_parameters_kaplan(model):
    """Count the trainable parameters of a PyTorch model, embeddings included.

    Every parameter with `requires_grad=True` is counted, regardless of the
    module that owns it. Frozen parameters are excluded so that the result is
    consistent with the function name and with the other counters in this
    module, which also only consider trainable parameters.

    NOTE(review): the previous docstring attributed the "count everything,
    including embeddings" convention to Kaplan et al., "Scaling Laws for
    Neural Language Models". That paper is commonly cited as reporting
    non-embedding parameter counts; confirm which convention is intended.

    NOTE: Generated from Claude 3.5 Sonnet on August 13, 2024.

    Args:
    model (nn.Module): PyTorch model

    Returns:
    int: Total number of trainable parameters

    """
    # TODO: verify code

    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_trainable_parameters_chinchilla(
    model: nn.Module, return_all: bool = False, verbose: bool = False
) -> int | tuple[int, int]:
    """Count the trainable non-embedding parameters of a PyTorch model.

    Based on an interpretation of the Hoffmann et al. paper "Training
    Compute-Optimal Large Language Models", this function attempts to exclude
    embedding parameters. Only `nn.Embedding` and `nn.EmbeddingBag` modules
    are treated as embeddings; the exact definition of what constitutes
    "embedding parameters" may vary depending on the model architecture.

    Each module contributes the parameters it owns directly, so parameters
    registered on modules that also have children are counted too. A
    parameter object shared by several modules is counted once per module.

    NOTE: Generated from Claude 3.5 Sonnet on August 13, 2024.

    Args:
    model (nn.Module): PyTorch model
    return_all (bool): Also return the embedding parameter count
    verbose (bool): Print total, embedding and non-embedding counts

    Returns:
    int: Estimated number of non-embedding parameters
    or
    tuple[int, int]: Non-embedding parameters and embedding parameters
    (returned when `return_all` is True)

    """
    # TODO: verify code

    def is_embedding_like(module):
        return isinstance(module, (nn.Embedding, nn.EmbeddingBag))

    non_embedding_params = 0
    embedding_params = 0

    for module in model.modules():
        # recurse=False: count only what this module owns directly, so nothing
        # is counted twice via a parent and nothing on a non-leaf is missed.
        module_params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
        if is_embedding_like(module):
            embedding_params += module_params
        else:
            non_embedding_params += module_params

    if verbose:
        print(f"Total parameters: {non_embedding_params + embedding_params}")
        print(f"Embedding parameters: {embedding_params}")
        print(f"Non-embedding parameters: {non_embedding_params}")

    if return_all:
        return non_embedding_params, embedding_params
    # The documented result is the NON-embedding count (the previous code
    # returned the embedding count here).
    return non_embedding_params


def calculate_flops(steps: pd.Series, info: dict) -> pd.Series:
    """
    Converts step counts into FLOPs using the 6ND formula used in
    https://arxiv.org/pdf/2203.15556 and https://arxiv.org/pdf/2001.08361

    `info` must provide `scales.block_size`, `effective_batch_size` and
    `parameters`; each step is taken to process `block_size *
    effective_batch_size` tokens.
    """
    block_size = info['scales']['block_size']
    tokens_per_step = block_size * info['effective_batch_size']
    flops_per_step = tokens_per_step * info['parameters'] * 6
    return steps * flops_per_step


def calculate_tokens(steps: pd.Series, info: dict) -> pd.Series:
    """Convert step counts into the number of tokens processed.

    Each step is taken to process `info['scales']['block_size'] *
    info['effective_batch_size']` tokens.
    """
    block_size = info['scales']['block_size']
    batch_size = info['effective_batch_size']
    return steps * (block_size * batch_size)


def mean_l2_weight_norm(model: nn.Module) -> torch.Tensor:
    """Return the mean of the per-parameter L2 weight norms.

    Only parameters that require gradients and currently hold a gradient are
    included. When no parameter qualifies, a zero tensor is returned instead
    of failing on the division.

    Returns:
        A 0-dim tensor holding the mean norm (convert with `float(...)` or
        `.item()` if a Python number is needed).
    """
    norms = [
        p.detach().norm(2)
        for p in model.parameters()
        if p.grad is not None and p.requires_grad
    ]
    if not norms:
        # previously this path crashed because the loop counter was never bound
        return torch.tensor(0.0)
    return sum(norms) / len(norms)


def mean_l1_weight_norm(model: nn.Module) -> torch.Tensor:
    """Return the mean of the per-parameter L1 weight norms.

    Only parameters that require gradients and currently hold a gradient are
    included. When no parameter qualifies, a zero tensor is returned instead
    of failing on the division.

    Returns:
        A 0-dim tensor holding the mean norm (convert with `float(...)` or
        `.item()` if a Python number is needed).
    """
    norms = [
        p.detach().norm(1)
        for p in model.parameters()
        if p.grad is not None and p.requires_grad
    ]
    if not norms:
        # previously this path crashed because the loop counter was never bound
        return torch.tensor(0.0)
    return sum(norms) / len(norms)
