"""Perform coordinate checks for MUP parameterization.

Performs a coordinate check to see if the preactivations diverge as a function
of width. It is a recommended check to see if the MUP has been set up correc-
tly.

Inspired by: https://github.com/microsoft/mup/blob/main/mup/coord_check.py
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.nn import Module
    from torch import Tensor
    from torch.optim import Optimizer
    from collections.abc import Callable, Iterable

import os

import torch

import matplotlib.pyplot as plt

from helpers.logger import get_logger
from helpers.hooks import GetHookVals

logger = get_logger()
# Use setLevel() instead of assigning ``.level`` directly: for a stdlib
# logging.Logger, setLevel also clears the cached isEnabledFor() results.
# 20 is logging.INFO. NOTE(review): assumes get_logger returns a
# logging.Logger (or something exposing setLevel) — confirm.
logger.setLevel(20)


class CheckActivations:
    """Compute and plot the coordinates for the MUP.

    Every model in ``models`` is trained for ``steps`` batches. At each step
    the mean absolute value (l1 average) of the output of every submodule
    whose name contains ``"layer"`` is recorded in ``self.results``, so that
    its dependence on width can be plotted with :meth:`plot`.
    """

    def __init__(
            self,
            models: dict[int, list[Module]],
            steps: int,
            dataloader: Iterable[tuple[Tensor, Tensor]],
            optimizer: Callable[..., Optimizer],
            loss_fn: Callable[[Tensor, Tensor], Tensor],
            lr: float,
            device: str="cuda",
    ) -> None:
        """Store the experiment configuration.

        Args:
            models: Maps each width to the list of models of that width.
            steps: Number of batches (optimizer steps) recorded per model.
            dataloader: Iterable yielding ``(X, y)`` batches.
            optimizer: Optimizer factory, called as
                ``optimizer(model.parameters(), lr=...)``.
            loss_fn: Called as ``loss_fn(y, outputs)``.
            lr: Base learning rate.
            device: Device that models and batches are moved to.
        """
        self.models = models
        self.steps = steps
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.lr = lr
        self.device = device

        # results[width][layer_name][step] holds one scalar tensor per model
        # of that width; filled by ``train_step``.
        self.results: dict[int, dict[str, dict[int, list[Tensor]]]] = {}

    @staticmethod
    def _tracked_modules(model: Module) -> list[tuple[str, Module]]:
        """Return the ``(name, module)`` pairs whose name contains "layer"."""
        return [
            (name, module)
            for name, module in model.named_modules()
            if "layer" in name
        ]

    def _init_results(self) -> None:
        """Reset ``self.results`` to empty per-step lists for every width."""
        for width, models in self.models.items():
            # Layer names are taken from the first model of each width; the
            # other models of that width are assumed to share them.
            self.results[width] = {
                name: {step: [] for step in range(self.steps)}
                for name, _ in self._tracked_modules(models[0])
            }

    def _make_optimizer(
            self,
            model: Module,
            width: int,
            parameterization: str,
    ) -> Optimizer:
        """Build the optimizer; ``"mfp"`` multiplies the base lr by width."""
        lr = self.lr * width if parameterization == "mfp" else self.lr
        return self.optimizer(model.parameters(), lr=lr)

    def _record_step(
            self,
            model: Module,
            optim: Optimizer,
            width: int,
            batch_idx: int,
            X: Tensor,
            y: Tensor,
    ) -> None:
        """Run one optimizer step and record each tracked layer's l1 average."""
        hookval = GetHookVals()
        handles = [
            module.register_forward_hook(
                hookval.getActivation(name=str(name)),
            )
            for name, module in self._tracked_modules(model)
        ]
        try:
            # Call the module (not ``.forward``) so PyTorch's normal call
            # path, including any module-level hooks, is used.
            outputs = model(X.to(self.device))
            # NOTE(review): the target is passed first, unlike torch.nn losses
            # which take (input, target); confirm loss_fn's convention.
            loss = self.loss_fn(y.to(self.device), outputs)

            optim.zero_grad()
            loss.backward()
            optim.step()
        finally:
            # Always detach the hooks, even if the step raised.
            for handle in handles:
                handle.remove()

        # Compute the statistics without autograd and keep only a detached
        # CPU scalar: otherwise each stored value could keep its whole
        # activation tensor (saved by ``abs`` for backward) alive.
        with torch.no_grad():
            for layer, activation in hookval.activation.items():
                l1_avg = torch.abs(activation).mean(dtype=torch.float32)
                self.results[width][layer][batch_idx].append(
                    l1_avg.detach().cpu(),
                )

    def train_step(
            self,
            parameterization: str,
    ) -> None:
        """Compute the coordinates for a few train steps.

        Each model is moved to ``self.device``, put in train mode and trained
        on the first ``self.steps`` batches of ``self.dataloader``.

        Args:
            parameterization: ``"mfp"`` multiplies the base learning rate by
                the model width; any other value uses ``self.lr`` unchanged.
        """
        self._init_results()

        for width, models in self.models.items():
            for model in models:
                model = model.to(self.device).train()
                optim = self._make_optimizer(model, width, parameterization)

                for batch_idx, (X, y) in enumerate(self.dataloader):
                    if batch_idx == self.steps:
                        break
                    self._record_step(model, optim, width, batch_idx, X, y)

    def _summarize(self) -> dict[int, dict[str, dict[str, list]]]:
        """Aggregate ``self.results`` across the models of each width.

        Returns:
            ``summary[step][layer]`` with keys ``"widths"`` (ascending),
            ``"mean"``, ``"error_plus"`` and ``"error_minus"`` (mean +/- one
            standard deviation across models). Step/layer entries with no
            recorded values are skipped; with a single model the standard
            deviation is taken as 0.
        """
        summary: dict[int, dict[str, dict[str, list]]] = {}
        for width in sorted(self.results):
            for layer, per_step in self.results[width].items():
                for step, values in per_step.items():
                    if not values:
                        # e.g. a matched module that is never called during
                        # forward, or fewer batches than ``self.steps``.
                        continue
                    stacked = torch.stack(values).detach().cpu()
                    mean = stacked.mean()
                    # torch.std of a single sample is NaN (Bessel's
                    # correction); use a zero spread so the band still renders.
                    std = (
                        stacked.std()
                        if stacked.numel() > 1
                        else torch.zeros_like(mean)
                    )
                    entry = summary.setdefault(step, {}).setdefault(
                        layer,
                        {
                            "widths": [],
                            "mean": [],
                            "error_plus": [],
                            "error_minus": [],
                        },
                    )
                    entry["widths"].append(width)
                    entry["mean"].append(mean.item())
                    entry["error_plus"].append((mean + std).item())
                    entry["error_minus"].append((mean - std).item())
        return summary

    def plot(self, path: str) -> None:
        """Plot coordinates.

        For every recorded step, each tracked layer's mean l1 activation is
        drawn against width on log2 axes, with a shaded band of +/- one
        standard deviation across the models of each width.

        Args:
            path: Output directory (created if missing). Each figure is saved
                to ``os.path.join(path, str(step))``.

        Raises:
            RuntimeError: If ``train_step`` has not produced any results.
        """
        if not self.results:
            raise RuntimeError("No results to plot; call train_step() first.")
        os.makedirs(path, exist_ok=True)

        # Wrap around the palette instead of raising IndexError when more
        # layers are tracked than there are colors.
        colors = [f"C{i}" for i in range(10)]

        for step, layers in sorted(self._summarize().items()):
            fig, ax = plt.subplots()
            for index, (layer, stats) in enumerate(layers.items()):
                color = colors[index % len(colors)]
                ax.loglog(
                    stats["widths"],
                    stats["mean"],
                    label=layer,
                    color=color,
                    marker="o",
                    base=2,
                    nonpositive="clip",
                )
                ax.fill_between(
                    stats["widths"],
                    stats["error_minus"],
                    stats["error_plus"],
                    color=color,
                    alpha=0.5,
                )

            ax.grid()
            ax.grid(which="minor", color="0.9")
            ax.set_xlabel("Width")
            ax.set_ylabel("l1")
            ax.set_title(f"Preactivations for time step {step}")
            ax.legend()
            fig.savefig(os.path.join(path, str(step)))
            # Close this specific figure so repeated calls do not leak figures.
            plt.close(fig)
