from __future__ import annotations

import argparse
import json
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.optim as optim

from utils import set_global_seed  # already available
import search_spaces              # already available
from optimizers.signum import Signum, SoftSignum
from optimizers.epsilon import AdamBiasCorrectedEps, Adam
from optimizers.adam_scheduling import AdamPaLM2Beta, AdamBeta2Schedule, AdamEpsilonSchedule
from optimizers.signum_dl import SignumDL
from optimizers.softsign import SoftSignumSGD

from toy_functions import TORCH_FUNCTIONS, BAD_CENTERS


# Per-function 3D view settings, keyed by toy-function name. The inner keys
# mirror the keyword arguments of plot_3d_landscape_with_trajectory:
# x_range / y_range bound the surface grid, elev / azim set the camera angle,
# and resolution is the number of grid samples per axis.
PLOT_PARAMS = dict(
    deceptive_landscape=dict(
        x_range=(-5, 5), y_range=(-5, 5), elev=30, azim=-135, resolution=200,
    ),
    canyon_waterfall=dict(
        x_range=(-5, 5), y_range=(-5, 5), elev=10, azim=135, resolution=200,
    ),
    concentric_barriers=dict(
        x_range=(-4.6, 4.6), y_range=(-4.6, 4.6), elev=5, azim=45, resolution=200,
    ),
    local_min_plateau_deep_min=dict(
        x_range=(-5, 5), y_range=(-5, 5), elev=20, azim=90, resolution=200,
    ),
    plateau_with_traps_deep_min=dict(
        x_range=(-4, 4), y_range=(-4, 4), elev=30, azim=10, resolution=500,
    ),
    complex_journey=dict(
        x_range=(-5, 5), y_range=(-5, 5), elev=30, azim=0, resolution=500,
    ),
)


class Point2D(nn.Module):
    """A single trainable point; calling the module returns it.

    The point is stored as the only parameter ``z`` (intended as a 2-element
    (x, y) tensor, though the shape is not enforced). The initial tensor is
    cloned, so optimizer updates to ``z`` never write into the caller's tensor.
    """

    def __init__(self, init_xy: torch.Tensor):
        super().__init__()
        start = init_xy.clone()
        self.z = nn.Parameter(start)

    def forward(self) -> torch.Tensor:
        point = self.z
        return point


def suggest_params(trial: "optuna.Trial", search_space: Dict) -> Dict:
    """Sample one value for every tunable entry of ``search_space``.

    Only entries whose value is a dict spec are sampled; any other entry is
    skipped and does not appear in the result. Supported specs:

    * ``{"type": "float", "min": lo, "max": hi, "log": bool}`` -> ``trial.suggest_float``
    * ``{"type": "categorical", "choices": [...]}`` -> ``trial.suggest_categorical``
    * any other ``type`` -> ``trial.suggest_int`` with ``min`` / ``max`` / ``log``

    ``log`` is optional and defaults to False.

    Args:
        trial: Optuna trial used to sample the values.
        search_space: Mapping from parameter name to spec (or to a non-spec value).

    Returns:
        Dict from parameter name to sampled value, in ``search_space`` order.

    Raises:
        KeyError: if a spec lacks ``type`` or the bounds/choices its type needs.
    """
    params = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            continue
        kind = spec["type"]
        if kind == "float":
            params[name] = trial.suggest_float(
                name, spec["min"], spec["max"], log=spec.get("log", False)
            )
        elif kind == "categorical":
            params[name] = trial.suggest_categorical(name, spec["choices"])
        else:
            # Preserves the historical behavior: every non-float, non-categorical
            # spec is treated as an integer range.
            params[name] = trial.suggest_int(
                name, spec["min"], spec["max"], log=spec.get("log", False)
            )
    return params


