import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import optuna
import json
import os
from utils import set_global_seed
from optimizers.muon import SingleDeviceMuonWithAuxAdam
from optimizers.signum import Signum, SoftSignum, SignumSGD
from optimizers.epsilon import AdamBiasCorrectedEps, Adam
import models
import search_spaces
import argparse
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from optimizers.signum_dl import SignumDL
from optimizers.softsign import SoftSignumSGD


class UnbalancedDataset(torch.utils.data.Dataset):
    """In-memory binary (odd vs. even label) view of a labelled dataset.

    Every ``(sample, label)`` pair of ``balanced_dataset`` is loaded into
    memory. The new target is ``label % 2``: even labels become class 0 and
    odd labels become class 1. All class-0 samples are kept. For ``k != 1``
    only the ``1/k`` held-out share of the class-1 samples is kept. That share
    comes from ``train_test_split``, stratified by the original label and
    seeded by ``seed``, which turns class 1 into the minority class.

    Stored order is: all class-0 samples first, then the kept class-1 samples.

    Args:
        balanced_dataset: Iterable dataset yielding ``(tensor, int_label)`` pairs.
        seed: ``random_state`` for the class-1 subsampling split.
        k: Subsampling factor for class 1; ``k == 1`` keeps every sample.
    """

    def __init__(
        self, balanced_dataset: torch.utils.data.Dataset, seed: int, k: int = 2
    ):
        features, labels = [], []
        for sample, label in balanced_dataset:
            features.append(sample)
            labels.append(label)
        features = torch.stack(features)
        labels = torch.tensor(labels)

        # label % 2 is always 0 or 1, so ~is_odd selects exactly the even labels.
        is_odd = labels % 2 == 1
        negatives = features[~is_odd]
        positives = features[is_odd]

        keep_idx = np.arange(positives.shape[0])
        if k != 1:
            # Keep only the held-out (test) part of the split: a 1/k share.
            _, keep_idx = train_test_split(
                keep_idx,
                test_size=1.0 / k,
                stratify=labels[is_odd],
                random_state=seed,
            )
        positives = positives[keep_idx]

        self._X = torch.cat([negatives, positives])
        self._y = torch.cat([
            torch.zeros(negatives.shape[0], dtype=labels.dtype),
            torch.ones(positives.shape[0], dtype=labels.dtype),
        ])

    def __len__(self) -> int:
        """Number of stored samples."""
        return self._X.shape[0]

    def __getitem__(self, i: int):
        """Return the ``(data, binary_target)`` pair stored at position ``i``."""
        return self._X[i], self._y[i]


def get_data(batch_size, seed=42, k=1, balanced=False):
    """Build binary, class-unbalanced CIFAR-10 loaders.

    The CIFAR-10 train split is turned into an ``UnbalancedDataset`` with
    factor ``k`` and then split 80/20 into train/validation with a generator
    seeded by ``seed``. The CIFAR-10 test split is also wrapped in
    ``UnbalancedDataset``: with factor 1 when ``balanced`` is truthy, otherwise
    with ``k``.

    Args:
        batch_size: Batch size for all three loaders.
        seed: Seed for the class-1 subsampling and the train/val split.
        k: Unbalance factor for the training data (and the test data unless ``balanced``).
        balanced: If truthy, the test set keeps every class-1 sample.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader shuffles.
    """
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ])

    def _cifar(is_train):
        return torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=normalize)

    full_train = UnbalancedDataset(_cifar(True), seed=seed, k=k)
    test_k = 1 if balanced else k
    test_dataset = UnbalancedDataset(_cifar(False), seed=seed, k=test_k)

    n_train = int(0.8 * len(full_train))
    n_val = len(full_train) - n_train
    split_gen = torch.Generator().manual_seed(seed)
    train_subset, val_subset = torch.utils.data.random_split(full_train, [n_train, n_val], generator=split_gen)

    def _loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    return _loader(train_subset, True), _loader(val_subset, False), _loader(test_dataset, False)


