"""
Train CIFAR10 with PyTorch.

Modified version:
- Removed ALL WD regularization functionality (poly.wd_regularization, sampling points, precompute_matrices, etc.)
- Kept MixUp (EnhancedMixUp) functionality
- Kept plotting/logging structure (reg_terms now recorded as 0.0)
"""

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle
import matplotlib.pyplot as plt

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2

from EnhancedMixup import EnhancedMixUp
import argparse
import time

from model import get_model
from data import get_data, make_planeloader
from utils import (
    get_loss_function,
    get_scheduler,
    get_random_images,
    produce_plot,
    get_noisy_images,
    AttackPGD,
)
from evaluation import train, test, test_on_trainset, decision_boundary, test_on_adv
from options import options
from utils import (
    simple_lapsed_time,
    adjust_learning_rate,
    adjust_lambda_reg_linear,
    adjust_lambda_reg_sin,
)
from tqdm import tqdm

from check_gpu import print_used_gpus
from set_seed import set_seed, set_seed_detailed
from sam import SAM


def train_only_mixup(args, net, trainloader, optimizer, criterion, device, mixup_fn=None):
    """
    Run one training epoch: standard supervised training plus optional MixUp.

    WD regularization has been completely removed.

    Args:
        args: Parsed options; reads ``mixup_alpha``, ``sam``, ``criterion``
            and ``dryrun``.
        net: Model to train; switched to train mode here.
        trainloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer to step. When ``args.sam`` is set it must be a
            SAM-style optimizer whose ``step(closure)`` uses the gradients
            already present on the parameters for its first (ascent) step and
            calls ``closure`` to recompute them at the perturbed weights.
        criterion: Loss called as ``criterion(outputs, targets)``.
        device: Device each batch is moved to.
        mixup_fn: Optional callable ``(inputs, targets) -> (inputs, targets)``
            applied when ``args.mixup_alpha > 0``. Defaults to the
            module-level ``mixup`` object created in the ``__main__`` section.

    Returns:
        ``(train_acc, train_loss)``. ``train_acc`` is in percent, and is 0.0
        when MixUp is on (accuracy against mixed targets is ambiguous) or no
        sample was seen. ``train_loss`` is the mean per-batch loss over the
        batches actually processed (only one under ``--dryrun``).
    """
    use_mixup = args.mixup_alpha > 0
    if use_mixup and mixup_fn is None:
        # Backward-compatible fallback: the global created in __main__.
        mixup_fn = mixup

    net.train()
    train_loss_total = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for inputs, targets in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        raw_targets = targets  # used for accuracy when not using mixup

        if use_mixup:
            inputs, targets = mixup_fn(inputs, targets)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # Both update rules need gradients at the current weights. For SAM
        # these drive the first (ascent) step; omitting this backward left
        # every p.grad empty and made SAM's first step fail.
        loss.backward()

        if args.sam:
            # Recomputes the loss and gradients at the perturbed weights.
            # It is invoked by optimizer.step() within this same iteration,
            # so capturing the loop's `inputs`/`targets` is safe.
            def closure():
                optimizer.zero_grad()
                closure_loss = criterion(net(inputs), targets)
                closure_loss.backward()
                return closure_loss

            optimizer.step(closure)
        else:
            optimizer.step()

        train_loss_total += loss.item()
        num_batches += 1

        # Accuracy:
        # - with mixup the targets are soft, so accuracy is skipped (0.0)
        # - with a KL criterion the targets are one-hot; recover class ids
        _, predicted = outputs.max(1)
        if 'kl' in args.criterion:
            _, hard_targets = targets.max(1)
        else:
            hard_targets = raw_targets

        total += hard_targets.size(0)
        if not use_mixup:
            correct += predicted.eq(hard_targets).sum().item()

        if args.dryrun:
            break

    if use_mixup or total == 0:
        train_acc = 0.0
    else:
        train_acc = 100.0 * correct / total
    train_loss = train_loss_total / num_batches if num_batches else 0.0
    return train_acc, train_loss


