
"""Code for the gradient experiment of the extrinsic transformer paper.

This script is used to generate the figures of the gradient experiment of the
extrinsic transformer paper. It is used to compare the gradients of the
transformer and the GEN on the baseline datasets.

How to run:
    python gradient.py --path path_to_outputs
"""
import argparse
import logging
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import yaml
from dotenv import load_dotenv

from exttfs.datasets.baseline_dataset import datasets
from exttfs.models.gen import GEN, GraphStructure, grid
from exttfs.models.msa import MSAEncoderOnly
from exttfs.models.transformer import EncoderOnly

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

load_dotenv()
log = logging.getLogger(__name__)


def patched_tfs_forward(self, cx, cy, tx):
    """Forward pass of the transformer that exposes its latents and scores.

    The final latents are stored in ``self.latents`` and the output of
    ``self.scaled_dist(tx, cx)`` in ``self.scores``. ``retain_grad()`` is
    called on both so that their ``.grad`` can be inspected after a backward
    pass.

    Args:
        cx (torch.Tensor): Context points
        cy (torch.Tensor): Context values
        tx (torch.Tensor): Target points

    Returns:
        torch.Tensor: Target values
    """
    hidden = self.encoder(torch.cat([cx, cy], dim=-1))

    # Each block receives the context points concatenated to the current
    # latents.
    for layer in self.blocks:
        hidden = layer(torch.cat([cx, hidden], dim=-1))

    # Keep a handle on the final latents and ask autograd to keep their
    # gradient.
    self.latents = hidden
    self.latents.retain_grad()

    # Same treatment for the target/context scores.
    self.scores = self.scaled_dist(tx, cx)
    self.scores.retain_grad()

    # Mix the latents with the scores, then decode together with the targets.
    pooled = torch.bmm(self.scores, hidden)
    return self.decoder(torch.cat([pooled, tx], dim=-1))


def patched_msa_forward(self, cx, cy, tx):
    """Replace the forward pass of the MSA to see the latents.

    Context and target embeddings are concatenated along dimension 1 (context
    first, targets last) and passed through ``self.blocks``. The resulting
    latents are stored in ``self.latents`` with ``retain_grad()`` so that
    their ``.grad`` can be inspected after a backward pass. Only the trailing
    target positions are decoded.

    Args:
        cx (torch.Tensor): Context points
        cy (torch.Tensor): Context values
        tx (torch.Tensor): Target points

    Returns:
        torch.Tensor: Target values
    """
    n_targets = tx.shape[1]

    latents_context = self.context_encoder(torch.cat((cx, cy), dim=-1))
    latents_target = self.target_encoder(tx)
    latents = torch.cat((latents_context, latents_target), dim=1)

    for block in self.blocks:
        latents = block(latents)
    self.latents = latents
    self.latents.retain_grad()

    # Slice from an explicit start index rather than ``-n_targets:``: with no
    # targets, ``-0:`` would select every position instead of none.
    first_target = latents.shape[1] - n_targets
    return self.decoder(latents[:, first_target:, :])


def patched_gen_forward(self, cx, cy, tx):
    """Forward pass of the GEN that exposes its latents and target scores.

    The latents produced by the last block of ``self.gn_blocks`` are stored in
    ``self.latents`` and the output of ``self.g(tx)`` in ``self.scores``.
    ``retain_grad()`` is called on both so that their ``.grad`` can be
    inspected after a backward pass.

    Args:
        cx (torch.Tensor): Context points
        cy (torch.Tensor): Context values
        tx (torch.Tensor): Target points

    Returns:
        torch.Tensor: Target values
    """
    batch_size = len(cx)
    # One copy of the graph positions per batch element.
    node_pos = self.g.pos.unsqueeze(0).repeat(batch_size, 1, 1)

    # Project the embedded context onto the graph through the context scores.
    context_scores = self.g(cx)
    embedded = self.encoder(torch.cat([cx, cy], dim=-1))
    node_state = torch.bmm(context_scores.transpose(1, 2), embedded)

    senders, receivers = self.g.senders, self.g.receivers
    for layer in self.gn_blocks:
        node_state = layer(torch.cat([node_pos, node_state], dim=-1), senders, receivers)

    # Keep a handle on the final latents and ask autograd to keep their
    # gradient.
    self.latents = node_state
    self.latents.retain_grad()

    # Same treatment for the target scores.
    self.scores = self.g(tx)
    self.scores.retain_grad()

    # Read the latents out at the targets, then decode together with them.
    readout = torch.bmm(self.scores, self.latents)
    return self.decoder(torch.cat([readout, tx], dim=-1))