def evaluate_model(model, loader, attack=None, desc="Evaluating", device='cpu'):
    """
    Evaluates model F1 score on a given data loader, with an optional attack.

    The model is switched to eval mode and left there. Predictions are the
    arg-max over dimension 1 of the model output. The score is
    ``sklearn.metrics.f1_score`` with its default binary averaging (F1 of
    label 1), using ``zero_division=0``.

    Args:
        model: Network mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(images, labels)`` batches.
        attack: Optional callable ``attack(images, labels)`` returning perturbed
            images. It runs with autograd enabled.
        desc: Progress-bar label.
        device: Device the batches are moved to.

    Returns:
        The F1 score scaled to percent (0-100).
    """
    model.eval()
    all_preds = []
    all_labels = []

    for images, labels in tqdm(loader, desc=desc, leave=False):
        images, labels = images.to(device), labels.to(device)

        if attack is not None:
            # Generating adversarial examples needs gradients w.r.t. the inputs.
            with torch.enable_grad():
                images = attack(images, labels)

        # Scoring never needs an autograd graph, for clean or adversarial inputs alike.
        with torch.no_grad():
            outputs = model(images)

        predicted = outputs.argmax(dim=1)

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    # zero_division=0 returns 0 instead of warning when nothing is predicted positive.
    score = f1_score(all_labels, all_preds, zero_division=0)

    return score * 100


def train(num_epoches, model, train_loader, optimizer, criterion, eval_loader, test_loader=None, device='cpu', trial=None, clipping=None, scheduler=None):
    """Train ``model`` for ``num_epoches`` epochs, scoring it after every epoch.

    ``scheduler`` (if any) is stepped after every optimizer step, not once per
    epoch. With ``clipping`` set, gradients are clipped to that max-norm using
    the infinity norm before each step.

    Args:
        num_epoches: Number of passes over ``train_loader``.
        model: Network being trained.
        train_loader: Iterable of ``(images, labels)`` training batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function ``criterion(outputs, labels)``.
        eval_loader: Validation batches scored with ``evaluate_model``.
        test_loader: Optional test batches, also scored each epoch.
        device: Device batches are moved to.
        trial: Optional Optuna trial. The validation score is reported to it
            each epoch.
        clipping: Optional gradient-norm threshold.
        scheduler: Optional LR scheduler.

    Returns:
        ``(eval_scores, test_scores)``: per-epoch validation F1 (percent) and
        per-epoch test F1 (empty when ``test_loader`` is None).

    Raises:
        optuna.exceptions.TrialPruned: if ``trial``'s pruner asks to stop.
    """
    eval_scores, test_scores = [], []
    for epoch_idx in range(num_epoches):
        tag = f"Epoch {epoch_idx+1}/{num_epoches}"
        model.train()
        loss_sum = 0.0

        for batch_x, batch_y in tqdm(train_loader, desc=f"{tag} [Training]"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_y)
            batch_loss.backward()
            if clipping is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            loss_sum += batch_loss.item()

        eval_score = evaluate_model(model, eval_loader, desc=f"{tag} [Evaluating]", device=device)
        test_score = -1  # printed as a placeholder when no test loader is given
        if test_loader is not None:
            test_score = evaluate_model(model, test_loader, desc=f"{tag} [Testing]", device=device)
            test_scores.append(test_score)

        if trial is not None:
            trial.report(eval_score, epoch_idx)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        eval_scores.append(eval_score)
        mean_loss = loss_sum / len(train_loader)
        print(f"Epoch [{epoch_idx+1}/{num_epoches}], Loss: {mean_loss:.4f}, Eval F1: {eval_score:.2f}%, Test F1: {test_score:.2f}%")
    return eval_scores, test_scores


def suggest_params(trial, search_space):
    """Sample one hyperparameter configuration from ``search_space`` via ``trial``.

    Only entries whose value is a dict are tunable. Any other entry (e.g. a
    fixed ``batch_size``) is skipped.

    Each tunable spec needs ``'type'``, ``'min'`` and ``'max'``. ``'log'`` is
    optional and defaults to ``False``. ``'type' == 'float'`` is sampled with
    ``trial.suggest_float``; every other type is sampled with
    ``trial.suggest_int``.

    Returns:
        Dict mapping each tunable name to its sampled value.
    """
    params = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            continue  # fixed, non-tuned setting
        log = spec.get('log', False)
        if spec['type'] == 'float':
            params[name] = trial.suggest_float(name, spec['min'], spec['max'], log=log)
        else:
            params[name] = trial.suggest_int(name, spec['min'], spec['max'], log=log)
    return params


