import copy
import logging
import math
import os
import random
from argparse import ArgumentParser, Namespace
from functools import partial
from typing import Any, List, Optional, Tuple

import numpy as np
import pandas as pd  # type: ignore
import seaborn as sns  # type: ignore
import torch
from matplotlib import pyplot as plt  # type: ignore
from scipy.stats import multivariate_normal  # type: ignore
from torch import nn
from torch.distributions import MultivariateNormal, Wishart
from torch.nn import functional as F

# Short alias used in type annotations throughout this file; it is also called
# as a constructor further down (e.g. ``T(out)``), so it must stay the class.
T = torch.Tensor
# Continuous "mako" colormap, used by ``contours`` for the density lines.
cm = sns.color_palette("mako", as_cmap=True)

# Route all figure text through LaTeX (axis labels below use raw LaTeX such as
# ``\Vert``). NOTE(review): presumably requires a LaTeX installation on the
# machine producing the plots -- confirm before running on a new host.
plt.rcParams['text.usetex'] = True

def set_sns() -> None:
    """Configure the global seaborn theme and context used by every plot.

    Sets the "white" theme and the "notebook" context with a 1.8x font scale
    and thick lines / large markers.

    NOTE(review): ``sns.color_palette("tab10")`` only builds a palette and its
    return value is discarded here; presumably ``sns.set_palette`` was meant.
    Left untouched so existing figures keep their colours.
    """
    line_overrides = {
        "lines.linewidth": 5,
        "lines.markerscale": 5,
    }

    sns.set_theme(style="white")
    sns.color_palette("tab10")
    sns.set_context("notebook", font_scale=1.8, rc=line_overrides)

# Apply the plotting theme once at import time so every figure below shares it.
set_sns()

def seed(run: int) -> None:
    """Seed the torch, ``random`` and numpy global generators with ``run``."""
    # Same order as always: torch first, then the stdlib, then numpy.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(run)


def get_data(n: int, dim: int, n_dists: int = 10, _seed: int = 0) -> Tuple[T, T, List[MultivariateNormal]]:
    """Build a train and a validation set drawn from ``n_dists`` random Gaussians.

    For every Gaussian a mean is drawn uniformly from [-3, 3) per dimension and
    a covariance is sampled from a Wishart distribution whose scale matrix is
    diagonal with entries in [0.25, 0.35) and whose degrees of freedom are
    ``int(dim * 1.5)``. ``n`` train and ``n`` validation points are then
    sampled from that Gaussian.

    NOTE(review): the samples are shifted by ``mean`` a second time after being
    drawn from a distribution that already has that mean, so the returned data
    is centred at ``2 * mean`` while the returned distribution objects keep
    ``loc == mean``. ``contours`` shifts its grid the same way, so the plots
    agree; change both together or not at all.

    Args:
        n: Number of samples per Gaussian in each split.
        dim: Dimensionality of every sample.
        n_dists: Number of Gaussians to generate.
        _seed: Value passed to ``seed`` before anything is sampled.

    Returns:
        ``(train, val, dists)`` where ``train`` and ``val`` are the per-Gaussian
        samples concatenated along dim 0 and ``dists`` holds the
        ``MultivariateNormal`` objects in generation order.
    """
    seed(_seed)

    train_chunks: List[T] = []
    val_chunks: List[T] = []
    dists: List[MultivariateNormal] = []

    for _ in range(n_dists):
        # The order of the random draws below fixes the dataset for a given seed.
        mean = torch.rand(dim) * 6 - 3
        scale = torch.diag_embed(torch.rand(dim) * 0.1 + 0.25)
        wishart = Wishart(covariance_matrix=scale, df=torch.tensor([int(dim * 1.5)]))
        cov = wishart.sample()[0]

        data_dist = MultivariateNormal(mean, covariance_matrix=cov)
        train_chunks.append(data_dist.sample([n]) + mean)
        val_chunks.append(data_dist.sample([n]) + mean)
        dists.append(data_dist)

    return torch.cat(train_chunks), torch.cat(val_chunks), dists


def contours(ax: Any, mu: T, sigma: T, steps: int = 100) -> Any:
    """Draw density contours of a 2-D Gaussian on ``ax``.

    The pdf of ``N(mu, sigma)`` is evaluated on a ``steps`` x ``steps`` grid
    spanning [-7, 7] on both axes; the contour is then drawn at grid
    coordinates shifted by ``mu``.

    NOTE(review): the pdf already peaks at ``mu`` on the unshifted grid, so the
    shift puts the drawn peak at ``2 * mu``. ``get_data`` offsets its samples
    by the mean in the same way, so the scatter and the contours line up;
    change both together or not at all.

    Args:
        ax: Matplotlib axes (anything exposing ``contour``).
        mu: Mean of the Gaussian; elements 0 and 1 are used for the shift.
        sigma: Covariance matrix of the Gaussian.
        steps: Number of grid points per axis.

    Returns:
        Whatever ``ax.contour`` returns, so callers can add labels or a
        colorbar. Existing callers that ignore the result are unaffected.
    """
    mx = np.linspace(-7, 7, steps)
    my = np.linspace(-7, 7, steps)
    xx, yy = np.meshgrid(mx, my)

    rv = multivariate_normal(mu.tolist(), sigma.tolist())
    z = rv.pdf(np.dstack((xx, yy)))

    return ax.contour(
        mx + mu[0].cpu().numpy(),
        my + mu[1].cpu().numpy(),
        z,
        3,
        alpha=0.75,
        linewidths=4.0,
        cmap=cm,
    )


