import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler
import math
import os
from torchvision.transforms import functional as Ft
from torchvision.transforms.functional import InterpolationMode
from torch.amp import autocast, GradScaler
from typing import List, Optional, Tuple, Union
from torch import Tensor
import numbers
import collections
from collections.abc import Sequence
from utils.utils import *
from utils.config import *
from utils.dataset import *
from utils.augmentation import *
import torch.distributed as dist


class SCELoss(nn.Module):
    """Symmetric cross-entropy: ``alpha * CE(target, pred) + beta * CE(pred, target)``.

    Both distributions are clamped before taking logarithms so that zero
    entries (e.g. in one-hot targets) do not produce ``log(0)``.

    Args:
        alpha: Weight of the standard cross-entropy term.
        beta: Weight of the reverse cross-entropy term.
    """

    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, y_pred, y_true_onehot):
        """Return the scalar SCE loss averaged over the batch.

        Args:
            y_pred: Raw logits; softmax is taken over ``dim=1``.
            y_true_onehot: Target distribution with the same shape as
                ``y_pred`` (one-hot or soft labels), or a 1-D tensor of class
                indices, which is expanded to one-hot here.
        """
        probs = torch.clamp(F.softmax(y_pred, dim=1), min=1e-7, max=1.0)

        # nn.CrossEntropyLoss, which this loss stands in for, accepts class
        # indices; accept them here too instead of failing on (or silently
        # mis-broadcasting) a 1-D target.
        if y_true_onehot.dim() == 1:
            y_true_onehot = F.one_hot(
                y_true_onehot.long(), num_classes=probs.shape[1]
            ).to(probs.dtype)

        # Lower clamp of 1e-4 bounds log(target) at about -9.21 for zero entries.
        y_true = torch.clamp(y_true_onehot, min=1e-4, max=1.0)
        ce = -torch.sum(y_true * torch.log(probs), dim=1).mean()
        rce = -torch.sum(probs * torch.log(y_true), dim=1).mean()

        return self.alpha * ce + self.beta * rce


def trainBaseline(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=None, augs=()):
    """Train ``model`` for one epoch with cross-entropy or symmetric cross-entropy.

    Args:
        model: Network to optimise; put into training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable yielding ``(data, target, _, indices)``; only
            ``data`` and ``target`` are used.
        optimizer: Optimizer, stepped once per batch through ``scaler``.
        lr_scheduler: Scheduler, stepped once per batch.
        epoch: Epoch number, used only in the log line.
        dataset: Unused here; kept so every ``train*`` function shares one
            signature.
        scaler: Gradient scaler exposing ``scale``/``step``/``update``.
        params: ``(label_smoothing, sce_alpha, sce_beta)``. ``SCELoss`` is used
            when ``sce_beta`` is non-zero, otherwise ``nn.CrossEntropyLoss``
            with the given label smoothing.
        augs: Batch-level augmentations applied in order by ``augment``.

    Returns:
        Mean training loss over every batch of the epoch (``0.0`` if the
        loader is empty).
    """
    model.train()
    label_smoothing, sce_alpha, sce_beta = params
    if sce_beta == 0.0:
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    else:
        criterion = SCELoss(alpha=sce_alpha, beta=sce_beta)

    # The window accumulators feed the periodic log line and are reset after
    # each report; the epoch accumulators are never reset so the return value
    # covers the whole epoch.
    total_loss, total_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0

    for batch_idx, (data, target, _, indices) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        targets = target.to(device, non_blocking=True)
        data, targets = augment(data, targets, augs)

        optimizer.zero_grad()
        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(data)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        total_batches += 1
        window_loss += batch_loss
        window_batches += 1

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, window_loss, window_batches, batch_idx, train_loader)
            window_loss, window_batches = 0.0, 0

    return total_loss / total_batches if total_batches > 0 else 0.0


def trainKD(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=None, augs=[]):
    model.train()
    epoch_loss, num_batches = 0.0, 0
    cross_entropy, kl_div = nn.CrossEntropyLoss(), nn.KLDivLoss(reduction='batchmean')
    temperature = params[0]

    for batch_idx, (data, teacher_preds, target, indices) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        targets = target.to(device, non_blocking=True)
        teacher_preds = teacher_preds.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(data)

            student_log_probs = F.log_softmax(outputs / temperature, dim=1)
            soft_loss = kl_div(student_log_probs, teacher_preds)
            hard_loss = cross_entropy(outputs, targets)
            loss = 0.5 * (temperature ** 2) * soft_loss + 0.5 * hard_loss

        scaler.scale(loss).backward(), scaler.step(optimizer), scaler.update(), lr_scheduler.step()
        epoch_loss += loss.item()
        num_batches += 1

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, epoch_loss, num_batches, batch_idx, train_loader)
            epoch_loss, num_batches = 0.0, 0

    return epoch_loss / num_batches if num_batches > 0 else 0.0