# SoftSignumSGD-based variants: name -> (phase schedule, temperature mode, coupled_wd).
#   phase schedule:
#     'sign_then_soft' - only-signum iters -> warmup soft transfer -> almost-SGD iters
#     'soft_only'      - warmup soft transfer from the start (only_sign_iters=0)
#     'sign_then_sgd'  - only-signum iters -> plain SGD iters (no soft warmup)
#   temperature mode:
#     'range' - tmin/tmax taken from the hyperparameters
#     'auto'  - tmin=2.0 with auto_temperature=True
#     'const' - tmax from the hyperparameters with const_temperature=True
#     None    - no temperature arguments (unused when warmup_iters=0)
#   coupled_wd: True -> pass decoupled_wd=False; False -> leave SoftSignumSGD's default.
_SOFT_SIGNUM_SGD_VARIANTS = {
    'SoftSignumSGD': ('sign_then_soft', 'range', False),
    'SoftSignumSGD-auto': ('sign_then_soft', 'auto', False),
    'SoftSignumSGD-const': ('sign_then_soft', 'const', False),
    'SoftSignumSGD_not_decoupled_wd': ('sign_then_soft', 'range', True),
    'SoftSignumSGD_not_decoupled_wd-auto': ('sign_then_soft', 'auto', True),
    'SoftSignumSGD_not_decoupled_wd-const': ('sign_then_soft', 'const', True),
    'SoftSignumPT': ('soft_only', 'range', False),
    'SoftSignumPT-auto': ('soft_only', 'auto', False),
    'SoftSignumPT-const': ('soft_only', 'const', False),
    'SoftSignumPT_not_decoupled_wd': ('soft_only', 'range', True),
    'SoftSignumPT_not_decoupled_wd-auto': ('soft_only', 'auto', True),
    'SoftSignumPT_not_decoupled_wd-const': ('soft_only', 'const', True),
    'Signum+SGD': ('sign_then_sgd', None, False),
    'Signum+SGD_not_decoupled_wd': ('sign_then_sgd', None, True),
}