def plot_training_curves(
    train_accs,
    test_accs,
    train_losses,
    test_losses,
    reg_terms,
    save_net_name,
    save_path="training_plots",
):
    """
    Plot accuracy and loss curves, and save them together with the raw data.

    Writes two files into ``save_path``, which is created if missing:
    ``<save_net_name>.png`` holds an accuracy subplot and a loss subplot, and
    ``<save_net_name>_data.pkl`` holds a pickled dict of the five input lists.

    Args:
        train_accs, test_accs: Per-epoch accuracies. The length of
            ``train_accs`` defines the epoch axis (1..N) used for every curve.
        train_losses, test_losses: Per-epoch losses.
        reg_terms: Per-epoch regularization values, kept for compatibility.
            They are only drawn when at least one entry is non-zero.
        save_net_name: Base file name (no extension) for both outputs.
        save_path: Output directory.
    """
    epochs = range(1, len(train_accs) + 1)
    os.makedirs(save_path, exist_ok=True)

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # Accuracy subplot
        ax_acc.plot(epochs, train_accs, "b-", label="Training Accuracy")
        ax_acc.plot(epochs, test_accs, "r-", label="Test Accuracy")
        ax_acc.set_title("Training and Test Accuracy")
        ax_acc.set_xlabel("Epochs")
        ax_acc.set_ylabel("Accuracy")
        ax_acc.legend()
        ax_acc.grid(True)

        # Loss subplot
        ax_loss.plot(epochs, train_losses, "g-", label="Training Loss")
        ax_loss.plot(epochs, test_losses, "orange", label="Test Loss")
        # reg_terms is kept for compatibility; with WD-reg removed it is all zeros.
        if any(reg_terms):
            ax_loss.plot(epochs, reg_terms, "m--", label="Regularization Term")
        ax_loss.set_title("Training and Test Loss")
        ax_loss.set_xlabel("Epochs")
        ax_loss.set_ylabel("Loss")
        ax_loss.legend()
        ax_loss.grid(True)

        fig.tight_layout()
        fig.savefig(os.path.join(save_path, f"{save_net_name}.png"))
    finally:
        # Release this figure even if plotting/saving fails. The function is
        # called every few epochs, so leaked figures would accumulate.
        plt.close(fig)

    # Save raw data
    data = {
        "train_accs": train_accs,
        "test_accs": test_accs,
        "train_losses": train_losses,
        "test_losses": test_losses,
        "reg_terms": reg_terms,
    }
    with open(os.path.join(save_path, f"{save_net_name}_data.pkl"), "wb") as f:
        pickle.dump(data, f)