def plot_data(outpath: str) -> None:
    """Plot a small 2-D example dataset with the generating Gaussians' contours.

    Draws 50 points from each of 4 Gaussians, scatters them coloured by the
    Gaussian they came from (labels 1..4), overlays the contours and writes
    ``data-example.pdf`` and ``data-example.png`` into ``outpath``.

    Args:
        outpath: Existing directory that receives the two image files.
    """
    n_per_dist, n_dists = 50, 4
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))

    samples, _, true_dists = get_data(n_per_dist, 2, n_dists=n_dists)
    points: Any = {
        "x": samples[:, 0].numpy().tolist(),
        "y": samples[:, 1].numpy().tolist(),
        # Samples are concatenated per Gaussian, so consecutive runs of
        # ``n_per_dist`` rows share a label.
        "label": [math.ceil((1 + i) / n_per_dist) for i in range(n_per_dist * n_dists)],
    }

    for mvn in true_dists:
        contours(ax, mvn.loc, mvn.covariance_matrix)

    frame = pd.DataFrame(points)
    ax = sns.scatterplot(data=frame, x="x", y="y", hue="label", style="label", ax=ax, s=11**2)

    fig.tight_layout()
    for filename in ("data-example.pdf", "data-example.png"):
        fig.savefig(os.path.join(outpath, filename))
    plt.close()


def update_teacher(alpha: float, S: nn.Linear, T: nn.Linear) -> None:
    """Move the teacher's parameters towards the student's, in place.

    Each teacher parameter becomes ``alpha * teacher + (1 - alpha) * student``.
    Note that the parameter ``T`` shadows the module-level tensor alias inside
    this function.

    Args:
        alpha: Weight kept on the current teacher parameters.
        S: Student module (read only).
        T: Teacher module (updated in place).
    """
    with torch.no_grad():
        for student_p, teacher_p in zip(S.parameters(), T.parameters()):
            student_share = (1 - alpha) * student_p.detach().data
            teacher_p.data.mul_(alpha)
            teacher_p.data.add_(student_share)


def sample_masked_data(data: T, args: Namespace) -> Tuple[T, T, T]:
    """Draw a random mini-batch and zero out a random subset of each row's dimensions.

    Up to ``args.batch_size`` rows are picked without replacement. For every
    picked row, ``ceil(args.masking_ratio * n_dims)`` dimensions are chosen
    independently at random and set to zero in the returned input.

    Args:
        data: 2-D tensor with one sample per row.
        args: Namespace providing ``batch_size`` and ``masking_ratio``.

    Returns:
        ``(x_masked, y, mask)``: the masked input, the untouched rows, and the
        mask itself, which is 1 where a dimension was kept and 0 where it was
        zeroed. All three share the shape ``(rows_drawn, n_dims)``.
    """
    idx = torch.randperm(data.size(0))[:args.batch_size]
    x, y = data[idx], data[idx]

    # Use the number of rows actually drawn rather than args.batch_size: when
    # the dataset is smaller than the batch size the two differ, and indexing
    # the mask with batch_size-many rows fails.
    n_rows, n_dims = x.size(0), x.size(1)

    # One independent permutation per row; its first `masked_components`
    # entries are the dimensions to zero out for that row.
    masked_components = math.ceil(args.masking_ratio * n_dims)
    perms = [torch.randperm(n_dims) for _ in range(n_rows)]
    masked_dims = torch.stack(perms)[:, :masked_components].reshape(-1).long()
    masked_idx = torch.arange(n_rows).unsqueeze(-1).repeat(1, masked_components).view(-1).long()

    # Start from an all-ones mask on the data's device and punch in the zeros.
    mask = torch.ones(n_rows, n_dims, device=x.device)
    mask[masked_idx, masked_dims] = 0

    return x * mask, y, mask


def mse_loss(y: T, yhat: T, mask: T) -> T:
    """Mean squared error restricted to the entries where ``mask`` is non-zero.

    The squared errors are weighted by ``mask``, summed, and divided by
    ``mask.sum()``.
    """
    squared_error = (y - yhat) ** 2
    weighted = squared_error * mask
    return weighted.sum() / mask.sum()  # type: ignore


def get_opt(model: nn.Module, args: Namespace) -> Any:
    """Create the SGD optimizer for ``model`` from ``args.lr`` and ``args.sgd_momentum``."""
    sgd_settings = {"lr": args.lr, "momentum": args.sgd_momentum}
    return torch.optim.SGD(model.parameters(), **sgd_settings)


def get_grad_experiment_inputs(x: T, diff_x: T, args: Namespace) -> Tuple[Tuple[List[T], ...], ...]:
    """Build the three input pairs used by the gradient experiment.

    Two random masks are drawn, each zeroing ``ceil(args.masking_ratio * len(x))``
    entries. They are used to form, in this order:

    * "same":      ``x`` under the first mask, twice;
    * "similar":   ``x`` under the first mask, then ``x`` under the second mask;
    * "different": ``x`` under the first mask, then ``diff_x`` under the second.

    Args:
        x: A single 1-D sample.
        diff_x: Another 1-D sample of the same length.
        args: Namespace providing ``masking_ratio``.

    Returns:
        Three ``(inputs, masks)`` pairs, each holding two-element lists.
    """
    n_dims = x.size(0)
    n_masked = math.ceil(args.masking_ratio * n_dims)

    def get_mask() -> T:
        # The first `n_masked` entries of a random permutation get zeroed.
        dropped = torch.randperm(n_dims)[:n_masked].long()
        keep = torch.ones(n_dims)
        keep[dropped] = 0
        return keep

    # Draw order matters for reproducibility: first mask, then second mask.
    mask = get_mask()
    mask_p = get_mask()

    same = ([x.clone() * mask, x.clone() * mask], [mask, mask])
    similar = ([x.clone() * mask, x.clone() * mask_p], [mask, mask_p])
    different = ([x.clone() * mask, diff_x.clone() * mask_p], [mask, mask_p])
    return same, similar, different


def student_grad(yhat: T, x: T, x_in: T, mask: T) -> T:
    """Outer product of the reconstruction residual with the input, gated by ``mask``.

    Computes ``outer(yhat - x, x_in)`` and multiplies row ``k`` by ``mask[k]``.
    The training loops use it as the hand-written gradient of the
    reconstruction term with respect to a linear layer's weight.
    """
    residual = yhat - x
    row_gate = mask.unsqueeze(-1)
    return torch.outer(residual, x_in) * row_gate  # type: ignore