def _scaled_iters(optimizer_params, key, n_iters):
    """Return ``int(optimizer_params[key] * n_iters)``, a step count scaled by the run length."""
    if n_iters is None:
        raise ValueError(f"n_iters is required to compute '{key}'")
    return int(optimizer_params[key] * n_iters)


def _linear_lr(optimizer, optimizer_params, n_iters):
    """Build a ``LinearLR`` going from the optimizer's lr to ``lr_min / lr_max`` of it.

    Callers construct the optimizer with ``lr=lr_max``, so the rate decays
    from ``lr_max`` to ``lr_min`` over ``schedule_iters * n_iters`` steps.
    """
    return optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=optimizer_params['lr_min'] / optimizer_params['lr_max'],
        total_iters=_scaled_iters(optimizer_params, 'schedule_iters', n_iters),
    )


def _soft_signum_sgd(model, optimizer_params, lr_key='lr', **variant_kwargs):
    """Construct ``SoftSignumSGD`` with the arguments every variant shares.

    ``dampening`` is tied to ``momentum``. ``hook`` is read from
    ``optimizer_params`` when present and is ``None`` otherwise.
    Variant-specific keyword arguments are forwarded unchanged, so any argument
    a variant does not pass keeps ``SoftSignumSGD``'s own default.
    """
    return SoftSignumSGD(
        model.parameters(),
        lr=optimizer_params[lr_key],
        momentum=optimizer_params['momentum'],
        dampening=optimizer_params['momentum'],
        weight_decay=optimizer_params['weight_decay'],
        hook=optimizer_params.get('hook'),
        **variant_kwargs,
    )


