from __future__ import annotations

from typing import Tuple, Iterable

import torch
import wandb


def calculate_grad_norm(model: torch.nn.Module) -> torch.Tensor:
    """Return the global L2 norm over every gradient held by ``model``'s parameters.

    Parameters whose ``.grad`` is ``None`` are ignored. The result lives on the
    device of the first gradient encountered. When no parameter has a gradient,
    a CPU scalar ``tensor(0.)`` is returned.

    Args:
        model: module whose ``parameters()`` are inspected.

    Returns:
        0-dim tensor ``sqrt(sum over params of sum(grad ** 2))``.
    """
    total: torch.Tensor | None = None
    for param in model.parameters():
        if param.grad is None:
            continue
        # Detach so the norm is not tracked by autograd. Reduce over all elements
        # directly: no flattening is needed, and view(-1) raises on
        # non-contiguous gradients (e.g. for transposed/tied parameters).
        sq = torch.sum(param.grad.detach().square())
        if total is None:
            total = sq
        else:
            # Gradients may live on different devices (model parallelism).
            total = total + sq.to(total.device)
    if total is None:
        # No gradients present; return a zero tensor on CPU
        return torch.tensor(0.0)
    return torch.sqrt(total)


def init_grad_flow_data(model: torch.nn.Module) -> dict[str, list[float]]:
    """Build an empty gradient-flow record for ``model``.

    The record always contains a ``'steps'`` key. For each top-level child of
    ``model``, it also gets one empty list per logical layer, keyed
    ``'<child_name>/<layer_name>'``. When ``model`` has no children, its own
    parameters are used and ``'model'`` is the prefix.

    Only trainable parameters (``requires_grad``) are considered. Any parameter
    whose name contains ``'bias'`` is ignored. Layer names are derived from the
    parameter name relative to the child:
      * ``layers.<i>.…`` -> ``layer_<i>``, with ``'0'`` prefixed unless ``<i>``
        is exactly two characters long;
      * anything else -> its first dotted component.
    """
    def _logical_layer(param_name: str) -> str:
        pieces = param_name.split('.')
        if pieces[0] != 'layers' or len(pieces) < 2:
            return pieces[0]
        block = pieces[1]
        # NOTE(review): 3+-digit indices also get the '0' prefix ('100' -> '0100');
        # the sibling helpers derive names the same way, so keep them in sync.
        padded = block if len(block) == 2 else '0' + block
        return 'layer_' + padded

    record: dict[str, list[float]] = {'steps': []}
    # Fall back to the model itself when it has no child modules.
    targets = list(model.named_children()) or [('model', model)]
    for prefix, module in targets:
        # dict.fromkeys acts as an insertion-ordered set of layer names.
        ordered_layers = dict.fromkeys(
            _logical_layer(pname)
            for pname, p in module.named_parameters(recurse=True)
            if p.requires_grad and 'bias' not in pname
        )
        for layer in ordered_layers:
            record[f'{prefix}/{layer}'] = []
    return record


def update_grad_flow_data(model: torch.nn.Module, grad_flow_data: dict[str, list[float]]) -> None:
    """Record one step of per-layer normalised gradient values into ``grad_flow_data``.

    Appends the next step index to ``grad_flow_data['steps']``, creating the key
    if it is missing. The index is the number of steps already recorded. Then,
    for each top-level child of ``model`` (or ``model`` itself, prefixed
    ``'model'``, when it has no children), the per-layer values returned by
    ``_calculate_grad_flow`` are appended under ``'<module>/<layer>'`` keys. A
    missing key is created on the fly.

    Alignment: before a value is appended, the series is padded up to the
    current step index. The pad repeats the series' last value, or uses 0.0
    when it has none. This way ``series[i]`` corresponds to ``steps[i]`` even
    when a layer had no gradient for some steps or first appears mid-run.
    Layers with no gradient on this step receive no value; trailing gaps are
    left for the caller to pad.

    Mutates ``grad_flow_data`` in place and returns ``None``.
    """
    steps = grad_flow_data.setdefault('steps', [])
    step_idx = len(steps)
    steps.append(step_idx)

    def _process_module(module_name: str, module: torch.nn.Module) -> None:
        layers, avg_grads, _ = _calculate_grad_flow(module.named_parameters(recurse=True))
        for layer_name, avg_grad in zip(layers, avg_grads):
            series = grad_flow_data.setdefault(f'{module_name}/{layer_name}', [])
            # Fill skipped steps so this value lands at index step_idx.
            missing = step_idx - len(series)
            if missing > 0:
                fill = series[-1] if series else 0.0
                series.extend([fill] * missing)
            # float() accepts Python numbers and single-element tensors alike.
            series.append(float(avg_grad))

    children = list(model.named_children())
    if not children:
        _process_module('model', model)
    else:
        for module_name, module in children:
            _process_module(module_name, module)