def teacher_grad(yhat_S: T, yhat_T: T, x_in: T, mask: T) -> T:
    """Outer product of the student-teacher gap with the input, gated by ``mask``.

    Computes ``outer(yhat_S - yhat_T, x_in)`` and multiplies row ``k`` by
    ``mask[k]``. The training loops use it as the hand-written gradient of the
    consistency term with respect to the student's weight.
    """
    gap = yhat_S - yhat_T
    row_gate = mask.unsqueeze(-1)
    return torch.outer(gap, x_in) * row_gate  # type: ignore


def spectral_norm(A: T) -> float:
    """Return the square root of the largest real eigenvalue part of ``A^T A``."""
    gram = A.T @ A
    eigenvalues = torch.linalg.eig(gram)[0]
    # eig returns complex values; only the real parts are compared.
    largest = torch.max(torch.real(eigenvalues))
    return torch.sqrt(largest).item()


def train_recon(train_data: T, val_data: T, S: nn.Linear, args: Namespace, do_grad_experiment: bool = False) -> Tuple[List[float], T, T, T, T]:
    """Train ``S`` with the masked reconstruction loss only (the "MAE" baseline).

    Every iteration reseeds the global RNGs with the iteration index, takes one
    SGD step on a masked batch from ``train_data`` and records the loss on a
    masked batch from ``val_data``.

    When ``do_grad_experiment`` is set, each iteration additionally simulates,
    for every sample of a fresh training batch and for each of the "same",
    "similar" and "different" input pairs, one hand-written SGD step on a copy
    of ``S`` followed by a gradient evaluation on the second input of the pair,
    and records the spectral norm of that second gradient averaged over the
    batch.

    NOTE(review): the experiment reads ``S.weight`` and
    ``opt.state[S.weight]["momentum_buffer"]``; this presumably requires a
    plain linear ``S`` and a non-zero SGD momentum -- confirm before using it
    with other models or settings.

    Args:
        train_data: Samples used for the optimisation steps.
        val_data: Samples used for the per-iteration validation loss.
        S: The model to train (modified in place).
        args: Namespace providing ``iters``, ``lr``, ``sgd_momentum`` and the
            fields read by ``sample_masked_data``.
        do_grad_experiment: Whether to run the gradient-norm experiment.

    Returns:
        ``(losses, norms, T(), T(), T())``: the validation loss per iteration,
        a tensor with one row per iteration and one column per input type
        (empty when the experiment is off), and three empty tensors that keep
        the return shape aligned with ``train_recon_cons``.
    """
    losses, out = [], []
    opt = get_opt(S, args=args)
    for i in range(args.iters):
        # Reseeding per iteration makes the batches of iteration i identical
        # for every training function that follows the same call order.
        seed(i)
        x, y, mask = sample_masked_data(train_data, args=args)
        yhat = S(x)

        loss = mse_loss(y, yhat, mask)
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Validation loss after the step, on a freshly masked validation batch.
        with torch.no_grad():
            xv, yv, maskv = sample_masked_data(val_data, args=args)
            yhatv = S(xv)
            lossv = mse_loss(yv, yhatv, maskv)
            losses.append(lossv.item())

        # =================================================================================
        # START GRADIENT NORM EXPERIMENT
        if do_grad_experiment:
            total_norms = []
            x, y, mask = sample_masked_data(train_data, args=args)
            # NOTE: this loop reuses the name `i`; the outer `for` rebinds it
            # at the start of the next iteration, before it is read again.
            for i, single_x in enumerate(y):
                # The "different" input is the previous row of the batch
                # (row -1, i.e. the last one, for the first sample).
                same, similar, different = get_grad_experiment_inputs(single_x, y[i - 1], args)

                # 2. go forward and calculate the gradients for all the inputs and the masks
                norms = []
                for tup in (same, similar, different):
                    x_ins, masks = tup

                    # Work on a copy so the real model is never touched.
                    S0 = copy.deepcopy(S)
                    yhat_S0 = S0(x_ins[0].unsqueeze(0)).squeeze()

                    # Hand-written gradient on the first input plus the
                    # optimizer's current momentum buffer scaled by the momentum.
                    s0_grads = student_grad(yhat_S0, single_x, x_ins[0], masks[0])
                    s0_grads += args.sgd_momentum * opt.state[S.weight]["momentum_buffer"]
                    # print("mae: ", (opt.state[S.weight]["momentum_buffer"] ** 2).sum())

                    # 3. simulated gradient step on the weight (the bias is left as copied)
                    S1 = copy.deepcopy(S0)
                    S1.weight.data = S0.weight.data - args.lr * s0_grads

                    yhat_S1 = S1(x_ins[1].unsqueeze(0)).squeeze()

                    # Gradient on the second input after the simulated step.
                    s1_grads = student_grad(yhat_S1, single_x, x_ins[1], masks[1])
                    norms.append(spectral_norm(s1_grads))
                total_norms.append(norms)

            # Average over the batch: one value per input type.
            total = torch.tensor(total_norms).mean(dim=0)
            out.append(total.tolist())
            print(f"student only: {total}")
        # =================================================================================
        # END GRADIENT NORM EXPERIMENT

    return losses, T(out), T(), T(), T()


