import sys
import os
sys.path.append('./')

import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import json

from cifar10_training import get_data, train

from unbalanced_cifar10_training import get_optimizer
from unbalanced_cifar10_training import train as train_unbalanced_cifar10
from unbalanced_cifar10_training import get_data as get_unbalanced_cifar10_data

from tiny_imagenet_training import get_data as get_tiny_imagenet
from tiny_imagenet_training import get_model as get_swin_tiny
from multilingual_nli_training import get_model as get_multilingual_nli_model

from unbalanced_cifar10_training import train as train_unbalanced
from multilingual_nli_training import train as train_multilingual_nli
from tiny_imagenet_training import train as train_swin

from unbalanced_tiny_imagenet_training import get_data as get_unbalanced_tiny_imagenet_data
from unbalanced_tiny_imagenet_training import get_model as get_swin_unbalanced_tiny
from unbalanced_tiny_imagenet_training import train as train_unbalanced_tiny_imagenet

from utils import set_global_seed
import models

# Default batch size for training runs.
# NOTE(review): presumably intended as the default for `args.batch_size`, which the
# `__main__` block copies into optimizer_params['batch_size'] — confirm.
BATCH_SIZE = 128

def run_experiment(optimizer_name, ModelCls, dataset, device, optimizer_params, subset_fraction, data_path, seed=42, num_epoches=10, k=10, balanced=False):
    """Train one model with one optimizer configuration for a single seed.

    Args:
        optimizer_name: Optimizer key understood by ``get_optimizer`` (e.g. ``'AdamW'``).
        ModelCls: Callable that builds the model. It is invoked with no arguments
            (assumed valid for both ``models.model_map`` classes and the SWIN
            ``get_model`` factories — TODO confirm).
        dataset: One of ``'cifar10'``, ``'unbalanced_cifar10'``, ``'tiny_imagenet'``,
            ``'unbalanced_tiny_imagenet'``.
        device: Torch device the model is moved to; also forwarded to the train loop.
        optimizer_params: Hyperparameters forwarded to ``get_optimizer``; must
            contain ``'batch_size'``, which is also used for the data loaders.
        subset_fraction: Training-data fraction (Tiny-ImageNet variants only).
        data_path: Tiny-ImageNet root directory (Tiny-ImageNet variants only).
        seed: Global seed, set before data loading and model construction.
        num_epoches: Number of training epochs.
        k: Imbalance coefficient for the unbalanced datasets.
        balanced: Forwarded to the unbalanced CIFAR-10 loader only.

    Returns:
        ``(val_accuracies, test_accuracies)`` as returned by the dataset's train function.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names. (This
            function sets up no loader or vocabulary for ``multilingual_nli``.)
    """
    set_global_seed(seed)
    batch_size = optimizer_params['batch_size']

    if dataset == 'cifar10':
        train_loader, eval_loader, test_loader = get_data(batch_size=batch_size, seed=seed)
    elif dataset == 'unbalanced_cifar10':
        train_loader, eval_loader, test_loader = get_unbalanced_cifar10_data(
            batch_size=batch_size, seed=seed, k=k, balanced=balanced)
    elif dataset == 'tiny_imagenet':
        train_loader, eval_loader, test_loader = get_tiny_imagenet(
            batch_size=batch_size, subset_fraction=subset_fraction, data_path=data_path, seed=seed)
    elif dataset == 'unbalanced_tiny_imagenet':
        # NOTE(review): `balanced` is not forwarded here, yet the caller stores results
        # under f1_balanced_results when --balanced is set. Confirm whether
        # get_unbalanced_tiny_imagenet_data accepts a `balanced` argument.
        train_loader, eval_loader, test_loader = get_unbalanced_tiny_imagenet_data(
            batch_size=batch_size, subset_fraction=subset_fraction, seed=seed, k=k, data_path=data_path)
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Build the model after seeding so weight initialization is reproducible per seed.
    model = ModelCls().to(device)
    criterion = nn.CrossEntropyLoss()

    is_tiny_imagenet = dataset in ('tiny_imagenet', 'unbalanced_tiny_imagenet')
    # The Tiny-ImageNet train loops receive gradient_accumulation_steps=4, so n_iters
    # handed to get_optimizer is divided by the same factor.
    grad_accumulation_steps = 4 if is_tiny_imagenet else None
    n_iters = num_epoches * len(train_loader)
    if grad_accumulation_steps is not None:
        n_iters //= grad_accumulation_steps

    optimizer, (clipping, scheduler) = get_optimizer(
        optimizer_name, model, search_space=None, trial=None,
        optimizer_params=optimizer_params, n_iters=n_iters)

    if is_tiny_imagenet:
        train_fn = train_swin if dataset == 'tiny_imagenet' else train_unbalanced_tiny_imagenet
        # Second scheduler with T_max=num_epoches, passed alongside the one from get_optimizer.
        scheduler_outter = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoches)
        val_accuracies, test_accuracies = train_fn(
            num_epoches=num_epoches,
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            scheduler_inner=scheduler,
            scheduler_outter=scheduler_outter,
            eval_loader=eval_loader,
            test_loader=test_loader,
            device=device,
            clipping=clipping,
            gradient_accumulation_steps=grad_accumulation_steps,
        )
    else:
        train_fn = train_unbalanced_cifar10 if dataset == 'unbalanced_cifar10' else train
        val_accuracies, test_accuracies = train_fn(
            num_epoches=num_epoches,
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            eval_loader=eval_loader,
            test_loader=test_loader,
            device=device,
            scheduler=scheduler,
            clipping=clipping,
        )
    return val_accuracies, test_accuracies


