import argparse
import shutil
import os
import time
import math
import torch
import warnings
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import numpy as np
from models.VGG_models import *
from models.WideResNet import *
import data_loaders
from functions import TET_loss, seed_all, get_logger
from models.resnet_models import resnet19
from models.ResNet import ResNet19


# Expose only GPU 0 to this process, then use CUDA when it is available.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` cannot be used for this: argparse would call ``bool()`` on
    the raw string, and every non-empty string (including ``"False"``) is
    truthy, so the flag could never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, as it was under ``bool("")``.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Help strings use %(default)s so the documented default can never drift
# away from the actual one.
parser = argparse.ArgumentParser(description='PyTorch Temporal Efficient Training')
parser.add_argument('-j',
                    '--workers',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs',
                    default=300,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start_epoch',
                    default=0,
                    type=int,
                    metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b',
                    '--batch_size',
                    default=64,
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr',
                    '--learning_rate',
                    default=0.1,
                    type=float,
                    metavar='LR',
                    help='initial learning rate',
                    dest='lr')
parser.add_argument('--seed',
                    default=1000,
                    type=int,
                    help='seed for initializing training. ')
parser.add_argument('-T',
                    '--time',
                    default=8,
                    type=int,
                    metavar='N',
                    help='snn simulation time (default: %(default)s)')
parser.add_argument('--means',
                    default=1.0,
                    type=float,
                    metavar='N',
                    help='make all the potential increment around the means (default: %(default)s)')
parser.add_argument('--TET',
                    default=True,
                    type=_str2bool,
                    metavar='BOOL',
                    help='if use Temporal Efficient Training (default: %(default)s)')
parser.add_argument('--lamb',
                    default=1e-3,
                    type=float,
                    metavar='N',
                    help='adjust the norm factor to avoid outlier (default: %(default)s)')
args = parser.parse_args()


def train(model, device, train_loader, criterion, optimizer, beta, args):
    """Run one epoch of training on unperturbed inputs.

    Args:
        model: Module called as ``model(images)``; it must return an
            ``(outputs, v)`` pair. ``outputs`` is averaged over dim 1 to
            obtain the logits used for prediction (presumably the time-step
            axis of the SNN - confirm against the model definition).
        device: Device that images, labels and outputs are moved to.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        optimizer: Optimizer; stepped once per batch.
        beta: Forwarded unchanged to ``TET_loss`` together with ``v``.
        args: Namespace providing ``TET``, ``means`` and ``lamb``.

    Returns:
        ``(running_loss, accuracy)``: ``running_loss`` is the *sum* (not the
        mean) of the per-batch loss values and ``accuracy`` is a percentage.
        An empty loader yields ``(0, 0.0)`` instead of dividing by zero.
    """
    running_loss = 0
    total = 0
    correct = 0
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        labels = labels.to(device)
        images = images.to(device)
        outputs, v = model(images)
        outputs = outputs.to(device)
        mean_out = outputs.mean(1)
        if args.TET:
            loss = TET_loss(outputs, labels, criterion, args.means, args.lamb, beta, v)
        else:
            loss = criterion(mean_out, labels)
        # .item() already requires a single-element loss, so no extra
        # reduction is needed before backpropagating.
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Accuracy uses the logits computed before this batch's update and is
        # evaluated on `device`, avoiding two host copies per batch.
        total += labels.size(0)
        predicted = mean_out.detach().max(1)[1]
        correct += predicted.eq(labels).sum().item()
    if total == 0:
        return running_loss, 0.0
    return running_loss, 100 * correct / total