def train_recon_cons(train_data: T, val_data: T, S: nn.Linear, Te: nn.Linear, args: Namespace, do_grad_experiment: bool = False) -> Tuple[List[float], T, T, T, T]:
    """Train student ``S`` with reconstruction plus teacher-consistency loss ("MAE+Teacher").

    Every iteration reseeds the global RNGs with the iteration index, takes one
    SGD step on ``recon_loss + cons_loss`` for a masked batch from
    ``train_data``, records the reconstruction loss on a masked batch from
    ``val_data``, and then moves the teacher ``Te`` towards the student with
    ``update_teacher``. The teacher coefficient is ``args.teacher_momentum``
    for the schedule ``""``, or follows a cosine / linear decay from 1 to 0
    over the run for ``"cosine"`` / ``"linear"``.

    When ``do_grad_experiment`` is set, each iteration additionally simulates,
    for every sample of a fresh training batch and for each of the "same",
    "similar" and "different" input pairs, one hand-written SGD step on copies
    of the student and teacher, then evaluates the reconstruction and
    consistency gradients on the second input of the pair.

    NOTE(review): the experiment reads ``S.weight`` and
    ``opt.state[S.weight]["momentum_buffer"]``; this presumably requires a
    plain linear ``S`` and a non-zero SGD momentum -- confirm before using it
    with other models or settings.

    Args:
        train_data: Samples used for the optimisation steps.
        val_data: Samples used only for the per-iteration validation loss.
        S: Student model (trained in place).
        Te: Teacher model (updated in place, never receives gradients).
        args: Namespace providing ``iters``, ``lr``, ``sgd_momentum``,
            ``teacher_momentum``, ``teacher_momentum_schedule`` and the fields
            read by ``sample_masked_data``.
        do_grad_experiment: Whether to run the gradient experiment.

    Returns:
        ``(losses, s_norms, t_norms, cos_s, cos_t)``: the validation loss per
        iteration followed by four tensors with one row per iteration and one
        column per input type (same, similar, different): spectral norms of
        the reconstruction and consistency gradients, and the cosine
        similarity of each with the simulated first-step gradient. The four
        tensors are empty when the experiment is off.
    """
    losses, out_s, out_t, out_cos_s, out_cos_t = [], [], [], [], []
    opt = get_opt(S, args=args)
    for i in range(args.iters):
        seed(i)
        # Optimise on the training split, exactly like train_recon does.
        # Drawing this batch from val_data would fit the student on the very
        # samples its validation loss is measured on.
        x, y, mask = sample_masked_data(train_data, args=args)

        yhat_s = S(x)
        with torch.no_grad():
            yhat_t = Te(x)

        recon_loss = mse_loss(y, yhat_s, mask)
        cons_loss = mse_loss(yhat_s, yhat_t, mask)
        opt.zero_grad()
        (recon_loss + cons_loss).backward()
        opt.step()

        # Validation uses the reconstruction loss only.
        with torch.no_grad():
            xv, yv, maskv = sample_masked_data(val_data, args=args)
            yhatv = S(xv)
            lossv = mse_loss(yv, yhatv, maskv)
            losses.append(lossv.item())

        # Teacher coefficient for this iteration under each supported schedule.
        alpha = {
            "": args.teacher_momentum,
            "cosine": 0.5 * (1 + np.cos(np.pi * (i / args.iters))),
            "linear": 1 - (i / args.iters)
        }

        update_teacher(alpha[args.teacher_momentum_schedule], S, Te)

        # =================================================================================
        # START GRADIENT NORM EXPERIMENT
        if do_grad_experiment:
            x, y, mask = sample_masked_data(train_data, args=args)
            total_s_norms, total_t_norms, total_grad_cos_s, total_grad_cos_t = [], [], [], []
            # A dedicated index name keeps the outer iteration counter intact.
            for sample_idx, single_x in enumerate(y):
                # The "different" input is the previous row of the batch
                # (row -1, i.e. the last one, for the first sample).
                same, similar, different = get_grad_experiment_inputs(single_x, y[sample_idx - 1], args)

                # 2. go forward and calculate the gradients for all the inputs and the masks
                s_grad_norms, t_grad_norms, grad_cos_s, grad_cos_t = [], [], [], []
                for tup in (same, similar, different):
                    x_ins, masks = tup

                    # Work on copies so the real student/teacher are never touched.
                    S0, T0 = copy.deepcopy(S), copy.deepcopy(Te)
                    yhat_S0 = S0(x_ins[0].unsqueeze(0)).squeeze()
                    with torch.no_grad():
                        yhat_T0 = T0(x_ins[0].unsqueeze(0)).squeeze()

                    # Hand-written gradient of both loss terms on the first input.
                    # To sanity-check it, compare against torch.autograd.grad of
                    # (mse_loss(single_x, yhat_S0, m) + mse_loss(yhat_S0, yhat_T0, m)) * m.sum() / 2
                    # with respect to S0.weight, where m = masks[0].
                    s0_grads = student_grad(yhat_S0, single_x, x_ins[0], masks[0]) + teacher_grad(yhat_S0, yhat_T0, x_ins[0], masks[0])

                    # add momentum to this gradient step
                    s0_grads += args.sgd_momentum * opt.state[S.weight]["momentum_buffer"]

                    # 3. simulated gradient step to t+1 on the weight (the bias
                    # is left as copied), followed by the matching teacher update
                    S1 = copy.deepcopy(S0)
                    S1.weight.data = S0.weight.data - (args.lr * s0_grads)
                    update_teacher(alpha[args.teacher_momentum_schedule], S1, T0)

                    yhat_S1 = S1(x_ins[1].unsqueeze(0)).squeeze()
                    with torch.no_grad():
                        yhat_T1 = T0(x_ins[1].unsqueeze(0)).squeeze()

                    # Gradients of the two loss terms on the second input.
                    s1_grads = student_grad(yhat_S1, single_x, x_ins[1], masks[1])
                    t1_grads = teacher_grad(yhat_S1, yhat_T1, x_ins[1], masks[1])

                    grad_cos_s.append(nn.CosineSimilarity(dim=0)((s1_grads).view(-1), s0_grads.view(-1)))
                    grad_cos_t.append(nn.CosineSimilarity(dim=0)((t1_grads).view(-1), s0_grads.view(-1)))
                    s_grad_norms.append(spectral_norm(s1_grads))
                    t_grad_norms.append(spectral_norm(t1_grads))

                total_grad_cos_s.append(grad_cos_s)
                total_grad_cos_t.append(grad_cos_t)
                total_s_norms.append(s_grad_norms)
                total_t_norms.append(t_grad_norms)

            # Average over the batch: one value per input type.
            total_cos_s = T(total_grad_cos_s).mean(dim=0)
            total_cos_t = T(total_grad_cos_t).mean(dim=0)
            total_s = T(total_s_norms).mean(dim=0)
            total_t = T(total_t_norms).mean(dim=0)

            out_cos_s.append(total_cos_s.tolist())
            out_cos_t.append(total_cos_t.tolist())
            out_s.append(total_s.tolist())
            out_t.append(total_t.tolist())
            print(f"student {total_s } teacher: {total_t}, cos student: {total_cos_s} cos teacher: {total_cos_t}")
        # =================================================================================
        # END GRADIENT NORM EXPERIMENT

    return losses, T(out_s), T(out_t), T(out_cos_s), T(out_cos_t)