def get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=None, n_iters=None):
    """Build the optimizer named ``optimizer_name`` plus optional clipping / LR schedule.

    Hyperparameters come from ``trial`` (sampled with ``suggest_params`` over
    ``search_space``) when it is given, otherwise from ``optimizer_params``.
    The keys ``warmup_iters``, ``only_sign_iters`` and ``schedule_iters`` are
    fractions of the run and are converted with ``int(fraction * n_iters)``.
    The exceptions are ``AdamBetaScheduling`` / ``AdamEpsScheduling``, which
    pass ``warmup_iters`` through unchanged.

    Args:
        optimizer_name: Name of the optimizer variant (same names as the CLI choices).
        model: Module whose ``parameters()`` are optimized.
        search_space: Search space for ``suggest_params``; only read when ``trial`` is given.
        trial: Optional Optuna trial; when given it overrides ``optimizer_params``.
        optimizer_params: Hyperparameter dict used when ``trial`` is None. An
            optional ``'hook'`` entry is forwarded to SoftSignumSGD variants.
        n_iters: Total number of optimizer steps. Required by every variant that
            converts a fraction into an iteration count.

    Returns:
        ``(optimizer, (clip, scheduler))``. ``clip`` is the gradient-clipping
        threshold or None; ``scheduler`` is a ``torch.optim.lr_scheduler``
        instance or None.

    Raises:
        ValueError: if both ``trial`` and ``optimizer_params`` are None.
        TypeError: if a fraction-based variant is requested without ``n_iters``.
        NotImplementedError: if ``optimizer_name`` is unknown.
    """
    if trial is None and optimizer_params is None:
        raise ValueError("Params and trial can not be None together")
    if trial is not None:
        optimizer_params = suggest_params(trial, search_space)
    p = optimizer_params

    def iters(key):
        # Scale a fraction-of-run hyperparameter to an absolute iteration count.
        fraction = p[key]
        if n_iters is None:
            raise TypeError(
                f"optimizer '{optimizer_name}' needs n_iters to convert '{key}' into an iteration count"
            )
        return int(fraction * n_iters)

    def linear_lr(opt):
        # Linear decay from lr_max (the optimizer's initial lr) to lr_min.
        return optim.lr_scheduler.LinearLR(
            opt,
            start_factor=1.0,
            end_factor=p['lr_min'] / p['lr_max'],
            total_iters=iters('schedule_iters'),
        )

    def soft_signum_sgd(lr, **extra):
        # Arguments shared by every SoftSignumSGD-based variant. dampening is
        # always set equal to momentum (presumably giving an EMA of gradients;
        # confirm against SoftSignumSGD).
        return SoftSignumSGD(
            model.parameters(),
            lr=lr,
            momentum=p['momentum'],
            dampening=p['momentum'],
            weight_decay=p['weight_decay'],
            hook=p.get('hook'),
            **extra,
        )

    def temperature(mode):
        if mode == 'range':
            return {'tmin': p['tmin'], 'tmax': p['tmax']}
        if mode == 'auto':
            return {'tmin': 2.0, 'auto_temperature': True}
        if mode == 'const':
            return {'tmax': p['tmax'], 'const_temperature': True}
        return {}

    clip = None
    scheduler = None

    if optimizer_name in _SOFT_SIGNUM_SGD_VARIANTS:
        phases, temp_mode, coupled_wd = _SOFT_SIGNUM_SGD_VARIANTS[optimizer_name]
        extra = temperature(temp_mode)
        if phases == 'sign_then_soft':
            extra.update(warmup_iters=iters('warmup_iters'), only_sign_iters=iters('only_sign_iters'))
        elif phases == 'soft_only':
            extra.update(warmup_iters=iters('warmup_iters'), only_sign_iters=0)
        else:  # 'sign_then_sgd'
            extra.update(warmup_iters=0, only_sign_iters=iters('only_sign_iters'), sgd_last=True)
        if coupled_wd:
            extra['decoupled_wd'] = False
        optimizer = soft_signum_sgd(p['lr'], **extra)
    elif optimizer_name in ('Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR'):
        with_schedule = optimizer_name == 'Signum_decoupled_wd_LinearLR'
        # Sign-only for the whole run; n_iters is passed through unscaled.
        optimizer = soft_signum_sgd(
            p['lr_max'] if with_schedule else p['lr'],
            warmup_iters=0,
            only_sign_iters=n_iters,
            decoupled_wd=True,
        )
        if with_schedule:
            scheduler = linear_lr(optimizer)
    elif optimizer_name in ('Signum', 'SignumLinearLR'):
        with_schedule = optimizer_name == 'SignumLinearLR'
        optimizer = Signum(
            model.parameters(),
            lr=p['lr_max'] if with_schedule else p['lr'],
            momentum=p['momentum'],
            weight_decay=p['weight_decay'],
        )
        if with_schedule:
            scheduler = linear_lr(optimizer)
    elif optimizer_name in ('SignumDL', 'SignumDLNesterov'):
        extra = {'nesterov': True} if optimizer_name == 'SignumDLNesterov' else {}
        optimizer = SignumDL(
            model.parameters(),
            lr=p['lr'],
            momentum=p['momentum'],
            weight_decay=p['weight_decay'],
            **extra,
        )
    elif optimizer_name in ('SoftSignum', 'SoftSignum_decoupled_wd'):
        extra = {'decoupled_wd': True} if optimizer_name == 'SoftSignum_decoupled_wd' else {}
        optimizer = SoftSignum(
            model.parameters(),
            lr=p['lr'],
            momentum=p['momentum'],
            weight_decay=p['weight_decay'],
            tmin=p['tmin'],
            tmax=p['tmax'],
            warmup_iters=iters('warmup_iters'),
            **extra,
        )
    elif optimizer_name in ('Adam', 'AdamEps'):
        adam_cls = Adam if optimizer_name == 'Adam' else AdamBiasCorrectedEps
        optimizer = adam_cls(
            model.parameters(),
            lr=p['lr'],
            weight_decay=p['weight_decay'],
            eps=p['eps'],
        )
    elif optimizer_name == 'AdamPaLM2':
        optimizer = AdamPaLM2Beta(
            model.parameters(),
            lr=p['lr'],
            weight_decay=p['weight_decay'],
            betas=(p['beta1'], p['beta2']),
            beta2_final=p['beta2_final'],
        )
    elif optimizer_name in ('AdamBetaScheduling', 'AdamEpsScheduling'):
        adam_cls = AdamBeta2Schedule if optimizer_name == 'AdamBetaScheduling' else AdamEpsilonSchedule
        # NOTE(review): warmup_iters is passed unscaled here, unlike the
        # fraction-of-n_iters convention used elsewhere; confirm what these
        # optimizers expect.
        optimizer = adam_cls(
            model.parameters(),
            lr=p['lr'],
            weight_decay=p['weight_decay'],
            warmup_iters=p['warmup_iters'],
        )
    elif optimizer_name in ('AdamW', 'AdamWClip'):
        optimizer = optim.AdamW(model.parameters(), lr=p['lr'], weight_decay=p['weight_decay'])
        if optimizer_name == 'AdamWClip':
            clip = p['clip']
    elif optimizer_name in ('SGD', 'SGDClip'):
        optimizer = optim.SGD(model.parameters(), lr=p['lr'], weight_decay=p['weight_decay'])
        if optimizer_name == 'SGDClip':
            clip = p['clip']
    elif optimizer_name in ('SGDLinearLR', 'SGDLinearLR+Clip'):
        optimizer = optim.SGD(model.parameters(), lr=p['lr_max'], weight_decay=p['weight_decay'])
        scheduler = linear_lr(optimizer)
        if optimizer_name == 'SGDLinearLR+Clip':
            clip = p['clip']
    elif optimizer_name == 'SGDCosineAnnealingLR':
        optimizer = optim.SGD(model.parameters(), lr=p['lr_max'], weight_decay=p['weight_decay'])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            eta_min=p['eta_min'],
            T_max=iters('schedule_iters'),
        )
    else:
        raise NotImplementedError(f"There is no optimizer {optimizer_name} yet")
    return optimizer, (clip, scheduler)


