import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_and_extract_archive
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from tqdm import tqdm
import optuna
import copy
import json
import argparse
import os
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from unbalanced_cifar10_training import get_optimizer
from utils import set_global_seed
from optimizers.muon import SingleDeviceMuonWithAuxAdam
from optimizers.signum import Signum, SoftSignum
import search_spaces


class UnbalancedTinyImageNet(torch.utils.data.Dataset):
    """
    Binary, class-imbalanced view over a labelled dataset.

    Original labels are binarised by parity: even labels become class 0
    ("major"), odd labels become class 1 ("minor"). Every major sample is
    kept, while the minor samples are reduced to about 1/k of their count
    with a split stratified on the original labels (no reduction when k == 1).
    Samples are fetched lazily from the wrapped dataset on access.

    Attributes:
        _dataset: the wrapped dataset, indexed with original positions.
        _indices: numpy array mapping positions of this view to positions in
            ``_dataset`` (all major samples first, then the kept minor ones).
        _y: long tensor of binary labels aligned with ``_indices``.
        _original_y: tensor of original labels aligned with ``_indices``.
    """

    def __init__(self, dataset, seed: int, k: int = 2):
        print(f"Creating unbalanced dataset (k={k})...")

        labels = self._extract_labels(dataset)
        print(f"Total samples: {len(labels)}")

        parity = labels % 2
        major_indices = (parity == 0).nonzero(as_tuple=True)[0].numpy()
        minor_indices = (parity == 1).nonzero(as_tuple=True)[0].numpy()
        print(f"Major class size: {len(major_indices)}, Minor class size: {len(minor_indices)}")

        if k != 1:
            # The "test" side of the split (size 1/k) is the part that is kept;
            # stratifying on the original labels keeps their relative proportions.
            _, kept_minor = train_test_split(
                minor_indices,
                test_size=1.0 / k,
                stratify=labels[minor_indices].numpy(),
                random_state=seed,
            )
        else:
            kept_minor = minor_indices

        n_major = len(major_indices)
        n_minor = len(kept_minor)
        print(f"After unbalancing - Major: {n_major}, Minor: {n_minor}")

        self._dataset = dataset
        self._indices = np.concatenate([major_indices, kept_minor])
        # Binary labels follow the same major-then-minor order as _indices.
        self._y = torch.cat([
            torch.zeros(n_major, dtype=torch.long),
            torch.ones(n_minor, dtype=torch.long),
        ])
        # Original labels are kept so later subsets can stratify on them.
        self._original_y = torch.cat([labels[major_indices], labels[kept_minor]])

        print(f"Unbalanced dataset created: {len(self)} samples (ratio {n_major}/{n_minor})")

    @staticmethod
    def _extract_labels(dataset):
        """Return the labels of ``dataset`` as a tensor, preferring cheap attributes."""
        for attribute in ('targets', 'labels'):
            if hasattr(dataset, attribute):
                return torch.tensor(getattr(dataset, attribute))
        # Last resort: walk the whole dataset, which loads every sample.
        print("Warning: dataset doesn't have .targets attribute, using slow iteration")
        return torch.tensor([label for _, label in dataset])

    def __len__(self) -> int:
        """Number of samples in the unbalanced view."""
        return len(self._indices)

    def __getitem__(self, i: int):
        """
        Load sample ``i`` from the wrapped dataset and pair it with its binary label.

        Args:
            i: Position within this view.

        Returns:
            Tuple ``(data, target)``; ``target`` is a 0-dim long tensor (0 or 1).
        """
        sample, _ = self._dataset[self._indices[i]]
        return sample, self._y[i]


def stratified_split(dataset, fraction, seed=42):
    """
    Return a subset holding ``fraction`` of ``dataset``, stratified by original class.

    Stratification uses ``dataset._original_y`` (as stored by
    UnbalancedTinyImageNet). Because the binary label there is
    ``original_label % 2``, preserving the original-class proportions also
    preserves the binary 0/1 proportions.

    Args:
        dataset: Dataset exposing ``_original_y``, a tensor of original class
            labels aligned with dataset positions (e.g. UnbalancedTinyImageNet).
        fraction: Share of samples to keep; values >= 1.0 return ``dataset``
            itself unchanged.
        seed: ``random_state`` passed to sklearn's ``train_test_split``.

    Returns:
        ``dataset`` itself, or a ``torch.utils.data.Subset`` over it.
    """
    if fraction >= 1.0:
        return dataset

    stratify_labels = dataset._original_y.numpy()
    positions = np.arange(len(dataset))

    # The "test" side of the split has size `fraction`: that is the part we keep.
    _, kept_positions = train_test_split(
        positions,
        test_size=fraction,
        stratify=stratify_labels,
        random_state=seed,
    )
    return torch.utils.data.Subset(dataset, kept_positions)