class Residual(nn.Module):
    """A single residual block: ``x + relu(linear(x))`` with a square linear layer."""

    def __init__(self, dim: int) -> None:
        """Create the block's ``dim`` -> ``dim`` linear layer."""
        super().__init__()
        self.layer = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x: T) -> T:
        """Add the rectified linear transform of ``x`` back onto ``x``."""
        update = F.relu(self.layer(x))
        return x + update  # type: ignore


class AutoEncoder(nn.Module):
    """Linear+ReLU input layer, a stack of ``Residual`` blocks, and a linear output layer."""

    def __init__(self, n_layers: int, in_dim: int, h_dim: int) -> None:
        """Build the network.

        Args:
            n_layers: Number of residual blocks in the hidden stack.
            in_dim: Input (and output) dimensionality.
            h_dim: Width of the hidden representation.
        """
        super().__init__()
        # Creation order is kept (input, residual stack, output) so that
        # parameter initialisation consumes the RNG in the same sequence.
        self.in_layer = nn.Sequential(nn.Linear(in_dim, h_dim), nn.ReLU())
        residual_blocks = [Residual(h_dim) for _ in range(n_layers)]
        self.layers = nn.Sequential(*residual_blocks)
        self.out_layer = nn.Linear(h_dim, in_dim)

    def forward(self, x: T) -> T:
        """Map ``x`` through the input layer, the residual stack and the output layer."""
        hidden = self.in_layer(x)
        hidden = self.layers(hidden)
        return self.out_layer(hidden)  # type: ignore


def do_runs(args: Namespace) -> pd.DataFrame:
    """Train "MAE+Teacher" and "MAE" on every dataset in ``args.run_data``.

    For each ``(train, val)`` pair a fresh model of type ``args.model`` is
    created; both training variants start from deep copies of it (the teacher
    starts as a copy of the student as well).

    Args:
        args: Namespace providing ``run_data``, ``model``, ``dim``, ``iters``
            and everything the training functions read.

    Returns:
        A long-format frame with columns ``loss`` (validation loss), ``type``
        (training variant) and ``iter`` (0 .. ``args.iters`` - 1).
    """
    records: Any = {"loss": [], "type": [], "iter": []}

    for run, (train_data, val_data) in enumerate(args.run_data):
        print(f"{run=}")

        # Both candidates are instantiated (as before) and one is selected.
        candidates = {
            "linear": nn.Linear(args.dim, args.dim),
            "deep": AutoEncoder(n_layers=1, in_dim=args.dim, h_dim=args.dim // 2),
        }
        S = candidates[args.model]

        experiments = (
            ("MAE+Teacher", partial(train_recon_cons, train_data, val_data, copy.deepcopy(S), copy.deepcopy(S), args=args)),
            ("MAE", partial(train_recon, train_data, val_data, copy.deepcopy(S), args=args)),
        )

        for name, func in experiments:
            losses = func()[0]

            records["loss"].extend(losses)
            records["type"].extend(name for _ in losses)
            records["iter"].extend(range(args.iters))

    return pd.DataFrame(records)


def get_csv(path: str) -> Optional[pd.DataFrame]:
    """Load the CSV at ``path``, or return ``None`` when no such path exists."""
    if not os.path.exists(path):
        return None
    return pd.read_csv(path)


def generic_test(outpath: str, ax_title: str, savefile: str, args: Optional[Namespace] = None) -> None:
    """Run (or reload) one convergence experiment and plot its validation loss.

    Results are cached as ``df-<savefile>.csv`` inside ``outpath``; when that
    file exists the training is skipped and the plot is drawn from the cache.
    The figure is written as ``<savefile>.png`` and ``<savefile>.pdf``.

    Args:
        outpath: Output directory (created if missing).
        ax_title: Title of the plot.
        savefile: Base name for the cached CSV and the image files.
        args: Experiment configuration passed to ``do_runs``. When omitted, the
            module-level ``args`` created by the command-line entry point is
            used, which is what existing callers rely on.

    Raises:
        RuntimeError: If training is needed but no configuration is available.
    """
    os.makedirs(outpath, exist_ok=True)

    # check to see if the df already exists at the outpath and just load it if it does
    df_path = os.path.join(outpath, f"df-{savefile}.csv")
    df = get_csv(df_path)
    if df is None:
        if args is None:
            # Legacy behaviour: fall back to the global set under __main__.
            args = globals().get("args")
        if args is None:
            raise RuntimeError("generic_test needs `args` (pass it explicitly or run the file as a script)")
        df = do_runs(args)
        df.to_csv(df_path)

    rows, cols = 1, 1
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))
    sns.lineplot(data=df, x="iter", y="loss", hue="type", style="type", ax=ax)
    ax.set_title(ax_title)
    ax.set(ylabel="reconstruction loss")

    fig.tight_layout()
    fig.savefig(os.path.join(outpath, f"{savefile}.png"))
    fig.savefig(os.path.join(outpath, f"{savefile}.pdf"))
    # Release the figure; otherwise every call leaves one open.
    plt.close(fig)


# Prepended verbatim to every experiment's output directory name
# (e.g. f"{PREFIX}heatmaps"); empty means "relative to the working directory".
PREFIX = ""