@torch.no_grad()
def _finite_or_inf(x: torch.Tensor) -> float:
    v = x.detach().float().cpu()
    if torch.isfinite(v).all():
        return float(v.item())
    return float("inf")


def _loss_at_point(f_torch, z: torch.Tensor) -> torch.Tensor:
    """Evaluate ``f_torch`` at the single point ``z`` and return a 0-d loss tensor.

    ``f_torch`` is called on a batch of one (``z.unsqueeze(0)``). A result of
    shape (1,) or (1, 1) is reduced to a scalar.
    """
    out = f_torch(z.unsqueeze(0))
    if out.ndim == 2 and out.shape[-1] == 1:
        out = out.squeeze(-1)
    return out.squeeze(0)


def run_one_start(f_torch, optimizer_name, search_space, optimizer_params, device, steps, init_xy):
    """Optimize a single point from ``init_xy`` and summarize the run.

    Args:
        f_torch: Loss function evaluated on a batch of points (see ``_loss_at_point``).
        optimizer_name, search_space, optimizer_params: Forwarded to
            ``get_optimizer`` with ``trial=None`` and ``n_iters=steps``.
        device: Device the point is optimized on.
        steps: Number of optimizer steps.
        init_xy: Starting point. It is cloned by ``Point2D``, so it is never modified.

    Returns:
        ``(best, final)``:
        - ``best`` is the lowest loss evaluated before each step.
        - ``final`` is the loss at the point reached after the last step.
        If any loss evaluated inside the loop is non-finite, both are ``inf``
        (divergence is penalized). A non-finite final loss is reported as ``inf``.
    """
    model = Point2D(init_xy.to(device)).to(device)

    optimizer, (clipping, scheduler) = get_optimizer(
        optimizer_name=optimizer_name,
        model=model,
        search_space=search_space,
        trial=None,
        optimizer_params=optimizer_params,
        n_iters=steps,
    )

    best = float("inf")
    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)

        loss = _loss_at_point(f_torch, model())
        if not torch.isfinite(loss):
            return float("inf"), float("inf")

        loss.backward()

        if clipping is not None:
            # Clip by the largest absolute gradient entry (infinity norm).
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type="inf")

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # `loss` was measured at the point *before* this step's update.
        best = min(best, float(loss.detach().cpu().item()))

    # NOTE(review): the point produced by the last update contributes only to
    # `final`, not to `best`. This is kept as-is so objective values stay
    # comparable with existing studies.
    with torch.no_grad():  # evaluation only; no autograd graph needed
        final_loss = _loss_at_point(f_torch, model())
    final = float(final_loss.cpu().item()) if torch.isfinite(final_loss) else float("inf")
    return best, final