def get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=None, n_iters=None):
    """Build the optimizer named ``optimizer_name``, plus an optional clip value and LR scheduler.

    Hyperparameters are sampled from ``search_space`` via ``suggest_params``
    when ``trial`` is given. Otherwise they are read from ``optimizer_params``.
    Hyperparameters named ``*_iters`` (``warmup_iters``, ``only_sign_iters``,
    ``schedule_iters``) are multiplied by ``n_iters`` and truncated to int.

    Args:
        optimizer_name: Key selecting the optimizer recipe.
        model: Module whose parameters are optimized (``'Muon'`` also reads
            ``model.body``, ``model.head`` and ``model.embed``).
        search_space: Search space passed to ``suggest_params`` when ``trial`` is set.
        trial: Optional Optuna trial to sample hyperparameters from.
        optimizer_params: Explicit hyperparameters, used when ``trial`` is None.
        n_iters: Total number of optimizer steps, needed by ``*_iters`` recipes.

    Returns:
        ``(optimizer, (clip, scheduler))``. ``clip`` is a gradient-norm
        threshold or None. ``scheduler`` is a ``LinearLR`` or None.

    Raises:
        ValueError: if both ``trial`` and ``optimizer_params`` are None, or a
            ``*_iters`` value is needed but ``n_iters`` is None.
        NotImplementedError: for an unknown ``optimizer_name``.
    """
    clip = None
    scheduler = None
    if trial is None and optimizer_params is None:
        raise ValueError("Params and trial can not be None together")
    if trial is not None:
        optimizer_params = suggest_params(trial, search_space)
    params = optimizer_params

    def iters(key):
        return _scaled_iters(params, key, n_iters)

    # --- Adam / AdamW ---------------------------------------------------------
    if optimizer_name == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
    elif optimizer_name == 'AdamWClip':
        optimizer = optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
        clip = params['clip']
    elif optimizer_name == 'AdamWBetas':
        optimizer = optim.AdamW(
            model.parameters(),
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            eps=params['eps'],
            betas=(params['beta1'], params['beta2']),
        )
        clip = params['clip']
    elif optimizer_name == 'Adam':
        optimizer = Adam(
            model.parameters(),
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            eps=params['eps'],
        )
    elif optimizer_name == 'AdamEps':
        optimizer = AdamBiasCorrectedEps(
            model.parameters(),
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            eps=params['eps'],
        )

    # --- SGD --------------------------------------------------------------------
    elif optimizer_name == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
    elif optimizer_name == 'SGDClip':
        optimizer = optim.SGD(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
        clip = params['clip']
    elif optimizer_name in ('SGDLinearLR', 'SGDLinearLR+Clip'):
        optimizer = optim.SGD(model.parameters(), lr=params['lr_max'], weight_decay=params['weight_decay'])
        scheduler = _linear_lr(optimizer, params, n_iters)
        if optimizer_name == 'SGDLinearLR+Clip':
            clip = params['clip']

    # --- Muon -------------------------------------------------------------------
    elif optimizer_name == 'Muon':
        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
        hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
        nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]
        param_groups = [
            dict(params=hidden_weights, use_muon=True,
                 lr=params['muon_lr'], weight_decay=params['muon_weight_decay']),
            dict(params=hidden_gains_biases + nonhidden_params, use_muon=False,
                 lr=params['adam_lr'], betas=(0.9, 0.95), weight_decay=params['adam_weight_decay']),
        ]
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    # --- Signum / SoftSignum / SignumSGD -----------------------------------------
    elif optimizer_name == 'Signum':
        optimizer = Signum(
            model.parameters(),
            lr=params['lr'],
            momentum=params['momentum'],
            weight_decay=params['weight_decay'],
        )
    elif optimizer_name == 'SignumLinearLR':
        optimizer = Signum(
            model.parameters(),
            lr=params['lr_max'],
            momentum=params['momentum'],
            weight_decay=params['weight_decay'],
        )
        scheduler = _linear_lr(optimizer, params, n_iters)
    elif optimizer_name == 'SignumSGD':
        # The sign_*/sgd_* arguments belong to SignumSGD; SoftSignum is built
        # with lr/momentum/tmin/tmax everywhere else in this function.
        optimizer = SignumSGD(
            model.parameters(),
            sign_lr=params['sign_lr'],
            sgd_lr=params['sgd_lr'],
            sgd_momentum=params['sgd_momentum'],
            sign_momentum=params['sign_momentum'],
            weight_decay=params['weight_decay'],
            warmup_iters=iters('warmup_iters'),
        )
    elif optimizer_name in ('SoftSignum', 'SoftSignum_decoupled_wd'):
        extra = {'decoupled_wd': True} if optimizer_name == 'SoftSignum_decoupled_wd' else {}
        optimizer = SoftSignum(
            model.parameters(),
            lr=params['lr'],
            momentum=params['momentum'],
            weight_decay=params['weight_decay'],
            tmin=params['tmin'],
            tmax=params['tmax'],
            warmup_iters=iters('warmup_iters'),
            **extra,
        )

    # --- SoftSignumSGD family -----------------------------------------------------
    # Author's schedule notes: "SoftSignumSGD*": only-sign iters -> warmup soft
    # transfer -> (almost) SGD iters; "SoftSignumPT*": warmup soft transfer only;
    # "Signum+SGD*": only-sign iters -> SGD iters. "-auto" passes tmin=2.0 with
    # auto_temperature=True, "-const" passes tmax with const_temperature=True,
    # and "_not_decoupled_wd" passes decoupled_wd=False.
    elif optimizer_name == 'Signum_decoupled_wd':
        optimizer = _soft_signum_sgd(
            model, params,
            warmup_iters=0,
            only_sign_iters=n_iters,
            decoupled_wd=True,
        )
    elif optimizer_name == 'Signum_decoupled_wd_LinearLR':
        optimizer = _soft_signum_sgd(
            model, params, lr_key='lr_max',
            warmup_iters=0,
            only_sign_iters=n_iters,
            decoupled_wd=True,
        )
        scheduler = _linear_lr(optimizer, params, n_iters)
    elif optimizer_name == 'SoftSignumSGD':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=params['tmin'],
            tmax=params['tmax'],
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=iters('only_sign_iters'),
        )
    elif optimizer_name == 'SoftSignumSGD-auto':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=2.0,
            auto_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=iters('only_sign_iters'),
        )
    elif optimizer_name == 'SoftSignumSGD-const':
        optimizer = _soft_signum_sgd(
            model, params,
            tmax=params['tmax'],
            const_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=iters('only_sign_iters'),
        )
    elif optimizer_name == 'SoftSignumSGD_not_decoupled_wd':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=params['tmin'],
            tmax=params['tmax'],
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=iters('only_sign_iters'),
            decoupled_wd=False,
        )
    elif optimizer_name == 'SoftSignumSGD_not_decoupled_wd-auto':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=2.0,
            auto_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=iters('only_sign_iters'),
            decoupled_wd=False,
        )
    elif optimizer_name in ('Signum+SGD', 'Signum+SGD-like-SoftSignumPT-auto'):
        # Both names build the identical optimizer.
        optimizer = _soft_signum_sgd(
            model, params,
            warmup_iters=0,
            only_sign_iters=iters('only_sign_iters'),
            sgd_last=True,
        )
    elif optimizer_name == 'Signum+SGD_not_decoupled_wd':
        optimizer = _soft_signum_sgd(
            model, params,
            warmup_iters=0,
            only_sign_iters=iters('only_sign_iters'),
            sgd_last=True,
            decoupled_wd=False,
        )
    elif optimizer_name == 'SoftSignumPT':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=params['tmin'],
            tmax=params['tmax'],
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=0,
        )
    elif optimizer_name == 'SoftSignumPT-auto':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=2.0,
            auto_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=0,
        )
    elif optimizer_name == 'SoftSignumPT-const':
        optimizer = _soft_signum_sgd(
            model, params,
            tmax=params['tmax'],
            const_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=0,
        )
    elif optimizer_name == 'SoftSignumPT_not_decoupled_wd':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=params['tmin'],
            tmax=params['tmax'],
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=0,
            decoupled_wd=False,
        )
    elif optimizer_name == 'SoftSignumPT_not_decoupled_wd-auto':
        optimizer = _soft_signum_sgd(
            model, params,
            tmin=2.0,
            auto_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=0,
            decoupled_wd=False,
        )
    elif optimizer_name == 'SoftSignumPT_not_decoupled_wd-const':
        optimizer = _soft_signum_sgd(
            model, params,
            tmax=params['tmax'],
            const_temperature=True,
            warmup_iters=iters('warmup_iters'),
            only_sign_iters=0,
            decoupled_wd=False,
        )
    else:
        raise NotImplementedError(f"There is no optimizer {optimizer_name} yet")
    return optimizer, (clip, scheduler)