def format_tiny_imagenet_val_folder(data_path):
    """
    Make sure Tiny ImageNet exists at ``data_path`` and its validation split is
    laid out for ``torchvision.datasets.ImageFolder``.

    If ``data_path`` does not exist, the official archive is downloaded and
    extracted into the parent directory of ``data_path``. For the default
    './data/tiny-imagenet-200' that parent is './data'. The archive presumably
    unpacks to a 'tiny-imagenet-200' folder, so ``data_path`` should end in
    that name.

    Every image listed in ``val/val_annotations.txt`` is then moved from
    ``val/images`` into ``val/<class_id>/``, and ``val/images`` is removed.

    The step runs only while ``val/images`` still exists. An interrupted run is
    therefore completed on the next call instead of being mistaken for a
    finished one; a leftover ``images`` directory would otherwise be picked up
    by ImageFolder as an extra class.

    Args:
        data_path: Root of the tiny-imagenet-200 directory.
    """
    archive_url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    if not os.path.exists(data_path):
        # Extract next to data_path (instead of a fixed './data') so that
        # non-default locations get populated too.
        download_root = os.path.dirname(os.path.abspath(data_path))
        print("📦 Downloading Tiny ImageNet-200 (~110MB)...")
        download_and_extract_archive(
            url=archive_url,
            download_root=download_root,
            extract_root=download_root,
            filename='tiny-imagenet-200.zip'
        )

    val_dir = os.path.join(data_path, 'val')
    images_dir = os.path.join(val_dir, 'images')
    val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')

    if not os.path.isdir(images_dir):
        return # Already formatted

    print("Formatting Tiny ImageNet validation folder (one-time operation)...")
    val_annotations = pd.read_csv(val_annotations_path, sep='\t', header=None, names=['File', 'Class', 'X_1', 'Y_1', 'X_2', 'Y_2'])

    for _, row in tqdm(val_annotations.iterrows(), total=len(val_annotations), desc="Moving val images"):
        class_dir = os.path.join(val_dir, row['Class'])
        os.makedirs(class_dir, exist_ok=True)

        original_image_path = os.path.join(images_dir, row['File'])
        # Files already moved by an earlier, interrupted run are skipped.
        if os.path.exists(original_image_path):
            shutil.move(original_image_path, class_dir)

    shutil.rmtree(images_dir)
    print("Formatting complete.")