def tune(
    n_trials: int,
    search_space: Dict,
    optimizer_name: str,
    function_name: str,
    device: str,
    steps: int,
    n_jitters: int,
    epsilon: float,
    n_startup_trials: int,
    path: str,
    seed: int = 42,
):
    """Tune ``optimizer_name`` on ``function_name`` with Optuna (TPE sampler, median pruner).

    Every trial runs the optimizer from the same ``n_jitters`` starting points,
    drawn once as ``BAD_CENTERS[function_name] + epsilon * N(0, I)`` from a
    generator seeded with ``seed``. The trial score is the mean over those
    starts of the best loss returned by ``run_one_start`` (lower is better).

    The study is stored in ``<path>/study/<function>__<optimizer>.db`` under the
    fixed name ``<function>__<optimizer>``. Rerunning with the same arguments
    therefore resumes that study and adds ``n_trials`` more trials.

    Returns:
        The best trial's parameters merged with run metadata: function,
        optimizer, best_objective, trial, steps, n_jitters, epsilon,
        bad_center and seed.

    Raises:
        ValueError: if ``n_jitters`` < 1 (the score would be the mean of nothing).
    """
    if n_jitters < 1:
        raise ValueError(f"n_jitters must be >= 1, got {n_jitters}")

    f_torch = TORCH_FUNCTIONS[function_name]
    center_xy = BAD_CENTERS[function_name]  # (x0, y0)

    # Set seed once before creating study for determinism
    set_global_seed(seed)

    # The same jittered starts are shared by all trials (fair comparison).
    # They depend only on `seed`, so draw them once instead of per trial.
    g = torch.Generator().manual_seed(seed)
    center = torch.tensor(center_xy, dtype=torch.float32)
    jit = torch.randn((n_jitters, 2), generator=g) * float(epsilon)  # normal jitter
    inits = center.unsqueeze(0) + jit  # (K, 2)

    def objective(trial: optuna.Trial) -> float:
        opt_params = suggest_params(trial, search_space)

        best_losses: List[float] = []
        for init_xy in inits:
            best_loss, _ = run_one_start(
                f_torch=f_torch,
                optimizer_name=optimizer_name,
                search_space=search_space,
                optimizer_params=opt_params,
                device=device,
                steps=steps,
                init_xy=init_xy,
            )
            best_losses.append(best_loss)

        score = float(np.mean(np.array(best_losses, dtype=np.float64)))

        # Single progress point. The score is already final here, so pruning
        # saves no compute; it only relabels worse-than-median trials as PRUNED.
        trial.report(score, step=0)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        return score  # minimize

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    study_dir = os.path.join(path, "study")
    os.makedirs(study_dir, exist_ok=True)
    study_name = f"{function_name}__{optimizer_name}"
    storage_path = f"sqlite:///{os.path.abspath(os.path.join(study_dir, study_name + '.db'))}"

    study = optuna.create_study(
        # A fixed name is required for load_if_exists to resume anything:
        # with study_name=None Optuna generates a new unique name on every call.
        study_name=study_name,
        storage=storage_path,
        direction="minimize",
        pruner=pruner,
        sampler=sampler,
        load_if_exists=True,
    )

    # Warning if study already contains trials
    if len(study.trials) > 0:
        print(f"Warning: Continuing existing study with {len(study.trials)} previous trials. "
              f"New trials will start from {len(study.trials)}.")

    study.optimize(objective, n_trials=n_trials)

    best_trial = study.best_trial
    result = dict(best_trial.params)
    result.update(
        {
            "function": function_name,
            "optimizer": optimizer_name,
            "best_objective": float(best_trial.value),
            "trial": int(best_trial.number),
            "steps": int(steps),
            "n_jitters": int(n_jitters),
            "epsilon": float(epsilon),
            "bad_center": [float(center_xy[0]), float(center_xy[1])],
            "seed": int(seed),
        }
    )
    return result


def get_arguments():
    """Parse the command-line options for a tuning run.

    Returns:
        argparse.Namespace with: function, optimizer (list of names), user,
        n_trials, n_startup_trials, steps, n_jitters, epsilon, device, seed, plot.
    """
    # The title goes in `description`. The first positional parameter of
    # ArgumentParser is `prog`, so passing the title positionally replaced the
    # program name in the usage line.
    parser = argparse.ArgumentParser(
        description="2D function hyperparameter tuning with Optuna (hard-region jitter)"
    )

    parser.add_argument("--function", type=str, required=True, choices=sorted(TORCH_FUNCTIONS.keys()))
    parser.add_argument(
        "--optimizer",
        type=str,
        nargs='+',
        default=["AdamW"],
        choices=[
            "SignumDL", "SignumDLNesterov",
            "Signum", "SignumLinearLR",
            "Signum_decoupled_wd", "Signum_decoupled_wd_LinearLR",
            "AdamW", "Adam", "AdamEps",
            "AdamWClip", "AdamEpsScheduling", "AdamBetaScheduling", "AdamPaLM2",
            "SGD", "SGDLinearLR", "SGDCosineAnnealingLR", "SGDClip", "SGDLinearLR+Clip",
            "SoftSignum", "SoftSignum_decoupled_wd",
            "SoftSignumSGD", "SoftSignumSGD-auto", "SoftSignumSGD-const",
            "SoftSignumSGD_not_decoupled_wd", "SoftSignumSGD_not_decoupled_wd-auto", "SoftSignumSGD_not_decoupled_wd-const",
            "Signum+SGD", "Signum+SGD_not_decoupled_wd",
            "SoftSignumPT", "SoftSignumPT-auto", "SoftSignumPT-const",
            "SoftSignumPT_not_decoupled_wd", "SoftSignumPT_not_decoupled_wd-auto", "SoftSignumPT_not_decoupled_wd-const",
        ],
    )

    parser.add_argument("--user", type=str, required=True, choices=["user1", "user2", "user3"])
    parser.add_argument("--n_trials", type=int, default=50)
    parser.add_argument("--n_startup_trials", type=int, default=20)

    parser.add_argument("--steps", type=int, default=200, help="Optimizer steps per run")
    parser.add_argument("--n_jitters", type=int, default=5, help="Number of jittered starts around BAD_CENTERS[function]")
    parser.add_argument("--epsilon", type=float, default=0.1, help="Jitter radius/scale around bad center")

    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--plot", action="store_true", help="Plot landscape + trajectory after tuning")

    return parser.parse_args()