# Command-line interface. The arguments are parsed at module level because
# the ``__main__`` section below reads the module-level ``args``.
parser = argparse.ArgumentParser(
    description="Plot the latent gradients of the trained models found under --path."
)
parser.add_argument(
    "--path",
    type=str,
    default="outputs",
    help="Directory searched recursively for *.pt checkpoints (default: outputs).",
)
args = parser.parse_args()

if __name__ == "__main__":
    pts = list(Path(args.path).glob("**/*.pt"))
    config_files = [p.parent / ".hydra/config.yaml" for p in pts]

    cfgs = []
    for c in config_files:
        with c.open() as f:
            cfgs.append(yaml.safe_load(f))

    models = []
    for cfg in cfgs:
        match cfg:
            case {"model": {"name": "gen"}}:
                gs = GraphStructure(*grid(cfg["model"]["grid_size"]), fixed=False)
                model = GEN(gs, **cfg["model"]["params"])
                model.forward = patched_gen_forward.__get__(model)
                models.append(model)
            case {"model": {"name": "encoder_only"}}:
                model = EncoderOnly(**cfg["model"]["params"])
                model.forward = patched_tfs_forward.__get__(model)
                models.append(model)
            case {"model": {"name": "msa_encoder_only"}}:
                model = MSAEncoderOnly(**cfg["model"]["params"])
                model.forward = patched_msa_forward.__get__(model)
                models.append(model)
            case _:
                models.append(None)

    pts = [p for p, m in zip(pts, models) if m is not None]
    cfgs = [c for c, m in zip(cfgs, models) if m is not None]
    models = [m for m in models if m is not None]
    print("Loading models")
    for m, p in zip(models, pts):
        m.load_state_dict(torch.load(p, map_location="cpu"))
        m.eval()

    names = []
    for cfg in cfgs:
        match cfg:
            case {"model": {"name": "gen", "grid_size": g}}:
                names.append(f"GEN {g}x{g}")
            case {"model": {"name": "encoder_only"}}:
                names.append("TFS")
            case {"model": {"name": "msa_encoder_only"}}:
                names.append("ZMSA")

    names, models = zip(*sorted(zip(names, models), key=lambda t: t[0], reverse=True))

    train_dataset, val_dataset = datasets("cpu", 64, "sin", 10)

    dl = torch.utils.data.DataLoader(val_dataset, batch_size=1000, shuffle=True)
    (cx, cy, tx), ty = next(iter(dl))

    n = 1

    print("Computing gradients")
    for model in models:
        for (cx, cy, tx), _ in dl:
            output = model(cx, cy, tx)
            o2 = output.detach().clone()
            o2[:, n, :] += 200.0
            loss = F.mse_loss(output, o2)
            loss.backward()

    print("Plotting")
    fig = plt.figure(figsize=(4.5, 5))
    gs = fig.add_gridspec(len(models) + 2, 64, left=0.18)

    axs = [fig.add_subplot(gs[0, :])]

    axs[0].imshow(
        F.mse_loss(output, o2, reduction="none").sum(0, keepdims=True).detach().numpy(),
        cmap="binary",
    )
    axs[0].set_yticks([0.0], ["Error"])
    axs[0].set_xticks([0, 64])

    i = 0
    for model, name_with_z in zip(models, names):
        name = name_with_z.replace("Z", "")

        grads = model.latents.grad[0].abs().sum(1).unsqueeze(0).detach().numpy()

        if name == "MSA":
            N = 64
            ax = fig.add_subplot(gs[i + 1, :N])
            ax.imshow(grads[:, :N], cmap="binary", vmin=0.0, vmax=1e-8)
            ax.set_yticks([0.0], [name + " ctx"])
            if N == 1:
                ax.set_xticks([1])
            else:
                ax.set_xticks([0, N])

            i += 1
            axs.append(ax)

            ax = fig.add_subplot(gs[i + 1, :N])
            ax.imshow(grads[:, N:], cmap="binary", vmin=0.0, vmax=1e-8)
            ax.set_yticks([0.0], [name + " tgt"])
        else:
            N = grads.shape[1]
            ax = fig.add_subplot(gs[i + 1, :N])

            ax.imshow(grads, cmap="binary", vmin=0.0, vmax=1e-8)
            ax.set_yticks([0.0], [name])

        if N == 1:
            ax.set_xticks([1])
        else:
            ax.set_xticks([0, N])
        axs.append(ax)

        i += 1

    for ax in axs:
        ax.set_aspect("equal")
        for loc in ["top", "left", "right"]:
            ax.spines[loc].set_visible(False)

    fig.savefig(f"grads{n}.pdf")