def PGD(model, images, labels, eps=2/255, alpha=1/255, iters=2):
    r"""
    PGD (Projected Gradient Descent) Attack
    Args:
        model: The target model to attack. It is called as ``model(x)`` and
            must return a pair whose first element is averaged over dim 1 to
            obtain the logits.
        images: The input images (the [0, 1] clamps below assume this value range).
        labels: The true labels of the images.
        eps: The maximum perturbation (epsilon), L-infinity radius.
        alpha: The step size for each iteration.
        iters: The number of iterations.
    Returns:
        adv_images: The adversarial examples generated by PGD, detached from
        the autograd graph.

    The attack is run with the model in eval mode; the mode the model was in
    on entry (train or eval) is restored before returning, even on error.
    """
    # Work on the device the model lives on instead of a module-level global;
    # parameter-free models fall back to the device of the inputs.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else images.device

    was_training = model.training
    model.eval()
    try:
        images = images.clone().detach().to(target_device)
        labels = labels.clone().detach().to(target_device)
        loss_fn = nn.CrossEntropyLoss()

        # Random start inside the eps-ball, kept within the valid pixel range.
        noise = torch.empty_like(images).uniform_(-eps, eps)
        adv_images = torch.clamp(images + noise, min=0, max=1).detach()

        for _ in range(iters):
            adv_images.requires_grad_(True)
            outputs, _unused = model(adv_images)
            cost = loss_fn(outputs.mean(1), labels)

            # Gradient w.r.t. the input only; parameter .grad is left untouched.
            grad = torch.autograd.grad(cost, adv_images)[0]

            # Signed gradient-ascent step on the loss.
            adv_images = adv_images.detach() + alpha * grad.sign()

            # Project back onto the eps-ball around the clean images, then
            # onto the valid pixel range.
            delta = torch.clamp(adv_images - images, min=-eps, max=eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()
    finally:
        model.train(was_training)

    return adv_images



def at_train(model, device, train_loader, criterion, optimizer, beta, args):
    """Run one epoch of adversarial training on PGD-perturbed inputs.

    Every batch is replaced by adversarial examples produced by ``PGD``
    (with its default eps/alpha/iters) before the usual forward/backward pass.

    Args:
        model: Module called as ``model(images)``; it must return an
            ``(outputs, v)`` pair. ``outputs`` is averaged over dim 1 to
            obtain the logits used for prediction.
        device: Device that images, labels and outputs are moved to.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        optimizer: Optimizer; stepped once per batch.
        beta: Forwarded unchanged to ``TET_loss`` together with ``v``.
        args: Namespace providing ``TET``, ``means`` and ``lamb``.

    Returns:
        ``(running_loss, accuracy)``: ``running_loss`` is the *sum* of the
        per-batch loss values and ``accuracy`` is the percentage of
        adversarial inputs classified correctly. An empty loader yields
        ``(0, 0.0)`` instead of dividing by zero.
    """
    running_loss = 0
    total = 0
    correct = 0
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        labels = labels.to(device)
        images = images.to(device)

        # Train on adversarial examples instead of the clean batch.
        images = PGD(model, images, labels)

        # The attack runs the model in eval mode; make sure we are back in
        # training mode before the forward pass that is backpropagated.
        model.train()
        outputs, v = model(images)
        outputs = outputs.to(device)
        mean_out = outputs.mean(1)

        if args.TET:
            loss = TET_loss(outputs, labels, criterion, args.means, args.lamb, beta, v)
        else:
            loss = criterion(mean_out, labels)
        # .item() already requires a single-element loss, so no extra
        # reduction is needed before backpropagating.
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        total += labels.size(0)
        predicted = mean_out.detach().max(1)[1]
        correct += predicted.eq(labels).sum().item()
    if total == 0:
        return running_loss, 0.0
    return running_loss, 100 * correct / total


def rat_train(model, device, train_loader, criterion, optimizer,  beta, args):
    """Run one epoch of PGD adversarial training with weight regularisation.

    Identical to plain adversarial training except that, after every
    optimizer step, ``orthogonal_retraction(model, 0.001)`` and
    ``convex_constraint(model)`` are applied to the model's weights.

    Args:
        model: Module called as ``model(images)``; it must return an
            ``(outputs, v)`` pair. ``outputs`` is averaged over dim 1 to
            obtain the logits used for prediction.
        device: Device that images, labels and outputs are moved to.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        optimizer: Optimizer; stepped once per batch.
        beta: Forwarded unchanged to ``TET_loss`` together with ``v``.
        args: Namespace providing ``TET``, ``means`` and ``lamb``.

    Returns:
        ``(running_loss, accuracy)``: ``running_loss`` is the *sum* of the
        per-batch loss values and ``accuracy`` is the percentage of
        adversarial inputs classified correctly. An empty loader yields
        ``(0, 0.0)`` instead of dividing by zero.
    """
    running_loss = 0
    total = 0
    correct = 0
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        labels = labels.to(device)
        images = images.to(device)

        # Train on adversarial examples instead of the clean batch.
        images = PGD(model, images, labels)

        # The attack runs the model in eval mode; make sure we are back in
        # training mode before the forward pass that is backpropagated.
        model.train()
        outputs, v = model(images)
        outputs = outputs.to(device)
        mean_out = outputs.mean(1)

        if args.TET:
            loss = TET_loss(outputs, labels, criterion, args.means, args.lamb, beta, v)
        else:
            loss = criterion(mean_out, labels)
        # .item() already requires a single-element loss, so no extra
        # reduction is needed before backpropagating.
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Post-step weight corrections, applied in this order every batch.
        orthogonal_retraction(model, 0.001)
        convex_constraint(model)

        total += labels.size(0)
        predicted = mean_out.detach().max(1)[1]
        correct += predicted.eq(labels).sum().item()
    if total == 0:
        return running_loss, 0.0
    return running_loss, 100 * correct / total





@torch.no_grad()
def test(model, test_loader, device):
    """Evaluate top-1 accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module called as ``model(inputs)``; it must return an
            ``(outputs, v)`` pair. ``outputs`` is averaged over dim 1 to
            obtain the logits used for prediction.
        test_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Accuracy as a percentage; ``0.0`` for an empty loader instead of
        dividing by zero.

    Side effects:
        Prints the running accuracy every 100 batches and, at the end, the
        sum over all batches of the second model output ``v``.
    """
    model.eval()
    correct = 0
    total = 0
    v_sum = 0
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs, v = model(inputs)
        v_sum = v_sum + v
        mean_out = outputs.mean(1)
        predicted = mean_out.cpu().max(1)[1]
        total += targets.size(0)
        # Compare on the CPU so targets delivered on another device still work.
        correct += predicted.eq(targets.cpu()).sum().item()
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
    print(v_sum)
    if total == 0:
        return 0.0
    return 100 * correct / total

def orthogonal_retraction(model, beta=0.002):
    """Nudge Conv2d / Linear weight rows in place via W <- (1 + beta) W - beta W W^T W.

    Conv2d weights are flattened to ``(out_channels, -1)`` for the update and
    written back in their original shape. Linear layers with fewer than 200
    output rows are updated in full; larger Linear layers update only a
    random 30% of their rows per call, drawn with ``np.random.permutation``.
    All other module types are left untouched.

    Args:
        model: Module whose ``modules()`` are scanned.
        beta: Step size of the update.
    """
    with torch.no_grad():
        for layer in model.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue

            weights = layer.weight.data
            n_rows = weights.shape[0]

            if isinstance(layer, nn.Linear) and n_rows >= 200:
                # Large Linear layer: only touch a random 30% subset of rows.
                shuffled = np.random.permutation(n_rows)
                rows = shuffled[: int(n_rows * 0.3)]
                flat = weights[rows, :]
                target_shape = flat.shape
            else:
                # Conv2d or small Linear layer: use every row, flattened to 2-D.
                rows = list(range(n_rows))
                target_shape = weights.shape
                flat = weights.reshape(n_rows, -1)

            gram = flat.matmul(flat.t())
            updated = (1 + beta) * flat - beta * gram.matmul(flat)
            weights[rows, :] = updated.reshape(target_shape)


def convex_constraint(model):
    """Shift and clip the ``comb`` coefficients of every ConvexCombination module, in place.

    For each such module the coefficients are sorted in descending order and
    the longest leading run of indices ``j`` satisfying
    ``1 + j * sorted[j-1] > sum(sorted[:j])`` is found. The offset
    ``(sum(sorted[:k]) - 1) / k`` for that run length ``k`` is subtracted from
    all coefficients and negatives are clamped to zero. This looks like the
    standard sort-based projection onto the probability simplex - confirm
    against the ConvexCombination definition.

    Args:
        model: Module whose ``modules()`` are scanned.
    """
    with torch.no_grad():
        for layer in model.modules():
            if not isinstance(layer, ConvexCombination):
                continue

            coeffs = layer.comb.data
            ranked, _ = torch.sort(coeffs, descending=True)

            # Length of the leading run for which the threshold test holds.
            support = 1
            for count in range(1, layer.n + 1):
                prefix_sum = torch.sum(ranked[:count])
                if not (1 + count * ranked[count - 1]) > prefix_sum:
                    break
                support = count

            shift = (torch.sum(ranked[:support]) - 1) / support
            layer.comb.data -= shift
            torch.relu_(layer.comb.data)


# Script entry point: adversarially train a ResNet19 on Tiny-ImageNet and keep
# the checkpoint with the best clean test accuracy.
if __name__ == '__main__':
    # Seed first: everything below (data pipeline, model init) runs after it.
    seed_all(args.seed)
    # train_dataset, val_dataset = data_loaders.build_dvscifar('')
    # train_dataset, val_dataset, znorm = data_loaders.cifar_dataset(use_cifar10=True)
    # `znorm` is handed to the model as its `norm` argument below.
    train_dataset, val_dataset, znorm = data_loaders.build_tinyimagenet()
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers, pin_memory=True)

    # 200 output classes, hard-coded for the Tiny-ImageNet loader above.
    model = ResNet19(num_classes=200, norm=znorm)

    # Training and evaluation go through the DataParallel wrapper; the
    # optimizer is built from the wrapped model's own parameters.
    parallel_model = torch.nn.DataParallel(model)
    parallel_model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # Cosine learning-rate schedule spanning args.epochs scheduler steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    best_acc = 0
    best_epoch = 0
    beta = 0
    beta_max = 36/17/2    # resnet19 17| vgg11 10| wrn16 13
    logger = get_logger('exp.log')
    logger.info('start training!')
    
    # NOTE(review): the loop runs args.epochs + 1 iterations (epoch 0 through
    # args.epochs inclusive) while the scheduler's T_max is args.epochs, so the
    # final iteration trains after T_max scheduler steps - presumably at the
    # schedule's minimum learning rate; confirm this is intended.
    # NOTE(review): args.start_epoch is parsed but not used here; training
    # always starts at epoch 0.
    for epoch in range(args.epochs + 1):
        # Half-cosine ramp: beta is 0 at epoch 0 and reaches beta_max at
        # epoch == args.epochs.
        beta = 0.5 * beta_max * (1 - math.cos(math.pi * epoch / args.epochs))
        loss, acc = at_train(parallel_model, device, train_loader, criterion, optimizer, beta, args)
        logger.info('Epoch:[{}/{}]\t beta={:.5f}\t loss={:.5f}\t acc={:.3f}'.format(epoch , args.epochs, beta, loss, acc ))
        scheduler.step()
        facc = test(parallel_model, test_loader, device)
        logger.info('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch , args.epochs, facc ))

        # Checkpoint only on a strict improvement; the weights of the unwrapped
        # module are saved, overwriting the same file each time.
        if best_acc < facc:
            best_acc = facc
            best_epoch = epoch
            torch.save(parallel_model.module.state_dict(), 'resnet19_noise_lag36_at.pth')
        logger.info('Best Test acc={:.3f}, Best Epoch = {}'.format(best_acc, best_epoch ))

    print('\n')