def test_heatmaps(args: Namespace) -> None:
    """Sweep SGD momentum and batch size, then draw one loss heatmap per training variant.

    Two one-dimensional sweeps are run: SGD momentum (at the configured batch
    size) and batch size (at the configured SGD momentum). Results are cached
    as a CSV in the ``heatmaps`` output directory and reused when present. The
    losses are averaged per (variant, momentum, batch size), capped at 1e-3 and
    drawn as a momentum x batch-size heatmap for "MAE" and "MAE+Teacher".

    Args:
        args: Experiment configuration. ``sgd_momentum`` and ``batch_size`` are
            modified during the sweeps and restored before returning.
    """
    outpath = f"{PREFIX}heatmaps"
    ax_title = f"({args.model}) Loss Heatmap"
    savefile = f"{args.model}-{args.sgd_momentum}-convergence"
    os.makedirs(outpath, exist_ok=True)

    # check to see if the df already exists at the outpath and just load it if it does
    df_path = os.path.join(outpath, f"df-{savefile}.csv")
    df = get_csv(df_path)
    if df is None:
        # Only these two scalars are touched, so remember them instead of
        # deep-copying the whole namespace (which also holds the datasets).
        base_momentum, base_batch_size = args.sgd_momentum, args.batch_size
        frames = []
        try:
            for sgd_momentum in [0.91, 0.93, 0.95, 0.97, 0.99]:
                args.sgd_momentum = sgd_momentum
                df_run = do_runs(args)
                df_run.insert(2, "batch_size", args.batch_size)
                df_run.insert(2, "sgd_momentum", args.sgd_momentum)
                frames.append(df_run)
                print(df_run)

            args.sgd_momentum = base_momentum
            for batch_size in [4, 8, 16, 32, 64, 128, 256]:
                args.batch_size = batch_size
                df_run = do_runs(args)
                df_run.insert(2, "batch_size", args.batch_size)
                df_run.insert(2, "sgd_momentum", args.sgd_momentum)
                frames.append(df_run)
                print(df_run)
        finally:
            # Leave the caller's configuration as it was, even on failure.
            args.sgd_momentum, args.batch_size = base_momentum, base_batch_size

        # One concat at the end instead of growing the frame sweep by sweep.
        df = pd.concat(frames)
        df.to_csv(df_path)

    # make the heatmaps
    df = df.groupby(["type", "sgd_momentum", "batch_size"]).mean()
    df = df.reset_index()
    for name in ["MAE", "MAE+Teacher"]:
        rows, cols = 1, 1
        fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

        # Copy before writing: assigning into a filtered slice of `df` is
        # ambiguous in pandas (SettingWithCopyWarning) and may not take effect.
        dfp = df[df.type == name].copy()
        # Cap the colour scale so a few diverged runs do not wash out the rest.
        dfp.loc[dfp.loss > 1e-3, "loss"] = 1e-3
        print("dfp: ", dfp)

        dfp = dfp.pivot(index="sgd_momentum", columns="batch_size", values="loss")
        print("pivoted: ", dfp)

        sns.heatmap(data=dfp, ax=ax)
        ax.set_title(ax_title)

        fig.tight_layout()
        fig.savefig(os.path.join(outpath, f"{savefile}-{name}.png"))
        fig.savefig(os.path.join(outpath, f"{savefile}-{name}.pdf"))
        # Release the figure; otherwise each variant leaves one open.
        plt.close(fig)

def test_momentum_and_model(args: Namespace) -> None:
    """Convergence experiment labelled by model type and SGD momentum."""
    out_dir = f"{PREFIX}momentum-and-model"
    title = f"({args.model}) SGD momentum: {args.sgd_momentum}"
    base_name = f"{args.model}-{args.sgd_momentum}-convergence"
    generic_test(out_dir, title, base_name)


def test_batch_size(args: Namespace) -> None:
    """Convergence experiment labelled by model type and batch size."""
    out_dir = f"{PREFIX}batch-size"
    title = f"({args.model}) batch size: {args.batch_size}"
    base_name = f"{args.model}-{args.batch_size}-convergence"
    generic_test(out_dir, title, base_name)


def test_teacher_momentum(args: Namespace) -> None:
    """Convergence experiment labelled by model type and teacher momentum."""
    out_dir = f"{PREFIX}teacher-momentum"
    title = f"({args.model}) teacher momentum: {args.teacher_momentum}"
    base_name = f"{args.model}-{args.teacher_momentum}-convergence"
    generic_test(out_dir, title, base_name)


def test_teacher_momentum_schedule(args: Namespace) -> None:
    """Convergence experiment labelled by model type and teacher momentum schedule.

    The plot title shows the schedule name, the same value that is used in the
    output file name; it previously showed the numeric teacher momentum, which
    does not identify the schedule being tested.
    """
    generic_test(
        outpath=f"{PREFIX}teacher-momentum-schedule",
        ax_title=f"({args.model}) teacher momentum schedule: {args.teacher_momentum_schedule}",
        savefile=f"{args.model}-{args.teacher_momentum_schedule}-convergence"
    )


def test_lr(args: Namespace) -> None:
    """Convergence experiment labelled by model type and SGD learning rate."""
    out_dir = f"{PREFIX}lr"
    title = f"({args.model}) SGD lr: {args.lr}"
    base_name = f"{args.model}-{args.lr}-convergence"
    generic_test(out_dir, title, base_name)


def test_masking_ratio(args: Namespace) -> None:
    """Convergence experiment labelled by model type and masking ratio."""
    out_dir = f"{PREFIX}masking-ratio"
    title = f"({args.model}) masking ratio: {args.masking_ratio}"
    base_name = f"{args.model}-{args.masking_ratio}-convergence"
    generic_test(out_dir, title, base_name)


