'''Train CIFAR10 with PyTorch.'''
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle
import matplotlib.pyplot as plt  

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2
from EnhancedMixup import EnhancedMixUp
import os
import argparse
import time

from model import get_model
from data import get_data, make_planeloader
from utils import get_loss_function, get_scheduler, get_random_images, produce_plot, get_noisy_images, AttackPGD
from evaluation import train, test, test_on_trainset, decision_boundary, test_on_adv
from options import options
from utils import simple_lapsed_time, adjust_learning_rate, adjust_lambda_reg_linear, adjust_lambda_reg_sin
from tqdm import tqdm

from poly.wd_regularization import compute_regularization as wd_reg
from poly.wd_regularization import PolynomialRegularization
from poly.wd_regularization import precompute_matrices
from check_gpu import print_used_gpus
from set_seed import set_seed, set_seed_detailed
from sam import SAM

def _pin_endpoints_to_labels(curve_probs, first_labels, second_labels):
    """Replace the first and last sample of every curve with one-hot label vectors.

    Args:
        curve_probs: Tensor of shape [num_pairs, resolution, num_classes] holding
            the softmax outputs along each interpolation segment.
        first_labels: Class indices used for sample 0 of each curve.
            # assumes integer class indices of shape [num_pairs] — TODO confirm
            # for mixup / 'kl' style soft targets.
        second_labels: Class indices used for the last sample of each curve.

    Returns:
        A tensor shaped like ``curve_probs`` whose interior samples are the
        original values and whose two endpoints are constant one-hot rows
        (so no gradient flows through the endpoints).
    """
    pair_ids = torch.arange(curve_probs.size(0), device=curve_probs.device)

    one_hot_full = torch.zeros_like(curve_probs)
    # The explicit re-zeroing keeps the original semantics when resolution == 1
    # (index 0 and index -1 coincide, so the second label must win).
    one_hot_full[pair_ids, 0] = 0.0
    one_hot_full[pair_ids, 0, first_labels] = 1.0
    one_hot_full[pair_ids, -1] = 0.0
    one_hot_full[pair_ids, -1, second_labels] = 1.0

    # Mask of shape [num_pairs, resolution, 1]; broadcasts over the class axis.
    endpoint_mask = torch.zeros_like(curve_probs[:, :, :1])
    endpoint_mask[:, 0] = 1.0
    endpoint_mask[:, -1] = 1.0

    return curve_probs * (1.0 - endpoint_mask) + one_hot_full * endpoint_mask


