import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import json


# Prefer the second GPU (index 1) as this experiment was set up to do, but fall
# back to the first GPU when only one is visible: "cuda:1" on a single-GPU
# machine fails as soon as the model is moved onto it.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

BATCH_SIZE = 128  # written into every optimizer's params, overriding the tuned value
EPOCHS = 50       # training epochs for each (optimizer, seed) run
K = 10            # forwarded to get_data(k=...); presumably the imbalance factor — TODO confirm


from unbalanced_cifar10_training import get_data, get_optimizer, train
from utils import set_global_seed
import models


def run_experiment(optimizer_name, ModelCls, device, optimizer_params, seed=42, num_epoches=10, k=2):
    """Run one seeded training job and return the scores reported by ``train``.

    Args:
        optimizer_name: Name handed to ``get_optimizer`` to pick the optimizer.
        ModelCls: Zero-argument callable that builds the network.
        device: Device the network is moved to; also forwarded to ``train``.
        optimizer_params: Hyperparameter dict. Its ``'batch_size'`` entry sizes
            the data loaders, and the whole dict is forwarded to ``get_optimizer``.
        seed: Value passed to ``set_global_seed`` before anything else runs.
        num_epoches: Epoch count forwarded to ``train``.
        k: Forwarded unchanged to ``get_data``.

    Returns:
        ``(val_accuracies, test_accuracies)`` as produced by ``train``.
    """
    set_global_seed(seed)

    # Keep this order (model, then loaders): both may draw from the seeded RNG,
    # so reordering could change the results obtained for a given seed.
    net = ModelCls().to(device)
    loaders = get_data(batch_size=optimizer_params['batch_size'], k=k)
    train_loader, eval_loader, test_loader = loaders

    loss_fn = nn.CrossEntropyLoss()
    opt, clip = get_optimizer(
        optimizer_name,
        net,
        search_space=None,
        trial=None,
        optimizer_params=optimizer_params,
    )

    val_scores, test_scores = train(
        num_epoches=num_epoches,
        model=net,
        train_loader=train_loader,
        optimizer=opt,
        criterion=loss_fn,
        eval_loader=eval_loader,
        test_loader=test_loader,
        device=device,
        clipping=clip,
    )
    return val_scores, test_scores


# Optimizers to compare; each needs a tuned-hyperparameter file at
# ../tuning/unbalanced_cifar10/<name>.json.
optimizers = ['AdamW', 'AdamWBetas', 'SoftSignum', 'SignumSGD', 'Signum']
# results[name] -> {'val': [one entry per seed], 'test': [one entry per seed]}
results = {}
seed_list = [42, 43, 44, 45, 46]
for optimizer in optimizers:
    with open(f'../tuning/unbalanced_cifar10/{optimizer}.json', 'r', encoding='utf-8') as f:
        optimizer_params = json.load(f)
    # The tuning file stores the scores it achieved next to the hyperparameters;
    # remove them so they are not forwarded to get_optimizer. pop() with a
    # default tolerates files that lack either score.
    optimizer_params.pop('val_score', None)
    optimizer_params.pop('test_score', None)
    optimizer_params['batch_size'] = BATCH_SIZE

    per_optimizer = results.setdefault(optimizer, {'val': [], 'test': []})
    for seed in seed_list:
        val_accuracies, test_accuracies = run_experiment(
            optimizer_name=optimizer,
            ModelCls=models.model_map['SimpleCNNBinClass'],
            device=DEVICE,
            seed=seed,
            # Fresh shallow copy per seed so any in-place change made downstream
            # cannot leak into the runs for later seeds.
            optimizer_params=dict(optimizer_params),
            num_epoches=EPOCHS,
            k=K,
        )
        per_optimizer['val'].append(val_accuracies)
        per_optimizer['test'].append(test_accuracies)


def _plot_scores(results, optimizer_names, out_path):
    """Plot mean ± std (across seeds) of val and test scores, then save the figure.

    Args:
        results: Mapping ``name -> {'val': [...], 'test': [...]}`` holding one
            score sequence per seed for each split.
        optimizer_names: Optimizers to draw, in legend order.
        out_path: File the figure is saved to; its parent directory is created
            if missing so a long training run is not lost at the final step.

    Returns:
        The matplotlib Figure that was drawn.
    """
    from pathlib import Path

    fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
    splits = ('val', 'test')
    for name in optimizer_names:
        for ax, split in zip(axes, splits):
            # Assumes every seed returned a same-length 1-D sequence, giving a
            # (num_seeds, num_points) array — TODO confirm against train().
            scores = np.array(results[name][split])
            mean, std = scores.mean(axis=0), scores.std(axis=0)
            xs = np.arange(len(mean))
            (line,) = ax.plot(xs, mean, label=name)
            # Tie the band colour to its line explicitly instead of relying on
            # the separate colour cycle used for filled patches.
            ax.fill_between(xs, mean - std, mean + std, color=line.get_color(), alpha=0.3)

    titles = ("Model Performance on Val Set", "Model Performance on Test Set")
    for ax, title in zip(axes, titles):
        ax.set_title(title)
        ax.set_xlabel("Number of Epochs")
        ax.set_ylabel("F-1 score (%)")
        ax.set_yticks(np.arange(0, 101, 10))
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        ax.legend()
        ax.set_ylim(bottom=0)

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    return fig


_plot_scores(results, optimizers, "figures/unbalanced/full_scores_1.png")
_plot_scores(results, ['SoftSignum', 'SignumSGD', 'AdamW'], "figures/unbalanced/detailed_scores_1.png")
plt.show()