def _build_optimizer(args, net):
    """
    Create the optimizer selected by ``args.opt``, optionally wrapped in SAM.

    Reads ``args.opt`` ("sgd", "adam" or "adamw", case-insensitive),
    ``args.lr``, ``args.weight_decay``, ``args.sam`` and ``args.adaptive_sam``.
    SGD uses momentum 0.9. SAM uses rho=0.05, or rho=2.0 in adaptive mode.
    SAM is not wired up for Adam, and a warning is printed if it is requested.

    Raises:
        ValueError: if ``args.opt`` is not a supported optimizer name.
    """
    opt_name = args.opt.lower()
    if opt_name == "sgd":
        base_optimizer = optim.SGD
        base_kwargs = dict(lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    elif opt_name == "adam":
        base_optimizer = torch.optim.Adam
        base_kwargs = dict(lr=args.lr, weight_decay=args.weight_decay)
    elif opt_name == "adamw":
        base_optimizer = optim.AdamW
        base_kwargs = dict(lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {args.opt}")

    if args.sam and opt_name == "adam":
        # Previously --sam was silently ignored here; make that visible.
        print("Warning: SAM is not supported with Adam; using plain Adam")

    if not args.sam or opt_name == "adam":
        return base_optimizer(net.parameters(), **base_kwargs)

    if args.adaptive_sam:
        print("Using Adaptive SAM optimizer")
        # rho is passed exactly once (positionally). Also passing rho=... as a
        # keyword raised "got multiple values for argument 'rho'".
        return SAM(net.parameters(), base_optimizer, 2.0, adaptive=True, **base_kwargs)

    print("Using SAM optimizer")
    return SAM(net.parameters(), base_optimizer, 0.05, **base_kwargs)


def _mean_loss(net, loader, criterion, device):
    """Return the mean per-batch ``criterion`` value of ``net`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()`` and leaves ``net`` in train
    mode afterwards.
    """
    net.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            loss_sum += criterion(net(data), target).item()
    net.train()
    return loss_sum / len(loader)


if __name__ == "__main__":
    args = options().parse_args()
    set_seed(args.set_seed)

    print("Args:")
    for k, v in vars(args).items():
        print("\t{}: {}".format(k, v))

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create directory for plots
    os.makedirs("training_plots", exist_ok=True)

    print_used_gpus()

    # Per-epoch logs
    train_accs = []
    test_accs = []
    train_losses = []
    test_losses = []
    reg_terms = []  # WD-reg removed; 0.0 is logged each epoch for compatibility
    num_classes = 10  # CIFAR-10

    # Data
    trainloader, testloader = get_data(args)

    # Un-augmented train set, used instead of `trainloader` when
    # data augmentation is disabled.
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    raw_trainset = torchvision.datasets.CIFAR10(
        root="~/data", train=True, download=True, transform=transform_test
    )
    raw_trainloader = torch.utils.data.DataLoader(
        raw_trainset, batch_size=args.bs, shuffle=True, num_workers=16
    )

    aug = args.use_data_aug
    use_train_loader = trainloader if aug else raw_trainloader

    set_seed(args.set_seed)

    # Model
    net = get_model(args, device)

    test_acc, predicted = test(args, net, testloader, device, 0)
    print("scratch prediction ", test_acc)

    # Loss and optimizer
    criterion = get_loss_function(args)
    optimizer = _build_optimizer(args, net)

    # MixUp. `mixup` must stay a module-level name: train_only_mixup falls
    # back to it as a global when no mixup_fn is passed.
    if args.mixup_alpha > 0:
        print("Using Mixup Augmentation with alpha =", args.mixup_alpha)
        mixup = EnhancedMixUp(
            alpha=args.mixup_alpha,
            num_classes=num_classes,
            resolution=int(args.resolution),
            pairs=None,
        )

    print("Training the network or loading the network")

    start = time.time()
    best_acc = 0.0

    if args.load_net is None:
        for epoch in range(args.epochs):
            lr = adjust_learning_rate(optimizer, epoch + 1, args)

            # Train (NO WD-reg)
            train_acc, train_loss = train_only_mixup(
                args, net, use_train_loader, optimizer, criterion, device
            )
            train_accs.append(train_acc)
            train_losses.append(train_loss)
            reg_terms.append(0.0)

            # Test acc
            test_acc, predicted = test(args, net, testloader, device, epoch)
            test_accs.append(test_acc)

            # Test loss
            test_loss = _mean_loss(net, testloader, criterion, device)
            test_losses.append(test_loss)

            print(
                f"EPOCH: {epoch}/{args.epochs}, LR: {lr:.6f}, "
                f"Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}, "
                f"Train loss: {train_loss:.5f}, Test loss: {test_loss:.5f}"
            )

            if args.dryrun:
                break

            # Plot every 5 epochs
            if epoch % 5 == 0:
                plot_training_curves(
                    train_accs,
                    test_accs,
                    train_losses,
                    test_losses,
                    reg_terms,
                    save_net_name=args.save_net,
                    save_path="training_plots",
                )

            # Save checkpoint logic (unchanged path)
            model_path = f"saved_models/wd_reg/{str(args.set_seed)}/{args.save_net}"
            if test_acc > best_acc:
                print(f"The best epoch is: {epoch}")
                os.makedirs(model_path, exist_ok=True)
                print(f"{model_path}/{args.save_net}.pth")
                best_acc = test_acc
                # NOTE: the original code printed the path but never called
                # torch.save(). To actually save, uncomment below:
                # if torch.cuda.device_count() > 1 and isinstance(net, torch.nn.DataParallel):
                #     torch.save(net.module.state_dict(), f"{model_path}/{args.save_net}.pth")
                # else:
                #     torch.save(net.state_dict(), f"{model_path}/{args.save_net}.pth")

        # Final plot
        plot_training_curves(
            train_accs,
            test_accs,
            train_losses,
            test_losses,
            reg_terms,
            save_net_name=args.save_net,
            save_path="training_plots",
        )

    else:
        # Load model. The state dict is loaded into the unwrapped module.
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        net.load_state_dict(torch.load(args.load_net, map_location=device))

    end = time.time()
    simple_lapsed_time("Time taken to train the model", end - start)