def tune(n_trials, search_space, optimizer_name, ModelCls, device, num_epoches, n_startup_trials, path, seed=42, k=2, prune=True):
    """Tune ``optimizer_name`` with Optuna, then retrain once with the best parameters.

    Each trial re-seeds, trains a fresh ``ModelCls`` for ``num_epoches`` epochs
    and is scored by its best per-epoch validation F1. Whenever a trial beats
    every previously completed trial, its params, score and trial number are
    written to ``<path>/<optimizer_name>.json``. After the study, the best
    parameters are retrained with the test loader attached.

    Args:
        n_trials: Number of Optuna trials.
        search_space: Search space; ``search_space['batch_size']`` is used as a fixed batch size.
        optimizer_name: Recipe name understood by ``get_optimizer``.
        ModelCls: Zero-argument model constructor.
        device: Training device.
        num_epoches: Epochs per trial and for the final retraining.
        n_startup_trials: Random startup trials for the TPE sampler.
        path: Output directory for intermediate results.
        seed: Seed for the global RNGs, the data subsampling/split and the sampler.
        k: Unbalance factor forwarded to ``get_data``.
        prune: If True, trials report per-epoch validation F1 so the
            ``MedianPruner`` can stop unpromising trials early.

    Returns:
        Best trial params plus ``val_score`` and ``test_score`` (both taken at
        the epoch with the best validation F1 of the retraining run), all as floats.
    """

    def _fit(trial=None, optimizer_params=None, with_test=False):
        """Seed, build model/data/optimizer and train once; return ``train``'s score lists."""
        set_global_seed(seed)
        model = ModelCls().to(device)
        # Forward seed so the class subsampling and train/val split follow it too.
        train_loader, eval_loader, test_loader = get_data(
            batch_size=search_space['batch_size'], seed=seed, k=k
        )
        n_iters = num_epoches * len(train_loader)
        criterion = nn.CrossEntropyLoss()
        optimizer, (clipping, scheduler) = get_optimizer(
            optimizer_name, model, search_space,
            trial=trial, optimizer_params=optimizer_params, n_iters=n_iters,
        )
        return train(
            num_epoches=num_epoches,
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            eval_loader=eval_loader,
            test_loader=test_loader if with_test else None,
            device=device,
            # Reporting to the trial is what lets the configured pruner act.
            trial=trial if prune else None,
            clipping=clipping,
            scheduler=scheduler,
        )

    def _best_completed_value(study):
        """Best objective value so far, or None if no trial has completed yet."""
        try:
            return study.best_value
        except ValueError:  # raised by Optuna when no trial is complete (e.g. trial 0, or all pruned)
            return None

    def objective(trial):
        val_scores, _ = _fit(trial=trial)
        best_epoch = int(np.argmax(val_scores))
        best_val = val_scores[best_epoch]
        previous_best = _best_completed_value(trial.study)
        if previous_best is None or previous_best < best_val:
            result = trial.params | {'val_score': best_val, 'trial': trial.number}
            save_result(result, path, optimizer_name)
        return best_val

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)
    study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)
    study.optimize(objective, n_trials=n_trials)

    best_params = study.best_trial.params
    val_scores, test_scores = _fit(optimizer_params=best_params, with_test=True)
    best_epoch = int(np.argmax(val_scores))
    result = best_params | {'val_score': val_scores[best_epoch], 'test_score': test_scores[best_epoch]}
    return {key: float(value) for key, value in result.items()}