def trainBSD(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=(0.99, 1.0), augs=()):
    """Train ``model`` for one epoch against per-sample buffered soft targets.

    Each batch is fitted with a KL loss to the (sharpened) buffered targets.
    Afterwards the buffers of the non-augmented samples are updated with a
    discounted running average of the model's predictions.

    Args:
        model: Network to optimise; put into training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable yielding ``(data, soft_targets, _, indices)``;
            ``indices`` address rows of the dataset buffers.
        optimizer: Optimizer, stepped once per batch through ``scaler``.
        lr_scheduler: Scheduler, stepped once per batch.
        epoch: Epoch number, used only in the log line.
        dataset: Object exposing ``buffered_targets`` and ``effective_counts``
            tensors indexed by sample; both are updated in place.
        scaler: Gradient scaler exposing ``scale``/``step``/``update``.
        params: ``(gamma, tau)`` — discount factor for the effective counts and
            sharpening temperature passed to ``sharpen``.
        augs: Batch-level augmentations; when non-empty they are applied to
            the first half of every batch.

    Returns:
        Mean training loss over every batch of the epoch (``0.0`` if the
        loader is empty).
    """
    model.train()
    gamma, tau = params
    kl_div = nn.KLDivLoss(reduction='batchmean')

    # Window accumulators drive the periodic log; epoch accumulators are never
    # reset so the returned mean covers the whole epoch.
    total_loss, total_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0

    for batch_idx, (data, soft_targets, _, indices) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        targets = soft_targets.to(device, non_blocking=True)

        # Augment only the first half of the batch; the second half stays
        # clean and is the only part used to refresh the buffers below.
        num_aug = int(data.size(0) * 0.5) if len(augs) > 0 else 0
        if num_aug > 0:
            data[:num_aug], targets[:num_aug] = augment(data[:num_aug], targets[:num_aug], augs)

        optimizer.zero_grad(set_to_none=True)
        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(data)
            log_p = F.log_softmax(outputs, dim=1)
            loss = kl_div(log_p, sharpen(targets, tau))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        total_batches += 1
        window_loss += batch_loss
        window_batches += 1

        with torch.no_grad():
            # Cast to float32 first: outside autocast the logits may still be
            # bfloat16, and these probabilities are accumulated across epochs.
            probs = F.softmax(outputs.detach().float(), dim=1).cpu()

            if num_aug > 0:
                indices = indices[num_aug:]
                probs = probs[num_aug:]

            # beta = gamma*A / (gamma*A + 1) is the weight kept by the old
            # target; it is only ever used on the CPU side, so it is computed
            # straight from the buffers without a device round-trip.
            counts = dataset.effective_counts[indices]
            beta = ((gamma * counts) / (gamma * counts + 1.0)).view(-1, 1).cpu()

            dataset.buffered_targets[indices] = (
                beta * dataset.buffered_targets[indices] + (1.0 - beta) * probs
            )
            dataset.effective_counts[indices] = gamma * counts + 1.0

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, window_loss, window_batches, batch_idx, train_loader)
            window_loss, window_batches = 0.0, 0

    return total_loss / total_batches if total_batches > 0 else 0.0


