# %%
import importlib
import torchvision
from torchvision.utils import save_image, make_grid
import torch
from einops import rearrange
import matplotlib.pyplot as plt
import numpy as np
import hydra
from omegaconf import OmegaConf

from nn.inr import INR
from experiments.data import INRImageDataset, INRDataset
from experiments.utils import common_parser, make_coordinates
from nn.rt_transformer import GraphProbeFeatures
from experiments.data_nfn import Siren, SIREN_kwargs



def create_object(name, **kwargs):
    """Instantiate a class from its fully-qualified dotted path.

    Args:
        name: Dotted import path of the class, e.g. ``"package.module.ClassName"``.
        **kwargs: Keyword arguments forwarded unchanged to the class constructor.

    Returns:
        A new instance of the resolved class.

    Raises:
        ValueError: If ``name`` has no module part (no dot, or a leading dot).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    module_name, _, class_name = name.rpartition(".")
    if not module_name:
        # Previously this surfaced as an opaque tuple-unpacking ValueError;
        # keep the exception type but say what was wrong.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {name!r}"
        )
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
    return class_(**kwargs)


# %%
# Load the MNIST INR training split and pull its first (unshuffled) batch of 4.
trainset = INRDataset(
    path='~/data/INR/mnist_splits.json',
    statistics_path="experiments/mnist/dataset/statistics.pth",
    split="train",
)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=False, num_workers=0
)
batch = next(iter(train_loader))

# %%
# Probe every INR in the batch with a 28x28 coordinate grid, then reshape the
# last entry along dim 1 of the probe features into one 28x28 image per sample.
coord = make_coordinates((28, 28), 1)
gpf = GraphProbeFeatures(2, 28 * 28, 1, input_init=coord)
out = gpf(batch.weights, batch.biases)
img = rearrange(out[:, -1].detach(), "b (h w) -> b h w", h=28)

# %%
# Compare an original MNIST digit against the image decoded from one INR checkpoint.
# NOTE(review): MNIST index 3 is paired with the first train INR
# (trainset.dataset["path"][0]); confirm these refer to the same digit.
mnist_ds = torchvision.datasets.MNIST("mnist_data", download=True)
original_img = torch.from_numpy(np.array(mnist_ds[3][0]))
fimg = original_img / 255  # scale 0-255 pixel values to [0, 1] for display
ckpt = torch.load(trainset.dataset["path"][0], map_location="cpu")

model = INR(in_dim=2, n_layers=3, up_scale=16)
model.load_state_dict(ckpt)
pimg = model(coord)
pimg = rearrange(pimg, "1 (h w) 1 -> h w", h=28)
# %%
plt.imshow(fimg, cmap="gray", vmin=0, vmax=1)
plt.colorbar()
# %%
plt.imshow(pimg.detach(), cmap="gray", vmin=0, vmax=1)
plt.colorbar()
# %%
# Re-decode with the INR. Its output already matches "1 (h w) 1" (see the
# cell above), so it must not be indexed with [0] first: that would leave a
# rank-2 tensor which this pattern rejects.
pimg = model(coord)
pimg = rearrange(pimg, "1 (h w) 1 -> h w", h=28)
# %%
# SIREN checkpoints: judging by the directory names, one network without
# augmentation followed by ten augmented variants (aug0 .. aug9).
paths = ["NFN_data/siren_mnist_wts/randinit_smaller_1s/net3.pth"]
paths += [f"NFN_data/siren_mnist_wts/randinit_smaller_aug{i}_1s/net3.pth" for i in range(10)]
# %%
# Load every checkpoint's state dict onto the CPU.
ckpts = [torch.load(ckpt_path, map_location="cpu") for ckpt_path in paths]
# %%
# Decode each checkpoint into a 28x28 image, reusing a single SIREN instance.
model = Siren(**SIREN_kwargs['mnist'])
imgs = []
for c in ckpts:
    model.load_state_dict(c)
    model.eval()
    # model(coord)[0] must match "1 (w h) 1" for this pattern to apply.
    # NOTE(review): "(w h)" (vs. "(h w)" used for the INR) transposes the
    # result; presumably the SIREN was trained with the other coordinate
    # order — confirm.
    pimg = rearrange(model(coord)[0], "1 (w h) 1 -> h w", h=28)
    imgs.append(pimg.detach())
# %%
# Show the image decoded from the first entry of `paths`.
plt.imshow(imgs[0], cmap="gray", vmin=-1, vmax=1)
plt.colorbar()
# %%
# Compose the hydra "config_style" config, pointing data.path at the local splits file.
with hydra.initialize(version_base=None, config_path="experiments/mnist/configs"):
    cfg = hydra.compose(config_name="config_style", overrides=["data.path='~/data/INR/mnist_splits.json'"])
    # print(OmegaConf.to_yaml(cfg))
# %%
# Build the train/val/test datasets (in that order) from the class path in
# cfg.data.cls, each with its own keyword arguments from cfg.data.<split>.
train_set, val_set, test_set = (
    create_object(cfg.data.cls, **cfg.data[split]) for split in ("train", "val", "test")
)

# One unshuffled, single-process loader per split, all sharing the same settings.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(
        dataset=split_set,
        batch_size=4,
        shuffle=False,
        num_workers=0,
        # pin_memory=True,
    )
    for split_set in (train_set, val_set, test_set)
)

#%%
# Take the first test batch: INR parameters together with their target images.
params, img = next(iter(test_loader))
img = img.squeeze(dim=1)  # drop dim 1 (presumably a singleton channel axis — confirm)
# Build a fresh 28x28 coordinate probe (a new GraphProbeFeatures instance).
coord = make_coordinates((28, 28), 1)
gpf = GraphProbeFeatures(2, 28 * 28, 1, input_init=coord)
out = gpf(params.weights, params.biases)

#%%
# Image decoded from the last entry along dim 1 of the probe features, first sample.
dimg = rearrange(out[:, -1].detach(), "b (h w) -> b h w", h=28)
plt.imshow(dimg[0], cmap="gray", vmin=0, vmax=1)
plt.colorbar()

# %%
# Target test image, displayed on a [-1, 1] scale (presumably the dataset's
# normalisation range — confirm).
plt.imshow(img[0], cmap="gray", vmin=-1, vmax=1)
plt.colorbar()