# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from copy import deepcopy

import mmengine
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults
from mmengine.runner import Runner


def parse_args():
    """Parse command-line options for testing (and evaluating) a model.

    Side effect: when ``LOCAL_RANK`` is not already in the environment it is
    set from ``--local_rank`` / ``--local-rank``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMPreTrain test (and eval) a model')
    add = parser.add_argument

    # Positional inputs.
    add('config', help='test config file path')
    add('checkpoint', help='checkpoint file')

    # Where results go.
    add('--work-dir',
        help='the directory to save the file containing evaluation metrics')
    add('--out', help='the file to output results.')
    add('--out-item',
        choices=['metrics', 'pred'],
        help='To output whether metrics or predictions. '
        'Defaults to output predictions.')

    # Config overrides.
    add('--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    add('--amp', action='store_true',
        help='enable automatic-mixed-precision test')

    # Visualization.
    add('--show-dir',
        help='directory where the visualization images will be saved.')
    add('--show', action='store_true',
        help='whether to display the prediction results in a window.')
    add('--interval', type=int, default=1,
        help='visualize per interval samples.')
    add('--wait-time', type=float, default=2,
        help='display time of every window. (second)')

    # Data loading / inference behaviour.
    add('--no-pin-memory', action='store_true',
        help='whether to disable the pin_memory option in dataloaders.')
    add('--tta', action='store_true',
        help='Whether to enable the Test-Time-Aug (TTA). If the config file '
        'has `tta_pipeline` and `tta_model` fields, use them to determine the '
        'TTA transforms and how to merge the TTA results. Otherwise, use flip '
        'TTA by averaging classification score.')

    # Distributed launching.
    add('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none', help='job launcher')
    # PyTorch >= 2.0's `torch.distributed.launch` passes `--local-rank`
    # instead of `--local_rank`, so both spellings are accepted.
    add('--local_rank', '--local-rank', type=int, default=0)

    parsed = parser.parse_args()
    # Only fill LOCAL_RANK in when the launcher has not already exported it.
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed


def merge_args(cfg, args):
    """Merge CLI arguments into the config.

    Args:
        cfg (Config): Config loaded from ``args.config``. It is modified in
            place.
        args (argparse.Namespace): Arguments returned by ``parse_args``.

    Returns:
        Config: ``cfg`` itself, so callers can write
        ``cfg = merge_args(cfg, args)``.
    """
    cfg.launcher = args.launcher

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.load_from = args.checkpoint

    # enable automatic-mixed-precision test
    if args.amp:
        cfg.test_cfg.fp16 = True

    # -------------------- visualization --------------------
    if args.show or (args.show_dir is not None):
        assert 'visualization' in cfg.default_hooks, \
            'VisualizationHook is not set in the `default_hooks` field of ' \
            'config. Please set `visualization=dict(type="VisualizationHook")`'

        cfg.default_hooks.visualization.enable = True
        cfg.default_hooks.visualization.show = args.show
        cfg.default_hooks.visualization.wait_time = args.wait_time
        cfg.default_hooks.visualization.out_dir = args.show_dir
        cfg.default_hooks.visualization.interval = args.interval

    # -------------------- TTA related args --------------------
    if args.tta:
        if 'tta_model' not in cfg:
            cfg.tta_model = dict(type='mmpretrain.AverageClsScoreTTA')
        if 'tta_pipeline' not in cfg:
            # Default TTA: run the test pipeline with and without a
            # horizontal flip, re-using its last transform after the flip.
            test_pipeline = cfg.test_dataloader.dataset.pipeline
            cfg.tta_pipeline = deepcopy(test_pipeline)
            flip_tta = dict(
                type='TestTimeAug',
                transforms=[
                    [
                        dict(type='RandomFlip', prob=1.),
                        dict(type='RandomFlip', prob=0.)
                    ],
                    [test_pipeline[-1]],
                ])
            cfg.tta_pipeline[-1] = flip_tta
        # Wrap the original model in the TTA model.
        cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline

    # ----------------- Default dataloader args -----------------
    default_dataloader_cfg = ConfigDict(
        pin_memory=True,
        collate_fn=dict(type='default_collate'),
    )

    def set_default_dataloader_cfg(cfg, field):
        """Fill unset keys of ``cfg[field]`` from ``default_dataloader_cfg``.

        Keys already present in the config win over the defaults;
        ``--no-pin-memory`` then forces ``pin_memory=False``.
        """
        if cfg.get(field, None) is None:
            return
        # A list of dataloader configs is left untouched (added by yoshimura).
        if isinstance(cfg.get(field, None), list):
            return
        dataloader_cfg = deepcopy(default_dataloader_cfg)
        dataloader_cfg.update(cfg[field])
        cfg[field] = dataloader_cfg
        if args.no_pin_memory:
            cfg[field]['pin_memory'] = False

    set_default_dataloader_cfg(cfg, 'test_dataloader')

    # Apply --cfg-options last so explicit key=value overrides take effect.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    return cfg


import argparse
import os
import os.path as osp
from copy import deepcopy
import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle

from mmengine.config import Config
from mmengine.runner import Runner
from tools.mmpretrain_tools.test import parse_args, merge_args


def save_model_state(model, filename):
    """Write ``model.state_dict()`` to ``filename`` via ``torch.save``."""
    state = model.state_dict()
    torch.save(state, filename)


def load_model_state(model, filename):
    """Load a state dict from ``filename`` into ``model``.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    model.load_state_dict(torch.load(filename))
    return model