def trainBSDPlus(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=(0.9, 3, 1.0), augs=()):
    """Train one epoch of BSD with an added consistency loss on augmented views.

    The loss is ``KL(model(simple) || sharpen(buffer))`` plus ``lambda_c`` times
    the mean, over all augmented views, of the KL between the model's
    prediction on the view and its (detached) prediction on the simple view.
    The dataset buffers are then updated from the simple-view predictions.

    Args:
        model: Network to optimise; put into training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable yielding
            ``(data_simple, data_aug_list, buffer_targets, indices)`` where
            ``data_aug_list`` is a sequence of augmented batches.
        optimizer: Optimizer, stepped once per batch through ``scaler``.
        lr_scheduler: Scheduler, stepped once per batch.
        epoch: Epoch number, used only in the log line.
        dataset: Object exposing ``buffered_targets`` and ``effective_counts``
            tensors indexed by sample; both are updated in place.
        scaler: Gradient scaler exposing ``scale``/``step``/``update``.
        params: ``(lambda_c, gamma, tau)`` — consistency weight, discount
            factor for the effective counts, and sharpening temperature.
        augs: Batch-level augmentations applied to every augmented view
            (and to its targets) via ``augment``.

    Returns:
        Mean training loss over every batch of the epoch (``0.0`` if the
        loader is empty).
    """
    model.train()
    lambda_c, gamma, tau = params
    kl_div = nn.KLDivLoss(reduction='batchmean')

    # Window accumulators drive the periodic log; epoch accumulators are never
    # reset so the returned mean covers the whole epoch.
    total_loss, total_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0

    for batch_idx, (data_simple, data_aug_list, buffer_targets, indices) in enumerate(train_loader):
        data_simple = data_simple.to(device)
        targets = buffer_targets.to(device)
        num_views = len(data_aug_list)

        optimizer.zero_grad()
        with autocast('cuda', dtype=torch.bfloat16):
            output = model(data_simple)
            log_p = F.log_softmax(output, dim=1)
            detached_probs = F.softmax(output, dim=1).detach()

            loss = kl_div(log_p, sharpen(targets, tau))

            # Average the consistency term over however many views the loader
            # provides, rather than assuming exactly two.
            if num_views > 0:
                consistency = 0.0
                for view in data_aug_list:
                    # Clone so an in-place augmentation cannot corrupt the
                    # predictions that are written to the buffers below.
                    data_aug, view_targets = augment(view.to(device), detached_probs.clone(), augs)
                    consistency = consistency + kl_div(F.log_softmax(model(data_aug), dim=1), view_targets)
                loss = loss + lambda_c * consistency / num_views

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        total_batches += 1
        window_loss += batch_loss
        window_batches += 1

        with torch.no_grad():
            # beta = gamma*A / (gamma*A + 1) is the weight kept by the old target.
            counts = dataset.effective_counts[indices]
            beta = ((gamma * counts) / (gamma * counts + 1.0)).view(-1, 1).cpu()

            dataset.buffered_targets[indices] = (
                beta * dataset.buffered_targets[indices]
                + (1.0 - beta) * detached_probs.cpu()
            )
            dataset.effective_counts[indices] = gamma * counts + 1.0

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, window_loss, window_batches, batch_idx, train_loader)
            window_loss, window_batches = 0.0, 0

    return total_loss / total_batches if total_batches > 0 else 0.0


def trainTE(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=(0.6, 100, 30, 50, 200), augs=()):
    """Train ``model`` for one epoch with temporal ensembling.

    The loss is cross-entropy plus ``w`` times the MSE between the current
    softmax output and the bias-corrected ensemble target ``Z``. ``w`` follows
    a Gaussian ramp-up over the first ``ramp_up_epochs`` epochs. During the
    last ``ramp_down_epochs`` epochs the first Adam beta is annealed from 0.9
    towards 0.5. After each step the dataset buffer is updated as an
    exponential moving average of the predictions.

    Args:
        model: Network to optimise; put into training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable yielding ``(data, Z, target, indices)``.
        optimizer: Optimizer, stepped once per batch through ``scaler``; param
            groups that carry ``betas`` are modified during ramp-down.
        lr_scheduler: Scheduler, stepped once per batch.
        epoch: Current epoch; expected to start at 1, since epoch 0 would make
            the bias-correction denominator zero.
        dataset: Object exposing a ``buffered_targets`` tensor indexed by
            sample, updated in place.
        scaler: Gradient scaler exposing ``scale``/``step``/``update``.
        params: ``(momentum, ramp_up_epochs, w_max, ramp_down_epochs,
            total_epochs)``.
        augs: Batch-level augmentations applied in order by ``augment``.

    Returns:
        Mean training loss over every batch of the epoch (``0.0`` if the
        loader is empty).
    """
    model.train()
    mse, cross_entropy = nn.MSELoss(), nn.CrossEntropyLoss()
    # ramp_down_epochs is taken from params as passed by the caller; it is no
    # longer overwritten with a hard-coded 50.
    momentum, ramp_up_epochs, w_max, ramp_down_epochs, total_epochs = params

    # Gaussian ramp-up of the consistency weight, kept as a plain Python float.
    if epoch < ramp_up_epochs:
        w = w_max * math.exp(-5 * (1 - epoch / ramp_up_epochs) ** 2)
    else:
        w = w_max

    # Anneal Adam's beta1 over the final ramp_down_epochs epochs.
    if ramp_down_epochs > 0 and epoch >= total_epochs - ramp_down_epochs:
        T = (total_epochs - epoch) / ramp_down_epochs
        factor = math.exp(-12.5 * (1 - T) ** 2)
        beta1 = 0.5 + (0.9 - 0.5) * factor
        for param_group in optimizer.param_groups:
            # Only Adam-style param groups carry 'betas'.
            if 'betas' in param_group:
                param_group['betas'] = (beta1, param_group['betas'][1])

    # Startup bias correction of the moving average; constant within an epoch.
    bias_correction = 1 - momentum ** epoch

    # Window accumulators drive the periodic log; epoch accumulators are never
    # reset so the returned mean covers the whole epoch.
    total_loss, total_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0

    for batch_idx, (data, Z, target, indices) in enumerate(train_loader):
        data = data.to(device)
        targets = target.to(device)
        data, targets = augment(data, targets, augs)
        Z = Z.to(device)

        optimizer.zero_grad()

        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(data)
            probs = F.softmax(outputs, dim=1)
            loss = cross_entropy(outputs, targets) + w * mse(probs, Z / bias_correction)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        total_batches += 1
        window_loss += batch_loss
        window_batches += 1

        with torch.no_grad():
            # Cast to float32 first: outside autocast the logits may still be
            # bfloat16, and these probabilities are accumulated across epochs.
            new_probs = F.softmax(outputs.detach().float(), dim=1).cpu()
            dataset.buffered_targets[indices] = (
                momentum * dataset.buffered_targets[indices]
                + (1 - momentum) * new_probs
            )

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, window_loss, window_batches, batch_idx, train_loader)
            window_loss, window_batches = 0.0, 0

    return total_loss / total_batches if total_batches > 0 else 0.0


