'''Train CIFAR10 with PyTorch.'''
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle
import matplotlib.pyplot as plt 

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time

from model import get_model
from data import get_data, make_planeloader
from utils import get_loss_function, get_scheduler, get_random_images, produce_plot, get_noisy_images, AttackPGD
from evaluation import train, test, test_on_trainset, decision_boundary, test_on_adv
from options import options
from utils import simple_lapsed_time, adjust_learning_rate, adjust_lambda_reg_linear, adjust_lambda_reg_sin, label_for_samples, pca, random_weighted_sum
from tqdm import tqdm
from utils import label_for_samples_random, adjust_mask_ratio

from poly.wd_regularization_torch import polynomial_regularization
from poly.wd_regularization_torch import precompute_chebyshev_matrix
from check_gpu import print_used_gpus
from set_seed import set_seed, set_seed_detailed
from sam import SAM
from TimeTracker import TimeTracker
from torch import vmap

def train_wd_reg_optimized(args, net, trainloader, optimizer, criterion, device, n_classes, num_pairs, lambda_reg, mask_ratio):
    """
    Train ``net`` for one epoch with the polynomial path regularizer.

    Optimized version: the clean mini-batch and all interpolated sampling
    points go through the network in a single consolidated forward pass.
    With ``args.random_alpha`` each sampled pair gets its own alpha sequence;
    otherwise one precomputed alpha grid is shared by all pairs.

    Per batch, ``num_pairs`` endpoint pairs (x1, x2) are drawn from the batch
    (distinct batch indices), or with ``args.use_noise`` from uniform noise
    normalized with the CIFAR-10 mean/std. Interior points
    ``x1 + t * (x2 - x1)`` are evaluated, their softmax outputs along each
    path are scored by ``polynomial_regularization`` and the mean score is
    added to the ``criterion`` loss with weight ``lambda_reg`` (or an annealed
    weight when ``args.reg_anneal`` is set).

    Args:
        args: parsed options; reads max_degree, resolution, miu, remove_const,
            use_norm, random_alpha, label, label_random, use_noise, pca_reg,
            square, degree_mode, ce_reg, reg_anneal, min_lambda_reg,
            lambda_reg, criterion and dryrun.
        net: model to train (put into train mode here).
        trainloader: iterable of (inputs, targets) batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to the clean-batch outputs.
        device: device inputs and sampled points are moved to.
        n_classes: number of classes (one-hot endpoints; ``log(n_classes)``
            normalizes the loss when annealing).
        num_pairs: number of interpolation pairs sampled per batch.
        lambda_reg: regularization weight for this epoch.
        mask_ratio: currently unused; kept for interface compatibility.

    Returns:
        ``(train_acc_percent, avg_train_loss, avg_reg_term, lambda_reg,
        avg_output_norm, time_tracker)``. ``avg_train_loss`` excludes the
        regularizer; ``lambda_reg`` is the value passed in, not the annealed
        weight. Averages are over the batches actually processed.
    """
    net.train()
    time_tracker = TimeTracker()
    cifar_mean = torch.tensor([0.4914, 0.4822, 0.4465], device=device).view(1, 3, 1, 1)
    cifar_std = torch.tensor([0.2023, 0.1994, 0.2010], device=device).view(1, 3, 1, 1)
    time_tracker.start("prepare data")
    train_loss_total = 0
    output_norm_total = 0
    reg_term_total = 0
    correct = 0
    total = 0
    num_batches = 0
    max_degree = args.max_degree
    resolution = int(args.resolution)
    miu = args.miu
    have_const = not args.remove_const
    use_norm = args.use_norm
    train_loss_max = np.log(n_classes)
    # In label mode the two path endpoints are the one-hot labels, so only the
    # interior points are pushed through the network.
    n_inner = resolution - 2 if args.label else resolution

    ema_alpha = 0.1
    ema_train_loss = None
    if not args.random_alpha:
        cached = precompute_chebyshev_matrix(resolution, max_degree, device)
        alpha_values = cached['alpha_values'].to(torch.float32)
        # Assumes alpha_values spans [-1, ..., 1]; label mode drops both endpoints.
        alpha_inner_only = alpha_values[1:-1] if args.label else alpha_values
        # [n_inner, 1, 1, 1], mapped from [-1, 1] to mixup weights in [0, 1].
        alpha_inner_sample = (alpha_inner_only.view(-1, 1, 1, 1) + 1) * 0.5

    time_tracker.end("prepare data")
    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):
        time_tracker.start("prepare batch")
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # The regularizer is evaluated whenever lambda_reg is non-negative (the
        # tiny negative tolerance absorbs float noise), so it is still computed
        # and logged when lambda_reg == 0.
        compute_reg_for_loss = (lambda_reg > -0.00001)

        N = inputs.size(0)
        # Need enough index pairs in the batch and at least one interior point
        # per path; otherwise there is nothing to regularize.
        can_compute_reg = (N * (N - 1) >= num_pairs * 2) and n_inner > 0

        reg_term = torch.tensor(0.0, device=device)

        if compute_reg_for_loss and can_compute_reg:
            # Pairs are sampled with replacement; a random offset in [1, N-1]
            # taken modulo N guarantees x2 != x1 by batch index.
            x1_indices = torch.randint(low=0, high=N, size=(num_pairs,), device=device)
            offset = torch.randint(low=1, high=N, size=(num_pairs,), device=device)
            x2_indices = (x1_indices + offset) % N

            if args.use_noise:
                # Endpoints are uniform [0, 1] noise images, normalized like the
                # test transform: (x - mean) / std.
                noise_shape = (num_pairs, *inputs.shape[1:])
                raw_noise1 = torch.rand(noise_shape, device=device)
                raw_noise2 = torch.rand(noise_shape, device=device)
                x1_batch = (raw_noise1 - cifar_mean) / cifar_std
                x2_batch = (raw_noise2 - cifar_mean) / cifar_std
            else:
                x1_batch = inputs[x1_indices]  # [num_pairs, C, H, W]
                x2_batch = inputs[x2_indices]  # [num_pairs, C, H, W]

            # Labels are fetched by index in both modes; for noise endpoints
            # they are effectively random labels.
            if args.label or args.label_random:
                label1_batch = targets[x1_indices]
                label2_batch = targets[x2_indices]

            if args.random_alpha:
                # One uniform sample inside each of n_inner equal-width bins.
                steps = torch.arange(n_inner, device=device, dtype=torch.float32).unsqueeze(0)  # [1, n_inner]
                jitter = torch.rand((num_pairs, n_inner), device=device, dtype=torch.float32)
                raw_fraction = (steps + jitter) / n_inner
                # Label mode keeps interior points away from the fixed endpoints.
                padding = 1.0 / resolution if args.label else 0.0
                compressed_fraction = padding + raw_fraction * (1 - 2 * padding)
                # Cosine-spaced points in [-1, 1], sorted per pair to be monotone.
                alpha_inner_cheb = -torch.cos(compressed_fraction * np.pi)
                alpha_inner_cheb, _ = torch.sort(alpha_inner_cheb, dim=1)

                if args.label:
                    ones = torch.ones((num_pairs, 1), device=device)
                    # [-1, inner..., 1] -> [num_pairs, resolution]
                    full_alpha_tensor = torch.cat([-ones, alpha_inner_cheb, ones], dim=1)
                else:
                    full_alpha_tensor = alpha_inner_cheb

                # Mixup weights in [0, 1]: [num_pairs, n_inner, 1, 1, 1]
                inner_alpha_device = (alpha_inner_cheb.view(num_pairs, n_inner, 1, 1, 1) + 1) * 0.5
            else:
                # Shared precomputed grid broadcast to every pair.
                inner_alpha_device = alpha_inner_sample.to(device).unsqueeze(0).expand(num_pairs, -1, -1, -1, -1)

            # Interior images: [num_pairs, n_inner, C, H, W] -> [num_pairs * n_inner, C, H, W]
            x1_expanded = x1_batch.unsqueeze(1)
            x2_expanded = x2_batch.unsqueeze(1)
            inner_samples = x1_expanded + inner_alpha_device * (x2_expanded - x1_expanded)
            samples_flat = inner_samples.reshape(-1, *inputs.shape[1:])

            # Single consolidated forward pass: clean batch first, then path samples.
            all_outputs = net(torch.cat([inputs, samples_flat], dim=0))
            outputs = all_outputs[:N]
            loss = criterion(outputs, targets)

            inner_probs = F.softmax(all_outputs[N:], dim=1)
            inner_probs_reshaped = inner_probs.reshape(num_pairs, n_inner, -1)  # [num_pairs, n_inner, n_classes]

            if args.label:
                # Path endpoints are the one-hot labels of the pair.
                endpoint_1 = F.one_hot(label1_batch, num_classes=n_classes).float().unsqueeze(1)  # [num_pairs, 1, n_classes]
                endpoint_2 = F.one_hot(label2_batch, num_classes=n_classes).float().unsqueeze(1)  # [num_pairs, 1, n_classes]
                full_sequence = torch.cat([endpoint_1, inner_probs_reshaped, endpoint_2], dim=1)
            else:
                full_sequence = inner_probs_reshaped
            # full_sequence: [num_pairs, resolution, n_classes]

            if args.pca_reg > 0:
                full_sequence = pca(full_sequence, num_pairs, args.pca_reg)

            if args.random_alpha:
                # Per-pair alpha sequences: mapped over dim 0.
                batched_alpha_inputs, alpha_in_dims = full_alpha_tensor, 0
            else:
                # One alpha grid shared by all pairs: not mapped.
                batched_alpha_inputs, alpha_in_dims = alpha_values.to(device), None

            def reg_wrapper(alpha, sample_output):
                # Bind the per-epoch constants so vmap only sees (alpha, outputs).
                return polynomial_regularization(
                    alpha,
                    sample_output,
                    resolution,
                    miu,
                    max_degree,
                    have_const,
                    use_norm,
                    args.random_alpha,
                    args.square,
                    args.degree_mode,
                    args.ce_reg
                )

            # Expected shape [num_pairs]: one regularization value per pair.
            reg_terms_tensor = vmap(reg_wrapper, in_dims=(alpha_in_dims, 0))(batched_alpha_inputs, full_sequence)
            reg_term = torch.mean(reg_terms_tensor)
        else:
            # Plain forward pass; reg_term stays 0.
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        # Statistics use the clean loss, before adding the regularizer.
        output_norm_total += torch.linalg.norm(outputs, dim=1).mean().item()
        train_loss_total += loss.item()

        if batch_idx == 0:
            ema_train_loss = loss.item()
        else:
            ema_train_loss = ema_alpha * loss.item() + (1 - ema_alpha) * ema_train_loss

        if compute_reg_for_loss and can_compute_reg:
            if args.reg_anneal:
                # Scale args.lambda_reg by the smoothed clean loss relative to log(n_classes).
                lambda_reg_current = max(args.min_lambda_reg, args.lambda_reg * ema_train_loss / train_loss_max)
            else:
                lambda_reg_current = lambda_reg
            loss = loss + lambda_reg_current * reg_term
        # Always record the regularization term for monitoring.
        reg_term_total += reg_term.item()

        loss.backward()
        # NOTE(review): SAM-style optimizers (see the `sam` import) often require a
        # closure or first_step/second_step instead of a bare step() — confirm.
        optimizer.step()

        _, predicted = outputs.max(1)
        if 'kl' in args.criterion:
            # Presumably per-class target vectors for KL criteria: reduce to class ids.
            _, targets = targets.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1
        time_tracker.end("prepare batch")
        if args.dryrun:
            break

    # Average over processed batches (fewer than len(trainloader) on dryrun).
    denom = max(num_batches, 1)
    avg_reg_term = reg_term_total / denom
    avg_output_norm = output_norm_total / denom
    train_acc = 100. * correct / total if total else 0.0

    return train_acc, train_loss_total / denom, avg_reg_term, lambda_reg, avg_output_norm, time_tracker