def test_gradient_momentum(args: Namespace) -> None:
    """Run the gradient experiment for both training variants and plot the results.

    Three frames are produced (or loaded from ``loss-df.csv``, ``norm-df.csv``
    and ``cos-df.csv`` in the ``gradient-norm`` output directory when all three
    exist): validation losses, gradient spectral norms, and gradient cosine
    similarities. They are then rendered into a series of line plots saved as
    PNG and PDF in the same directory.

    NOTE(review): the training functions' gradient experiment reads
    ``S.weight``, so this presumably only works with ``--model linear`` --
    confirm before running it with the deep model.

    Args:
        args: Experiment configuration; ``run_data``, ``model``, ``dim`` and
            ``sgd_momentum`` are read here, the rest by the training functions.
    """
    outpath = f"{PREFIX}gradient-norm"
    os.makedirs(outpath, exist_ok=True)

    loss_df_path = os.path.join(outpath, "loss-df.csv")
    norm_df_path = os.path.join(outpath, "norm-df.csv")
    cos_df_path = os.path.join(outpath, "cos-df.csv")

    # Recompute everything unless all three cached frames are available.
    norm_df, loss_df, cos_df = get_csv(norm_df_path), get_csv(loss_df_path), get_csv(cos_df_path)
    if norm_df is None or loss_df is None or cos_df is None:
        loss_df = {"loss": [], "type": [], "iter": []}
        norm_df = {"norm": [], "input_type": [], "grad_type": [], "name": [], "iter": []}
        cos_df = {"cos": [], "input_type": [], "grad_type": [], "iter": []}
        # bound_df = {"bound": [], "input_type": [], "grad_type": [], "iter": []}

        for run, (train_data, val_data) in enumerate(args.run_data):
            print(f"{args.sgd_momentum=} {run=}")
            S = {"linear": nn.Linear(args.dim, args.dim), "deep": AutoEncoder(n_layers=1, in_dim=args.dim, h_dim=args.dim // 2)}[args.model]

            # Both variants start from deep copies of the same initial model.
            names = ["MAE+Teacher", "MAE"]
            funcs = [
                partial(train_recon_cons, train_data, val_data, copy.deepcopy(S), copy.deepcopy(S), args=args, do_grad_experiment=True),
                partial(train_recon, train_data, val_data, copy.deepcopy(S), args=args, do_grad_experiment=True)
            ]

            for (name, func) in zip(names, funcs):
                losses, s_norms, t_norms, cos_s, cos_t = func()
                # each returned tensor has one row per iteration and one column
                # per input type, in the order (same, similar, different)

                # Empty tensors (e.g. the teacher norms of the plain MAE run) are skipped.
                for grad_type, tns in zip(["student", "teacher"], [s_norms, t_norms]):
                    if tns.numel() != 0:
                        for i, typ in enumerate(["same", "similar", "different"]):
                            norm_df["norm"].extend(tns[:, i].tolist())
                            norm_df["input_type"].extend([typ for _ in range(tns.size(0))])
                            norm_df["grad_type"].extend([grad_type for _ in range(tns.size(0))])
                            norm_df["name"].extend([name for _ in range(tns.size(0))])
                            norm_df["iter"].extend([i for i in range(tns.size(0))])

                # Cosine similarities are only recorded for the teacher variant.
                if name == "MAE+Teacher":
                    for grad_type, tns in zip(["student", "teacher"], [cos_s, cos_t]):
                        if tns.numel() != 0:
                            for i, typ in enumerate(["same", "similar", "different"]):
                                cos_df["cos"].extend(tns[:, i].tolist())
                                cos_df["input_type"].extend([typ for _ in range(tns.size(0))])
                                cos_df["grad_type"].extend([grad_type for _ in range(tns.size(0))])
                                cos_df["iter"].extend([i for i in range(tns.size(0))])

                loss_df["loss"].extend(losses)
                loss_df["type"].extend([name] * len(losses))
                loss_df["iter"].extend([i for i in range(len(losses))])

        # Convert the accumulated columns to frames and cache them for next time.
        loss_df = pd.DataFrame(loss_df)
        norm_df = pd.DataFrame(norm_df)
        cos_df = pd.DataFrame(cos_df)
        loss_df.to_csv(loss_df_path)
        norm_df.to_csv(norm_df_path)
        cos_df.to_csv(cos_df_path)

    hue_order = ["same", "similar", "different"]
    rows, cols = 1, 1

    # plot the cos lines for the student teacher grads on the MAE+Teacher Model
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    # Smooth each (input type, gradient type) series with a 20-row rolling mean.
    # NOTE(review): the rows of all runs are stacked in one series here, so the
    # window also averages across run boundaries -- confirm this is intended.
    for t in ["different", "same", "similar"]:
        for s in ["student", "teacher"]:
            # select only the columns we care about, calculate rolling mean
            sub_df = cos_df[(cos_df.input_type == t) & (cos_df.grad_type == s)]
            sub_df = sub_df[["cos"]].rolling(20).mean()
            
            # set the rolling mean back in the original
            cos_df.loc[(cos_df.input_type == t) & (cos_df.grad_type == s), "cos"] = sub_df["cos"]

    q = cos_df
    print("MAE Teacher student")
    print(q)
    sns.lineplot(data=q, x="iter", y="cos", hue="input_type", style="grad_type", ax=ax, hue_order=hue_order) 
    ax.set_title("Linear MAE+Teacher")
    ax.set(ylabel="cosine similarity")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "rc-mae-cos-grads.png"))
    fig.savefig(os.path.join(outpath, "rc-mae-cos-grads.pdf"))

    plt.close()

    # plot the losses for both models
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    sns.lineplot(data=loss_df, x="iter", y="loss", hue="type", style="type", ax=ax) 
    ax.set_title("Convergence")
    ax.set(ylabel="reconstruction loss")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "convergence.png"))
    fig.savefig(os.path.join(outpath, "convergence.pdf"))

    plt.close()

    # plot the norm lines for the MAE model
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    q = norm_df[norm_df.name == "MAE"]
    print("MAE")
    print(q)
    sns.lineplot(data=q, x="iter", y="norm", hue="input_type", ax=ax, hue_order=hue_order) 
    ax.set_title("MAE (recon. loss)")
    ax.set(ylabel=r"$\Vert \nabla \Vert_2$")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "mae-gradient-norm.png"))
    fig.savefig(os.path.join(outpath, "mae-gradient-norm.pdf"))

    plt.close()

    # plot the norm lines for the student on the MAE+Teacher Model
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    q = norm_df[(norm_df.name == "MAE+Teacher") & (norm_df.grad_type == "student")]
    print("MAE Teacher student")
    print(q)
    sns.lineplot(data=q, x="iter", y="norm", hue="input_type", ax=ax, hue_order=hue_order) 
    ax.set_title("Linear MAE+Teacher (recon.)")
    ax.set(ylabel=r"$\Vert \nabla \Vert_2$")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "rc-mae-recon-gradient-norm.png"))
    fig.savefig(os.path.join(outpath, "rc-mae-recon-gradient-norm.pdf"))

    plt.close()

    # plot the norm lines for the teacher on the MAE+Teacher Model
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    q = norm_df[(norm_df.name == "MAE+Teacher") & (norm_df.grad_type == "teacher")]
    print("MAE Teacher teacher")
    print(q)
    sns.lineplot(data=q, x="iter", y="norm", hue="input_type", ax=ax, hue_order=hue_order) 
    ax.set_title("Linear MAE+Teacher (cons.)")
    ax.set(ylabel=r"$\Vert \nabla \Vert_2$")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "rc-mae-cons-gradient-norm.png"))
    fig.savefig(os.path.join(outpath, "rc-mae-cons-gradient-norm.pdf"))
    plt.close()

    # same teacher (consistency) norms, restricted to iterations after 400
    # NOTE(review): the 400 cut-off is hard-coded and selects nothing when
    # fewer than ~400 iterations were run.
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    q = norm_df[(norm_df.name == "MAE+Teacher") & (norm_df.grad_type == "teacher") & (norm_df.iter > 400)]
    print("MAE Teacher teacher")
    print(q)
    sns.lineplot(data=q, x="iter", y="norm", hue="input_type", ax=ax, hue_order=hue_order) 
    ax.set_title("MAE+Teacher (cons. loss)")
    ax.set(ylabel="gradient norm")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "rc-mae-cons-end-gradient-norm.png"))
    fig.savefig(os.path.join(outpath, "rc-mae-cons-end-gradient-norm.pdf"))
    plt.close()

    # student and teacher norms together, iterations after 400, y-axis fixed to [0, 10]
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))

    q = norm_df[(norm_df.name == "MAE+Teacher") & (norm_df.iter > 400)]
    print("MAE Teacher teacher")
    print(q)
    sns.lineplot(data=q, x="iter", y="norm", hue="input_type", style="grad_type", ax=ax, hue_order=hue_order) 
    ax.set_title("MAE+Teacher (recon. and cons. loss)")
    ax.set(ylabel="gradient norm", ylim=(0, 10))
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "rc-mae-recon-cons-end-gradient-norm.png"))
    fig.savefig(os.path.join(outpath, "rc-mae-recon-cons-end-gradient-norm.pdf"))
    plt.close()

    # student and teacher norms together over the whole run
    fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(8 * cols, 6 * rows))
    q = norm_df[(norm_df.name == "MAE+Teacher")]
    print("MAE Teacher teacher")
    print(q)
    sns.lineplot(data=q, x="iter", y="norm", hue="input_type", style="grad_type", ax=ax, hue_order=hue_order) 
    ax.set_title("MAE+Teacher (recon. and cons. loss)")
    ax.set(ylabel="gradient norm")
    fig.tight_layout()
    fig.savefig(os.path.join(outpath, "rc-mae-recon-cons-gradient-norm.png"))
    fig.savefig(os.path.join(outpath, "rc-mae-recon-cons-gradient-norm.pdf"))
    plt.close()
    

