import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from mmengine.config import Config
from mmengine.registry import MODELS
from mmengine.runner import Runner, load_checkpoint

# IMPORTANT: Register all MMPreTrain models (including ImageClassifier)
from mmpretrain.models import *


def load_model_and_cfg(config_path, checkpoint_path, dataset_name):
    """Build the model described by a config file and load checkpoint weights.

    The config is adjusted in place for ``dataset_name`` before the model is
    built: ``sub_dataset_name`` is set on the config root and on the test
    dataset, and the class count is overridden for known datasets.

    Args:
        config_path: Path of the config file to load.
        checkpoint_path: Path of the checkpoint whose weights are loaded.
        dataset_name: Sub-dataset to evaluate on. Datasets not listed in
            ``NUM_CLASSES`` keep the class count already in the config.

    Returns:
        tuple: ``(model, cfg)`` with the model in eval mode, placed on CUDA
        when available and on CPU otherwise.
    """
    # Imported locally so the file's import block does not need to change.
    from mmengine.registry import init_default_scope

    cfg = Config.fromfile(config_path)

    # The root ``MODELS`` registry only resolves types registered by
    # downstream packages once a default scope is active; Runner normally
    # sets it, but the model is built here before any Runner exists.
    init_default_scope(cfg.get('default_scope', 'mmpretrain'))

    cfg.sub_dataset_name = dataset_name

    # Set the correct number of classes (update if necessary)
    NUM_CLASSES = {
        'diabetic_retinopathy': 5,
        'cifar': 100,  # the original config used 100 classes for cifar
    }
    if dataset_name in NUM_CLASSES:
        cfg.model.head.num_classes = NUM_CLASSES[dataset_name]
        # Not every config defines a top-level data_preprocessor.
        if cfg.get('data_preprocessor') is not None:
            cfg.data_preprocessor.num_classes = NUM_CLASSES[dataset_name]

    # Make sure the test dataloader uses the desired sub-dataset.
    if hasattr(cfg, 'test_dataloader') and hasattr(cfg.test_dataloader,
                                                   'dataset'):
        cfg.test_dataloader.dataset.sub_dataset_name = dataset_name

    # Mirror Runner: a top-level data_preprocessor becomes the model's
    # preprocessor unless the model config already declares one. Without
    # this the standalone model falls back to a default preprocessor.
    if cfg.get('data_preprocessor') is not None:
        cfg.model.setdefault('data_preprocessor', cfg.data_preprocessor)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MODELS.build(cfg.model)
    model.to(device)
    model.eval()

    # Weights are read onto the CPU and copied into the model's parameters.
    load_checkpoint(model, checkpoint_path, map_location='cpu')

    return model, cfg


def get_random_directions(model):
    """Sample two random perturbation directions for a model.

    For every parameter with ``requires_grad`` set, two Gaussian tensors of
    the parameter's shape are drawn and each is divided by its own norm, so
    every returned tensor has (approximately) unit norm. Parameters that do
    not require gradients contribute no entry.

    Args:
        model: Module whose ``parameters()`` are visited in order.

    Returns:
        tuple: Two lists of tensors, one entry per trainable parameter.
    """

    def _unit(noise):
        # The epsilon guards against division by zero for a zero-norm draw.
        return noise / (torch.norm(noise) + 1e-10)

    first_axis, second_axis = [], []
    for weight in model.parameters():
        if not weight.requires_grad:
            continue
        # Draw both samples before normalising to keep the RNG call order.
        noise_a = torch.randn_like(weight)
        noise_b = torch.randn_like(weight)
        first_axis.append(_unit(noise_a))
        second_axis.append(_unit(noise_b))
    return first_axis, second_axis


def perturb_model(base_weights, d1, d2, alpha, beta):
    """Return ``base_weights`` shifted by ``alpha * d1 + beta * d2``.

    The direction lists are matched to ``base_weights`` by position. Two
    layouts are accepted:

    * one direction per entry of ``base_weights``; or
    * one direction per entry whose tensor has ``requires_grad`` set, which
      is what ``get_random_directions`` produces when some parameters are
      frozen. Pairing those lists against every entry would attach each
      direction to the wrong tensor.

    Args:
        base_weights: Ordered mapping of parameter name to tensor.
        d1: First list of direction tensors.
        d2: Second list of direction tensors, same length as ``d1``.
        alpha: Step size along ``d1``.
        beta: Step size along ``d2``.

    Returns:
        dict: Every key of ``base_weights``. float32 entries that have a
        direction are perturbed; all other entries are passed through.

    Raises:
        ValueError: If ``d1`` and ``d2`` differ in length, or their length
            matches neither accepted layout.
    """
    if len(d1) != len(d2):
        raise ValueError(
            f'd1 and d2 must have the same length, got {len(d1)} and '
            f'{len(d2)}')

    names = list(base_weights)
    if len(d1) == len(names):
        targets = names
    else:
        targets = [
            name for name in names
            if getattr(base_weights[name], 'requires_grad', False)
        ]
        if len(targets) != len(d1):
            raise ValueError(
                f'got {len(d1)} directions for {len(names)} weights '
                f'({len(targets)} of them trainable)')

    # Plain floats keep numpy scalars (e.g. from np.linspace) out of the
    # tensor arithmetic.
    alpha, beta = float(alpha), float(beta)
    direction_for = dict(zip(targets, zip(d1, d2)))

    new_state = {}
    # The result is only copied into a model, so no autograd graph is needed.
    with torch.no_grad():
        for name, weight in base_weights.items():
            pair = direction_for.get(name)
            if pair is None or weight.dtype != torch.float32:
                new_state[name] = weight
                continue
            step1, step2 = pair
            new_state[name] = weight + alpha * step1 + beta * step2
    return new_state


