"""
Train CIFAR10 with PyTorch.

Modified version:
- Removed ALL WD regularization functionality (poly.wd_regularization, sampling points, precompute_matrices, etc.)
- MixUp (EnhancedMixUp) is NOT used: training runs on plain (inputs, targets) batches
- Kept the plotting/logging structure (reg_terms are still recorded, always as 0.0)
- Supports optional SAM / ASAM training via args.sam / args.adaptive_sam
"""

import os

# Restrict this process to GPU 4 unless CUDA_VISIBLE_DEVICES was already set
# by the caller (e.g. `CUDA_VISIBLE_DEVICES=0 python <script>`). It must be set
# before torch initialises CUDA, hence its position above `import torch`.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")
# Multi-GPU alternative: os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle
import matplotlib.pyplot as plt

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2


import argparse
import time

from model import get_model
from data import get_data, make_planeloader
from utils import (
    get_loss_function,
    get_scheduler,
    get_random_images,
    produce_plot,
    get_noisy_images,
    AttackPGD,
)
from evaluation import train, test, test_on_trainset, decision_boundary, test_on_adv
from options import options
from utils import (
    simple_lapsed_time,
    adjust_learning_rate,
    adjust_lambda_reg_linear,
    adjust_lambda_reg_sin,
)
from tqdm import tqdm

from check_gpu import print_used_gpus
from set_seed import set_seed, set_seed_detailed
from sam import SAM

def train_one_epoch(args, net, trainloader, optimizer, criterion, device):
    """
    Train ``net`` for one pass over ``trainloader``.

    - No MixUp: each batch is used as-is.
    - Supports a normal optimizer (``optimizer.step()``) or SAM/ASAM, which is
      driven through ``optimizer.step(closure)`` when ``args.sam`` is truthy.
    - If ``args.dryrun`` is truthy, stops after the first batch.

    Args:
        args: parsed options; reads ``args.sam`` and ``args.dryrun``.
        net: model to train (switched to train mode here).
        trainloader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer instance (a SAM wrapper when ``args.sam``).
        criterion: loss called as ``criterion(outputs, targets)``.
        device: device the batches are moved to.

    Returns:
        ``(train_acc, train_loss)``: accuracy in percent over all samples seen,
        and the mean of the per-batch loss values over the batches actually
        processed (so a dry run reports the loss of its single batch).
    """
    net.train()
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0

    for inputs, targets in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        if args.sam:
            # Record loss and logits from the FIRST closure call (on the
            # unperturbed weights $w$) so logging does not depend on what
            # optimizer.step() returns.
            first_pass = {"loss": None, "out": None}

            def closure():
                optimizer.zero_grad(set_to_none=True)
                out = net(inputs)
                loss = criterion(out, targets)
                loss.backward()
                if first_pass["out"] is None:
                    first_pass["out"] = out.detach()
                    first_pass["loss"] = loss.detach()
                return loss

            optimizer.step(closure)
            if first_pass["out"] is None:
                raise RuntimeError(
                    "optimizer.step(closure) did not call the closure; "
                    "cannot log SAM training metrics"
                )
            outputs = first_pass["out"]
            loss_value = first_pass["loss"].item()
        else:
            optimizer.zero_grad(set_to_none=True)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()

        pred = outputs.argmax(dim=1)
        correct += pred.eq(targets).sum().item()
        total += targets.size(0)
        loss_sum += loss_value
        num_batches += 1

        if args.dryrun:
            break

    train_acc = 100.0 * correct / max(total, 1)
    # Average over batches actually processed, not len(trainloader): the loop
    # can stop early (dry run).
    train_loss = loss_sum / max(num_batches, 1)
    return train_acc, train_loss

def split_params_bias(model):
    """Split a model's trainable parameters into ``(non_bias, bias)`` lists.

    A parameter counts as a bias when the last component of its dotted name
    is ``bias``. This matches both ``"layer.bias"`` inside a container module
    and the bare ``"bias"`` of a standalone layer such as ``nn.Linear``.
    Parameters with ``requires_grad=False`` are skipped. Order follows
    ``model.named_parameters()``.

    Used to build the ASAM param groups in which bias terms are not adaptive.
    """
    weight_params, bias_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # rpartition handles both "a.b.bias" and a top-level "bias".
        if name.rpartition(".")[2] == "bias":
            bias_params.append(param)
        else:
            weight_params.append(param)
    return weight_params, bias_params