def _extract_scores(metrics):
    """Return ``(f1, f1_loss, acc)`` from a ``runner.test()`` metrics dict.

    Missing keys default to 0.0. F1 is treated as a percentage, so
    ``f1_loss = 1 - f1 / 100`` lies in [0, 1]. Every evaluated point
    (original and perturbed) goes through this one helper so all points of
    the surface share the same scale.
    """
    acc = metrics.get('accuracy/top1', 0.0)
    f1 = metrics.get('single-label/f1-score', 0.0)
    f1_loss = 1.0 - f1 / 100.0  # Convert percentage to [0,1] range
    return f1, f1_loss, acc


def _plot_landscapes(alphas, betas, f1_loss_surface, acc_surface):
    """Save 2D contour and 3D surface plots of both landscapes.

    The surfaces are indexed ``[i, j]`` -> (alphas[i], betas[j]). Matplotlib
    expects ``Z[row, col]`` -> (y=betas[row], x=alphas[col]), so each surface
    is transposed before plotting. NaN entries (failed evaluations) are
    masked out.
    """
    f1_z = np.ma.masked_invalid(f1_loss_surface).T
    acc_z = np.ma.masked_invalid(acc_surface).T

    # ---- 2D contour plots ----
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
    panels = (
        (ax1, f1_z, 'viridis', "F1 Loss (1-F1)", "F1 Loss Landscape"),
        (ax2, acc_z, 'plasma', "Accuracy (%)", "Accuracy Landscape"),
    )
    for ax, z, cmap, cbar_label, title in panels:
        cp = ax.contourf(alphas, betas, z, levels=30, cmap=cmap)
        fig.colorbar(cp, ax=ax, label=cbar_label)
        ax.contour(alphas, betas, z, levels=8, colors='white', alpha=0.5,
                   linewidths=0.5)
        ax.scatter([0], [0], color='red', marker='*', s=100,
                   label='Original model')
        ax.set_title(title)
        ax.set_xlabel("Direction 1 (α)")
        ax.set_ylabel("Direction 2 (β)")
        ax.legend()
        ax.grid(alpha=0.2)
    plt.tight_layout()
    plt.savefig("f1_loss_landscape_reload.png", dpi=300)

    # ---- 3D surface plots ----
    fig = plt.figure(figsize=(18, 8))
    X, Y = np.meshgrid(alphas, betas)  # shape (len(betas), len(alphas))
    surfaces = (
        (121, f1_z, 'viridis', "F1 Loss (1-F1)", 'F1 Loss',
         '3D F1 Loss Landscape'),
        (122, acc_z, 'plasma', "Accuracy (%)", 'Accuracy (%)',
         '3D Accuracy Landscape'),
    )
    for position, z, cmap, cbar_label, zlabel, title in surfaces:
        ax = fig.add_subplot(position, projection='3d')
        surf = ax.plot_surface(X, Y, z, cmap=cmap, edgecolor='none')
        fig.colorbar(surf, ax=ax, label=cbar_label)
        ax.set_xlabel('Direction 1 (α)')
        ax.set_ylabel('Direction 2 (β)')
        ax.set_zlabel(zlabel)
        ax.set_title(title)
    plt.tight_layout()
    plt.savefig("f1_loss_landscape_3d_reload.png", dpi=300)
    plt.show()