def save_result(
    result: dict,
    path: str,
    optimizer_name: str
):
    """Serialize ``result`` as JSON to ``<path>/<optimizer_name>.json``.

    ``path`` is created (with parents) if it does not exist yet; an existing
    file for the same optimizer is overwritten.
    """
    os.makedirs(path, exist_ok=True)
    target = os.path.join(path, f"{optimizer_name}.json")
    with open(target, "w") as handle:
        json.dump(result, handle)





import matplotlib.pyplot as plt


def _extract_optimizer_params(results: dict, search_space: Dict) -> Dict:
    """Extract from results only those keys that are hyperparameters from search_space."""
    hp_keys = [k for k, v in search_space.items() if isinstance(v, dict)]
    return {k: results[k] for k in hp_keys if k in results}


@torch.no_grad()
def _eval_batch_func(func_torch, xy: torch.Tensor) -> torch.Tensor:
    """
    func_torch: z -> loss, where z shape (...,2), returns (...,) or (...,1)
    xy: (N,2)
    returns: (N,)
    """
    out = func_torch(xy)
    if out.ndim == 2 and out.shape[-1] == 1:
        out = out.squeeze(-1)
    return out


def run_trajectory_from_bad_center(
    function_name: str,
    optimizer_name: str,
    optimizer_params: Dict,
    search_space: Dict,
    device: str,
    steps: int,
    seed: int = 42,
):
    """
    Run the optimizer from BAD_CENTERS[function_name] and record its trajectory.

    The starting point and its loss are recorded unconditionally before the
    first step. After that, the point and its loss are recorded after every
    optimizer step. The loop stops early when the loss at the current point,
    or at the freshly updated point, is not finite; that point is not recorded.

    Returns:
      traj_xy: (T, 2) numpy array of visited points, T <= steps + 1
      traj_f:  (T,)   numpy array of the loss at each recorded point
    """
    set_global_seed(seed)

    f_torch = TORCH_FUNCTIONS[function_name]
    center_xy = BAD_CENTERS[function_name]
    init_xy = torch.tensor(center_xy, dtype=torch.float32)

    model = Point2D(init_xy.to(device)).to(device)

    optimizer, (clipping, scheduler) = get_optimizer(
        optimizer_name=optimizer_name,
        model=model,
        search_space=search_space,
        trial=None,
        optimizer_params=optimizer_params,
        n_iters=steps,
    )

    traj_xy = []
    traj_f = []

    def record(z, loss):
        # On CPU, .detach().float().cpu().numpy() is a view of the parameter's
        # storage, which the optimizer updates in place. Copy it so every entry
        # keeps the position at the time it was recorded (otherwise all
        # entries end up equal to the final point).
        traj_xy.append(z.detach().float().cpu().numpy().copy())
        traj_f.append(float(loss.detach().cpu().item()))

    # Record the starting point BEFORE the first step.
    with torch.no_grad():
        z0 = model()
        record(z0, f_torch(z0.unsqueeze(0)).squeeze())

    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)

        z = model()  # (2,)
        loss = f_torch(z.unsqueeze(0)).squeeze()

        if not torch.isfinite(loss):
            break

        loss.backward()

        if clipping is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type="inf")

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Record AFTER the step so each (z_new, loss_new) pair is consistent.
        with torch.no_grad():
            z_new = model()
            loss_new = f_torch(z_new.unsqueeze(0)).squeeze()
            if not torch.isfinite(loss_new):
                break
            record(z_new, loss_new)

    return np.asarray(traj_xy), np.asarray(traj_f)