def get_data(data_path, batch_size, seed=42, subset_fraction=1.0, k=2, num_workers=4):
    """
    Build the train / validation / test DataLoaders for unbalanced binary Tiny ImageNet.

    Steps:
      1. Format the ``val`` folder if needed.
      2. Wrap both the ``train`` and ``val`` ImageFolders in
         UnbalancedTinyImageNet, using the same ``seed`` and ``k``.
      3. Optionally keep a stratified ``subset_fraction`` of the training data.
      4. Split the training data 80/20 into train/validation with a
         ``seed``-driven generator.

    The Tiny ImageNet ``val`` folder serves as the test set.

    Training samples use random crop/flip augmentation. The validation split
    and the test set use the deterministic resize + center-crop transform, so
    the validation F1 used for model selection and pruning is not perturbed by
    random augmentation.

    Args:
        data_path: Root of the tiny-imagenet-200 directory.
        batch_size: Batch size of all three loaders.
        seed: Seed for the unbalancing, subset and train/validation splits.
        subset_fraction: Fraction of the unbalanced training data to keep.
        k: Unbalance coefficient forwarded to UnbalancedTinyImageNet.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    IMAGE_SIZE = 224

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE + 32), # 256
        transforms.CenterCrop(IMAGE_SIZE),  # 224
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    format_tiny_imagenet_val_folder(data_path)
    test_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_path, 'val'), transform=test_transform)
    train_folder = torchvision.datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=train_transform)

    # Apply unbalancing
    train_dataset_full = UnbalancedTinyImageNet(train_folder, seed=seed, k=k)
    test_dataset = UnbalancedTinyImageNet(test_dataset, seed=seed, k=k)
    print('apply unbalancing passed')

    # Same unbalanced training samples, loaded with the deterministic test
    # transform. The shallow copies share the sample list and the index/label
    # arrays, so positions line up exactly; only the image transform differs.
    eval_folder = copy.copy(train_folder)
    eval_folder.transform = test_transform
    eval_dataset_full = copy.copy(train_dataset_full)
    eval_dataset_full._dataset = eval_folder

    # Apply subset_fraction first (with stratification to preserve class proportions)
    if subset_fraction < 1.0:
        train_dataset_full = stratified_split(train_dataset_full, subset_fraction, seed=seed)
        # Mirror the selected positions on the deterministic-transform view.
        eval_dataset_full = torch.utils.data.Subset(eval_dataset_full, train_dataset_full.indices)
        print(f"Using a {subset_fraction*100:.0f}% subset of the data: {len(train_dataset_full)} train samples.")

    # Print statistics for original classes 0-9 (after subset_fraction)
    print("\n📊 Statistics for original classes (0-9) in train_dataset (after subset_fraction):")
    # Handle both UnbalancedTinyImageNet and Subset (from stratified_split)
    if hasattr(train_dataset_full, '_original_y'):
        original_labels = train_dataset_full._original_y.numpy()
    elif hasattr(train_dataset_full, 'dataset') and hasattr(train_dataset_full.dataset, '_original_y'):
        # If it's a Subset, get original labels from the underlying dataset
        subset_indices = train_dataset_full.indices
        original_labels = train_dataset_full.dataset._original_y[subset_indices].numpy()
    else:
        print("  ⚠️  Failed to get original labels")
        original_labels = None

    if original_labels is not None:
        for class_id in range(10):
            count = np.sum(original_labels == class_id)
            binary_class = class_id % 2
            print(f"  Class {class_id} (binary class {binary_class}): {count} examples")

    # Then do train/val split
    train_len = int(0.8 * len(train_dataset_full))
    val_len = len(train_dataset_full) - train_len
    g = torch.Generator().manual_seed(seed)
    train_subset, val_split = torch.utils.data.random_split(train_dataset_full, [train_len, val_len], generator=g)
    # Serve the validation positions through the deterministic-transform view.
    val_subset = torch.utils.data.Subset(eval_dataset_full, val_split.indices)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, test_loader


def get_model(num_classes=2):
    """
    Build an ImageNet-pretrained Swin-T with a new classification head.

    The backbone is torchvision's ``swin_t`` with ``IMAGENET1K_V1`` weights.
    Its ``head`` is replaced by a freshly initialised ``nn.Linear`` that
    outputs ``num_classes`` logits.

    Args:
        num_classes: Output size of the new head. Defaults to 2 for the binary
            (major/minor) task used throughout this script.

    Returns:
        The modified model (on CPU; callers move it to a device).
    """
    model = torchvision.models.swin_t(weights=torchvision.models.Swin_T_Weights.IMAGENET1K_V1)
    model.head = nn.Linear(model.head.in_features, num_classes)
    return model


def evaluate_model(model, loader, desc="Evaluating", device='cpu'):
    """
    Score ``model`` on ``loader`` with the F1 metric, in percent.

    The model is switched to eval mode and every batch runs without gradient
    tracking. The predicted class is the arg-max of the outputs along dim 1.
    The score comes from sklearn's ``f1_score`` with its default averaging
    (binary, positive label 1, i.e. the minority class of
    UnbalancedTinyImageNet) and ``zero_division=0``.

    Args:
        model: Classifier whose outputs are scored by a dim-1 arg-max
            (presumably (batch, classes) logits — confirm for new models).
        loader: Iterable of ``(images, labels)`` batches.
        desc: Progress-bar label.
        device: Device the batches are moved to.

    Returns:
        The F1 score multiplied by 100.
    """
    model.eval()
    predicted_labels = []
    true_labels = []

    progress = tqdm(loader, desc=desc, leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.no_grad():
            logits = model(inputs)

        batch_predictions = logits.data.max(dim=1).indices
        predicted_labels.extend(batch_predictions.cpu().numpy())
        true_labels.extend(targets.cpu().numpy())

    return 100 * f1_score(true_labels, predicted_labels, zero_division=0)


def train(num_epoches, model, train_loader, optimizer, criterion, scheduler_outter, scheduler_inner, eval_loader, test_loader=None, device='cpu', trial=None, clipping=None, gradient_accumulation_steps=1, use_amp=True):
    """
    Train ``model`` for ``num_epoches`` epochs with optional gradient
    accumulation, mixed precision, gradient clipping and Optuna pruning.

    After each epoch the model is scored with ``evaluate_model`` (F1, in
    percent) on ``eval_loader`` and, when given, on ``test_loader``.

    Args:
        num_epoches: Number of passes over ``train_loader``.
        model: Model to train in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        optimizer: Optimizer stepped through the GradScaler.
        criterion: Loss function ``criterion(outputs, labels)``.
        scheduler_outter: LR scheduler stepped once per epoch, or None.
        scheduler_inner: LR scheduler stepped after every optimizer step, or None.
        eval_loader: Loader used for the per-epoch validation score.
        test_loader: Optional loader for a per-epoch test score.
        device: Device the batches are moved to.
        trial: Optional Optuna trial; the validation score is reported to it
            every epoch.
        clipping: Max infinity-norm for gradient clipping, or None to disable.
        gradient_accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer step. Gradients of a trailing
            partial group at the end of an epoch are never applied; the
            ``zero_grad`` at the start of the next epoch discards them.
        use_amp: Enable autocast and loss scaling.

    Returns:
        ``(eval_scores, test_scores)``: per-epoch F1 lists; ``test_scores`` is
        empty when ``test_loader`` is None.

    Raises:
        optuna.exceptions.TrialPruned: When ``trial`` asks to prune.
    """
    eval_scores, test_scores = [], []

    scaler = GradScaler(enabled=use_amp)
    # Guard the average-loss print against an empty loader.
    n_batches = max(len(train_loader), 1)

    for epoch in range(num_epoches):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for i, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Training]")):
            images, labels = images.to(device), labels.to(device)
            with autocast(enabled=use_amp):
                outputs = model(images)
                # Divide so the accumulated gradient averages over the group.
                loss = criterion(outputs, labels) / gradient_accumulation_steps
            scaler.scale(loss).backward()

            if (i + 1) % gradient_accumulation_steps == 0:
                if clipping is not None:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                if scheduler_inner is not None:
                    scheduler_inner.step()

            # Undo the accumulation scaling so the logged loss is per batch.
            running_loss += loss.item() * gradient_accumulation_steps

        if scheduler_outter is not None:
            scheduler_outter.step()
        eval_score = evaluate_model(model, eval_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Evaluating]", device=device)
        eval_scores.append(eval_score)
        test_score = -1
        if test_loader is not None:
            test_score = evaluate_model(model, test_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Testing]", device=device)
            test_scores.append(test_score)
        if trial is not None:
            trial.report(eval_score, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        # scheduler_outter is optional: fall back to the optimizer's current LR.
        if scheduler_outter is not None:
            current_lr = scheduler_outter.get_last_lr()[0]
        else:
            current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{num_epoches}], Loss: {running_loss/n_batches:.4f}, LR: {current_lr:.6f}, Eval F1: {eval_score:.2f}%, Test F1: {test_score:.2f}%")

    return eval_scores, test_scores


def suggest_params(trial, search_space):
    """
    Draw one value per tunable entry of ``search_space`` from an Optuna trial.

    Only dict-valued entries are tunable. Entries with ``type == 'float'`` use
    ``trial.suggest_float``; any other type uses ``trial.suggest_int``. Both
    receive the entry's ``min``/``max`` and an optional ``log`` flag
    (default False). Non-dict entries, such as the fixed batch size and
    accumulation steps injected by the script, are skipped.

    Args:
        trial: Optuna trial (anything with ``suggest_float``/``suggest_int``).
        search_space: Mapping of parameter name to spec dict or fixed value.

    Returns:
        Dict of parameter name to suggested value.
    """
    suggested = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            continue
        suggest = trial.suggest_float if spec['type'] == 'float' else trial.suggest_int
        suggested[name] = suggest(name, spec['min'], spec['max'], log=spec.get('log', False))
    return suggested


def tune(n_trials, data_path, search_space, optimizer_name, device, num_epoches, n_startup_trials, path, seed=42, subset_fraction=1.0, k=2):
    """
    Tune optimizer hyperparameters on unbalanced Tiny ImageNet with Optuna,
    then retrain once with the best trial's parameters.

    Each trial trains a fresh Swin-T. The objective is the best per-epoch
    validation F1. Whenever a trial sets a new study best, its params and
    score are written with ``save_result`` to ``<path>/<optimizer_name>.json``.
    The study is stored in SQLite under ``<path>/study`` and resumed if it
    already exists.

    Args:
        n_trials: Number of Optuna trials to run.
        data_path: Root of the tiny-imagenet-200 directory.
        search_space: Search-space dict for ``get_optimizer``. Must contain
            'gradient_accumulation_steps'; 'batch_size' is used if present
            (default 64).
        optimizer_name: Optimizer key; also names the study and result file.
        device: Device to train on.
        num_epoches: Epochs per training run.
        n_startup_trials: Random startup trials for the TPE sampler.
        path: Output directory for results and the study database.
        seed: Global seed, also used for all data splits.
        subset_fraction: Fraction of the training data to use.
        k: Unbalance coefficient.

    Returns:
        Dict of the best trial's params plus 'val_score' and 'test_score'
        (taken at the epoch with the best validation F1), all cast to float.
    """
    batch_size = search_space.get('batch_size', 64)
    accumulation_steps = search_space['gradient_accumulation_steps']

    # get_data's splits use only explicit seeds (no global RNG), so building
    # the loaders once is equivalent to rebuilding them in every trial.
    train_loader, eval_loader, test_loader = get_data(
        data_path, batch_size=batch_size, seed=seed,
        subset_fraction=subset_fraction, k=k,
    )
    n_iters = num_epoches * len(train_loader) / accumulation_steps

    def improves_on_study_best(trial, value):
        """True if ``value`` beats all completed trials, or none has completed yet."""
        try:
            return trial.study.best_value < value
        except ValueError:
            # Optuna raises ValueError when no trial has completed yet
            # (e.g. every earlier trial failed or was pruned).
            return True

    def objective(trial):
        set_global_seed(seed)
        model = get_model().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer, (clipping, scheduler_inner) = get_optimizer(optimizer_name, model, search_space, trial=trial, n_iters=n_iters)

        scheduler_outter = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoches)

        val_accuracies, _ = train(num_epoches=num_epoches, model=model, train_loader=train_loader,
                                  optimizer=optimizer, criterion=criterion, scheduler_outter=scheduler_outter, scheduler_inner=scheduler_inner,
                                  eval_loader=eval_loader, test_loader=None, device=device,
                                  trial=trial, clipping=clipping,
                                  gradient_accumulation_steps=accumulation_steps
        )
        best_val = max(val_accuracies)

        if trial.number == 0 or improves_on_study_best(trial, best_val):
            best_epoch = int(np.argmax(val_accuracies))
            result = trial.params | {'val_score': val_accuracies[best_epoch], 'trial': trial.number}
            save_result(result, path, optimizer_name)

        return best_val

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    os.makedirs(os.path.join(path, 'study'), exist_ok=True)
    storage_path = f"sqlite:///{os.path.join(path, 'study', optimizer_name + '.db')}"

    study = optuna.create_study(
        study_name=f"{optimizer_name}",
        storage=storage_path,
        direction="maximize",
        pruner=pruner,
        sampler=sampler,
        load_if_exists=True
    )
    study.optimize(objective, n_trials=n_trials)

    best_trial = study.best_trial

    # Final run with the best params, now also scoring on the test set.
    set_global_seed(seed)
    model = get_model().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer, (clipping, scheduler_inner) = get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=best_trial.params, n_iters=n_iters)

    scheduler_outter = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoches)

    val_accuracies, test_accuracies = train(num_epoches=num_epoches, model=model, train_loader=train_loader,
                                            optimizer=optimizer, criterion=criterion, scheduler_outter=scheduler_outter, scheduler_inner=scheduler_inner,
                                            eval_loader=eval_loader, test_loader=test_loader,
                                            device=device, clipping=clipping,
                                            gradient_accumulation_steps=accumulation_steps)

    best_epoch = int(np.argmax(val_accuracies))
    result = best_trial.params | {'val_score': val_accuracies[best_epoch], 'test_score': test_accuracies[best_epoch]}
    return {key: float(value) for key, value in result.items()}


def get_arguments():
    """
    Parse the command-line arguments of the tuning script.

    Returns:
        argparse.Namespace with optimizer, user, data_path, n_trials,
        n_startup_trials, max_epochs, unbalance_coef, device,
        subset_fraction and seed.
    """
    parser = argparse.ArgumentParser(description="Fine-tune Swin-Tiny on Unbalanced Tiny ImageNet with Optuna")
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=['Muon',
                                                                           'SignumDL', 'SignumDLNesterov',
                                                                           'Signum', 'SignumLinearLR',
                                                                           'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
                                                                           'AdamW', 'Adam', 'AdamEps', 
                                                                           'AdamWClip', 'AdamEpsScheduling', 'AdamBetaScheduling', 'AdamPaLM2', 
                                                                           'SGD', 'SGDLinearLR', 'SGDCosineAnnealingLR', 'SGDClip', 'SGDLinearLR+Clip',
                                                                           'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
                                                                           'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
                                                                           'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
                                                                           'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto', 'SoftSignumSGD_not_decoupled_wd-const',
                                                                           'Signum+SGD', 'Signum+SGD_not_decoupled_wd'],
                        help='Optimizer to use for training.')
    # NOTE(review): 'user3' replaces a duplicated 'user2' entry that looked like
    # a typo — confirm the intended folder name.
    parser.add_argument('--user', type=str, required=True, choices=['user1', 'user2', 'user3'],
                        help='Name of the user folder for saving the results.')
    parser.add_argument('--data_path', type=str, default='./data/tiny-imagenet-200',
                        help='Root path to the tiny-imagenet-200 directory.')
    parser.add_argument('--n_trials', type=int, default=30, help='Number of Optuna trials.')
    parser.add_argument('--n_startup_trials', type=int, default=10, help='Number of Optuna startup trials.')
    parser.add_argument('--max_epochs', type=int, default=20, help='Epochs for training each trial.')
    parser.add_argument('--unbalance_coef', type=int, default=2, help='The unbalance coefficient.')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training.')
    # argparse %-formats help strings, so a literal percent sign must be '%%'.
    parser.add_argument('--subset_fraction', type=float, default=0.25, help='Fraction of training data to use (e.g., 0.25 for 25%%).')
    parser.add_argument('--seed', type=int, default=42, help='Global random seed.')
    return parser.parse_args()


def save_result(
    result: dict,
    path: str,
    optimizer_name: str
):
    """
    Serialize ``result`` as JSON into ``<path>/<optimizer_name>.json``.

    ``path`` is created if missing, and an existing file is overwritten.

    Args:
        result: JSON-serializable mapping to store.
        path: Output directory.
        optimizer_name: Base name of the JSON file.
    """
    os.makedirs(path, exist_ok=True)
    destination = f'{path}/{optimizer_name}.json'
    with open(destination, mode='w') as out_file:
        json.dump(result, out_file)


if __name__ == '__main__':
    args = get_arguments()
    # One output directory for both the per-trial and the final results file.
    output_dir = f'tuning/{args.user}/unbalanced_tiny_imagenet/swin_tiny_{args.unbalance_coef}'

    # Fixed (non-tuned) settings injected into the optimizer's search space;
    # tune() reads 'gradient_accumulation_steps' from it.
    search_space = search_spaces.search_spaces_map[args.optimizer]
    search_space['batch_size'] = 64
    search_space['gradient_accumulation_steps'] = 4

    results = tune(
        n_trials=args.n_trials,
        data_path=args.data_path,
        search_space=search_space,
        optimizer_name=args.optimizer,
        device=args.device,
        num_epoches=args.max_epochs,
        n_startup_trials=args.n_startup_trials,
        seed=args.seed,
        subset_fraction=args.subset_fraction,
        k=args.unbalance_coef,
        path=output_dir,
    )

    save_result(results, output_dir, args.optimizer)