def plot_training_curves(train_accs, test_accs, train_losses, test_losses, reg_terms, save_path='training_plots', save_name=None):
    """Plot per-epoch accuracy/loss curves and pickle the raw histories.

    Writes ``<save_path>/<save_name>.png``, which has an accuracy panel and a
    loss panel. The loss panel also shows the regularization term when any
    entry is non-zero. Also writes ``<save_path>/<save_name>_data.pkl``
    containing the five input lists.

    Args:
        train_accs: per-epoch training accuracy; its length sets the number of
            epochs plotted, so the other lists must have the same length.
        test_accs: per-epoch test accuracy.
        train_losses: per-epoch training loss (without the regularizer).
        test_losses: per-epoch test loss.
        reg_terms: per-epoch regularization term.
        save_path: output directory, created if missing.
        save_name: file stem. When omitted, falls back to ``args.save_net``
            from the module-level ``args`` created by the ``__main__`` block.

    Raises:
        ValueError: if ``save_name`` is omitted and no module-level ``args``
            exists (e.g. when this module is imported instead of run).
    """
    if save_name is None:
        cli_args = globals().get('args')
        if cli_args is None:
            raise ValueError("save_name is required when no module-level `args` is defined")
        save_name = cli_args.save_net

    epochs = range(1, len(train_accs) + 1)

    fig = plt.figure(figsize=(12, 5))
    try:
        # Left panel: training and test accuracy
        plt.subplot(1, 2, 1)
        plt.plot(epochs, train_accs, 'b-', label='Training Accuracy')
        plt.plot(epochs, test_accs, 'r-', label='Test Accuracy')
        plt.title('Training and Test Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)

        # Right panel: losses plus (optionally) the regularization term
        plt.subplot(1, 2, 2)
        plt.plot(epochs, train_losses, 'g-', label='Training Loss (w/o reg)')
        plt.plot(epochs, test_losses, 'orange', label='Test Loss')
        if any(reg_terms):
            plt.plot(epochs, reg_terms, 'm--', label='Regularization Term')
        plt.title('Training and Test Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, f'{save_name}.png'))
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak.
        plt.close(fig)

    # Save the raw histories for later analysis.
    data = {
        'train_accs': train_accs,
        'test_accs': test_accs,
        'train_losses': train_losses,
        'test_losses': test_losses,
        'reg_terms': reg_terms
    }
    with open(os.path.join(save_path, f'{save_name}_data.pkl'), 'wb') as f:
        pickle.dump(data, f)


if __name__ == "__main__":

    args = options().parse_args()
    # Restrict the visible GPUs; set before any CUDA call so it takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(args.gpu_id)
    # Resolution is only validated when the regularizer is weighted: it needs
    # > 0 points, and > 2 in label mode (two of them are the label endpoints).
    if args.resolution <= 0 and args.lambda_reg > 0:
        raise ValueError("Resolution must be greater than 0.")
    if args.label and args.resolution <= 2 and args.lambda_reg > 0:
        raise ValueError("In label mode, resolution must be greater than 2.")
    set_seed(args.set_seed)
    print('Args:')
    for k, v in vars(args).items():
        print('\t{}: {}'.format(k, v))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Create directory to save training plots
    os.makedirs('training_plots', exist_ok=True)
    print_used_gpus()
    # Per-epoch histories
    train_accs = []
    test_accs = []
    train_losses = []  # training loss (without regularization)
    test_losses = []   # test loss
    reg_terms = []     # average regularization term
    num_classes = 10  # Number of classes in CIFAR-10

    trainloader, testloader = get_data(args)

    if args.use_data_aug:
        use_train_loader = trainloader
    else:
        # Un-augmented CIFAR-10 train split with the test-time transform; only
        # built (and downloaded) when it is actually used.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        raw_trainset = torchvision.datasets.CIFAR10(
            root='~/data', train=True, download=True, transform=transform_test)
        use_train_loader = torch.utils.data.DataLoader(
            raw_trainset, batch_size=args.bs, shuffle=True, num_workers=16)

    # Re-seed right before building the model.
    set_seed(args.set_seed)
    net = get_model(args, device)

    test_acc, predicted = test(args, net, testloader, device, 0)
    print("scratch prediction ", test_acc)

    criterion = get_loss_function(args)
    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        if args.sam:
            base_optimizer = optim.SGD
            if args.adaptive_sam:
                print("Using Adaptive SAM optimizer")
                rho = 2.0
                # rho is passed positionally only, as in the other SAM calls.
                # Also passing rho=rho supplied it twice and presumably raised a
                # TypeError, assuming SAM(params, base_optimizer, rho, ...).
                optimizer = SAM(net.parameters(), base_optimizer, rho, adaptive=True, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
            else:
                print("Using SAM optimizer")
                rho = 0.05
                optimizer = SAM(net.parameters(), base_optimizer, rho, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        else:
            optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)

    elif opt_name == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    elif opt_name == 'adamw':
        if args.sam:
            base_optimizer = optim.AdamW
            if args.adaptive_sam:
                print("Using Adaptive SAM optimizer")
                rho = 2.0
                optimizer = SAM(net.parameters(), base_optimizer, rho, adaptive=True, lr=args.lr, weight_decay=args.weight_decay)
            else:
                print("Using SAM optimizer")
                rho = 0.05
                optimizer = SAM(net.parameters(), base_optimizer, rho, lr=args.lr, weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    else:
        # Only an error if we actually train; loading a checkpoint needs no optimizer.
        optimizer = None

    # Train or load base network
    print("Training the network or loading the network")

    start = time.time()
    best_acc = 0  # best test accuracy
    best_epoch = 0

    if args.load_net is None:
        if optimizer is None:
            raise ValueError(f"Unsupported optimizer '{args.opt}' (expected sgd, adam or adamw)")
        for epoch in range(args.epochs):
            lr = adjust_learning_rate(optimizer, epoch+1, args)
            lambda_reg = adjust_lambda_reg_sin(epoch, args)
            mask_ratio = adjust_mask_ratio(epoch, args)
            # Train one epoch and collect time statistics
            train_acc, train_loss, reg_term, lambda_reg, output_norm, time_tracker = train_wd_reg_optimized(
                args, net, use_train_loader, optimizer, criterion, device,
                num_classes, args.nums_pairs, lambda_reg, mask_ratio
            )
            print(time_tracker.summary())

            train_accs.append(train_acc)
            train_losses.append(train_loss)
            reg_terms.append(reg_term)

            test_acc, predicted = test(args, net, testloader, device, epoch)
            test_accs.append(test_acc)

            # Mean test loss under the training criterion, in eval mode.
            net.eval()
            test_loss = 0
            with torch.no_grad():
                for data, target in testloader:
                    data, target = data.to(device), target.to(device)
                    output = net(data)
                    test_loss += criterion(output, target).item()
            test_loss /= len(testloader)
            test_losses.append(test_loss)
            net.train()

            print(f'EPOCH: {epoch}/{args.epochs}, LR: {lr:.6f}, Reg: {lambda_reg:.5f}, Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}, Train loss: {train_loss:.5f}, Test loss: {test_loss:.5f}, Reg term: {reg_term:.5f}, Output norm: {output_norm:.5f}')

            if args.dryrun:
                break

            # Save plots every 5 epochs
            if epoch % 5 == 0:
                plot_training_curves(train_accs, test_accs, train_losses, test_losses, reg_terms)

            # Save checkpoint of the best model so far.
            model_path = f'saved_models/wd_reg/{str(args.set_seed)}/{args.save_net}'
            if test_acc > best_acc:
                print(f'The best epoch is: {epoch}')
                os.makedirs(model_path, exist_ok=True)
                print(f'{model_path}/{args.save_net}.pth')
                # Save unwrapped weights (the load path strips DataParallel too);
                # decide by the actual wrapper rather than by the GPU count.
                model_to_save = net.module if isinstance(net, torch.nn.DataParallel) else net
                torch.save(model_to_save.state_dict(),
                           f'{model_path}/{args.save_net}.pth')
                best_acc = test_acc
                best_epoch = epoch

        # After training, save the final plots
        plot_training_curves(train_accs, test_accs, train_losses, test_losses, reg_terms)

    else:
        # Checkpoints hold unwrapped weights, so unwrap DataParallel before loading.
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        net.load_state_dict(torch.load(args.load_net, map_location=device))

    end = time.time()
    simple_lapsed_time("Time taken to train the model", end-start)