def format_grad_flow_logs(model: torch.nn.Module, grad_flow_data: dict[str, list[float]]) -> dict[str, wandb.plotting.PlotlyFacet]:
    """Turn recorded gradient-flow series into one wandb line-series chart per module.

    For each top-level child of ``model`` (or ``model`` itself, prefixed
    ``'model'``, when it has no children), the chart covers the logical layers
    that currently have a gradient. These are trainable, non-bias parameters
    with ``.grad`` set, in first-seen parameter order. A layer whose series
    ``'<module>/<layer>'`` is missing or empty is left out. Series shorter than
    ``steps`` are extended, in a copy, by repeating their last value. Modules
    with no plottable series produce no entry.

    Returns:
        dict mapping module name -> the object returned by ``wandb.plot.line_series``.
    """

    assert 'steps' in grad_flow_data, "grad_flow_data must contain a 'steps' entry"

    steps: list[int] = grad_flow_data['steps']
    n_steps = len(steps)
    logs: dict[str, wandb.plotting.PlotlyFacet] = {}

    def _layer_of(param_name: str) -> str:
        # Same naming rule as the other grad-flow helpers in this file.
        pieces = param_name.split('.')
        if pieces[0] != 'layers' or len(pieces) < 2:
            return pieces[0]
        block = pieces[1]
        return 'layer_' + (block if len(block) == 2 else '0' + block)

    targets = list(model.named_children()) or [('model', model)]
    for prefix, module in targets:
        # Insertion-ordered set of layers that have a gradient right now.
        active_layers = dict.fromkeys(
            _layer_of(pname)
            for pname, p in module.named_parameters(recurse=True)
            if p.requires_grad and p.grad is not None and 'bias' not in pname
        )
        series_keys: list[str] = []
        series_values: list[list[float]] = []
        for layer in active_layers:
            values = grad_flow_data.get(f'{prefix}/{layer}')
            if not values:
                continue
            shortfall = n_steps - len(values)
            if shortfall > 0:
                # Build a new list so the stored series is not mutated.
                values = values + [values[-1]] * shortfall
            series_keys.append(layer)
            series_values.append(values)
        if not series_values:
            continue
        logs[prefix] = wandb.plot.line_series(
            xs=steps,
            ys=series_values,
            keys=series_keys,
            title=f"{prefix} grad flow (normalised by weight values)",
            xname="steps",
        )
    return logs


def _calculate_grad_flow(named_parameters: Iterable[Tuple[str, torch.nn.Parameter]], *, epsilon: float = 1e-13, norm: str = 'l2') -> Tuple[list[str], list[float], list[float]]:

    avg_grads: list[float] = []
    avg_weights: list[float] = []
    max_grads: list[float] = []
    layers: list[str] = []
    # Helper to normalise a layer index
    def _layer_name_from(name: str) -> str:
        if name.startswith('layers.'):
            parts = name.split('.')
            idx = parts[1] if len(parts) > 1 else '0'
            idx = idx if len(idx) == 2 else f'0{idx}'
            return f'layer_{idx}'
        return name.split('.')[0]
    for name, param in named_parameters:
        if not param.requires_grad:
            continue
        if param.grad is None:
            continue
        # Skip biases (either named exactly 'bias' or containing 'bias' anywhere)
        if 'bias' in name.split('.')[-1] or 'bias' in name:
            continue
        layer_name = _layer_name_from(name)
        # If this layer hasn't been seen yet, initialise its accumulators
        if layer_name not in layers:
            layers.append(layer_name)
            avg_grads.append(0.0)
            avg_weights.append(0.0)
            max_grads.append(0.0)
        # index of current layer
        idx = layers.index(layer_name)
        # detach grads and weights to CPU for safe operations
        p_grad = param.grad.detach().to('cpu')
        p_weight = param.detach().clone().to('cpu')
        if norm == 'l2':
            avg_grads[idx] += float(torch.sum(p_grad.square()).item())
            avg_weights[idx] += float(torch.sum(p_weight.square()).item())
        elif norm == 'l1':
            avg_grads[idx] += float(torch.sum(p_grad.abs()).item())
            avg_weights[idx] += float(torch.sum(p_weight.abs()).item())
        else:
            raise ValueError(f"unknown norm value: {norm!r}")
        current_max = float(p_grad.abs().max().item())
        # update maximum gradient seen so far for this layer
        if current_max > max_grads[idx]:
            max_grads[idx] = current_max
    # Finalise average gradients by normalising with weight norms
    normalised_avg_grads: list[float] = []
    for g_sum, w_sum in zip(avg_grads, avg_weights):
        if norm == 'l2':
            # sqrt(g_sum / (w_sum + epsilon))
            ratio = g_sum / (w_sum + epsilon)
            normalised = ratio ** 0.5
        elif norm == 'l1':
            normalised = g_sum / (w_sum + epsilon)
        else:
            raise ValueError(f"unknown norm value: {norm!r}")
        normalised_avg_grads.append(normalised)
    return layers, normalised_avg_grads, max_grads