def plot_training_curves(
    train_accs,
    test_accs,
    train_losses,
    test_losses,
    reg_terms,
    save_net_name,
    save_path="training_plots",
):
    """Save training curves as a PNG and the raw series as a pickle.

    Writes ``<save_path>/<save_net_name>.png`` (accuracy on the left, loss on
    the right) and ``<save_path>/<save_net_name>_data.pkl`` (a dict holding
    the five input lists). ``save_path`` is created if missing.

    The x-axis is ``1..len(train_accs)``, so every series is expected to have
    one entry per epoch. ``reg_terms`` is kept for compatibility and only
    drawn when at least one entry is non-zero.
    """
    epochs = range(1, len(train_accs) + 1)
    os.makedirs(save_path, exist_ok=True)

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # Accuracy subplot
        ax_acc.plot(epochs, train_accs, "b-", label="Training Accuracy")
        ax_acc.plot(epochs, test_accs, "r-", label="Test Accuracy")
        ax_acc.set_title("Training and Test Accuracy")
        ax_acc.set_xlabel("Epochs")
        ax_acc.set_ylabel("Accuracy")
        ax_acc.legend()
        ax_acc.grid(True)

        # Loss subplot
        ax_loss.plot(epochs, train_losses, "g-", label="Training Loss")
        ax_loss.plot(epochs, test_losses, "orange", label="Test Loss")
        # With WD-reg removed reg_terms is all zeros, so this is normally skipped.
        if any(reg_terms):
            ax_loss.plot(epochs, reg_terms, "m--", label="Regularization Term")
        ax_loss.set_title("Training and Test Loss")
        ax_loss.set_xlabel("Epochs")
        ax_loss.set_ylabel("Loss")
        ax_loss.legend()
        ax_loss.grid(True)

        fig.tight_layout()
        fig.savefig(os.path.join(save_path, f"{save_net_name}.png"))
    finally:
        # Release the figure even if saving fails; this is called repeatedly
        # during training and open figures would otherwise accumulate.
        plt.close(fig)

    # Save raw data
    data = {
        "train_accs": train_accs,
        "test_accs": test_accs,
        "train_losses": train_losses,
        "test_losses": test_losses,
        "reg_terms": reg_terms,
    }
    with open(os.path.join(save_path, f"{save_net_name}_data.pkl"), "wb") as f:
        pickle.dump(data, f)


def _build_optimizer(args, net):
    """Build the optimizer selected by ``args.opt`` (case-insensitive).

    - ``"sgd"``: SGD (momentum 0.9), wrapped in SAM/ASAM when ``args.sam``.
    - ``"adamw"``: AdamW, wrapped in SAM/ASAM when ``args.sam``.
    - ``"adam"``: plain Adam; ``args.sam`` is not applied (a warning is printed).

    With ``args.adaptive_sam`` the SAM wrapper gets two param groups from
    ``split_params_bias``: non-bias params adaptive, bias params not.

    Raises:
        ValueError: for any other ``args.opt``.
    """
    opt_name = args.opt.lower()
    if opt_name == "sgd":
        base_optimizer = optim.SGD
        base_kwargs = {"lr": args.lr, "momentum": 0.9, "weight_decay": args.weight_decay}
    elif opt_name == "adamw":
        base_optimizer = optim.AdamW
        base_kwargs = {"lr": args.lr, "weight_decay": args.weight_decay}
    elif opt_name == "adam":
        if args.sam:
            print("Warning: SAM is not applied for opt='adam'; using plain Adam")
        return optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {args.opt}")

    if not args.sam:
        return base_optimizer(net.parameters(), **base_kwargs)

    rho = args.rho
    if args.adaptive_sam:
        print("Using ASAM (bias NOT adaptive)")
        w_params, b_params = split_params_bias(net)
        return SAM(
            [
                {"params": w_params, "adaptive": True, "rho": rho},
                {"params": b_params, "adaptive": False, "rho": rho},  # bias not adaptive
            ],
            base_optimizer,
            **base_kwargs,
        )

    print("Using SAM")
    return SAM(net.parameters(), base_optimizer, rho=rho, adaptive=False, **base_kwargs)


