import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_and_extract_archive
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from tqdm import tqdm
import optuna
import copy
import json
import argparse
import os
import shutil
import pandas as pd

from cifar10_training import get_optimizer
from utils import set_global_seed
from optimizers.muon import SingleDeviceMuonWithAuxAdam
from optimizers.signum import Signum, SoftSignum
import search_spaces


def format_tiny_imagenet_val_folder(data_path):
    """
    Formats the Tiny ImageNet validation set for use with torchvision.datasets.ImageFolder.

    If `data_path` does not exist, the archive is first downloaded and extracted into
    the parent directory of `data_path` (this assumes the archive unpacks into a folder
    whose name matches the last component of `data_path` -- TODO confirm for
    non-default paths).

    Every image listed in 'val/val_annotations.txt' is then moved from 'val/images/'
    into 'val/{class_id}/', and 'val/images' is removed afterwards.

    The step is resumable: it is skipped only when the marker class folder exists AND
    'val/images' is gone. If a previous run was interrupted, the remaining images are
    moved on the next call instead of leaving a stray 'images' folder inside 'val'.

    Args:
        data_path: Path to the Tiny ImageNet root directory (the one containing 'val').
    """

    archive_url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    if not os.path.exists(data_path):
        # Extract next to where data_path is expected rather than into a fixed './data',
        # so a custom data_path actually exists after the download.
        data_root = os.path.dirname(os.path.normpath(data_path)) or '.'
        download_and_extract_archive(
            url=archive_url,
            download_root=data_root,
            extract_root=data_root,
            filename='tiny-imagenet-200.zip'
        )

    val_dir = os.path.join(data_path, 'val')
    val_images_dir = os.path.join(val_dir, 'images')
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')

    # The marker folder alone is not enough: it already exists after the first moved
    # image of that class, i.e. also when an earlier run stopped half-way.
    marker_exists = os.path.exists(os.path.join(val_dir, 'n01443537'))
    if marker_exists and not os.path.exists(val_images_dir):
        return  # Already formatted

    print("Formatting Tiny ImageNet validation folder (one-time operation)...")
    val_annotations = pd.read_csv(val_annotations_path, sep='\t', header=None, names=['File', 'Class', 'X_1', 'Y_1', 'X_2', 'Y_2'])

    file_class_pairs = zip(val_annotations['File'], val_annotations['Class'])
    for file_name, class_id in tqdm(file_class_pairs, total=len(val_annotations), desc="Moving val images"):
        class_dir = os.path.join(val_dir, class_id)
        os.makedirs(class_dir, exist_ok=True)

        original_image_path = os.path.join(val_images_dir, file_name)
        # Images moved by an earlier (interrupted) run are simply skipped.
        if os.path.exists(original_image_path):
            shutil.move(original_image_path, class_dir)

    if os.path.exists(val_images_dir):
        shutil.rmtree(val_images_dir)
    print("Formatting complete.")


def get_data(data_path, batch_size, seed=42, subset_fraction=1.0):
    """
    Build the train / validation / test DataLoaders for Tiny ImageNet.

    The 'train' folder is split 90/10 into a training part and a held-out validation
    part using a generator seeded with `seed`. The formatted 'val' folder is used as
    the test set. When `subset_fraction < 1.0`, both the training and the validation
    part are further reduced to that fraction by a second seeded random split.

    Training samples use random-resized-crop + horizontal-flip augmentation;
    validation and test samples use the deterministic resize + center-crop pipeline.

    Args:
        data_path: Root of the Tiny ImageNet directory (contains 'train' and 'val').
        batch_size: Batch size used by all three loaders.
        seed: Seed for the random splits.
        subset_fraction: Fraction of the train/validation parts to keep.

    Returns:
        Tuple `(train_loader, val_loader, test_loader)`.
    """
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    IMAGE_SIZE = 224

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE + 32), # 256
        transforms.CenterCrop(IMAGE_SIZE),  # 224
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    format_tiny_imagenet_val_folder(data_path)
    test_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_path, 'val'), transform=test_transform)
    train_dataset_full = torchvision.datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=train_transform)

    train_len = int(0.9 * len(train_dataset_full))
    val_len = len(train_dataset_full) - train_len
    g = torch.Generator().manual_seed(seed)
    train_subset, val_subset = torch.utils.data.random_split(train_dataset_full, [train_len, val_len], generator=g)

    if subset_fraction < 1.0:
        subset_train_len = int(subset_fraction * len(train_subset))
        subset_val_len = int(subset_fraction * len(val_subset))

        g_subset = torch.Generator().manual_seed(seed)
        train_subset, _ = torch.utils.data.random_split(train_subset, [subset_train_len, len(train_subset) - subset_train_len], generator=g_subset)
        val_subset, _ = torch.utils.data.random_split(val_subset, [subset_val_len, len(val_subset) - subset_val_len], generator=g_subset)
        print(f"Using a {subset_fraction*100:.0f}% subset of the data: {len(train_subset)} train samples, {len(val_subset)} val samples.")

    # The validation part shares its underlying ImageFolder with the training part, so
    # it is deep-copied before its transform is swapped to the evaluation pipeline.
    val_subset = copy.deepcopy(val_subset)
    # After the optional second split, val_subset is a Subset of a Subset; setting
    # `.transform` on the intermediate Subset would be silently ignored. Walk down to
    # the dataset that actually applies the transform.
    val_base_dataset = val_subset
    while isinstance(val_base_dataset, torch.utils.data.Subset):
        val_base_dataset = val_base_dataset.dataset
    val_base_dataset.transform = test_transform

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader



def get_model():
    """
    Create a Swin-T network initialised with the IMAGENET1K_V1 weights and replace its
    classification head with a new linear layer that outputs 200 classes.
    """
    pretrained_weights = torchvision.models.Swin_T_Weights.IMAGENET1K_V1
    net = torchvision.models.swin_t(weights=pretrained_weights)
    # Keep the feature width of the original head, only change the number of outputs.
    head_in_features = net.head.in_features
    net.head = nn.Linear(head_in_features, 200)
    return net


def evaluate_model(model, loader, desc="Evaluating", device='cpu'):
    """
    Compute the top-1 accuracy (in percent) of `model` over all batches of `loader`.

    The model is switched to eval mode and gradients are disabled for the whole pass.

    Args:
        model: Network called as `model(images)`; its output is arg-maxed over dim 1.
        loader: Iterable yielding `(images, labels)` batches.
        desc: Label shown on the progress bar.
        device: Device the batches are moved to before the forward pass.

    Returns:
        `100 * correct / total` as a float.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    progress = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            # Index of the largest logit per sample is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return 100 * n_correct / n_seen


def train(num_epoches, model, train_loader, optimizer, criterion, scheduler_outter, scheduler_inner, eval_loader, test_loader=None, device='cpu', trial=None, clipping=None, gradient_accumulation_steps=1, use_amp=True):
    """
    Train `model` for `num_epoches` epochs with AMP and gradient accumulation,
    evaluating after every epoch.

    Args:
        num_epoches: Number of epochs to run.
        model: Network to train (already on `device`).
        train_loader: Loader yielding `(images, labels)` training batches.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss called as `criterion(outputs, labels)`.
        scheduler_outter: Optional scheduler stepped once per epoch (may be None).
        scheduler_inner: Optional scheduler stepped after every optimizer step (may be None).
        eval_loader: Loader used for the per-epoch validation accuracy.
        test_loader: Optional loader for a per-epoch test accuracy.
        device: Device the batches are moved to.
        trial: Optional Optuna trial; validation accuracy is reported to it each epoch.
        clipping: If not None, max inf-norm the gradients are clipped to.
        gradient_accumulation_steps: Number of micro-batches per optimizer step.
        use_amp: Enables autocast and gradient scaling.

    Returns:
        Tuple `(eval_accuracies, test_accuracies)` with one entry per completed epoch
        (`test_accuracies` stays empty when `test_loader` is None).

    Raises:
        optuna.exceptions.TrialPruned: When `trial` reports that it should be pruned.

    Note:
        Micro-batches left over at the end of an epoch (when `len(train_loader)` is not
        a multiple of `gradient_accumulation_steps`) never trigger an optimizer step;
        their gradients are cleared by the `zero_grad()` at the start of the next epoch.
    """
    eval_accuracies, test_accuracies = [], []

    scaler = GradScaler(enabled=use_amp)

    for epoch in range(num_epoches):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for i, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Training]")):
            images, labels = images.to(device), labels.to(device)
            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
                # Scale down so the accumulated gradient is the mean over micro-batches.
                loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()

            if (i + 1) % gradient_accumulation_steps == 0:
                if clipping is not None:
                    # Gradients must be unscaled before their norm is measured.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                if scheduler_inner is not None:
                    scheduler_inner.step()

            # Undo the accumulation scaling so the logged loss is per micro-batch.
            running_loss += loss.item() * gradient_accumulation_steps

        if scheduler_outter is not None:
            scheduler_outter.step()
        eval_acc = evaluate_model(model, eval_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Evaluating]", device=device)
        eval_accuracies.append(eval_acc)
        test_acc = -1
        if test_loader is not None:
            test_acc = evaluate_model(model, test_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Testing]", device=device)
            test_accuracies.append(test_acc)
        if trial is not None:
            trial.report(eval_acc, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
        # scheduler_outter is optional (see the None check above), so fall back to the
        # optimizer's own learning rate instead of crashing on `None.get_last_lr()`.
        if scheduler_outter is not None:
            current_lr = scheduler_outter.get_last_lr()[0]
        else:
            current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{num_epoches}], Loss: {running_loss/len(train_loader):.4f}, LR: {current_lr:.6f}, Eval Acc: {eval_acc:.2f}%, Test Acc: {test_acc:.2f}%")

    return eval_accuracies, test_accuracies


def suggest_params(trial, search_space):
    """
    Sample one value per tunable entry of `search_space` from `trial`.

    Only entries whose config is a dict are sampled; anything else (plain constants)
    is skipped. A config with `type == 'float'` is drawn with `trial.suggest_float`,
    every other type with `trial.suggest_int`. Bounds come from `min` / `max`, and the
    optional `log` flag (default False) is forwarded.

    Returns:
        Dict mapping each sampled parameter name to its suggested value.
    """
    sampled = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            continue
        suggest = trial.suggest_float if spec['type'] == 'float' else trial.suggest_int
        sampled[name] = suggest(name, spec['min'], spec['max'], log=spec.get('log', False))
    return sampled


def tune(n_trials, data_path, search_space, optimizer_name, device, num_epoches, n_startup_trials, path, seed=42, subset_fraction=1.0):
    """
    Run an Optuna hyper-parameter search for `optimizer_name`, then retrain once with
    the best parameters and report validation and test accuracy.

    Each trial trains a fresh model on the (optionally sub-sampled) data and returns
    its best validation accuracy. Whenever a trial beats the study's best value so
    far, its parameters are written to `{path}/{optimizer_name}.json`. The study is
    persisted in an SQLite file under `{path}/study/` and resumed if it already exists.

    Args:
        n_trials: Number of Optuna trials to run.
        data_path: Root of the Tiny ImageNet directory.
        search_space: Search-space dict; must also contain 'batch_size' and
            'gradient_accumulation_steps'.
        optimizer_name: Name passed to `get_optimizer` and used for file/study names.
        device: Device to train on.
        num_epoches: Epochs per trial and for the final retraining.
        n_startup_trials: Number of startup trials passed to the TPE sampler.
        path: Output directory for the study database and intermediate results.
        seed: Seed for global RNG state, the sampler, and the data splits.
        subset_fraction: Fraction of the data used during the trials (the final
            retraining uses the full split).

    Returns:
        Dict of the best trial's parameters plus 'val_score' and 'test_score' taken at
        the epoch with the best validation accuracy, all converted to float.
    """
    grad_accum_steps = search_space['gradient_accumulation_steps']

    def objective(trial):
        set_global_seed(seed)
        model = get_model().to(device)
        # Trials use a fixed batch size of 64 (search_space['batch_size'] is not used here).
        train_loader, eval_loader, _ = get_data(data_path, batch_size=64, seed=seed, subset_fraction=subset_fraction)
        n_iters = num_epoches * len(train_loader) / grad_accum_steps
        criterion = nn.CrossEntropyLoss()
        optimizer, (clipping, scheduler_inner) = get_optimizer(optimizer_name, model, search_space, trial=trial, n_iters=n_iters)

        scheduler_outter = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoches)

        val_accuracies, _ = train(num_epoches=num_epoches, model=model, train_loader=train_loader,
                                  optimizer=optimizer, criterion=criterion, scheduler_outter=scheduler_outter, scheduler_inner=scheduler_inner,
                                  eval_loader=eval_loader, test_loader=None, device=device,
                                  trial=trial, clipping=clipping,
                                  gradient_accumulation_steps=grad_accum_steps
        )

        best_val = max(val_accuracies)

        # `study.best_value` raises ValueError while no trial has completed yet
        # (e.g. every earlier trial was pruned), so treat that as "no best so far".
        try:
            previous_best = trial.study.best_value
        except ValueError:
            previous_best = None

        if trial.number == 0 or previous_best is None or previous_best < best_val:
            result = trial.params | {'val_score': best_val, 'trial': trial.number}
            save_result(result, path, optimizer_name)

        return best_val

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    os.makedirs(os.path.join(path, 'study'), exist_ok=True)
    storage_path = f"sqlite:///{os.path.join(path, 'study', optimizer_name + '.db')}"

    study = optuna.create_study(
        study_name=f"{optimizer_name}",
        storage=storage_path,
        direction="maximize",
        pruner=pruner,
        sampler=sampler,
        load_if_exists=True
    )
    study.optimize(objective, n_trials=n_trials)

    best_trial = study.best_trial

    # Final run: retrain from scratch with the best parameters on the full split.
    set_global_seed(seed)
    model = get_model().to(device)
    train_loader, eval_loader, test_loader = get_data(data_path, batch_size=search_space['batch_size'], seed=seed)
    criterion = nn.CrossEntropyLoss()
    n_iters = num_epoches * len(train_loader) / grad_accum_steps
    optimizer, (clipping, scheduler_inner) = get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=best_trial.params, n_iters=n_iters)

    scheduler_outter = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoches)

    # Use the same accumulation as during the trials: n_iters above already assumes
    # it, and the tuned parameters were selected under that effective batch size.
    val_accuracies, test_accuracies = train(num_epoches=num_epoches, model=model, train_loader=train_loader,
                                            optimizer=optimizer, criterion=criterion, scheduler_outter=scheduler_outter, scheduler_inner=scheduler_inner,
                                            eval_loader=eval_loader, test_loader=test_loader,
                                            device=device, clipping=clipping,
                                            gradient_accumulation_steps=grad_accum_steps)

    best_epoch = int(np.argmax(np.array(val_accuracies)))
    result = best_trial.params | {'val_score': val_accuracies[best_epoch], 'test_score': test_accuracies[best_epoch]}
    return {key: float(value) for key, value in result.items()}


def get_arguments():
    """
    Parse the command-line arguments of the tuning script.

    Returns:
        argparse.Namespace with `optimizer`, `user`, `data_path`, `n_trials`,
        `n_startup_trials`, `max_epochs`, `device`, `subset_fraction` and `seed`.
    """
    optimizer_choices = [
        'Muon',
        'SignumDL', 'SignumDLNesterov',
        'Signum', 'SignumLinearLR',
        'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
        'AdamW', 'Adam', 'AdamEps',
        'AdamWClip', 'AdamEpsScheduling', 'AdamBetaScheduling', 'AdamPaLM2',
        'SGD', 'SGDLinearLR', 'SGDCosineAnnealingLR', 'SGDClip', 'SGDLinearLR+Clip',
        'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
        'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
        'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
        'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto', 'SoftSignumSGD_not_decoupled_wd-const',
        'Signum+SGD', 'Signum+SGD_not_decoupled_wd',
    ]

    parser = argparse.ArgumentParser(description="Fine-tune Swin-Tiny on Tiny ImageNet with Optuna")
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=optimizer_choices,
                        help='Optimizer to use for training.')
    parser.add_argument('--user', type=str, required=True, choices=['user1', 'user2', 'user3'],
                        help='Name of the user folder for saving the results.')
    parser.add_argument('--data_path', type=str, default='./data/tiny-imagenet-200',
                        help='Root path to the tiny-imagenet-200 directory.')
    parser.add_argument('--n_trials', type=int, default=30, help='Number of Optuna trials.')
    parser.add_argument('--n_startup_trials', type=int, default=10, help='Number of Optuna startup trials.')
    parser.add_argument('--max_epochs', type=int, default=20, help='Epochs for training each trial.')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training.')
    # argparse applies %-formatting to help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes `--help` fail with a ValueError.
    parser.add_argument('--subset_fraction', type=float, default=0.25, help='Fraction of training data to use (e.g., 0.25 for 25%%).')
    parser.add_argument('--seed', type=int, default=42, help='Global random seed.')
    return parser.parse_args()


def save_result(
    result: dict,
    path: str,
    optimizer_name: str
):
    """
    Write `result` as JSON to `{path}/{optimizer_name}.json`.

    The directory `path` is created if needed; an existing file is overwritten.
    """
    os.makedirs(path, exist_ok=True)
    target_file = f'{path}/{optimizer_name}.json'
    serialized = json.dumps(result)
    with open(target_file, 'w') as out:
        out.write(serialized)


if __name__ == '__main__':
    # Script entry point: pick the search space of the requested optimizer, pin the
    # batch size and accumulation steps, run the tuning, and store the final scores.
    args = get_arguments()
    output_dir = f'tuning/{args.user}/tiny_imagenet/swin_tiny'

    selected_search_space = search_spaces.search_spaces_map[args.optimizer]
    # Fixed (non-tuned) settings live in the same dict as the tunable ranges.
    selected_search_space.update({'batch_size': 64, 'gradient_accumulation_steps': 4})

    results = tune(
        n_trials=args.n_trials,
        n_startup_trials=args.n_startup_trials,
        num_epoches=args.max_epochs,
        optimizer_name=args.optimizer,
        search_space=selected_search_space,
        data_path=args.data_path,
        subset_fraction=args.subset_fraction,
        device=args.device,
        seed=args.seed,
        path=output_dir,
    )

    # Overwrites the intermediate per-trial file with the final val/test scores.
    save_result(results, output_dir, args.optimizer)
