import sys
import os

SAFARI_PATH = "/mnt/safari-internal/"
sys.path.append(SAFARI_PATH)
import yaml
from pathlib import Path
import torch
from torch.utils.checkpoint import checkpoint
from torch.profiler import profile, record_function, ProfilerActivity

from functools import partial
from src.models.sequence.hyena import HyenaOperator
from src.utils.profiling import *
from src.models.sequence.long_conv_lm import ConvLMHeadModel
from src.utils.profiling import benchmark_forward, benchmark_backward

import matplotlib.pyplot as plt
import pandas as pd

# Load the evaluation config that describes the model to benchmark.
config_path = os.path.join(SAFARI_PATH, "configs", "evals", "hyena_dna_512ksl.yaml")
# Open via a context manager so the file handle is closed as soon as parsing is
# done instead of being left to the garbage collector.
with open(config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Length of the benchmark input sequence, read from the layer config.
SEQ_LEN = config["model_config"]["layer"]["l_max"]
device_id = 0
batch_size = 1
# NOTE(review): `dtype` is not referenced by the rest of the visible script.
dtype = torch.bfloat16
device = torch.device(f"cuda:{device_id}")


class LMNoLogits(torch.nn.Module):
    """Adapter that reduces a language model's output to a single tensor.

    The wrapped ``model`` is expected to return an indexable result whose
    first element exposes a ``logits`` attribute; ``forward`` returns that
    attribute so callers can treat the model as a plain tensor-to-tensor
    module.
    """

    def __init__(self, model):
        """Store ``model`` as a submodule under the attribute ``model``."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return ``output[0].logits``."""
        outputs = self.model(x)
        head_output = outputs[0]
        return head_output.logits


# Build the language model from the config, move it to the benchmark device,
# and wrap it so that calling `model(x)` returns the logits tensor directly.
model = LMNoLogits(ConvLMHeadModel(**config["model_config"]).to(device))


class Wrapper(torch.nn.Module):
    """Run a wrapped layer under activation checkpointing.

    ``forward`` delegates to ``torch.utils.checkpoint.checkpoint``, so the
    layer's intermediate activations are recomputed during the backward pass
    instead of being kept alive after the forward pass.
    """

    def __init__(self, layer):
        """Store ``layer`` as a submodule under the attribute ``layer``."""
        super().__init__()
        self.layer = layer

    def forward(self, x):
        """Return ``self.layer(x)`` computed through ``checkpoint``."""
        # Pass `use_reentrant` explicitly: relying on the implicit default is
        # deprecated in recent PyTorch releases. The non-reentrant variant is
        # the one PyTorch recommends, and it still produces parameter
        # gradients when `x` itself does not require grad.
        return checkpoint(self.layer, x, use_reentrant=False)


mixer = model.model.backbone.layers[0].mixer
# Fix the parameter list once so that both gradient snapshots below cover the
# same tensors in the same order.
mixer_params = [p for p in mixer.parameters() if p.requires_grad]

# check if the gradients between checkpointed and regular are the same
x = torch.randint(0, 4, (batch_size, SEQ_LEN), device=device).long()
y = model(x)
y.mean().backward()

# Snapshot copies, not references: later backward passes accumulate into the
# very same `.grad` tensors, which would silently change these values too.
grads = [None if p.grad is None else p.grad.detach().clone() for p in mixer_params]

# measure regular runtime
# NOTE(review): if `t.mean` is in seconds (as it is for torch.utils.benchmark
# Measurement objects), dividing by 1000 does not give milliseconds -- confirm
# what benchmark_forward / benchmark_backward return before trusting the unit.
m, t = benchmark_forward(model, x, repeats=3, desc="", verbose=False)
m, tbwd = benchmark_backward(model, x, repeats=3, desc="", verbose=False)
print(f"Regular forward: {t.mean / 1000}ms")
print(f"Regular backward: {tbwd.mean / 1000}ms")

# log per-layer memory usage
log_mem = log_memory(model, x)
df = pd.DataFrame(log_mem)
print(df)
# save to csv
df.to_csv("mem_log_nc.csv")


# Swap the first mixer for its checkpointed wrapper; the wrapper holds the same
# module, so `mixer_params` still refers to the parameters being trained.
wrapper = Wrapper(mixer)
model.model.backbone.layers[0].mixer = wrapper

# Drop every gradient left over from the reference pass and the benchmarks so
# the checkpointed pass starts from a clean state rather than accumulating.
model.zero_grad(set_to_none=True)
y = model(x)
y.mean().backward()
grads2 = [None if p.grad is None else p.grad.detach().clone() for p in mixer_params]

m, t = benchmark_forward(model, x, repeats=3, desc="", verbose=False)
m, tbwd = benchmark_backward(model, x, repeats=3, desc="", verbose=False)
print(f"Checkpointed forward: {t.mean / 1000}ms")
print(f"Checkpointed backward: {tbwd.mean / 1000}ms")

# log per-layer memory usage
log_mem = log_memory(model, x)
df = pd.DataFrame(log_mem)
print(df)
# save to csv
df.to_csv("mem_log_c.csv")


# One True/False line per mixer parameter: do the regular and checkpointed
# gradients agree?
for g1, g2 in zip(grads, grads2):
    if g1 is None or g2 is None:
        # A parameter that received no gradient only matches if that happened
        # in both passes; torch.allclose cannot take None.
        print(g1 is None and g2 is None)
    else:
        print(torch.allclose(g1, g2))