def _mean_batch_loss(net, loader, criterion, device):
    """Return the mean of ``criterion(net(x), y).item()`` over ``loader``'s batches.

    Runs under ``torch.no_grad()`` with ``net`` in eval mode, then puts ``net``
    back into train mode. Returns 0.0 for an empty loader.
    """
    net.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            total_loss += criterion(net(data), target).item()
            num_batches += 1
    net.train()
    return total_loss / max(num_batches, 1)


if __name__ == "__main__":
    args = options().parse_args()
    set_seed(args.set_seed)

    print("Args:")
    for k, v in vars(args).items():
        print("\t{}: {}".format(k, v))

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create directory for plots
    os.makedirs("training_plots", exist_ok=True)

    print_used_gpus()

    # Per-epoch logs. reg_terms is kept for plot/pickle compatibility; with
    # WD-reg removed it always receives 0.0.
    train_accs = []
    test_accs = []
    train_losses = []
    test_losses = []
    reg_terms = []

    # Data
    trainloader, testloader = get_data(args)

    if args.use_data_aug:
        use_train_loader = trainloader
    else:
        # Un-augmented CIFAR-10 train split (ToTensor + Normalize only). Built
        # only when it is actually used, to skip a needless dataset load.
        transform_test = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]
        )
        raw_trainset = torchvision.datasets.CIFAR10(
            root="~/data", train=True, download=True, transform=transform_test
        )
        use_train_loader = torch.utils.data.DataLoader(
            raw_trainset, batch_size=args.bs, shuffle=True, num_workers=16
        )

    # Re-seed right before model creation (presumably so initialisation does
    # not depend on RNG use during data setup).
    set_seed(args.set_seed)

    # Model
    net = get_model(args, device)

    test_acc, _ = test(args, net, testloader, device, 0)
    print("scratch prediction ", test_acc)

    # Loss
    criterion = get_loss_function(args)

    # Optimizer
    optimizer = _build_optimizer(args, net)

    print("Training the network or loading the network")

    start = time.time()
    best_acc = 0.0

    if args.load_net is None:
        # Checkpoint directory (unchanged path); constant across epochs.
        model_path = f"saved_models/wd_reg/{str(args.set_seed)}/{args.save_net}"

        for epoch in range(args.epochs):
            lr = adjust_learning_rate(optimizer, epoch + 1, args)

            # Train (NO WD-reg)
            train_acc, train_loss = train_one_epoch(
                args, net, use_train_loader, optimizer, criterion, device
            )
            train_accs.append(train_acc)
            train_losses.append(train_loss)
            reg_terms.append(0.0)

            # Test accuracy, then test loss in a separate no-grad pass
            test_acc, _ = test(args, net, testloader, device, epoch)
            test_accs.append(test_acc)
            test_loss = _mean_batch_loss(net, testloader, criterion, device)
            test_losses.append(test_loss)

            print(
                f"EPOCH: {epoch}/{args.epochs}, LR: {lr:.6f}, "
                f"Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}, "
                f"Train loss: {train_loss:.5f}, Test loss: {test_loss:.5f}"
            )

            if args.dryrun:
                break

            # Plot every 5 epochs (epochs 0, 5, 10, ...)
            if epoch % 5 == 0:
                plot_training_curves(
                    train_accs,
                    test_accs,
                    train_losses,
                    test_losses,
                    reg_terms,
                    save_net_name=args.save_net,
                    save_path="training_plots",
                )

            if test_acc > best_acc:
                print(f"The best epoch is: {epoch}")
                os.makedirs(model_path, exist_ok=True)
                print(f"{model_path}/{args.save_net}.pth")
                best_acc = test_acc
                # NOTE: the path is printed but the weights are not saved.
                # To save, uncomment below:
                # if torch.cuda.device_count() > 1 and isinstance(net, torch.nn.DataParallel):
                #     torch.save(net.module.state_dict(), f"{model_path}/{args.save_net}.pth")
                # else:
                #     torch.save(net.state_dict(), f"{model_path}/{args.save_net}.pth")

        # Final plot
        plot_training_curves(
            train_accs,
            test_accs,
            train_losses,
            test_losses,
            reg_terms,
            save_net_name=args.save_net,
            save_path="training_plots",
        )

    else:
        # Load model. map_location lets a checkpoint saved on a GPU load on a
        # CPU-only machine (or onto a different visible GPU).
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        net.load_state_dict(torch.load(args.load_net, map_location=device))

    end = time.time()
    simple_lapsed_time("Time taken to train the model", end - start)