def trainDLB(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=(3, 1), augs=()):
    """Train one epoch with self-distillation from the previous mini-batch.

    Every step forwards the previous batch concatenated with the current one.
    Cross-entropy is taken over both halves, and the logits on the previous
    batch are additionally pulled (temperature-scaled KL) towards the logits
    the model produced for that same batch one step earlier.

    Args:
        model: Network to optimise; put into training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable yielding ``(data, _, target, indices)``; only
            ``data`` and ``target`` are used.
        optimizer: Optimizer, stepped once per batch through ``scaler``.
        lr_scheduler: Scheduler, stepped once per batch.
        epoch: Epoch number, used only in the log line.
        dataset: Unused here; kept for a uniform ``train*`` signature.
        scaler: Gradient scaler exposing ``scale``/``step``/``update``.
        params: ``(temp, alpha)`` — distillation temperature and weight of the
            distillation term.
        augs: Batch-level augmentations applied in order by ``augment``.

    Returns:
        Mean training loss over every batch of the epoch (``0.0`` if the
        loader is empty).
    """
    model.train()
    kl_divergence, cross_entropy = nn.KLDivLoss(reduction='batchmean'), nn.CrossEntropyLoss()
    temp, alpha = params
    last_logits, last_data, last_targets = None, None, None

    # Window accumulators drive the periodic log; epoch accumulators are never
    # reset so the returned mean covers the whole epoch.
    total_loss, total_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0

    for batch_idx, (data, _, target, indices) in enumerate(train_loader):
        data = data.to(device)
        targets = target.to(device)
        data, targets = augment(data, targets, augs)

        optimizer.zero_grad()
        if last_data is not None:
            all_data = torch.concat((last_data, data))
            all_targets = torch.concat((last_targets, targets))
            prev_count = last_data.shape[0]
        else:
            all_data = data
            all_targets = targets
            prev_count = 0

        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(all_data)
            loss = cross_entropy(outputs, all_targets)

        if last_logits is not None:
            # Rows [0, prev_count) of `outputs` belong to the previous batch.
            soft_targets = F.softmax(last_logits / temp, dim=1)
            log_probabilities = F.log_softmax(outputs[:prev_count] / temp, dim=1)
            loss = loss + alpha * temp ** 2 * kl_divergence(log_probabilities, soft_targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        total_batches += 1
        window_loss += batch_loss
        window_batches += 1

        # Keep the logits of the *current* batch for the next step. They start
        # after the previous batch's rows, so slice by the previous batch
        # size: the current size would be wrong whenever the two differ.
        last_logits = outputs[prev_count:].detach()
        last_data = data
        last_targets = targets

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, window_loss, window_batches, batch_idx, train_loader)
            window_loss, window_batches = 0.0, 0

    return total_loss / total_batches if total_batches > 0 else 0.0