def main():
    """Evaluate a checkpoint, then map its F1-loss and accuracy landscapes.

    1. Build a runner from the CLI/config and evaluate the unperturbed model.
    2. Unless ``f1_loss_landscape_reload.pkl`` already exists:
       a. draw two random directions over all trainable parameters, each
          scaled to unit global L2 norm;
       b. for each (alpha, beta) on the grid, save the weights
          ``base + alpha * d1 + beta * d2`` under ``perturbed_models/``
          (phase 1);
       c. reload each file and re-run ``runner.test()`` (phase 2);
       d. cache the resulting surfaces.
    3. Save 2D and 3D plots of both surfaces.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    cfg = merge_args(cfg, args)

    runner = Runner.from_cfg(cfg)
    model = runner.model
    model.eval()

    # Run standard evaluation on the original model
    print("=== Running standard evaluation on original model ===")
    metrics = runner.test()
    print(f"Original metrics: {metrics}")

    orig_f1, orig_f1_loss, orig_acc = _extract_scores(metrics)
    print(f"Original model - F1: {orig_f1:.4f}, F1 Loss: {orig_f1_loss:.4f}, Accuracy: {orig_acc:.2f}%")

    # Snapshot the weights after the first test() call.
    # NOTE(review): assumes mmengine loads `load_from` during the first
    # test() and does not reload it on later calls — confirm.
    orig_state_dict = deepcopy(model.state_dict())

    # Directory to save perturbed models
    perturbed_dir = 'perturbed_models'
    os.makedirs(perturbed_dir, exist_ok=True)

    # Save original model as baseline
    save_model_state(model, os.path.join(perturbed_dir, "original_model.pth"))

    def perturbed_path(alpha, beta):
        """File holding the weights perturbed by (alpha, beta)."""
        return os.path.join(perturbed_dir, f"perturbed_a{alpha:.4f}_b{beta:.4f}.pth")

    def is_center(alpha, beta):
        """True for the unperturbed grid point (tolerant to float error)."""
        return bool(np.isclose(alpha, 0.0) and np.isclose(beta, 0.0))

    # Perturbation grid, in units of the unit-norm directions. With an even
    # number of points, 0 is not on the grid and the original model is never
    # re-evaluated as part of the surface; use an odd count to include it.
    alphas = np.linspace(-100, 100, 10)
    betas = np.linspace(-100, 100, 10)
    f1_loss_surface = np.zeros((len(alphas), len(betas)))
    acc_surface = np.zeros((len(alphas), len(betas)))

    landscape_cache = 'f1_loss_landscape_reload.pkl'
    if os.path.exists(landscape_cache):
        # NOTE: caches written by earlier versions of this script stored
        # perturbed points as `100 - f1 / 100`; delete the file to recompute.
        # The cache is a local file this script wrote itself; never point it
        # at untrusted pickles.
        print(f"[INFO] Loading cached landscapes from {landscape_cache}")
        with open(landscape_cache, 'rb') as f:
            cache = pickle.load(f)
        alphas = cache['alphas']
        betas = cache['betas']
        f1_loss_surface = cache['f1_loss']
        acc_surface = cache['accuracy']
    else:
        print("[INFO] Computing F1 loss landscape with model reloading")

        # Trainable parameters: (name, snapshot of original data, live param).
        orig_params = [(name, param.data.clone(), param)
                       for name, param in model.named_parameters()
                       if param.requires_grad]

        # Two random directions, each scaled to unit L2 norm over all
        # trainable parameters jointly (not per layer/filter).
        direction1 = [torch.randn_like(p) for _, _, p in orig_params]
        direction2 = [torch.randn_like(p) for _, _, p in orig_params]
        d1_norm = np.sqrt(sum(torch.sum(d * d).item() for d in direction1))
        d2_norm = np.sqrt(sum(torch.sum(d * d).item() for d in direction2))
        direction1 = [d / d1_norm for d in direction1]
        direction2 = [d / d2_norm for d in direction2]

        # First, create and save all perturbed models
        print("[Phase 1] Creating and saving perturbed models...")
        for alpha in alphas:
            for beta in betas:
                # The centre point is the original model, already saved above.
                if is_center(alpha, beta):
                    continue

                print(f"  Creating model with α={alpha:.4f}, β={beta:.4f}")

                # Start from original weights (restores buffers too)
                model.load_state_dict(orig_state_dict)
                for (_, base, param), d1, d2 in zip(orig_params, direction1, direction2):
                    param.data = base + alpha * d1 + beta * d2

                save_model_state(model, perturbed_path(alpha, beta))

        # Now, load and evaluate each model
        print("[Phase 2] Loading and evaluating each model...")
        for i, alpha in enumerate(alphas):
            for j, beta in enumerate(betas):
                print(f"Evaluating α={alpha:.4f}, β={beta:.4f}")

                if is_center(alpha, beta):
                    print("  Center point (original model)")
                    f1_loss_surface[i, j] = orig_f1_loss
                    acc_surface[i, j] = orig_acc
                    continue

                load_model_state(model, perturbed_path(alpha, beta))

                try:
                    metrics = runner.test()
                except Exception as e:
                    # Best-effort sweep: record the point as missing (NaN,
                    # masked when plotting) and keep going.
                    print(f"  Error in evaluation: {e}")
                    f1_loss_surface[i, j] = float('nan')
                    acc_surface[i, j] = float('nan')
                    continue

                f1, f1_loss, acc = _extract_scores(metrics)
                print(f"  F1: {f1:.4f}, F1 Loss: {f1_loss:.4f}, Accuracy: {acc:.2f}%")
                f1_loss_surface[i, j] = f1_loss
                acc_surface[i, j] = acc

        # Save cache
        with open(landscape_cache, 'wb') as f:
            pickle.dump({
                'alphas': alphas,
                'betas': betas,
                'f1_loss': f1_loss_surface,
                'accuracy': acc_surface,
                'original_f1_loss': orig_f1_loss,
                'original_accuracy': orig_acc
            }, f)
        print(f"[INFO] Saved landscapes to {landscape_cache}")

        # Restore original model
        model.load_state_dict(orig_state_dict)

    # Surfaces already hold 1 - F1 in [0, 1]; plot them directly.
    _plot_landscapes(alphas, betas, f1_loss_surface, acc_surface)


if __name__ == '__main__':
    main()

# def main():
#     args = parse_args()
#     cfg = Config.fromfile(args.config)
#     cfg = merge_args(cfg, args)
#
#     # Turn off visual hooks for speed
#     if 'visualization' in cfg.default_hooks:
#         cfg.default_hooks.visualization.enable = False
#
#     # Create runner and model
#     runner = Runner.from_cfg(cfg)
#     model = runner.model
#     model.eval()
#
#     # Get all trainable parameters
#     orig_params = []
#     for n, p in model.named_parameters():
#         if p.requires_grad:
#             orig_params.append((n, p.data.clone(), p))
#
#     # Generate directions
#     direction1 = [torch.randn_like(p) for _, _, p in orig_params]
#     direction2 = [torch.randn_like(p) for _, _, p in orig_params]
#
#     alphas = np.linspace(-1.0, 1.0, 10)
#     betas = np.linspace(-1.0, 1.0, 10)
#     loss_surface = np.zeros((len(alphas), len(betas)))
#
#     loss_cache = 'loss_10level_runner_test.pkl'
#     if os.path.exists(loss_cache):
#         print(f"Loading cached loss surface from {loss_cache}")
#         with open(loss_cache, 'rb') as f:
#             cache = pickle.load(f)
#             alphas = cache['alphas']
#             betas = cache['betas']
#             loss_surface = cache['loss']
#     else:
#         for i, alpha in enumerate(alphas):
#             for j, beta in enumerate(betas):
#                 print(f'alpha={alpha:.2f}, beta={beta:.2f}')
#
#                 # Apply perturbation
#                 for (_, base, param), d1, d2 in zip(orig_params, direction1, direction2):
#                     param.data = base + alpha * d1 + beta * d2
#
#                 # Evaluate using runner.test()
#                 runner.model.eval()
#                 metrics = runner.test()
#                 print(metrics)
#                 loss = metrics.get('loss', None)
#
#                 if loss is None:
#                     print("Warning: 'loss' not found in metrics. Got:", metrics)
#                     loss = np.nan
#
#                 loss_surface[i, j] = loss
#
#         # Save loss cache
#         with open(loss_cache, 'wb') as f:
#             pickle.dump({'alphas': alphas, 'betas': betas, 'loss': loss_surface}, f)
#         print(f"Saved loss surface to {loss_cache}")
#
#         # Reset parameters
#         for _, base, param in orig_params:
#             param.data = base
#
#     # === Plotting ===
#     plt.figure(figsize=(6, 5))
#     cp = plt.contourf(alphas, betas, loss_surface, levels=50, cmap='viridis')
#     plt.colorbar(cp, label="Loss")
#     plt.title('Loss Landscape (runner.test-based)')
#     plt.xlabel('Direction 1 (alpha)')
#     plt.ylabel('Direction 2 (beta)')
#     plt.tight_layout()
#     plt.savefig('loss_10level_runner_test.png', dpi=300)
#     plt.show()
#
#
# if __name__ == '__main__':
#     main()