def get_arguments(argv=None):
    """Parse command-line options for the multi-seed evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # With nargs='+', argparse uses the default object as-is, so it must be a list;
    # a bare string would be iterated character by character by the caller.
    parser.add_argument('--optimizer', type=str, nargs='+', default=['AdamW'], choices=['Muon',
                                                                           'SignumDL', 'SignumDLNesterov',
                                                                           'Signum', 'SignumLinearLR',
                                                                           'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
                                                                           'AdamW', 'AdamWBetas', 'Adam', 'AdamEps', 
                                                                           'AdamWClip', 'AdamEpsScheduling', 'AdamBetaScheduling', 'AdamPaLM2', 
                                                                           'SGD', 'SGDLinearLR', 'SGDCosineAnnealingLR', 'SGDClip', 'SGDLinearLR+Clip',
                                                                           'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
                                                                           'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
                                                                           'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
                                                                           'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto', 'SoftSignumSGD_not_decoupled_wd-const',
                                                                           'Signum+SGD', 'Signum+SGD_not_decoupled_wd'],
                        help='Optimizers to use for training.')
    parser.add_argument('--user', type=str, required=True, choices=['user1', 'user2', 'user3'],
                        help='Name of the user folder for getting the optimizer hyperparameters.')
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'tiny_imagenet', 'unbalanced_cifar10', 'unbalanced_tiny_imagenet'],
                        help='Dataset to use.')
    parser.add_argument('--model', type=str, default='SimpleCNN', choices=['SimpleCNN', 'ResNet18_32x32', 'SWIN_tiny',
                                                                           'SimpleCNNBinClass', 'ResNet18_32x32BinClass', 'SWIN_tiny_unbalanced'],
                        help='Model architecture to use.')
    parser.add_argument('--max_epochs', type=int, default=50,
                        help='Maximum number of epochs for training each trial.')
    # The caller copies this into optimizer_params['batch_size'], overriding the tuned value.
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='Batch size used for training (default: %(default)s).')
    parser.add_argument('--device', type=str, default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training (e.g., "cpu", "cuda:0").')
    # NOTE(review): not read by the __main__ block of this script.
    parser.add_argument('--use_augmentations', action="store_true", help="Use augmentations")
    parser.add_argument('--subset_fraction', type=float, default=0.25, help='Fraction of training data to use (only for imagenet).')
    parser.add_argument('--data_path', type=str, default='./data/tiny-imagenet-200',
                        help='Root path to the tiny-imagenet-200 directory.')
    parser.add_argument('--unbalance_coef', type=int, default=10,
                        help='Unbalance coefficient k for unbalanced_cifar10 dataset.')
    parser.add_argument('--num_seeds', type=int, default=20,
                        help='Number of random seeds to use for training (default: 20).')

    parser.add_argument('--balanced', action='store_true', help='If set, use balanced test dataset.')
    return parser.parse_args(argv)


def _resolve_model_cls(model_name):
    """Return the model class/factory for the ``--model`` CLI choice."""
    if model_name == 'SWIN_tiny':
        return get_swin_tiny
    if model_name == 'SWIN_tiny_unbalanced':
        return get_swin_unbalanced_tiny
    return models.model_map[model_name]


def _model_folder(args):
    """Return the per-model sub-folder name under ``./tuning/<user>/<dataset>/``."""
    if args.dataset == 'unbalanced_cifar10':
        return f'{args.model.lower()}_{args.unbalance_coef}'
    if args.dataset == 'unbalanced_tiny_imagenet':
        return f'swin_tiny_{args.unbalance_coef}'
    return args.model.lower()


def _results_dir(args, model_folder):
    """Return the directory holding per-optimizer result files for this configuration."""
    base = f'./tuning/{args.user}/{args.dataset}/{model_folder}'
    if args.dataset in ('unbalanced_cifar10', 'unbalanced_tiny_imagenet'):
        return f'{base}/f1_balanced_results' if args.balanced else f'{base}/f1_results'
    return f'{base}/results'


def _load_existing_results(results_path):
    """Load saved ``{'val': [...], 'test': [...]}`` results so a run can resume.

    Returns a fresh, empty results dict when the file does not exist or is not a
    valid results file. In the invalid case a warning is printed; the file will be
    overwritten by the next save.
    """
    fresh = {'val': [], 'test': []}
    if not os.path.exists(results_path):
        return fresh
    try:
        with open(results_path, 'r') as f:
            results = json.load(f)
    except json.JSONDecodeError as err:
        print(f'Warning: could not parse {results_path} ({err}); starting from scratch.', file=sys.stderr)
        return fresh
    if not isinstance(results, dict) or 'val' not in results or 'test' not in results:
        print(f'Warning: {results_path} lacks "val"/"test" entries; starting from scratch.', file=sys.stderr)
        return fresh
    return results


def main():
    """Evaluate every requested optimizer over the configured seeds.

    Results are saved after every seed, and previously completed seeds are skipped.
    """
    args = get_arguments()
    seed_list = list(range(42, 42 + args.num_seeds))
    # Guard against a plain-string --optimizer default, which would otherwise be
    # iterated character by character.
    optimizers = [args.optimizer] if isinstance(args.optimizer, str) else list(args.optimizer)
    # Use the module default if the parser does not define --batch_size.
    batch_size = getattr(args, 'batch_size', BATCH_SIZE)

    model_cls = _resolve_model_cls(args.model)
    model_folder = _model_folder(args)
    results_dir = _results_dir(args, model_folder)

    for optimizer in optimizers:
        params_path = f'./tuning/{args.user}/{args.dataset}/{model_folder}/{optimizer}.json'
        with open(params_path, 'r') as f:
            optimizer_params = json.load(f)
        # The tuning file stores scores next to the hyperparameters; drop them so
        # only hyperparameters reach get_optimizer.
        optimizer_params.pop('val_score', None)
        optimizer_params.pop('test_score', None)
        optimizer_params['batch_size'] = batch_size

        results_path = f'{results_dir}/{optimizer}.json'
        results = _load_existing_results(results_path)
        # Seeds are consumed in order, so the number of stored runs tells us where to resume.
        seeds_to_process = seed_list[len(results['val']):]
        if not seeds_to_process:
            continue

        os.makedirs(results_dir, exist_ok=True)
        for seed in seeds_to_process:
            val_accuracies, test_accuracies = run_experiment(
                optimizer_name=optimizer,
                ModelCls=model_cls,
                dataset=args.dataset,
                device=args.device,
                seed=seed,
                optimizer_params=optimizer_params,
                num_epoches=args.max_epochs,
                subset_fraction=args.subset_fraction,
                data_path=args.data_path,
                k=args.unbalance_coef,
                balanced=args.balanced,
            )
            results['val'].append(val_accuracies)
            results['test'].append(test_accuracies)
            # Persist after every seed so an interrupted run resumes where it stopped.
            with open(results_path, 'w') as f:
                json.dump(results, f)


if __name__ == '__main__':
    main()