def trainPSKD(model, device, train_loader, optimizer, lr_scheduler, epoch, dataset=None, scaler=None, params=(0.8, 200), augs=()):
    """Train one epoch with progressive self-knowledge distillation.

    The training target is ``(1 - alpha) * one_hot(label) + alpha * last_preds``
    with ``alpha = alpha_T * epoch / total_epochs``. After each step the
    model's softmax output replaces the stored predictions for the batch.

    Args:
        model: Network to optimise; put into training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable yielding ``(data, last_preds, target, indices)``
            where ``target`` holds integer class labels.
        optimizer: Optimizer, stepped once per batch through ``scaler``.
        lr_scheduler: Scheduler, stepped once per batch.
        epoch: Current epoch; sets the mixing weight ``alpha``.
        dataset: Object exposing ``num_classes`` and a ``buffered_targets``
            tensor indexed by sample, overwritten in place.
        scaler: Gradient scaler exposing ``scale``/``step``/``update``.
        params: ``(alpha_T, total_epochs)`` — final mixing weight and the
            number of epochs over which it is reached.
        augs: Batch-level augmentations applied in order by ``augment``.

    Returns:
        Mean training loss over every batch of the epoch (``0.0`` if the
        loader is empty).
    """
    model.train()
    cross_entropy = nn.CrossEntropyLoss()
    alpha_T, total_epochs = params
    # Trust in the model's own past predictions grows linearly with the epoch.
    alpha = alpha_T * epoch / total_epochs

    # Window accumulators drive the periodic log; epoch accumulators are never
    # reset so the returned mean covers the whole epoch.
    total_loss, total_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0

    for batch_idx, (data, last_preds, target, indices) in enumerate(train_loader):
        data = data.to(device)
        targets = F.one_hot(target.to(device), num_classes=dataset.num_classes).float()
        data, targets = augment(data, targets, augs)

        optimizer.zero_grad()

        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(data)
            loss = cross_entropy(outputs, (1 - alpha) * targets + alpha * last_preds.to(device))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        total_batches += 1
        window_loss += batch_loss
        window_batches += 1

        # Store this step's predictions (in float32) for later lookups.
        dataset.buffered_targets[indices] = F.softmax(outputs.detach().cpu().float(), dim=1)

        if (batch_idx + 1) % 100 == 0:
            log_train_loss(epoch, window_loss, window_batches, batch_idx, train_loader)
            window_loss, window_batches = 0.0, 0

    return total_loss / total_batches if total_batches > 0 else 0.0


def sharpen(p, tau):
    """Raise each row of ``p`` to the power ``1 / tau`` and renormalise it.

    Normalisation is over ``dim=1``, so every row of the result sums to one.
    ``tau < 1`` makes the distribution peakier, ``tau == 1`` leaves an
    already-normalised input unchanged.
    """
    powered = torch.pow(p, 1.0 / tau)
    row_totals = powered.sum(dim=1, keepdim=True)
    return powered / row_totals

def log_train_loss(epoch, loss, num_batches, batch_idx, train_loader):
    """Print the mean loss of the current logging window.

    Args:
        epoch: Epoch number shown in the message.
        loss: Sum of the losses accumulated over the window.
        num_batches: Number of batches in the window; ``loss`` is divided by it.
        batch_idx: Zero-based index of the current batch.
        train_loader: Loader whose ``len`` gives the total number of steps.

    When ``torch.distributed`` is initialised, only rank 0 prints.
    """
    process_group = torch.distributed
    if process_group.is_initialized() and process_group.get_rank() != 0:
        return

    mean_loss = loss / num_batches
    step, total_steps = batch_idx + 1, len(train_loader)
    print(f"Epoch [{epoch}], Step [{step}/{total_steps}], Loss: {mean_loss:.4f}", flush=True)

def augment(data, target, augs=[]):
    """Run ``data`` and ``target`` through each augmentation in ``augs``, in order.

    Every augmentation is called as ``aug(data, target)`` and must return the
    transformed ``(data, target)`` pair, which is fed to the next one. With no
    augmentations the inputs are returned unchanged.
    """
    batch, labels = data, target
    for transform in augs:
        batch, labels = transform(batch, labels)
    return batch, labels


