import numpy as np
import torch
import utils
import argparse
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import genotypes
from torch.autograd import Variable
from model import NetworkCIFAR
from loss import trades_loss, madry_loss
from utils import madry_generate
import torch.nn.functional as F
from utils import load_tinyimagenet
import logging
import datetime



def load_model_weights(model, checkpoint_path, device='cuda'):
    """Populate ``model`` from the ``'model'`` entry of a saved checkpoint.

    The state dict is loaded non-strictly and ``model`` is returned so the call
    can be chained.

    NOTE: this module defines ``load_model_weights`` again further down; that
    later definition is the one bound to the name once the module has loaded.
    """
    weights = torch.load(checkpoint_path, map_location=device)['model']
    model.load_state_dict(weights, strict=False)
    print(f"Model weights loaded from {checkpoint_path}")
    return model

parser = argparse.ArgumentParser("cifar")
parser.add_argument('--gpu', type=int, default=4, help='gpu device id')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--epochs', type=int, default=120, help='num of training epochs')
parser.add_argument('--adv_loss', type=str, default='pgd', help='adversarial loss type')
parser.add_argument('--data', type=str, default="", help='location of the data corpus')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--epsilon', type=float, default=8/255, help='perturbation')
parser.add_argument('--num_steps', type=int, default=7, help='perturb number of steps')
parser.add_argument('--step_size', type=float, default=0.01, help='perturb step size')
parser.add_argument('--beta', type=float, default=6.0, help='regularization in TRADES')
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
# --auxiliary and --cutout default to on. A `store_true` flag whose default is
# already True can never be switched off, so each one also gets a `--no_*` flag
# that writes False to the same dest. The positive flags are still accepted.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--no_auxiliary', dest='auxiliary', action='store_false',
                    help='disable the auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--no_cutout', dest='cutout', action='store_false', help='disable cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--data_dir', type=str, default="", help='imagenet-200 datadir')
parser.add_argument('--workers', type=int, default=4, help='data loading workers')
parser.add_argument("--data_name", choices=["tiny", "cifar100", "cifar10"], default="tiny",
                    help="choose dataset; this decides mean/std")

args = parser.parse_args()

# Number of output classes of the network. 200 matches Tiny-ImageNet; main()
# always builds its data loaders with load_tinyimagenet(args).
CIFAR_CLASSES = 200

def load_model_weights(model, checkpoint_path, device='cuda'):
    """Load the ``'model'`` state dict from a checkpoint file into ``model``.

    Loading is non-strict, so partially matching checkpoints still load.
    Any missing or unexpected keys are reported rather than silently ignored,
    so an architecture mismatch does not go unnoticed.

    Args:
        model: ``nn.Module`` to populate in place.
        checkpoint_path: file written by ``torch.save`` holding a dict with a
            ``'model'`` entry.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` instance.

    Raises:
        KeyError: if the checkpoint dict has no ``'model'`` entry.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # With strict=False, load_state_dict does not raise on mismatches but
    # returns them; surface them so a silently half-loaded model is visible.
    result = model.load_state_dict(checkpoint['model'], strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} missing key(s) when loading "
              f"{checkpoint_path}: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected key(s) when loading "
              f"{checkpoint_path}: {result.unexpected_keys}")

    print(f"Model weights loaded from {checkpoint_path}")

    return model

def main():
    """Train the fixed-genotype NetworkCIFAR with a clean + adversarial + ALP loss.

    All hyperparameters come from the module-level ``args``. Each epoch:

    1. Set the learning rate.
    2. Compute the adversarial and logit-pairing loss weights.
    3. Train for one pass over the training loader.
    4. Evaluate clean validation accuracy.
    5. Checkpoint via ``utils.save`` when validation accuracy improves or once
       past epoch 70.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    print('gpu device: ', args.gpu)
    print("args: ", args)
    MAX_WEIGHT_LOSS = 1.2
    genotype = genotypes.search_cifar10_338_0

    # Per-dataset channel (mean, std) handed to the network, selected by --data_name.
    STATS = {
        "tiny": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "cifar100": ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
        "cifar10": ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    }

    mean, std = STATS[args.data_name]
    model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype,
                         mean=mean, std=std)

    print(model)
    print("param size: ", utils.count_parameters_in_MB(model), 'MB')

    # Move the model to the GPU *before* constructing the optimizer, as the
    # torch.optim documentation recommends, so the optimizer is built over the
    # on-device parameters.
    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    train_queue, valid_queue, _ = load_tinyimagenet(args)

    base_lambda = 2
    best_acc = 0.0

    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch)
        print('epoch: %d' % epoch)
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # Adversarial-loss weight ramps up linearly (base_lambda / 60 per epoch)
        # and is capped at MAX_WEIGHT_LOSS.
        lambda_epoch = min(base_lambda * ((epoch + 1) / 60), MAX_WEIGHT_LOSS)
        print(f"lambdaepoch: {lambda_epoch}, MaxWeightLoss : {MAX_WEIGHT_LOSS}")
        # Logit-pairing weight warms up linearly to 0.5 over the first 20 epochs.
        lambda_alp = 0.5 * min((epoch + 1) / 20.0, 1.0)

        train_acc, _ = train(train_queue, model, criterion, optimizer, args, lambda_epoch, lambda_alp)

        print('train_acc: ', train_acc)

        valid_acc, _ = infer(valid_queue, model, criterion)

        improved = valid_acc > best_acc
        if improved:
            best_acc = valid_acc
        # Save on a new best, and unconditionally late in training; at most one
        # save per epoch even when both conditions hold.
        if improved or epoch > 70:
            utils.save(model, epoch, optimizer, valid_acc)
            print('Save model!')
        print('valid acc: ', valid_acc, 'best_acc: ', best_acc)