def plot_3d_landscape_with_trajectory(
    function_name: str,
    traj_xy: np.ndarray,
    x_range=(-5, 5),
    y_range=(-5, 5),
    resolution=180,
    elev=20,
    azim=90,
    cmap="viridis",
    alpha=0.8,
    line_every=5,
    out_path: Optional[str] = None,
):
    """
    Plot the surface of ``function_name`` with an optimisation trajectory drawn on top.

    Args:
        function_name: key into ``TORCH_FUNCTIONS`` selecting the landscape to draw.
        traj_xy: sequence of visited (x, y) points in visiting order; indexed as
            ``traj_xy[:, 0]`` / ``traj_xy[:, 1]``, so expected shape is (T, 2).
        x_range, y_range: (min, max) extents of the surface grid.
        resolution: number of grid samples along each axis.
        elev, azim: camera angles passed to ``ax.view_init``.
        cmap, alpha: colormap and opacity of the surface.
        line_every: draw every ``line_every``-th trajectory point (values < 1 are
            treated as 1). The final point is always kept, so the "End" marker is
            placed on the true last iterate.
        out_path: if given, save a PNG there (creating parent directories) and
            close the figure; otherwise call ``plt.show()``.

    Returns:
        ``(fig, ax)``.

    Raises:
        ValueError: if ``traj_xy`` is empty.
    """
    func_torch = TORCH_FUNCTIONS[function_name]

    traj_xy = np.asarray(traj_xy)
    if len(traj_xy) == 0:
        raise ValueError("traj_xy is empty; nothing to plot")

    # --- surface grid ---
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)

    xy = np.stack([X.reshape(-1), Y.reshape(-1)], axis=1)  # (N, 2)
    xy_t = torch.tensor(xy, dtype=torch.float32)

    with torch.no_grad():
        # Assumes _eval_batch_func returns one loss value per input row (N values).
        Z = _eval_batch_func(func_torch, xy_t).cpu().numpy().reshape(resolution, resolution)

    # --- trajectory values ---
    step = max(1, int(line_every))
    traj_xy_plot = traj_xy[::step]
    # A strided slice only includes the last point when (T - 1) is a multiple of
    # `step`; append it otherwise so the line and "End" marker reach the real end.
    if (len(traj_xy) - 1) % step != 0:
        traj_xy_plot = np.concatenate([traj_xy_plot, traj_xy[-1:]], axis=0)

    traj_t = torch.tensor(traj_xy_plot, dtype=torch.float32)
    with torch.no_grad():
        traj_z = _eval_batch_func(func_torch, traj_t).cpu().numpy()

    fig = plt.figure(figsize=(12, 9))
    ax = fig.add_subplot(111, projection="3d")

    surf = ax.plot_surface(X, Y, Z, cmap=cmap, alpha=alpha, linewidth=0, antialiased=True)

    # trajectory line + start/end markers
    traj_color = "tab:orange"
    ax.plot(traj_xy_plot[:, 0], traj_xy_plot[:, 1], traj_z, linewidth=3, color=traj_color, label="Trajectory")
    ax.scatter(traj_xy_plot[0, 0], traj_xy_plot[0, 1], traj_z[0], s=80, marker="o", label="Start")
    ax.scatter(traj_xy_plot[-1, 0], traj_xy_plot[-1, 1], traj_z[-1], s=80, marker="X", label="End")

    ax.view_init(elev=elev, azim=azim)
    ax.set_xlabel("X (Parameter 1)", fontsize=12)
    ax.set_ylabel("Y (Parameter 2)", fontsize=12)
    ax.set_zlabel("Loss", fontsize=12)
    ax.set_title(f"{function_name}: landscape + tuned trajectory", fontsize=14, pad=20)

    fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10, label="Loss Value")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if out_path is not None:
        out_dir = os.path.dirname(out_path)
        if out_dir:  # dirname is empty for bare filenames
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path, dpi=200)
        plt.close(fig)  # close instead of show so headless runs don't block/leak figures
    else:
        plt.show()
    return fig, ax