def validate(model, device, val_loader, num_classes, distributed=False):
    """Evaluate ``model`` on ``val_loader`` and report mean loss and accuracy.

    The loss is the batch-mean KL divergence between the model's log-softmax
    and the one-hot encoding of the labels. Per-batch means are re-weighted by
    batch size so the final average is per sample. With ``distributed=True``
    the loss sum, correct count and sample count are summed across all
    processes before averaging, and only rank 0 prints.

    Args:
        model: Network to evaluate; put into eval mode here.
        device: Device each batch is moved to.
        val_loader: Iterable yielding ``(data, target)`` batches.
        num_classes: Number of classes for the one-hot encoding.
        distributed: Whether to aggregate the metrics with ``all_reduce``.

    Returns:
        ``(avg_loss, accuracy)`` with accuracy in percent; both are ``0.0``
        when no samples were seen.
    """
    model.eval()
    criterion = nn.KLDivLoss(reduction='batchmean')

    # Per-process running totals.
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch, labels in val_loader:
            batch, labels = batch.to(device), labels.to(device)
            logits = model(batch)

            one_hot = F.one_hot(labels.long(), num_classes=num_classes).float()
            batch_loss = criterion(F.log_softmax(logits, dim=1), one_hot)

            predictions = torch.max(logits.data, 1)[1]

            count = labels.size(0)
            # 'batchmean' gives a per-sample mean; scale back to a sum.
            loss_sum += batch_loss.item() * count
            num_correct += (predictions == labels).sum().item()
            num_seen += count

    if distributed:
        # One float64 tensor [loss_sum, correct, total] summed over processes.
        packed = torch.tensor([loss_sum, num_correct, num_seen], device=device, dtype=torch.float64)
        dist.all_reduce(packed, op=dist.ReduceOp.SUM)
        loss_sum = packed[0].item()
        num_correct = packed[1].item()
        num_seen = packed[2].item()

    if num_seen == 0:
        avg_loss, accuracy = 0.0, 0.0
    else:
        avg_loss = loss_sum / num_seen
        accuracy = 100.0 * num_correct / num_seen

    should_print = (not distributed) or dist.get_rank() == 0
    if should_print:
        print(f"Validation Loss: {avg_loss:.4f}, Validation Accuracy: {accuracy:.2f}%", flush=True)

    return avg_loss, accuracy


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported here because this module has no explicit `import random`; the
    # name would otherwise only exist if a wildcard import happened to leak it.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def parse_args(argv=None):
    """Build the command-line parser for the training script and parse ``argv``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train a model on CIFAR-10 with optional checkpointing and dataset selection.")
    parser.add_argument('--model', type=str, default='ResNet18', help='Model architecture to use (default: ResNet18)')
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs (default: 200)')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to a checkpoint file to resume training (default: None)')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='Dataset to use for training (default: CIFAR10)')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)')
    parser.add_argument('--lr', type=float, default=1e-5, help='Initial learning rate (default: 1e-5)')
    parser.add_argument('--max_lr', type=float, default=1e-2, help='Maximum learning rate for OneCycleLR (default: 1e-2)')
    parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay (default: 0)')
    parser.add_argument('--gamma', type=float, default=0.99, help='Strength of discounting (default: 0.99)')
    parser.add_argument('--noise_rate', type=float, default=0.0, help='Fraction of labels to flip as noise (default: 0.0)')
    parser.add_argument('--noise_type', type=str, default="sym", help='Type of label noise to use (default: sym)')
    parser.add_argument('--save_every', type=int, default=40, help='Save checkpoint every n epochs (default: 40)')
    parser.add_argument('--seed', type=int, default=0, help='Seed for training (default: 0)')
    parser.add_argument('--noise_seed', type=int, default=0, help='Seed for noise (default: 0)')
    parser.add_argument('--teacher_model', type=str, default=None, help='Model for distillation (default: None)')
    parser.add_argument('--teacher_model_paths', nargs='+', type=str, required=False, help='List of file paths for teacher models.')
    parser.add_argument('--method', type=str, default="BSD", help='Method for distillation (default: BSD)')
    parser.add_argument('--ramp_up_epochs', type=int, default=100, help='Epochs increasing w_max for TE (default: 100)')
    parser.add_argument('--ramp_down_epochs', type=int, default=50, help='Epochs decreasing beta1 in Adam for TE (default: 50)')
    parser.add_argument('--w_max', type=int, default=30, help='Maximum temporal weight for TE (default: 30)')
    parser.add_argument('--temp', type=float, default=3.0, help='Temperature for DLB/KD (default: 3.0)')
    parser.add_argument('--alpha', type=float, default=1.0, help='Alpha for DLB (default: 1.0)')
    parser.add_argument('--num_workers', type=int, default=12, help='Number of workers (default: 12)')
    parser.add_argument('--distributed', action='store_true', help='Use distributed training (default: False)')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd', 'lamb', 'lars'], help='Optimizer to use (default: adam)')
    parser.add_argument('--lr_schedule', type=str, default='onecycle', help='LR scheduling strategy (default: onecycle)')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='Warm-up epochs for cosine_warmup (default: 10)')
    parser.add_argument('--cutmix', action='store_true', help='Use cutmix during training (default: False)')
    parser.add_argument('--mixup', action='store_true', help='Use mixup during training (default: False)')
    parser.add_argument('--cutout', action='store_true', help='Use cutout during training (default: False)')
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Amount of label smoothing to apply for baseline (default: 0.0)")
    parser.add_argument("--sce_beta", type=float, default=0.0, help="Beta for SCE loss (default: 0.0).")
    parser.add_argument("--sce_alpha", type=float, default=1.0, help="Alpha for SCE loss (default: 1.0).")
    parser.add_argument("--lambda_c", type=float, default=1.0, help="Weight of the contrastive loss for BSD+ (default: 1.0).")
    parser.add_argument("--tau", type=float, default=1.0, help="Temperature for sharpening (default: 1.0).")
    # main() evaluates on the test set unless this flag is given.
    parser.add_argument('--val', action='store_true', help='Evaluate on the validation split instead of the test set (default: False)')
    # main() uses c (together with eps) to initialise the effective counts and
    # smooth the buffered targets for BSD/BSD+.
    parser.add_argument("--c", type=float, default=100.0, help="Initial effective count of the buffered targets for BSD/BSD+ (default: 100.0).")
    parser.add_argument('--momentum_te', type=float, default=0.6, help='Momentum for TE (default: 0.6)')
    parser.add_argument('--eps', type=float, default=0.0, help='Epsilon for BSD (default: 0.0)')

    return parser.parse_args(argv)

def _is_main_process():
    """Return True when not running distributed, or when this is global rank 0."""
    return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0


def main():
    """Entry point: parse arguments, build data/model/optimizer, then train.

    Steps, in order: seed the RNGs; set up the device (and the NCCL process
    group when ``--distributed``); pick the training function and its
    parameters from ``--method``; load the datasets and wrap them in loaders;
    build the model, the weight-decay parameter groups, the optimizer and the
    LR scheduler; optionally resume from a checkpoint; then alternate one
    training epoch with one validation pass, saving a checkpoint every
    ``--save_every`` epochs on the main process.

    Raises:
        ValueError: If ``--method``, ``--optimizer`` or ``--lr_schedule``
            names something this script does not implement.
    """
    torch.backends.cudnn.benchmark = True

    args = parse_args()
    set_seed(args.seed)

    if args.distributed:
        dist.init_process_group(backend='nccl')
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Printed after the process group exists so the rank check is meaningful
    # and only one process echoes the arguments.
    if _is_main_process():
        print("Arguments: ", ', '.join(f"{key}={value}" for key, value in vars(args).items()))
    print(f"Using device: {device}")

    # method name -> (training function, names of the args it takes as params).
    # Non-string entries are passed through as literal values.
    method_configs = {
        "bsd":      (trainBSD,      ["gamma", "tau"]),
        "bsd+":     (trainBSDPlus,  ["lambda_c", "gamma", "tau"]),
        "te":       (trainTE,       ["momentum_te", "ramp_up_epochs", "w_max", "ramp_down_epochs", "epochs"]),
        "dlb":      (trainDLB,      ["temp", "alpha"]),
        "baseline": (trainBaseline, ["label_smoothing", "sce_alpha", "sce_beta"]),
        "kd":       (trainKD,       ["temp"]),
        "pskd":     (trainPSKD,     [0.3, "epochs"]),
    }

    method = args.method.lower()
    if method not in method_configs:
        raise ValueError(f"Unknown method: {args.method}")

    train, param_keys = method_configs[method]
    params = [key if isinstance(key, (int, float)) else getattr(args, key) for key in param_keys]

    train_dataset, val_dataset, test_dataset = load_datasets(args)
    if not args.val:
        val_dataset = test_dataset

    if method in ['bsd', 'bsd+']:
        # Smooth the initial targets with eps and start every sample with an
        # effective count of A = c + (K - 1) * eps.
        A = args.c + (train_dataset.num_classes - 1) * args.eps
        train_dataset.buffered_targets = (train_dataset.buffered_targets * (args.c - args.eps) + args.eps) / A
        train_dataset.effective_counts[:] = A

    num_classes = train_dataset.num_classes
    all_augs = [CutMix(), MixUp(), Cutout(n_holes=1, args=args, length=train_dataset.img_size[0]//2)]
    selected_augs = [a for u, a in zip([args.cutmix, args.mixup, args.cutout], all_augs) if u]

    loader_kwargs = {"batch_size": args.batch_size, "num_workers": args.num_workers, "pin_memory": True}
    if args.num_workers > 0:
        # DataLoader rejects an explicit prefetch_factor without worker processes.
        loader_kwargs["prefetch_factor"] = 2

    if args.distributed:
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(train_dataset, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, shuffle=False)
        train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)
        val_loader = DataLoader(val_dataset, sampler=val_sampler, **loader_kwargs)
    else:
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = load_model(args.model, num_classes)
    model.to(device)

    if _is_main_process():
        print("Total parameters:", sum(p.numel() for p in model.parameters()))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index], broadcast_buffers=False)

    # Normalisation-layer parameters and biases are exempt from weight decay.
    decay, no_decay = [], []
    for module in model.modules():
        is_norm = isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d, nn.LayerNorm))
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if is_norm or param_name.endswith("bias"):
                no_decay.append(param)
            else:
                decay.append(param)

    optimizer_grouped_parameters = [
        {"params": decay, "weight_decay": args.weight_decay},
        {"params": no_decay, "weight_decay": 0.0}
    ]

    optimizer_name = args.optimizer.lower()
    if optimizer_name == 'adam':
        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.lr)
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(optimizer_grouped_parameters, lr=args.lr, momentum=0.9)
    else:
        # 'lamb' and 'lars' pass argparse validation but have no implementation here.
        raise ValueError(f"Optimizer '{args.optimizer}' is not implemented (supported: adam, sgd)")

    # Every scheduler is stepped once per batch, so all schedules are in steps.
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * args.epochs
    schedule = args.lr_schedule.lower()
    if schedule == 'onecycle':
        lr_scheduler_instance = lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.max_lr,
            steps_per_epoch=steps_per_epoch,
            epochs=args.epochs
        )
    elif schedule == 'cosine_warmup':
        warmup_steps = args.warmup_epochs * steps_per_epoch

        def lr_lambda(current_step):
            # Linear warm-up from 1% of the base LR, then cosine decay to zero.
            if current_step < warmup_steps:
                warmup_factor = 0.01
                alpha = float(current_step) / float(max(1, warmup_steps))
                return warmup_factor + alpha * (1.0 - warmup_factor)
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        lr_scheduler_instance = lr_scheduler.LambdaLR(optimizer, lr_lambda)
    elif schedule == 'step':
        # Decay at 50% and 75% of training. Milestones must be integers: a
        # fractional step count would never equal the scheduler's counter.
        milestones = [total_steps // 2, (3 * total_steps) // 4]
        lr_scheduler_instance = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    else:
        raise ValueError(f"Unknown lr_schedule: {args.lr_schedule} (supported: onecycle, cosine_warmup, step)")

    scaler = GradScaler()
    save_path = os.path.join(SAVE_PATH, args.dataset.lower(), args.model.lower(), args.method.lower(), str(args.seed))
    os.makedirs(save_path, exist_ok=True)

    start_epoch = 1
    if args.checkpoint_path is not None:
        loaded_epoch = load_checkpoint(args.checkpoint_path, model, optimizer, lr_scheduler_instance)
        if loaded_epoch is not None:
            start_epoch = loaded_epoch + 1

    if _is_main_process():
        print("Starting Training...")

    for epoch in range(start_epoch, args.epochs + 1):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            # Snapshot the buffers so this process's updates can be isolated
            # as a delta after the epoch.
            targets_before = train_dataset.buffered_targets.clone()
            counts_before = train_dataset.effective_counts.clone()

        train(model, device, train_loader, optimizer, lr_scheduler_instance, epoch, dataset=train_dataset, scaler=scaler, params=params, augs=selected_augs)

        if args.distributed:
            # Sum every process's buffer delta and apply the total to the
            # snapshot, so all processes end the epoch with identical buffers.
            gpu_targets_delta = (train_dataset.buffered_targets - targets_before).to(device)
            gpu_counts_delta = (train_dataset.effective_counts - counts_before).to(device)

            dist.all_reduce(gpu_targets_delta, op=dist.ReduceOp.SUM)
            dist.all_reduce(gpu_counts_delta, op=dist.ReduceOp.SUM)

            train_dataset.buffered_targets.copy_(targets_before + gpu_targets_delta.cpu())
            train_dataset.effective_counts.copy_(counts_before + gpu_counts_delta.cpu())

        val_loss, val_accuracy = validate(model, device, val_loader, num_classes, distributed=args.distributed)

        if _is_main_process() and epoch % args.save_every == 0:
            save_checkpoint(args, save_path, epoch, model, optimizer, lr_scheduler_instance)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()