import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import optuna
import copy
import os
import json
from utils import set_global_seed
from optimizers.muon import SingleDeviceMuonWithAuxAdam
from optimizers.signum import Signum, SoftSignum
from optimizers.epsilon import AdamBiasCorrectedEps, Adam
from optimizers.adam_scheduling import AdamPaLM2Beta, AdamBeta2Schedule, AdamEpsilonSchedule
from optimizers.signum_dl import SignumDL
from optimizers.softsign import SoftSignumSGD
import models
import search_spaces
import argparse


def get_data(batch_size, seed=42, use_augmentations=False, num_workers=4):
    """Build CIFAR-10 train / validation / test ``DataLoader``s.

    The official training set is split 80/20 into train and validation
    subsets with a generator seeded by ``seed``, so the split is reproducible.
    The train subset uses the (optionally augmented) training transform; the
    validation subset and the official test set use the plain
    tensor + normalize transform.

    Args:
        batch_size: batch size used by all three loaders.
        seed: seed of the generator driving the train/validation split.
        use_augmentations: prepend ``RandomCrop(32, padding=4)`` and
            ``RandomHorizontalFlip()`` to the training transform.
        num_workers: worker processes per loader (defaults to 4).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.247, 0.243, 0.261)

    def to_normalized_tensor():
        # Fresh transform objects per Compose, mirroring separate construction.
        return [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]

    augment = (
        [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
        if use_augmentations
        else []
    )
    train_transform = transforms.Compose(augment + to_normalized_tensor())
    test_transform = transforms.Compose(to_normalized_tensor())

    train_dataset_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    train_len = int(0.8 * len(train_dataset_full))
    val_len = len(train_dataset_full) - train_len
    g = torch.Generator().manual_seed(seed)
    train_subset, val_subset = torch.utils.data.random_split(train_dataset_full, [train_len, val_len], generator=g)

    # Validation samples come from the training set but must not see the
    # training augmentations. A shallow copy gives the validation view its own
    # ``transform`` attribute while sharing the underlying image storage;
    # ``copy.deepcopy`` would duplicate the entire training set on every call.
    val_dataset = copy.copy(train_dataset_full)
    val_dataset.transform = test_transform
    val_subset = torch.utils.data.Subset(val_dataset, val_subset.indices)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


def evaluate_model(model, loader, attack=None, desc="Evaluating", device='cpu'):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    Args:
        model: network to evaluate; it is switched to ``eval()`` mode and left
            in that mode.
        loader: iterable of ``(images, labels)`` batches.
        attack: optional callable ``attack(images, labels)`` returning the
            images to score instead of the clean ones. It runs with autograd
            enabled (presumably because the attack needs input gradients).
        desc: label shown on the progress bar.
        device: device every batch is moved to.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    for images, labels in tqdm(loader, desc=desc, leave=False):
        images, labels = images.to(device), labels.to(device)
        if attack:
            with torch.enable_grad():
                images = attack(images, labels)
        # Scoring never needs gradients, so keep the forward pass out of
        # autograd for both the clean and the adversarial path.
        with torch.no_grad():
            outputs = model(images)

        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError(f"{desc}: loader yielded no samples")
    return 100 * correct / total


def train(num_epoches, model, train_loader, optimizer, criterion, eval_loader, test_loader=None, device='cpu', trial=None, clipping=None, scheduler=None):
    """Run ``num_epoches`` epochs of supervised training with per-epoch evaluation.

    The LR ``scheduler`` (if given) is stepped after every optimizer step, not
    once per epoch. When ``clipping`` is set, gradients are clipped to that
    maximum infinity-norm before each step.

    If ``trial`` is given, each epoch's validation accuracy is reported to it,
    and ``optuna.exceptions.TrialPruned`` is raised when it asks to prune.

    Returns:
        ``(eval_accuracies, test_accuracies)``: per-epoch accuracies in percent.
        ``test_accuracies`` stays empty when ``test_loader`` is None.
    """
    eval_history = []
    test_history = []
    for epoch in range(num_epoches):
        label = f"Epoch {epoch+1}/{num_epoches}"
        model.train()
        epoch_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"{label} [Training]"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            if clipping is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_loss += batch_loss.item()

        eval_acc = evaluate_model(model, eval_loader, desc=f"{label} [Evaluating]", device=device)
        test_acc = -1  # sentinel that is printed when there is no test loader
        if test_loader is not None:
            test_acc = evaluate_model(model, test_loader, desc=f"{label} [Testing]", device=device)
            test_history.append(test_acc)

        if trial is not None:
            trial.report(eval_acc, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        eval_history.append(eval_acc)
        mean_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epoches}], Loss: {mean_loss:.4f}, Eval Accuracy: {eval_acc:.2f}%, Test Accuracy: {test_acc:.2f}%")
    return eval_history, test_history


def suggest_params(trial, search_space, prefix=None):
    """Sample one hyperparameter configuration from ``search_space``.

    Each entry of ``search_space`` is either:

    * a dict describing an Optuna distribution:
        - ``{'type': 'float', 'min': lo, 'max': hi, 'log': bool}`` -> ``trial.suggest_float``
        - ``{'type': 'categorical', 'choices': [...]}`` -> ``trial.suggest_categorical``
        - any other ``type`` (e.g. ``'int'``) -> ``trial.suggest_int`` with
          ``min`` / ``max`` / ``log``.
      ``'log'`` is optional and defaults to ``False``.
    * any other value, which is a constant copied through unchanged.

    Args:
        trial: the Optuna trial to sample from (unused if every entry is a constant).
        search_space: mapping of parameter name -> distribution spec or constant.
        prefix: if given, each sampled parameter is registered with Optuna as
            ``f"{prefix}_{name}"``. The returned keys are always un-prefixed.

    Returns:
        dict mapping every parameter name to its sampled or constant value.
    """
    params = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            params[name] = spec
            continue

        optuna_name = name if prefix is None else f"{prefix}_{name}"
        kind = spec['type']
        if kind == 'categorical':
            params[name] = trial.suggest_categorical(optuna_name, spec['choices'])
        elif kind == 'float':
            params[name] = trial.suggest_float(optuna_name, spec['min'], spec['max'], log=spec.get('log', False))
        else:
            params[name] = trial.suggest_int(optuna_name, spec['min'], spec['max'], log=spec.get('log', False))
    return params


# SoftSignumSGD "soft sign -> SGD transfer" variants, keyed by optimizer name.
# Value: (has_sign_phase, temperature_mode, decoupled_wd)
#   has_sign_phase: start with hp['only_sign_iters'] * n_iters pure-sign steps
#                   (otherwise only_sign_iters=0)
#   temperature_mode: 'tuned' -> tmin/tmax from hp; 'auto' -> tmin=2.0 with
#                     auto_temperature; 'const' -> tmax from hp with const_temperature
#   decoupled_wd: value forwarded to SoftSignumSGD, or None to leave its default
_SOFT_TRANSFER_VARIANTS = {
    'SoftSignumSGD': (True, 'tuned', None),
    'SoftSignumSGD-auto': (True, 'auto', None),
    'SoftSignumSGD-const': (True, 'const', None),
    'SoftSignumSGD_not_decoupled_wd': (True, 'tuned', False),
    'SoftSignumSGD_not_decoupled_wd-auto': (True, 'auto', False),
    'SoftSignumSGD_not_decoupled_wd-const': (True, 'const', False),
    'SoftSignumPT': (False, 'tuned', None),
    'SoftSignumPT-auto': (False, 'auto', None),
    'SoftSignumPT-const': (False, 'const', None),
    'SoftSignumPT_not_decoupled_wd': (False, 'tuned', False),
    'SoftSignumPT_not_decoupled_wd-auto': (False, 'auto', False),
    'SoftSignumPT_not_decoupled_wd-const': (False, 'const', False),
}


def _resolve_optimizer_params(search_space, trial, optimizer_params, prefix):
    """Return the flat ``{name: value}`` hyperparameter dict for get_optimizer.

    * With a ``trial``, values are sampled via ``suggest_params``, which also
      passes the search-space constants through.
    * Without a ``trial``, the explicit ``optimizer_params`` are used. When
      ``prefix`` is given, only keys of the form ``"<prefix>_<name>"`` are
      kept, with the prefix stripped (keys of other prefixes cannot collide).
      Constant (non-dict) entries of ``search_space`` fill in any name the
      explicit params lack. Optuna's ``trial.params`` only records sampled
      values, so rebuilding from them would otherwise lose those constants.
    """
    if trial is None and optimizer_params is None:
        raise ValueError("Params and trial can not be None together")
    if trial is not None:
        return suggest_params(trial, search_space, prefix)

    if prefix is not None:
        head = prefix + '_'
        optimizer_params = {
            key[len(head):]: value
            for key, value in optimizer_params.items()
            if key.startswith(head)
        }
    constants = {
        key: value
        for key, value in (search_space or {}).items()
        if not isinstance(value, dict)
    }
    return constants | dict(optimizer_params)


def get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=None, n_iters=None, prefix=None):
    """Build the optimizer called ``optimizer_name`` for ``model``.

    Hyperparameters are either sampled from ``search_space`` with ``trial``, or
    taken from ``optimizer_params`` and completed with the constant entries of
    ``search_space``.

    Args:
        optimizer_name: one of the names handled below.
        model: module whose parameters are optimized. For ``'Muon'``, a model
            exposing ``features`` is assumed to also have ``norm``,
            ``avgpool`` and ``head`` submodules.
        search_space: mapping of name -> distribution spec (dict) or constant.
        trial: Optuna trial to sample from; takes precedence over ``optimizer_params``.
        optimizer_params: explicit hyperparameters (e.g. ``study.best_trial.params``).
        n_iters: total number of training iterations. Fractional schedule
            hyperparameters (``warmup_iters``, ``only_sign_iters`` and
            ``schedule_iters`` where they are scaled) are multiplied by it.
        prefix: Optuna parameter-name prefix (see ``suggest_params``).

    Returns:
        ``(optimizer, (clip, scheduler))``. ``clip`` is the gradient-clipping
        threshold or ``None``; ``scheduler`` is an LR scheduler or ``None``.

    Raises:
        ValueError: if both ``trial`` and ``optimizer_params`` are ``None``.
        NotImplementedError: for an unknown ``optimizer_name``.
    """
    hp = _resolve_optimizer_params(search_space, trial, optimizer_params, prefix)
    clip = None
    scheduler = None

    def frac_to_iters(key):
        # Schedule lengths are tuned as fractions of the whole training run.
        return int(hp[key] * n_iters)

    def linear_lr(opt):
        return optim.lr_scheduler.LinearLR(
            opt,
            start_factor=1.0,
            end_factor=hp['lr_min'] / hp['lr_max'],
            total_iters=frac_to_iters('schedule_iters'),
        )

    def soft_signum_sgd(lr_key='lr', **kwargs):
        # Shared SoftSignumSGD arguments; dampening deliberately equals momentum.
        return SoftSignumSGD(
            model.parameters(),
            lr=hp[lr_key],
            momentum=hp['momentum'],
            dampening=hp['momentum'],
            weight_decay=hp['weight_decay'],
            hook=hp.get('hook'),
            **kwargs,
        )

    if optimizer_name == 'Muon':
        if hasattr(model, 'features'):
            hidden_weights = [w for w in model.features.parameters() if w.ndim >= 2]
            hidden_gains_biases = [w for w in model.features.parameters() if w.ndim < 2]
            nonhidden_params = [*model.norm.parameters(), *model.avgpool.parameters(), *model.head.parameters()]
        else:
            hidden_weights = [w for w in model.parameters() if w.ndim >= 2 and w.requires_grad]
            hidden_gains_biases = [w for w in model.parameters() if w.ndim < 2 and w.requires_grad]
            nonhidden_params = []
        param_groups = [
            dict(params=hidden_weights, use_muon=True, lr=hp['muon_lr'], weight_decay=hp['muon_weight_decay']),
            dict(params=hidden_gains_biases + nonhidden_params, use_muon=False, lr=hp['adam_lr'], betas=(0.9, 0.95), weight_decay=hp['adam_weight_decay']),
        ]
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    # --- AdamW / Adam family -------------------------------------------------
    elif optimizer_name in ('AdamW', 'AdamWClip'):
        optimizer = optim.AdamW(model.parameters(), lr=hp['lr'], weight_decay=hp['weight_decay'])
        if optimizer_name == 'AdamWClip':
            clip = hp['clip']
    elif optimizer_name == 'AdamWBetas':
        optimizer = optim.AdamW(
            model.parameters(),
            lr=hp['lr'],
            weight_decay=hp['weight_decay'],
            eps=hp['eps'],
            betas=(hp['beta1'], hp['beta2']),
        )
        clip = hp['clip']
    elif optimizer_name == 'Adam':
        optimizer = Adam(model.parameters(), lr=hp['lr'], weight_decay=hp['weight_decay'], eps=hp['eps'])
    elif optimizer_name == 'AdamEps':
        optimizer = AdamBiasCorrectedEps(model.parameters(), lr=hp['lr'], weight_decay=hp['weight_decay'], eps=hp['eps'])
    elif optimizer_name == 'AdamPaLM2':
        optimizer = AdamPaLM2Beta(
            model.parameters(),
            lr=hp['lr'],
            weight_decay=hp['weight_decay'],
            betas=(hp['beta1'], hp['beta2']),
            beta2_final=hp['beta2_final'],
        )
    # NOTE(review): unlike the SoftSignum variants, warmup_iters is passed
    # through unscaled here -- confirm these optimizers expect that unit.
    elif optimizer_name == 'AdamBetaScheduling':
        optimizer = AdamBeta2Schedule(model.parameters(), lr=hp['lr'], weight_decay=hp['weight_decay'], warmup_iters=hp['warmup_iters'])
    elif optimizer_name == 'AdamEpsScheduling':
        optimizer = AdamEpsilonSchedule(model.parameters(), lr=hp['lr'], weight_decay=hp['weight_decay'], warmup_iters=hp['warmup_iters'])

    # --- Signum family ---------------------------------------------------------
    elif optimizer_name == 'Signum':
        optimizer = Signum(model.parameters(), lr=hp['lr'], momentum=hp['momentum'], weight_decay=hp['weight_decay'])
    elif optimizer_name == 'SignumLinearLR':
        optimizer = Signum(model.parameters(), lr=hp['lr_max'], momentum=hp['momentum'], weight_decay=hp['weight_decay'])
        scheduler = linear_lr(optimizer)
    elif optimizer_name == 'SignumDL':
        optimizer = SignumDL(model.parameters(), lr=hp['lr'], momentum=hp['momentum'], weight_decay=hp['weight_decay'])
    elif optimizer_name == 'SignumDLNesterov':
        optimizer = SignumDL(model.parameters(), lr=hp['lr'], momentum=hp['momentum'], weight_decay=hp['weight_decay'], nesterov=True)
    elif optimizer_name == 'SoftSignum':
        optimizer = SoftSignum(
            model.parameters(),
            lr=hp['lr'],
            momentum=hp['momentum'],
            weight_decay=hp['weight_decay'],
            tmin=hp['tmin'],
            tmax=hp['tmax'],
            warmup_iters=frac_to_iters('warmup_iters'),
        )
    elif optimizer_name == 'SoftSignum_decoupled_wd':
        optimizer = SoftSignum(
            model.parameters(),
            lr=hp['lr'],
            momentum=hp['momentum'],
            weight_decay=hp['weight_decay'],
            tmin=hp['tmin'],
            tmax=hp['tmax'],
            warmup_iters=frac_to_iters('warmup_iters'),
            decoupled_wd=True,
        )

    # --- SoftSignumSGD-based variants ----------------------------------------
    elif optimizer_name == 'Signum_decoupled_wd':
        optimizer = soft_signum_sgd(warmup_iters=0, only_sign_iters=n_iters, decoupled_wd=True)
    elif optimizer_name == 'Signum_decoupled_wd_LinearLR':
        optimizer = soft_signum_sgd('lr_max', warmup_iters=0, only_sign_iters=n_iters, decoupled_wd=True)
        scheduler = linear_lr(optimizer)
    elif optimizer_name == 'SGD-approx':
        optimizer = soft_signum_sgd(warmup_iters=0, only_sign_iters=0)
    elif optimizer_name == 'Signum-approx':
        optimizer = soft_signum_sgd(warmup_iters=0, only_sign_iters=int(1.0 * n_iters))
    elif optimizer_name == 'Signum+SGD':  # only-signum iters -> sgd iters
        optimizer = soft_signum_sgd(warmup_iters=0, only_sign_iters=frac_to_iters('only_sign_iters'), sgd_last=True)
    elif optimizer_name == 'Signum+SGD_not_decoupled_wd':
        optimizer = soft_signum_sgd(
            warmup_iters=0,
            only_sign_iters=frac_to_iters('only_sign_iters'),
            sgd_last=True,
            decoupled_wd=False,
        )
    elif optimizer_name == 'Signum+SGD_with_LinearLR':
        optimizer = soft_signum_sgd('lr_max', warmup_iters=0, only_sign_iters=frac_to_iters('only_sign_iters'), sgd_last=True)
        scheduler = linear_lr(optimizer)
    elif optimizer_name in _SOFT_TRANSFER_VARIANTS:
        # only-signum iters (SGD variants) -> warmup soft transfer -> almost-SGD iters
        has_sign_phase, temperature, decoupled_wd = _SOFT_TRANSFER_VARIANTS[optimizer_name]
        if temperature == 'auto':
            kwargs = dict(tmin=2.0, auto_temperature=True)
        elif temperature == 'const':
            kwargs = dict(tmax=hp['tmax'], const_temperature=True)
        else:
            kwargs = dict(tmin=hp['tmin'], tmax=hp['tmax'])
        kwargs['warmup_iters'] = frac_to_iters('warmup_iters')
        kwargs['only_sign_iters'] = frac_to_iters('only_sign_iters') if has_sign_phase else 0
        if decoupled_wd is not None:
            kwargs['decoupled_wd'] = decoupled_wd
        optimizer = soft_signum_sgd(**kwargs)

    # --- SGD family ------------------------------------------------------------
    elif optimizer_name in ('SGD', 'SGDClip'):
        optimizer = optim.SGD(model.parameters(), lr=hp['lr'], weight_decay=hp['weight_decay'])
        if optimizer_name == 'SGDClip':
            clip = hp['clip']
    elif optimizer_name in ('SGDLinearLR', 'SGDLinearLR+Clip'):
        optimizer = optim.SGD(model.parameters(), lr=hp['lr_max'], weight_decay=hp['weight_decay'])
        scheduler = linear_lr(optimizer)
        if optimizer_name == 'SGDLinearLR+Clip':
            clip = hp['clip']
    elif optimizer_name == 'SGDCosineAnnealingLR':
        optimizer = optim.SGD(model.parameters(), lr=hp['lr_max'], weight_decay=hp['weight_decay'])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            eta_min=hp['eta_min'],
            T_max=frac_to_iters('schedule_iters'),
        )
    else:
        raise NotImplementedError(f"There is no optimizer {optimizer_name} yet")
    return optimizer, (clip, scheduler)


def tune(n_trials, search_space, optimizer_name, ModelCls, device, num_epoches, n_startup_trials, path, seed=42, use_augmentations=False):
    """Tune ``optimizer_name`` with Optuna, then retrain once with the best params.

    Every trial, and the final run, does the following:
      * reseeds with ``seed``;
      * builds a fresh ``ModelCls()``;
      * uses a train/validation split drawn with ``seed``;
      * trains for ``num_epoches`` epochs.

    A trial's objective is its best per-epoch validation accuracy. The study is
    stored in ``<path>/study/<optimizer_name>.db`` (SQLite) and resumed if it
    already exists. Whenever a trial beats the best completed trial so far,
    its params and score are written to ``<path>/<optimizer_name>.json``.

    Returns:
        dict of the best trial's params plus ``val_score`` / ``test_score`` of
        the final run, both taken at the epoch with the highest validation
        accuracy. Numeric values are converted to ``float``.
    """
    def is_new_best(study, score):
        # ``best_value`` raises ValueError until some trial has completed
        # (the very first trial, or a resumed study with no completed trial).
        try:
            return study.best_value < score
        except ValueError:
            return True

    def objective(trial):
        set_global_seed(seed)
        model = ModelCls().to(device)
        # Forward ``seed`` so the train/validation split follows the tuning seed.
        train_loader, eval_loader, _ = get_data(batch_size=search_space['batch_size'], seed=seed, use_augmentations=use_augmentations)
        n_iters = num_epoches * len(train_loader)
        criterion = nn.CrossEntropyLoss()
        optimizer, (clipping, scheduler) = get_optimizer(optimizer_name, model, search_space, trial=trial, n_iters=n_iters)
        # NOTE(review): ``trial`` is not forwarded to train(), so the
        # MedianPruner configured below never prunes -- confirm this is intended.
        val_accuracies, _ = train(num_epoches=num_epoches,
                                  model=model,
                                  train_loader=train_loader,
                                  optimizer=optimizer,
                                  criterion=criterion,
                                  eval_loader=eval_loader,
                                  test_loader=None,
                                  device=device,
                                  clipping=clipping,
                                  scheduler=scheduler
                                  )
        best_val = max(val_accuracies)
        if is_new_best(trial.study, best_val):
            result = trial.params | {'val_score': best_val, 'trial': trial.number}
            save_result(result, path, optimizer_name)
        return best_val

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    os.makedirs(os.path.join(path, 'study'), exist_ok=True)
    storage_path = f"sqlite:///{os.path.join(path, 'study', optimizer_name + '.db')}"

    study = optuna.create_study(
        study_name=f"{optimizer_name}",
        storage=storage_path,
        direction="maximize",
        pruner=pruner,
        sampler=sampler,
        load_if_exists=True
    )
    study.optimize(objective, n_trials=n_trials)
    best_trial = study.best_trial

    # Final run with the best hyperparameters, now also tracking the test set.
    set_global_seed(seed)
    model = ModelCls().to(device)
    train_loader, eval_loader, test_loader = get_data(batch_size=search_space['batch_size'], seed=seed, use_augmentations=use_augmentations)
    n_iters = num_epoches * len(train_loader)
    criterion = nn.CrossEntropyLoss()
    optimizer, (clipping, scheduler) = get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=best_trial.params, n_iters=n_iters)
    val_accuracies, test_accuracies = train(num_epoches=num_epoches,
                                            model=model,
                                            train_loader=train_loader,
                                            optimizer=optimizer,
                                            criterion=criterion,
                                            eval_loader=eval_loader,
                                            test_loader=test_loader,
                                            device=device,
                                            clipping=clipping,
                                            scheduler=scheduler
                                            )
    best_epoch = int(np.argmax(val_accuracies))
    result = best_trial.params | {'val_score': val_accuracies[best_epoch], 'test_score': test_accuracies[best_epoch]}
    # Non-numeric params (e.g. categorical choices) are kept as-is instead of
    # crashing float().
    return {
        key: float(value) if isinstance(value, (int, float, np.number)) else value
        for key, value in result.items()
    }


def get_arguments():
    """Parse the command-line options of the CIFAR-10 tuning script.

    Returns:
        argparse.Namespace with ``optimizer``, ``user``, ``model``, ``n_trials``,
        ``n_startup_trials``, ``max_epochs``, ``device`` and ``use_augmentations``.
    """
    optimizer_choices = [
        'Muon',
        'SignumDL', 'SignumDLNesterov',
        'Signum', 'SignumLinearLR',
        'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
        'AdamW', 'AdamWBetas', 'Adam', 'AdamEps',
        'AdamWClip', 'AdamEpsScheduling', 'AdamBetaScheduling', 'AdamPaLM2',
        'SGD', 'SGDLinearLR', 'SGDCosineAnnealingLR', 'SGDClip', 'SGDLinearLR+Clip',
        'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
        'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
        'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
        'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto', 'SoftSignumSGD_not_decoupled_wd-const',
        'Signum+SGD', 'Signum+SGD_not_decoupled_wd', 'Signum+SGD_with_LinearLR',
    ]
    default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser(description="CIFAR-10 Hyperparameter Tuning with Optuna")
    add = parser.add_argument
    add('--optimizer', type=str, default='AdamW', choices=optimizer_choices,
        help='Optimizer to use for training.')
    add('--user', type=str, required=True, choices=['user1', 'user2', 'user3'],
        help='Name of the user folder for saving the results.')
    add('--model', type=str, default='SimpleCNN', choices=['SimpleCNN', 'ResNet18_32x32'],
        help='Model architecture to use.')
    add('--n_trials', type=int, default=50,
        help='Number of Optuna trials to run.')
    add('--n_startup_trials', type=int, default=20,
        help='Number of Optuna startup trials to run.')
    add('--max_epochs', type=int, default=50,
        help='Maximum number of epochs for training each trial.')
    add('--device', type=str, default=default_device,
        help='Device to use for training (e.g., "cpu", "cuda:0").')
    add('--use_augmentations', action="store_true", help="Use augmentations")
    return parser.parse_args()


def save_result(
    result: dict,
    path: str,
    optimizer_name: str
) -> None:
    """Serialize ``result`` as JSON to ``<path>/<optimizer_name>.json``.

    ``path`` (and any missing parents) is created first; an existing file of
    the same name is overwritten.
    """
    destination = f'{path}/{optimizer_name}.json'
    os.makedirs(path, exist_ok=True)
    with open(destination, 'w') as out_file:
        json.dump(result, out_file)


if __name__ == '__main__':
    args = get_arguments()
    output_dir = f'tuning/{args.user}/cifar10/{args.model.lower()}'

    results = tune(
        n_trials=args.n_trials,
        search_space=search_spaces.search_spaces_map[args.optimizer],
        optimizer_name=args.optimizer,
        ModelCls=models.model_map[args.model],
        device=args.device,
        num_epoches=args.max_epochs,
        n_startup_trials=args.n_startup_trials,
        use_augmentations=args.use_augmentations,
        path=output_dir,
    )

    # Same file the per-trial records are written to: the final retrained
    # scores replace the last intermediate best-trial record.
    save_result(results, output_dir, args.optimizer)