def train(train_queue, model, criterion, optimizer, args, lambda_epoch, lambda_alp, misclass_weight=1.2):
    """Run one epoch of adversarial training.

    For every batch, PGD examples are crafted with ``madry_generate``. The
    loss minimised is::

        CE(clean) + lambda_epoch * mean(w_i * CE_i(adv)) + lambda_alp * ALP
            [+ args.auxiliary_weight * CE(aux)   when args.auxiliary]

    where ``w_i`` is ``misclass_weight`` for samples the model misclassifies
    under attack and 1 otherwise. ALP is the MSE between batch-standardised
    clean logits and detached, batch-standardised adversarial logits.

    Args:
        train_queue: iterable of ``(inputs, targets)`` batches.
        model: network whose forward returns ``(logits, aux_logits)``.
        criterion: loss applied to the clean and auxiliary logits.
        optimizer: optimizer for ``model`` (also passed to ``madry_generate``).
        args: parsed CLI namespace; reads ``step_size``, ``epsilon``,
            ``num_steps``, ``auxiliary``, ``auxiliary_weight`` and ``grad_clip``.
        lambda_epoch: weight of the adversarial cross-entropy term.
        lambda_alp: weight of the logit-pairing term.
        misclass_weight: per-sample weight for adversarially misclassified samples.

    Returns:
        ``(mean clean top-1 accuracy, mean total loss)`` over the epoch.
    """
    # Guards the logit standardisation against division by a zero std.
    eps = 1e-8

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    robust_top1 = utils.AvgrageMeter()  # For robust accuracy
    robust_top5 = utils.AvgrageMeter()  # For robust accuracy
    standard_losses = utils.AvgrageMeter()  # For tracking standard loss
    robust_losses = utils.AvgrageMeter()  # For tracking robust loss
    for step, (inputs, target) in enumerate(train_queue):
        model.train()
        n = inputs.size(0)

        inputs = inputs.cuda()
        target = target.cuda(non_blocking=True)

        # Generate adversarial examples (PGD).
        input_pgd = madry_generate(model, inputs, target, optimizer,
                                   step_size=args.step_size, epsilon=args.epsilon,
                                   perturb_steps=args.num_steps)
        # NOTE(review): madry_generate's handling of train/eval mode is not
        # visible here. Re-assert train mode so the forward passes below use
        # training behaviour; this is a no-op if it already restores it.
        model.train()

        logits_standard, logits_aux = model(inputs)
        loss_standard = criterion(logits_standard, target)

        # Adversarial loss, kept per-sample so it can be reweighted below.
        logits_pgd, _ = model(input_pgd)
        loss_pgd = F.cross_entropy(logits_pgd, target, reduction='none')

        # Logit pairing on logits standardised over the whole batch. The
        # adversarial side is detached, so only the clean logits are pulled.
        logits_standard_norm = (logits_standard - logits_standard.mean()) / (logits_standard.std() + eps)
        logits_pgd_norm = (logits_pgd - logits_pgd.mean()) / (logits_pgd.std() + eps)
        loss_alp = F.mse_loss(logits_standard_norm, logits_pgd_norm.detach())

        # Up-weight samples the attack currently fools.
        preds_pgd = logits_pgd.argmax(dim=1)
        sample_weights = torch.ones_like(target, dtype=torch.float, device=target.device)
        sample_weights[preds_pgd != target] = misclass_weight
        loss_adv_weighted = (sample_weights * loss_pgd).mean()

        loss = loss_standard + lambda_epoch * loss_adv_weighted + lambda_alp * loss_alp
        # Defined unconditionally so the per-step log works when args.auxiliary is False.
        loss_aux_value = 0.0
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss = loss + args.auxiliary_weight * loss_aux
            loss_aux_value = loss_aux.item()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        # Accuracy on clean and on adversarial inputs (logits from before the step).
        prec1_standard, prec5_standard = utils.accuracy(logits_standard, target, topk=(1, 5))
        prec1_robust, prec5_robust = utils.accuracy(logits_pgd, target, topk=(1, 5))

        top1.update(prec1_standard.item(), n)
        top5.update(prec5_standard.item(), n)
        robust_top1.update(prec1_robust.item(), n)
        robust_top5.update(prec5_robust.item(), n)
        standard_losses.update(loss_standard.item(), n)
        robust_losses.update(loss_adv_weighted.item(), n)
        objs.update(loss.item(), n)

        print(f"Step [{step + 1}/{len(train_queue)}] | "
              f"Loss: {loss.item():.4f} | "
              f"last Top-1robust Accuracy: {prec1_robust.item():.2f}% | "
              f"last Top-1std Accuracy: {prec1_standard.item():.2f}% |"
              f"loss_standard: {loss_standard.item():.4f} |"
              f"loss_adv_weighted: {lambda_epoch * loss_adv_weighted.item():.4f} |"
              f"loss_loss_aux: {loss_aux_value:.4f} |"
              f"loss_alp: {loss_alp.item():.4f} ")

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    """Evaluate clean (non-adversarial) accuracy and loss on a validation loader.

    Puts ``model`` in eval mode and runs without gradient tracking.

    Args:
        valid_queue: iterable of ``(inputs, targets)`` batches.
        model: network whose forward returns ``(logits, aux_logits)``.
        criterion: loss applied to the logits.

    Returns:
        ``(mean top-1 accuracy, mean loss)``, each weighted by batch size.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for inputs, target in valid_queue:
            # Plain tensors suffice; the Variable wrapper has been a no-op since PyTorch 0.4.
            inputs = inputs.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            logits, _ = model(inputs)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    return top1.avg, objs.avg



def adjust_learning_rate(optimizer, epoch, base_lr=None):
    """Apply a step-decay learning-rate schedule to every parameter group.

    The rate is:

    - ``base_lr`` for epochs below 40;
    - ``0.1 * base_lr`` from epoch 40;
    - ``0.01 * base_lr`` from epoch 80;
    - ``0.001 * base_lr`` from epoch 100.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: zero-based epoch index.
        base_lr: starting learning rate. Defaults to the module-level
            ``args.learning_rate``, so existing two-argument calls are unchanged.

    Returns:
        The learning rate that was applied.
    """
    if base_lr is None:
        base_lr = args.learning_rate

    if epoch >= 100:
        lr = base_lr * 0.001
    elif epoch >= 80:
        lr = base_lr * 0.01
    elif epoch >= 40:
        lr = base_lr * 0.1
    else:
        lr = base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Read the rates back from the optimizer to confirm they were applied.
    current_lrs = [group['lr'] for group in optimizer.param_groups]
    print(f"[Epoch {epoch}] Learning Rates in Optimizer: {[f'{group_lr:.6f}' for group_lr in current_lrs]}")
    return lr


if __name__ == '__main__':
    # Launch training only when this file is executed as a script.
    main()