def _loss_term_to_float(value):
    """Reduce one loss-dict entry (tensor or sequence of tensors) to a float."""
    if torch.is_tensor(value):
        return value.mean().item()
    if isinstance(value, (list, tuple)):
        return sum(term.mean().item() for term in value)
    return float(value)


@torch.no_grad()
def evaluate_loss(model, dataloader, max_batches=10):
    """Return the mean per-batch loss over the first batches of a dataloader.

    Args:
        model: Module with a ``loss(...)`` method returning a mapping of
            named terms. When the model has a callable ``data_preprocessor``
            attribute (as mmengine ``BaseModel`` subclasses do), each batch
            is passed through it in eval mode first; otherwise tensor values
            of the batch dict are moved to the model's device.
        dataloader: Iterable yielding batches.
        max_batches: Maximum number of batches to evaluate, to keep each
            grid point cheap. ``None`` evaluates every batch.

    Returns:
        float: Average of the per-batch loss values.

    Raises:
        ValueError: If no batch was evaluated.
    """
    # Imported locally so the file's import block does not need to change.
    from itertools import islice

    device = next(model.parameters()).device
    preprocess = getattr(model, 'data_preprocessor', None)

    total_loss = 0.0
    count = 0
    # islice stops without fetching a batch beyond the limit.
    for data in islice(dataloader, max_batches):
        if callable(preprocess):
            # Applies the model's own device transfer and normalisation.
            data = preprocess(data, False)
        else:
            data = {k: (v.to(device) if torch.is_tensor(v) else v)
                    for k, v in data.items()}

        if isinstance(data, dict):
            losses = model.loss(**data)
        else:
            losses = model.loss(*data)

        # Follow mmengine's convention: only entries whose key contains
        # 'loss' count towards the loss (others are metrics like accuracy).
        # If no key matches, fall back to summing every entry.
        terms = {k: v for k, v in losses.items() if 'loss' in k} or losses
        total_loss += sum(_loss_term_to_float(v) for v in terms.values())
        count += 1

    if count == 0:
        raise ValueError('evaluate_loss received no batches to evaluate')
    return total_loss / count


def main():
    """Sweep a 2-D slice of the loss surface around a checkpoint and plot it.

    Loads the model and config, takes the test dataloader from a Runner built
    from that config, draws two random directions, evaluates the loss on a
    21 x 21 grid of ``(alpha, beta)`` offsets along them, and writes the
    surface plot to ``loss_landscape_3d.png`` before showing it.
    """
    # Update these paths as needed
    config_path = '/gpfs/gibbs/pi/panda/dl2345/Research/MambaPEFT/vision/configs/mmpretrain/vim/vtab1k/1_small/2_small_lorap_Z_64.py'
    checkpoint_path = 'work_dirs/your_run/epoch_100_diabetic_retinopathy.pth'
    dataset_name = 'diabetic_retinopathy'

    model, cfg = load_model_and_cfg(config_path, checkpoint_path, dataset_name)
    runner = Runner.from_cfg(cfg)
    dataloader = runner.test_dataloader

    # Snapshot of the unperturbed parameters; every grid point starts from it.
    base_weights = deepcopy(dict(model.named_parameters()))
    d1, d2 = get_random_directions(model)

    # Define grid limits for perturbation (adjust range or steps if needed)
    alphas = np.linspace(-1.0, 1.0, 21)
    betas = np.linspace(-1.0, 1.0, 21)
    # Row index follows alpha, column index follows beta.
    losses = np.zeros((len(alphas), len(betas)))

    # Evaluate the loss landscape by perturbing the model parameters
    for i, alpha in enumerate(alphas):
        for j, beta in enumerate(betas):
            perturbed = perturb_model(base_weights, d1, d2, alpha, beta)
            model.load_state_dict(perturbed, strict=False)
            loss = evaluate_loss(model, dataloader)
            losses[i, j] = loss
            print(f"alpha={alpha:.2f}, beta={beta:.2f} → loss={loss:.4f}")

    # 'ij' indexing makes A[i, j] == alphas[i] and B[i, j] == betas[j], the
    # same layout as ``losses``. The default 'xy' indexing would transpose
    # the surface relative to the axis labels.
    A, B = np.meshgrid(alphas, betas, indexing='ij')
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(A, B, losses, cmap='viridis')
    ax.set_xlabel('Alpha')
    ax.set_ylabel('Beta')
    ax.set_zlabel('Loss')
    plt.title('3D Loss Landscape for VisionMamba (Diabetic Retinopathy)')
    plt.colorbar(surf, shrink=0.5, aspect=5)
    plt.tight_layout()
    plt.savefig('loss_landscape_3d.png', dpi=300)
    plt.show()


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