def train_wd_reg_optimized(args, net, trainloader, optimizer, criterion, device, n_classes, num_pairs, lambda_reg):
    """Train ``net`` for one epoch with a polynomial regularizer on interpolated inputs.

    For every batch, random pairs of inputs are drawn, ``resolution`` points are
    interpolated between the two inputs of each pair, and those points are pushed
    through the network together with the regular batch in a single forward pass.
    The softmax outputs along each segment are passed to
    ``PolynomialRegularization.apply``; the mean over pairs is the batch's
    regularization term, which is always recorded and is added to the loss
    (scaled by lambda) when ``lambda_reg > -0.00001``.

    When ``args.mixup_alpha > 0`` this function calls a module-level ``mixup``
    callable, which the script's main section is expected to have created.

    Args:
        args: Parsed options. Reads ``max_degree``, ``resolution``, ``miu``,
            ``remove_const``, ``use_norm``, ``mixup_alpha``, ``label``,
            ``reg_anneal``, ``min_lambda_reg``, ``lambda_reg``, ``criterion``
            and ``dryrun``.
        net: Model to train (put into train mode here).
        trainloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called per batch.
        criterion: Loss callable ``criterion(outputs, targets)``.
        device: Device that inputs, targets and sampling coefficients are moved to.
        n_classes: Number of classes; ``log(n_classes)`` normalizes the annealed lambda.
        num_pairs: Number of input pairs per batch used when the regularizer
            contributes to the loss (1 pair is used when it is monitoring-only).
        lambda_reg: Regularization weight for this epoch.

    Returns:
        Tuple ``(train_acc_percent, mean_loss_without_reg, mean_reg_term,
        lambda_reg, mean_output_norm)``. Means are taken over the batches
        actually processed. Accuracy is reported as 0 when mixup is active.

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    net.train()
    train_loss_total = 0
    output_norm_total = 0
    reg_term_total = 0
    correct = 0
    total = 0
    num_batches = 0  # batches actually processed (differs from len(trainloader) on dryrun)

    max_degree = args.max_degree
    resolution = int(args.resolution)
    miu = args.miu
    have_const = not args.remove_const
    use_norm = args.use_norm
    train_loss_max = np.log(n_classes)

    # Exponential moving average of the (unregularized) training loss.
    ema_alpha = 0.1
    ema_train_loss = None

    # Precomputed sampling coefficients; fetched once per epoch.
    cached = precompute_matrices(resolution, max_degree)
    alpha_values = cached['alpha_values']
    alpha_values_cpu = alpha_values.cpu()  # the regularizer is evaluated on CPU tensors

    # Affine map x -> (x + 1) / 2 of the coefficients, shaped [1, resolution, 1, 1, 1]
    # for broadcasting against [num_pairs, 1, C, H, W] inputs.
    # assumes alpha_values spans [-1, 1] so the result spans [0, 1] — TODO confirm
    alpha_grid = ((alpha_values.view(-1, 1, 1, 1) + 1) * 0.5).to(device).unsqueeze(0)

    # The threshold is slightly below zero, so lambda_reg == 0 still counts as
    # "regularizer in the loss" and uses the full num_pairs; only a (numerically)
    # negative lambda switches to monitoring-only mode with a single pair.
    compute_reg_for_loss = (lambda_reg > -0.00001)
    num_pairs_actual = num_pairs if compute_reg_for_loss else 1

    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):
        inputs, targets = inputs.to(device), targets.to(device)
        raw_targets = targets
        if args.mixup_alpha > 0:
            inputs, targets = mixup(inputs, targets)

        optimizer.zero_grad()

        batch_size = inputs.size(0)
        # Two distinct samples are needed per pair.
        can_compute_reg = (batch_size >= num_pairs_actual * 2)
        reg_term = torch.tensor(0.0, device=device)

        if can_compute_reg:
            # 1. Draw disjoint index sets for the two ends of each pair.
            indices = torch.randperm(batch_size, device=device)[:num_pairs_actual * 2]
            x1_indices = indices[:num_pairs_actual]
            x2_indices = indices[num_pairs_actual:num_pairs_actual * 2]

            # 2. Interpolate: [num_pairs_actual, resolution, C, H, W].
            x1_expanded = inputs[x1_indices].unsqueeze(1)
            x2_expanded = inputs[x2_indices].unsqueeze(1)
            samples = x1_expanded + alpha_grid * (x2_expanded - x1_expanded)
            samples_flat = samples.reshape(-1, *inputs.shape[1:])

            # 3. One forward pass over the regular batch plus all sampling points.
            all_outputs = net(torch.cat([inputs, samples_flat], dim=0))
            outputs = all_outputs[:batch_size]
            samples_outputs = F.softmax(all_outputs[batch_size:], dim=1)

            # 4. Standard loss on the regular batch only.
            loss = criterion(outputs, targets)

            # 5. Regularization term, one value per pair.
            curve_probs = samples_outputs.reshape(num_pairs_actual, resolution, -1)

            if args.label:
                curve_probs = _pin_endpoints_to_labels(
                    curve_probs, targets[x1_indices], targets[x2_indices]
                )
                # NOTE(review): this branch passes the outputs twice (8 arguments)
                # whereas the non-label branch passes them once (7 arguments).
                # Preserved as-is; confirm against PolynomialRegularization's signature.
                reg_terms_list = [
                    PolynomialRegularization.apply(
                        alpha_values_cpu,
                        curve_probs[i].cpu(),
                        curve_probs[i].cpu(),
                        resolution,
                        miu,
                        max_degree,
                        have_const,
                        use_norm
                    )
                    for i in range(num_pairs_actual)
                ]
            else:
                reg_terms_list = [
                    PolynomialRegularization.apply(
                        alpha_values_cpu,
                        curve_probs[i].cpu(),
                        resolution,
                        miu,
                        max_degree,
                        have_const,
                        use_norm
                    )
                    for i in range(num_pairs_actual)
                ]

            reg_term = torch.mean(torch.stack(reg_terms_list))
        else:
            # Cannot compute regularization term
            print(f"Batch size too small to compute regularization term (need at least {num_pairs_actual * 2} samples, current is {inputs.size(0)})")
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        # Statistics are taken before the regularizer is added to the loss.
        output_norm_total += torch.linalg.norm(outputs, dim=1).mean().item()
        batch_loss = loss.item()
        train_loss_total += batch_loss

        if ema_train_loss is None:
            ema_train_loss = batch_loss
        else:
            ema_train_loss = ema_alpha * batch_loss + (1 - ema_alpha) * ema_train_loss

        if compute_reg_for_loss and can_compute_reg:
            if args.reg_anneal:
                # NOTE(review): annealing scales args.lambda_reg, not the
                # per-epoch lambda_reg argument — confirm this is intended.
                lambda_reg_current = max(args.min_lambda_reg, args.lambda_reg * ema_train_loss / train_loss_max)
            else:
                lambda_reg_current = lambda_reg
            loss = loss + lambda_reg_current * reg_term

        # Always record regularization term for monitoring
        reg_term_total += reg_term.item()

        loss.backward()
        optimizer.step()

        if 'kl' in args.criterion:
            _, targets = targets.max(1)
        total += targets.size(0)
        if args.mixup_alpha > 0:
            # Accuracy against hard labels is not tracked for mixed inputs.
            correct = 0
        else:
            _, predicted = outputs.max(1)
            correct += predicted.eq(raw_targets).sum().item()

        num_batches += 1
        if args.dryrun:
            break

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches; cannot compute epoch statistics")

    # Average over the batches that were actually run, so a dryrun (one batch)
    # reports that batch's values instead of values deflated by len(trainloader).
    avg_reg_term = reg_term_total / num_batches
    avg_output_norm = output_norm_total / num_batches

    return 100. * correct / total, train_loss_total / num_batches, avg_reg_term, lambda_reg, avg_output_norm



def plot_training_curves(train_accs, test_accs, train_losses, test_losses, reg_terms, save_path='training_plots', save_name=None):
    """Save accuracy/loss curves as a PNG and the raw series as a pickle.

    Writes ``<save_path>/<save_name>.png`` (two side-by-side panels: accuracy
    and loss) and ``<save_path>/<save_name>_data.pkl`` (a dict of the five
    input series).

    Args:
        train_accs: Per-epoch training accuracy; its length defines the x axis.
        test_accs: Per-epoch test accuracy.
        train_losses: Per-epoch training loss (without the regularization term).
        test_losses: Per-epoch test loss.
        reg_terms: Per-epoch regularization term; plotted only if any entry is truthy.
        save_path: Output directory; created if missing.
        save_name: Base file name. When ``None`` the module-level
            ``args.save_net`` is used, which requires the script's parsed
            ``args`` to exist (the original behavior).
    """
    if save_name is None:
        # Backward-compatible default: fall back to the script-level arguments.
        save_name = args.save_net

    epochs = range(1, len(train_accs) + 1)
    os.makedirs(save_path, exist_ok=True)

    plt.figure(figsize=(12, 5))
    try:
        # Left panel: training and test accuracy.
        plt.subplot(1, 2, 1)
        plt.plot(epochs, train_accs, 'b-', label='Training Accuracy')
        plt.plot(epochs, test_accs, 'r-', label='Test Accuracy')
        plt.title('Training and Test Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)

        # Right panel: losses and, when present, the regularization term.
        plt.subplot(1, 2, 2)
        plt.plot(epochs, train_losses, 'g-', label='Training Loss (w/o reg)')
        plt.plot(epochs, test_losses, 'orange', label='Test Loss')
        if any(reg_terms):
            plt.plot(epochs, reg_terms, 'm--', label='Regularization Term')
        plt.title('Training and Test Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f'{save_name}.png'))
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close()

    # Save data for later analysis
    data = {
        'train_accs': train_accs,
        'test_accs': test_accs,
        'train_losses': train_losses,
        'test_losses': test_losses,
        'reg_terms': reg_terms
    }

    with open(os.path.join(save_path, f'{save_name}_data.pkl'), 'wb') as f:
        pickle.dump(data, f)


if __name__ == "__main__":
    # Script entry point: parse options, build data/model/optimizer, then either
    # train from scratch (logging curves and periodic checkpoints) or load a
    # saved state dict when args.load_net is given.

    args = options().parse_args()
    set_seed(args.set_seed)
    print('Args:')
    for k, v in vars(args).items():
        print('\t{}: {}'.format(k, v))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Create directory to save training plots
    os.makedirs('training_plots', exist_ok=True)
    save_path = args.save_net
    print_used_gpus()

    # Per-epoch history
    train_accs = []
    test_accs = []
    train_losses = []  # Record training loss (without regularization)
    test_losses = []   # Record test loss
    reg_terms = []     # Record regularization term
    num_classes = 10  # Number of classes in CIFAR-10

    # Data loaders from the project pipeline
    trainloader, testloader = get_data(args)

    # Un-augmented CIFAR-10 training set (ToTensor + Normalize only), used when
    # args.use_data_aug is false.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    raw_trainset = torchvision.datasets.CIFAR10(
        root='~/data', train=True, download=True, transform=transform_test)
    raw_trainloader = torch.utils.data.DataLoader(
        raw_trainset, batch_size=args.bs, shuffle=True, num_workers=16)

    aug = args.use_data_aug
    use_train_loader = trainloader if aug else raw_trainloader

    # Re-seed right before model creation.
    set_seed(args.set_seed)
    net = get_model(args, device)

    test_acc, predicted = test(args, net, testloader, device, 0)
    print("scratch prediction ", test_acc)

    criterion = get_loss_function(args)

    # NOTE(review): train_wd_reg_optimized calls optimizer.step() without a
    # closure; confirm the local SAM implementation supports that usage.
    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        if args.sam:
            base_optimizer = optim.SGD
            if args.adaptive_sam:
                print("Using Adaptive SAM optimizer")
                rho = 2.0
                # rho is passed positionally only (as in the other SAM calls);
                # also passing rho=rho would supply the same argument twice.
                optimizer = SAM(net.parameters(), base_optimizer, rho, adaptive=True, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
            else:
                print("Using SAM optimizer")
                rho = 0.05
                optimizer = SAM(net.parameters(), base_optimizer, rho, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        else:
            optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)

    elif opt_name == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    elif opt_name == 'adamw':
        if args.sam:
            base_optimizer = optim.AdamW
            if args.adaptive_sam:
                print("Using Adaptive SAM optimizer")
                rho = 2.0
                optimizer = SAM(net.parameters(), base_optimizer, rho, adaptive=True, lr=args.lr, weight_decay=args.weight_decay)
            else:
                print("Using SAM optimizer")
                rho = 0.05
                optimizer = SAM(net.parameters(), base_optimizer, rho, lr=args.lr, weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    else:
        # Deferred error: an optimizer is only required on the training path,
        # so loading a saved network with an unrecognized --opt still works.
        optimizer = None

    if args.mixup_alpha > 0:
        print("Using Mixup Augmentation with alpha =", args.mixup_alpha)
        # Module-level name: train_wd_reg_optimized looks up `mixup` globally.
        mixup = EnhancedMixUp(
            alpha=args.mixup_alpha,
            num_classes=num_classes,
            resolution=int(args.resolution),
            pairs=None
        )

    # Train or load base network
    print("Training the network or loading the network")

    start = time.time()
    best_acc = 0  # best test accuracy
    best_epoch = 0

    if args.load_net is None:
        if optimizer is None:
            raise ValueError(
                f"Unsupported optimizer {args.opt!r}; expected 'sgd', 'adam' or 'adamw'"
            )

        # Checkpoint directory is the same for every epoch.
        model_path = f'saved_models/wd_reg/{str(args.set_seed)}/{args.save_net}'

        for epoch in range(args.epochs):
            lr = adjust_learning_rate(optimizer, epoch+1, args)
            lambda_reg = adjust_lambda_reg_sin(epoch, args)
            train_acc, train_loss, reg_term, lambda_reg, output_norm = train_wd_reg_optimized(args, net, use_train_loader, optimizer, criterion, device, num_classes, args.nums_pairs, lambda_reg)
            train_accs.append(train_acc)
            train_losses.append(train_loss)
            reg_terms.append(reg_term)

            test_acc, predicted = test(args, net, testloader, device, epoch)
            test_accs.append(test_acc)

            # Mean test loss over test batches
            net.eval()
            test_loss = 0
            with torch.no_grad():
                for data, target in testloader:
                    data, target = data.to(device), target.to(device)
                    output = net(data)
                    test_loss += criterion(output, target).item()
            test_loss /= len(testloader)
            test_losses.append(test_loss)

            net.train()

            print(f'EPOCH: {epoch}/{args.epochs}, LR: {lr:.6f}, Reg: {lambda_reg:.5f}, Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}, Train loss: {train_loss:.5f}, Test loss: {test_loss:.5f}, Reg term: {reg_term:.5f}, Output norm: {output_norm:.5f}')

            if args.dryrun:
                break

            # Plot training curves at the end of each epoch
            if epoch % 5 == 0:  # Save plots every 5 epochs
                plot_training_curves(train_accs, test_accs, train_losses, test_losses, reg_terms)

            # Track the best epoch. NOTE: the best model is only logged here,
            # not written to disk; only the periodic checkpoints below are saved.
            if test_acc > best_acc:
                print(f'The best epoch is: {epoch}')
                os.makedirs(model_path, exist_ok=True)
                print(f'{model_path}/{args.save_net}.pth')
                best_acc = test_acc
                best_epoch = epoch

            # Periodic checkpoint every 50 epochs.
            if (epoch + 1) % 50 == 0:
                # Do not rely on the best-epoch branch having created the directory.
                os.makedirs(model_path, exist_ok=True)
                checkpoint_file = f'{model_path}/{args.save_net}_{epoch}.pth'
                print(checkpoint_file)
                # Save the unwrapped module's weights so keys match what the
                # load branch below expects (it strips DataParallel before loading).
                if isinstance(net, torch.nn.DataParallel):
                    torch.save(net.module.state_dict(), checkpoint_file)
                else:
                    torch.save(net.state_dict(), checkpoint_file)

        # Save final training plots after training is complete
        plot_training_curves(train_accs, test_accs, train_losses, test_losses, reg_terms)

    else:
        # Before loading the model, check if it is wrapped in DataParallel
        if isinstance(net, torch.nn.DataParallel):
            net = net.module  # Remove DataParallel wrapper
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        net.load_state_dict(torch.load(args.load_net, map_location=device))

    end = time.time()
    simple_lapsed_time("Time taken to train the model", end-start)