if __name__ == "__main__":
    # Command-line entry point: parse the configuration, optionally plot the
    # example data and exit, otherwise generate one dataset per run and
    # dispatch to the experiment selected with --test.
    parser = ArgumentParser("Linear Ablation Study")

    parser.add_argument("--runs", type=int, default=5, help="number of independent runs (one dataset seed per run)")
    parser.add_argument("--test", type=str, default="", help="the test name to run")
    parser.add_argument("--iters", type=int, default=500, help="number of iterations")
    parser.add_argument("--plot-data", action="store_true", help="whether or not to plot data")
    parser.add_argument("--n-per-class", type=int, default=200, help="number of instances per Gaussian")
    parser.add_argument("--classes", type=int, default=10, help="number of Gaussians to sample data from")
    parser.add_argument("--dim", type=int, default=32, help="dimensionality of the data")
    # NOTE(review): --outpath is parsed but nothing in this file reads it.
    parser.add_argument("--outpath", type=str, default="", help="the outpath to save plots")
    parser.add_argument("--model", type=str, choices=["linear", "deep"], help="the model choice")
    parser.add_argument("--sgd-momentum", type=float, default=0.99, help="the momentum for SGD")
    parser.add_argument("--batch-size", type=int, default=32, help="mini-batch size")
    parser.add_argument("--teacher-momentum", type=float, default=0.5, help="weight kept on the old teacher in each teacher update")
    parser.add_argument("--teacher-momentum-schedule", type=str, default="", choices=["", "linear", "cosine"], help="the momentum scheduler to use")
    parser.add_argument("--lr", type=float, default=0.01, help="SGD learning rate")
    parser.add_argument("--masking-ratio", type=float, default=0.5, help="fraction of input dimensions zeroed out per sample")

    args = parser.parse_args()
    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level="INFO",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    args.logger = logging.getLogger()

    if args.plot_data:
        os.makedirs("data", exist_ok=True)
        plot_data(outpath="data")
        args.logger.info("plotted data, exiting")
        # Exit status 0, without relying on the site-provided `exit` helper.
        raise SystemExit(0)

    tests = {
        "momentum-and-model": test_momentum_and_model,
        "batch-size": test_batch_size,
        "teacher-momentum": test_teacher_momentum,
        "teacher-momentum-schedule": test_teacher_momentum_schedule,
        "lr": test_lr,
        "masking-ratio": test_masking_ratio,
        "test-gradient-momentum": test_gradient_momentum,
        "test-heatmaps": test_heatmaps
    }

    # Fail with a usage message before any data is generated, instead of a
    # KeyError after it.
    if args.test not in tests:
        parser.error(f"unknown --test {args.test!r}; choose one of: {', '.join(tests)}")

    # call all of these first to make sure there are not PSD errors
    run_data, true_dists = [], []  # type: ignore
    for run in range(args.runs):
        data_train, data_val, td = get_data(args.n_per_class, args.dim, n_dists=args.classes, _seed=run)
        run_data.append((data_train, data_val))
        true_dists.append(td)
    args.run_data = run_data
    args.true_dists = true_dists

    tests[args.test](args)
