import copy
import os

import pytest
import torch
import torchvision.models as models
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from nesim.experiments.imagenet import (
    ImageNetTrainingState,
)


@pytest.fixture
def resnet18_model():
    """Provide a torchvision ResNet-18 built without requesting pretrained weights."""
    network = models.resnet18()
    return network


@pytest.fixture
def optimizer_scheduler(resnet18_model):
    """Return an ``(optimizer, scheduler)`` pair bound to ``resnet18_model``.

    The optimizer is plain SGD at lr=0.01; the StepLR scheduler multiplies the
    learning rate by 0.1 on every ``scheduler.step()`` call.
    """
    sgd = SGD(resnet18_model.parameters(), lr=0.01)
    return sgd, StepLR(sgd, step_size=1, gamma=0.1)


def assert_matching_state_dict(a: dict, b: dict, _path: str = "") -> None:
    """Assert that two ``state_dict``-style mappings hold the same contents.

    Keys must match, in the same order. Values are compared recursively:
    nested dicts (e.g. an optimizer's ``state``), lists/tuples (e.g.
    ``param_groups``), tensors and plain Python values are all checked, so
    optimizer and scheduler state dicts are genuinely verified rather than
    only their top-level keys.

    Args:
        a: Reference state dict.
        b: State dict expected to match ``a``.
        _path: Internal; where ``a``/``b`` sit inside an enclosing state dict.
            Only used to make failure messages point at the offending entry.

    Raises:
        AssertionError: On any key, type, shape, dtype, length or value mismatch.
    """
    where = f" at {_path}" if _path else ""
    assert list(a.keys()) == list(
        b.keys()
    ), f"Keys Mismatch{where}: {list(a.keys())} != {list(b.keys())}"

    # Keys are known to be identical here, so index ``b`` by key instead of
    # relying on zip() over two value views.
    for key, value_a in a.items():
        _assert_state_values_match(value_a, b[key], f"{_path}[{key!r}]")


def _assert_state_values_match(value_a, value_b, path: str) -> None:
    """Recursively assert that two state-dict entries are equal."""
    if torch.is_tensor(value_a) or torch.is_tensor(value_b):
        # A tensor compared with a non-tensor is a mismatch, not something to skip.
        assert torch.is_tensor(value_a) and torch.is_tensor(value_b), (
            f"Type Mismatch at {path}: "
            f"{type(value_a).__name__} != {type(value_b).__name__}"
        )
        # Check shape/dtype explicitly: torch.allclose broadcasts its inputs
        # (so e.g. shapes (1,) and (2,) could compare equal) and raises
        # RuntimeError rather than failing an assertion on dtype mismatch.
        assert (
            value_a.shape == value_b.shape
        ), f"Shape Mismatch at {path}: {value_a.shape} != {value_b.shape}"
        assert (
            value_a.dtype == value_b.dtype
        ), f"Dtype Mismatch at {path}: {value_a.dtype} != {value_b.dtype}"
        if value_a.is_floating_point() or value_a.is_complex():
            matches = torch.allclose(value_a, value_b)
        else:
            # Integer/bool tensors (e.g. BatchNorm's num_batches_tracked) need
            # exact equality.
            matches = torch.equal(value_a, value_b)
        assert matches, f"Value Mismatch at {path}: {value_a} != {value_b}"
    elif isinstance(value_a, dict) and isinstance(value_b, dict):
        assert_matching_state_dict(value_a, value_b, _path=path)
    elif isinstance(value_a, (list, tuple)) and isinstance(value_b, (list, tuple)):
        # list vs tuple is tolerated: serialization may change one into the other.
        assert len(value_a) == len(
            value_b
        ), f"Length Mismatch at {path}: {len(value_a)} != {len(value_b)}"
        for index, (item_a, item_b) in enumerate(zip(value_a, value_b)):
            _assert_state_values_match(item_a, item_b, f"{path}[{index}]")
    else:
        assert value_a == value_b, f"Value Mismatch at {path}: {value_a!r} != {value_b!r}"


def test_ImageNetTrainingState_save_load(tmp_path, resnet18_model, optimizer_scheduler):
    """Round-trip an ImageNetTrainingState through ``save`` and ``load``.

    The state is saved from one set of model/optimizer/scheduler objects and
    loaded through a separate set whose contents differ. The comparisons can
    therefore only pass if ``load`` really restores what ``save`` wrote.
    Loading through the very same objects would compare each object with
    itself and pass trivially.
    """
    model, (optimizer, scheduler) = resnet18_model, optimizer_scheduler

    # Advance the schedule once so the saved optimizer/scheduler state differs
    # from a freshly constructed pair (StepLR changes the lr and last_epoch).
    # optimizer.step() runs first to keep PyTorch's call-order check happy; no
    # gradients exist, so it leaves the weights untouched.
    optimizer.step()
    scheduler.step()

    training_state = ImageNetTrainingState(model, optimizer, scheduler)

    save_path = os.path.join(tmp_path, "test_training_state.pth")
    training_state.save(save_path)

    # Independent objects to load into: perturbed weights plus an un-stepped
    # optimizer/scheduler pair mirroring the fixture's hyperparameters.
    fresh_model = copy.deepcopy(model)
    with torch.no_grad():
        for param in fresh_model.parameters():
            param.add_(1.0)
    fresh_optimizer = SGD(fresh_model.parameters(), lr=0.01)
    fresh_scheduler = StepLR(fresh_optimizer, step_size=1, gamma=0.1)

    loaded = ImageNetTrainingState(fresh_model, fresh_optimizer, fresh_scheduler).load(
        filename=save_path
    )
    loaded_model = loaded["model"]
    loaded_optimizer = loaded["optimizer"]
    loaded_scheduler = loaded["scheduler"]

    assert_matching_state_dict(
        training_state.model.state_dict(), loaded_model.state_dict()
    )
    assert_matching_state_dict(
        training_state.optimizer.state_dict(), loaded_optimizer.state_dict()
    )
    assert_matching_state_dict(
        training_state.scheduler.state_dict(), loaded_scheduler.state_dict()
    )

    assert training_state.wandb_run_id == loaded["wandb_run_id"]
    assert training_state.global_step == loaded["global_step"]
    assert training_state.nesim_config_filename == loaded["nesim_config_filename"]