def add_best_final_points_to_results(results: dict, traj_xy: np.ndarray, traj_f: np.ndarray) -> dict:
    """
    Record the best and final points of a trajectory into ``results`` (in place).

    Adds to results:
      - best_x, best_y, best_step, best_loss_from_bad_center
      - final_x, final_y, final_loss_from_bad_center

    ``traj_xy[i]`` is taken to be the (x, y) point whose loss is ``traj_f[i]``.
    If the two sequences differ in length, only their common prefix is used.
    NaN losses are never chosen as the best step while any non-NaN loss exists.
    If either sequence is None or empty, ``results`` is returned unchanged.

    Returns:
        The same ``results`` dict (mutated), for call chaining.
    """
    if traj_xy is None or traj_f is None:
        return results

    # Points and losses are recorded pairwise; align on the common prefix so a
    # best index found in traj_f can never fall outside traj_xy.
    n = min(len(traj_xy), len(traj_f))
    if n == 0:
        return results

    xy = np.asarray(traj_xy)[:n]
    losses = np.asarray(traj_f, dtype=np.float64)[:n]

    # np.argmin returns the index of the first NaN when one is present;
    # treat NaN as +inf so that any real loss is preferred.
    best_step = int(np.argmin(np.where(np.isnan(losses), np.inf, losses)))
    best_xy = xy[best_step]
    final_xy = xy[-1]

    results.update(
        {
            "best_x": float(best_xy[0]),
            "best_y": float(best_xy[1]),
            "best_step": best_step,
            "best_loss_from_bad_center": float(losses[best_step]),

            "final_x": float(final_xy[0]),
            "final_y": float(final_xy[1]),
            "final_loss_from_bad_center": float(losses[-1]),
        }
    )
    return results




if __name__ == "__main__":
    args = get_arguments()

    # args.optimizer may be a single name or a list; normalise to a list.
    optimizers = args.optimizer if isinstance(args.optimizer, list) else [args.optimizer]

    # Every artifact for this function (tuning output, JSON results, plots) goes here.
    out_dir = f"tuning/{args.user}/toy_functions/{args.function}"

    for optimizer_name in optimizers:
        print(f"\n{'='*60}")
        print(f"Running optimizer: {optimizer_name}")
        print(f"{'='*60}\n")

        search_space = search_spaces.search_spaces_map[optimizer_name]

        results = tune(
            n_trials=args.n_trials,
            search_space=search_space,
            optimizer_name=optimizer_name,
            function_name=args.function,
            device=args.device,
            steps=args.steps,
            n_jitters=args.n_jitters,
            epsilon=args.epsilon,
            n_startup_trials=args.n_startup_trials,
            seed=args.seed,
            path=out_dir,
        )

        # Always run the tuned optimizer from BAD_CENTER to record best/final points in JSON.
        opt_params = _extract_optimizer_params(results, search_space)

        traj_xy, traj_f = run_trajectory_from_bad_center(
            function_name=args.function,
            optimizer_name=optimizer_name,
            optimizer_params=opt_params,
            search_space=search_space,
            device=args.device,
            steps=args.steps,
            seed=args.seed,
        )

        # Persist results BEFORE plotting so a plotting failure (backend, display,
        # disk) cannot discard the outcome of the (expensive) tuning run.
        save_result(
            add_best_final_points_to_results(results, traj_xy, traj_f),
            out_dir,
            optimizer_name,
        )

        # --- PLOT AFTER TUNE (only with --plot) ---
        if args.plot:
            if len(traj_xy) == 0:
                print("Warning: trajectory is empty, skipping plot.")
            else:
                plot_path = f"{out_dir}/{optimizer_name}__trajectory.png"

                # Per-function view settings, falling back to a generic view.
                plot_params = PLOT_PARAMS.get(args.function, {
                    "x_range": (-5, 5),
                    "y_range": (-5, 5),
                    "elev": 30,
                    "azim": 45,
                    "resolution": 200,
                })

                plot_3d_landscape_with_trajectory(
                    function_name=args.function,
                    traj_xy=traj_xy,
                    x_range=plot_params["x_range"],
                    y_range=plot_params["y_range"],
                    resolution=plot_params["resolution"],
                    elev=plot_params["elev"],
                    azim=plot_params["azim"],
                    line_every=5,
                    out_path=plot_path,
                )