def get_arguments():
    """Parse the command line for an unbalanced CIFAR-10 tuning run.

    ``--optimizer`` only offers names that ``get_optimizer`` can build. An
    unsupported name is rejected by argparse up front instead of raising
    ``NotImplementedError`` later at runtime.
    """
    optimizer_choices = [
        'Muon',
        'Signum',
        'AdamW', 'AdamWBetas', 'Adam', 'AdamEps',
        'AdamWClip',
        'SGD', 'SGDLinearLR', 'SGDClip', 'SGDLinearLR+Clip',
        'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
        'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
        'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
        'Signum+SGD',
        'SignumLinearLR', 'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
        'Signum+SGD_not_decoupled_wd',
        'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto',
    ]
    parser = argparse.ArgumentParser(description="CIFAR-10 Hyperparameter Tuning with Optuna")
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=optimizer_choices,
                        help='Optimizer to use for training.')
    parser.add_argument('--user', type=str, required=True, choices=['user1', 'user2'],
                        help='Name of the user folder for saving the results.')
    parser.add_argument('--model', type=str, default='SimpleCNNBinClass', choices=['SimpleCNNBinClass', 'ResNet18_32x32BinClass'],
                        help='Model architecture to use.')
    parser.add_argument('--n_trials', type=int, default=50,
                        help='Number of Optuna trials to run.')
    parser.add_argument('--unbalance_coef', type=int, default=2,
                        help='The unbalance coef.')
    parser.add_argument('--n_startup_trials', type=int, default=20,
                        help='Number of Optuna startup trials to run.')
    parser.add_argument('--max_epochs', type=int, default=50,
                        help='Maximum number of epochs for training each trial.')
    parser.add_argument('--device', type=str, default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training (e.g., "cpu", "cuda:0").')
    return parser.parse_args()

def save_result(
    result: dict,
    path: str,
    optimizer_name: str
):
    """Write ``result`` as JSON to ``<path>/<optimizer_name>.json``.

    The directory ``path`` is created if missing. An existing file with the
    same name is overwritten.
    """
    target_file = f'{path}/{optimizer_name}.json'
    os.makedirs(path, exist_ok=True)
    with open(target_file, 'w') as out:
        json.dump(result, out)

if __name__ == '__main__':
    args = get_arguments()
    # One output directory, shared by the per-trial snapshots written during
    # tuning and the final result written below (same file name, so the final one wins).
    out_dir = f'tuning/{args.user}/unbalanced_cifar10/{args.model.lower()}_{args.unbalance_coef}'

    results = tune(
        n_trials=args.n_trials,
        search_space=search_spaces.search_spaces_map[args.optimizer],
        optimizer_name=args.optimizer,
        ModelCls=models.model_map[args.model],
        device=args.device,
        num_epoches=args.max_epochs,
        n_startup_trials=args.n_startup_trials,
        path=out_dir,
        k=args.unbalance_coef,
    )

    save_result(results, out_dir, args.optimizer)
