
import pickle
import argparse
import time
import math
import torch
import torch.nn as nn
from tqdm import tqdm
from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader


def count_parameters(model):
    """Return the number of trainable scalar elements in ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute to the
    total; each contributes ``numel()`` elements.
    """
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements


def train(epoch):
    """Run one training epoch, then measure the full-dataset gradient.

    Uses the module-level ``net``, ``cifar_training_loader``,
    ``loss_function``, ``optimizer`` and ``scheduler`` objects.

    The epoch makes two passes over the training loader:

    1. Optimisation pass: one optimizer step per batch, counting correct
       predictions along the way.
    2. Measurement pass: no optimizer steps; every batch's gradient is
       accumulated into ``.grad`` so the L1/L2 norm ratio of the summed
       gradient can be computed at the post-update weights.

    Args:
        epoch: Epoch index, used only in progress-bar and log labels.

    Returns:
        A tuple ``(training_loss, gradient_norm_ratio, training_accuracy)``:
        ``training_loss`` is the sum of the per-batch loss values from the
        measurement pass, ``gradient_norm_ratio`` is ``||g||_1 / ||g||_2`` of
        the accumulated gradient, and ``training_accuracy`` is the percentage
        of correct predictions made during the optimisation pass.
    """
    start = time.time()
    net.train()
    running_loss = 0.0
    total = 0
    correct = 0

    # Pass 1: parameter updates, with a progress bar showing the running loss.
    progress_bar_train = tqdm(
        cifar_training_loader,
        total=len(cifar_training_loader),
        desc=f'Epoch {epoch} Training')
    for images, labels in progress_bar_train:
        images = images.cuda()
        labels = labels.cuda()

        outputs = net(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        # Clearing right after the step also guarantees that pass 2 starts
        # from empty gradient buffers.
        optimizer.zero_grad()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar_train.set_postfix(loss=running_loss)

    # Pass 2: accumulate gradients over every batch without stepping.
    training_loss = 0.0
    progress_bar_grad = tqdm(
        cifar_training_loader,
        total=len(cifar_training_loader),
        desc=f'Epoch {epoch} Gradient Accumulation')
    for images, labels in progress_bar_grad:
        images = images.cuda()
        labels = labels.cuda()

        outputs = net(images)
        loss = loss_function(outputs, labels)
        loss.backward()  # No zero_grad here: gradients add up across batches.
        training_loss += loss.item()

    # Flatten with reshape rather than view: view(-1) raises for gradients
    # that are not contiguous, while the norms do not depend on element order.
    grad_all = torch.cat([p.grad.reshape(-1)
                          for p in net.parameters() if p.grad is not None])
    grad_l1_norm = grad_all.norm(1)
    grad_l2_norm = grad_all.norm(2)
    gradient_norm_ratio = (grad_l1_norm / grad_l2_norm).item()

    # Drop the accumulated gradient so it cannot leak into the next epoch.
    optimizer.zero_grad()
    scheduler.step()

    finish = time.time()
    training_accuracy = 100. * correct / total
    print(f'Completed Epoch {epoch}: Loss: {training_loss:.4f}, Accuracy: {training_accuracy:.2f}%, Gradient Norm Ratio: {gradient_norm_ratio:.6f}, Time: {finish - start:.2f}s')

    return training_loss, gradient_norm_ratio, training_accuracy


def evaluate(epoch):
    """Compute top-1 accuracy of ``net`` over ``cifar_test_loader``.

    Reads the module-level ``net``, ``cifar_test_loader`` and ``args``.
    The network is switched to eval mode and gradients are disabled for
    the whole pass. Batches are moved to the GPU only when ``args.gpu``
    is set.

    Args:
        epoch: Epoch index, used only in the printed summary line.

    Returns:
        The fraction (not percentage) of correctly classified samples.
    """
    net.eval()
    num_correct = 0.0
    num_seen = 0

    with torch.no_grad():
        for image, label in cifar_test_loader:
            if args.gpu:
                image, label = image.cuda(), label.cuda()

            # Index of the highest-scoring class for every sample.
            predicted = net(image).max(1)[1]

            num_seen += label.size(0)
            num_correct += predicted.eq(label).sum().item()

    accuracy = num_correct / num_seen
    print(f'Epoch {epoch} Evaluation Accuracy: {accuracy:.4f}')
    return accuracy


if __name__ == '__main__':
    # Script entry point: parse the CLI, build the network, data loaders,
    # loss, optimizer and scheduler, run settings.EPOCH epochs of
    # train()/evaluate(), and pickle the collected statistics.
    #
    # train() and evaluate() read `args`, `net`, `cifar_training_loader`,
    # `cifar_test_loader`, `loss_function`, `optimizer` and `scheduler` as
    # module-level names, so those names must be kept as they are.
    #
    # Reference training settings:
    #   cifar100 + resnet50  (sqrt=5059): batch_size=2048, lr=1e-4, 4 gpus
    #   cifar100 + resnet101 (sqrt=6670): batch_size=2048, lr=1e-4, 4 gpus
    #   cifar100 + resnet152 (sqrt=7746): batch_size=2048, lr=1e-4, 6 gpus
    cli_options = [
        (('-net',), dict(type=str, default='resnet101', help='net type')),
        (('-gpu',), dict(action='store_true', default=True,
                         help='use gpu or not')),
        (('-b',), dict(type=int, default=2048,
                       help='batch size for dataloader')),
        (('-lr',), dict(type=float, default=1e-4,
                        help='initial learning rate')),
        (('-wd',), dict(type=float, default=0.01, help='weight decay')),
        (('-momentum',), dict(type=float, default=0.9,
                              help='momentum parameter for SGD and RMSProp')),
        (('-dataset',), dict(type=str, default='CIFAR10',
                             choices=['CIFAR10', 'CIFAR100'],
                             help='Dataset type')),
        (('-optimizer', '-opt'), dict(type=str, default='rmsprop',
                                      help='Optimizer type')),
    ]
    parser = argparse.ArgumentParser()
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    print(args)

    # The output file name encodes the run configuration.
    run_tags = [
        args.dataset,
        args.net,
        args.optimizer,
        f'bs{args.b}',
        f'lr{args.lr}',
        f'wd{args.wd}',
        f'momentum{args.momentum}',
    ]
    if args.net.startswith('vit'):
        run_tags = ['pretrain'] + run_tags

    save_file = '_'.join(run_tags) + '.pkl'
    print(f'Saving training statistics to {save_file}')

    net = get_network(args).cuda()
    total_params = count_parameters(net)
    print(f"Parameter numbers: {total_params}")
    print(f"Square root of parameter numbers: {math.sqrt(total_params)}")
    gpu_count = torch.cuda.device_count()
    if gpu_count <= 1:
        print("Using a single GPU for training.")
    else:
        print(f"Using {gpu_count} GPUs for training.")
        net = nn.DataParallel(net)

    # Both loaders are built with the same keyword arguments.
    loader_kwargs = dict(
        dataset_type=args.dataset,
        num_workers=8,
        batch_size=args.b,
        shuffle=True,
        arch=args.net,
    )
    cifar_training_loader = get_training_dataloader(**loader_kwargs)
    cifar_test_loader = get_test_dataloader(**loader_kwargs)

    loss_function = nn.CrossEntropyLoss(reduction='sum')

    # Factories are used so that only the selected optimizer is constructed.
    optimizer_factories = {
        'adamw': lambda params: torch.optim.AdamW(
            params, lr=args.lr, weight_decay=args.wd),
        'sgd': lambda params: torch.optim.SGD(
            params, lr=args.lr, momentum=args.momentum,
            weight_decay=args.wd),
        'rmsprop': lambda params: torch.optim.RMSprop(
            params, lr=args.lr, momentum=args.momentum,
            weight_decay=args.wd),
    }
    if args.optimizer not in optimizer_factories:
        raise ValueError(f'Invalid optimizer: {args.optimizer}')
    optimizer = optimizer_factories[args.optimizer](net.parameters())
    print(optimizer)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=settings.EPOCH)

    # Per-epoch statistics are appended straight into the dict that is
    # pickled at the end; 'args' stores the parsed options as a plain dict.
    training_data = {
        'training_losses': [],
        'gradient_norm_ratios': [],
        'args': vars(args),
        'training_accuracy': [],
        'evaluation_accuracy': [],
    }
    for epoch in range(settings.EPOCH):
        epoch_loss, epoch_ratio, epoch_train_acc = train(epoch)
        training_data['training_losses'].append(epoch_loss)
        training_data['gradient_norm_ratios'].append(epoch_ratio)
        training_data['training_accuracy'].append(epoch_train_acc)

        training_data['evaluation_accuracy'].append(evaluate(epoch))

    with open(save_file, 'wb') as stats_file:
        pickle.dump(training_data, stats_file)

    print(f'Training statistics has been saved to {save_file}')
