import time
import numpy as np
import torch
import utils
import argparse
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import ranking
import logging
from torch.autograd import Variable
from model_search import Network

from datetime import datetime
import os

from utils import madry_generate

# ---------------------------------------------------------------------------
# Command-line interface. Arguments are parsed at import time and exposed to
# the rest of the module through the global `args`.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser("cifar")
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
parser.add_argument('--samples', type=int, default=10, help='number of samples for estimation')
parser.add_argument('--data', type=str, default="", help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=80, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.0, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
# NOTE: store_true combined with default=True means args.cutout is always
# True; the flag cannot be switched off from the command line.
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# PGD attack settings, forwarded to madry_generate by train() and infer().
parser.add_argument('--epsilon', type=float, default=8/255, help='perturbation')
parser.add_argument('--num_steps', type=int, default=7, help='perturb number of steps')
parser.add_argument('--step_size', type=float, default=2/255, help='perturb step size')
parser.add_argument('--save_dir', type=str, default="",
                    help='directory to save checkpoints')
# Upper bounds on the number of samples drawn for the search-time train and
# validation subsets (see main()).
parser.add_argument('--train_size', type=int, default=1000)
parser.add_argument('--valid_size', type=int, default=500)
args = parser.parse_args()


# Number of output classes passed to the search network.
CIFAR_CLASSES = 10


# Timestamp embedded in the log file name so that each run writes its own log.
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')


# Log records are written only to a per-run file in the current working
# directory; no console handler is installed.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'search_{timestamp}.txt'),
    ]
)

# os.makedirs('') raises FileNotFoundError, and '' is the default for
# --save_dir. An empty save_dir makes os.path.join(save_dir, name) resolve to
# the current working directory, which already exists, so nothing needs to be
# created in that case.
if args.save_dir:
    os.makedirs(args.save_dir, exist_ok=True)




def main():
    """Run the architecture search loop.

    Seeds NumPy and PyTorch, builds the search `Network` with an SGD optimizer
    and a cosine learning-rate schedule, and draws two small index subsets
    from the CIFAR-10 training set: up to `args.train_size` samples from the
    first `args.train_portion` of the data for training, and up to
    `args.valid_size` samples from the remainder for validation.

    Every epoch calls `train` and `infer`. From epoch `warmup` onwards it
    additionally estimates per-operation values with `ranking.compute_value`,
    folds them into the architecture parameters via `ranking.update_alpha`,
    row-normalises those parameters, derives a genotype with
    `ranking.ranking`, and writes a checkpoint into `args.save_dir`.

    Reads the module-level `args`; requires a CUDA device.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    print('gpu device: %d' % args.gpu)
    print('args: %s' % args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()

    # Flat array of '<cell>_<edge>_<op>' labels, one per candidate operation of
    # every edge in each of the three cell types; handed to
    # ranking.compute_value below.
    ops = []
    for cell_type in ['normal', 'reduce', 'robust']:
        for edge in range(model.num_edges):
            ops.append(['{}_{}_{}'.format(cell_type, edge, i) for i in range(0, model.num_ops)])
    ops = np.concatenate(ops)

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    # The test split is only instantiated so that its size can be printed;
    # both loaders below draw from train_data.
    valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
    print(len(train_data), len(valid_data))
    num_train = len(train_data)
    indices_train = list(range(num_train))

    split = int(np.floor(args.train_portion * num_train))

    # Disjoint small subsets: training indices come from [0, split) and
    # validation indices from [split, num_train). To search on the full split
    # instead, pass indices_train[:split] / indices_train[split:num_train]
    # directly to the samplers.
    rng = np.random.default_rng(args.seed)
    train_indices_small = rng.choice(indices_train[:split], size=min(args.train_size, split), replace=False)
    valid_indices_small = rng.choice(indices_train[split:num_train], size=min(args.valid_size, num_train - split),
                                     replace=False)
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices_small),
        pin_memory=True)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_indices_small),
        pin_memory=True)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    # Initial "previous" values fed to ranking.update_alpha, one
    # (num_edges, num_ops) tensor per cell type.
    prev_values = [1e-3 * torch.randn(model.num_edges, model.num_ops).cuda(),
                   1e-3 * torch.randn(model.num_edges, model.num_ops).cuda(),
                   1e-3 * torch.randn(model.num_edges, model.num_ops).cuda()]

    # NOTE: resuming is not wired up. To resume from a checkpoint written
    # below, load 'model_state', 'optimizer_state' and 'scheduler_state' into
    # the respective objects, copy each saved 'arch_parameters' tensor into
    # model.arch_parameters, and start the loop at ckpt['epoch'] + 1.

    time1 = time.time()
    # First epoch index at which value estimation, the architecture update and
    # checkpointing run. With args.epochs <= warmup none of them ever happens.
    warmup = 40

    for epoch in range(args.epochs):
        lr = scheduler.get_last_lr()
        print('epoch: %d, lr: %e' % (epoch, lr[0]))
        logging.info('epoch: %d, lr: %e' % (epoch, lr[0]))
        train_acc, _, _ = train(train_queue, model, criterion, optimizer, args)

        logging.info('train acc: %f' % train_acc)
        valid_acc, _, _ = infer(valid_queue, model, criterion, args)
        logging.info('valid acc: %f' % valid_acc)
        if epoch >= warmup:
            normal_values, reduce_values, robust_values = ranking.compute_value(valid_queue, model, ops, args.samples)
            print('normal estimation values: ', normal_values)
            print('reduce estimation values: ', reduce_values)
            print('robust estimation values: ', robust_values)
            logging.info(f'normal estimation values: {normal_values}')
            logging.info(f'reduce estimation values:   {reduce_values} ')
            logging.info(f'robust estimation values: {robust_values}')

            prev_values = ranking.update_alpha([normal_values, reduce_values, robust_values], prev_values)
            # Apply the update in place without autograd tracking. A bare
            # `alpha += delta` raises if alpha is a leaf tensor that requires
            # grad; under no_grad the update is valid either way.
            with torch.no_grad():
                for idx in range(3):
                    model.arch_parameters[idx].add_(prev_values[idx])

            # Architecture parameters after the additive update, before
            # normalisation.
            logging.info(f'normal cell alpha:   {model.arch_parameters[0]} ')
            logging.info(f'reduction cell alpha:   {model.arch_parameters[1]} ')
            logging.info(f'robust cell alpha:   {model.arch_parameters[2]} ')

            # Rescale every row (one row per edge) by its sum; the 1e-9 term
            # avoids dividing by an exact zero.
            with torch.no_grad():
                for alpha in model.arch_parameters:
                    alpha.div_(alpha.sum(dim=1, keepdim=True) + 1e-9)

            # Same parameters again, now after normalisation.
            logging.info(f'normal cell alpha:   {model.arch_parameters[0]} ')
            logging.info(f'reduction cell alpha:   {model.arch_parameters[1]} ')
            logging.info(f'robust cell alpha:   {model.arch_parameters[2]} ')

            cur_genotype = ranking.ranking(
                model.arch_parameters[0],
                model.arch_parameters[1],
                model.arch_parameters[2],
                epoch
            )
            ckpt_path = os.path.join(args.save_dir, f'checkpoint_epoch{epoch:03d}.pth')

            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'arch_parameters': [alpha.data.cpu() for alpha in model.arch_parameters],
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
            }, ckpt_path)
            print(f"Saved checkpoint to {ckpt_path}")
            print('genotype for current epoch: ', cur_genotype)
            logging.info('genotype for current epoch: %s' % str(cur_genotype))

        scheduler.step()
    time2 = time.time()
    logging.info(f'total cost {time2-time1}')

def train(train_queue, model, criterion, optimizer, args, lambda_robust=0.5):
    """Train `model` for one pass over `train_queue` on a clean/adversarial loss mix.

    For every batch an adversarial counterpart is produced with
    `madry_generate`, the model is evaluated on both the clean and the
    adversarial batch, and the optimizer takes one step on
    ``(1 - lambda_robust) * clean_loss + lambda_robust * adversarial_loss``
    after clipping the gradient norm to `args.grad_clip`.

    Args:
        train_queue: iterable yielding ``(inputs, targets)`` batches.
        model: network being trained; must already live on the GPU.
        criterion: loss applied to ``(logits, targets)``.
        optimizer: optimizer stepping the model parameters; also forwarded to
            `madry_generate`.
        args: namespace providing `step_size`, `epsilon`, `num_steps` and
            `grad_clip`.
        lambda_robust: weight of the adversarial loss in the combined loss.

    Returns:
        Tuple ``(clean top-1 average, combined loss average, adversarial
        top-1 average)`` as accumulated by the `utils.AvgrageMeter` instances.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    robust_top1 = utils.AvgrageMeter()  # accuracy on adversarial inputs
    robust_top5 = utils.AvgrageMeter()
    standard_losses = utils.AvgrageMeter()  # loss on clean inputs
    robust_losses = utils.AvgrageMeter()  # loss on adversarial inputs

    for images, target in train_queue:
        model.train()
        n = images.size(0)

        # Plain tensors are enough here: the deprecated Variable wrapper adds
        # nothing. non_blocking matches infer() and the pinned-memory loaders.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # Craft the adversarial batch (PGD).
        images_pgd = madry_generate(model, images, target, optimizer,
                                    step_size=args.step_size, epsilon=args.epsilon,
                                    perturb_steps=args.num_steps)
        # madry_generate is defined elsewhere and may toggle the module mode
        # while crafting the attack; re-assert training mode so both forward
        # passes below run with training-time behaviour (a no-op if the mode
        # was left untouched).
        model.train()

        logits_standard = model(images)
        loss_standard = criterion(logits_standard, target)

        logits_pgd = model(images_pgd)
        loss_pgd = criterion(logits_pgd, target)

        # Convex combination of the clean and the adversarial loss.
        loss = (1 - lambda_robust) * loss_standard + lambda_robust * loss_pgd

        # Clear gradients only now, so anything accumulated on the parameters
        # during attack generation is discarded before the real backward pass.
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1_standard, prec5_standard = utils.accuracy(logits_standard, target, topk=(1, 5))
        prec1_robust, prec5_robust = utils.accuracy(logits_pgd, target, topk=(1, 5))

        top1.update(prec1_standard.item(), n)
        top5.update(prec5_standard.item(), n)
        robust_top1.update(prec1_robust.item(), n)
        robust_top5.update(prec5_robust.item(), n)
        standard_losses.update(loss_standard.item(), n)
        robust_losses.update(loss_pgd.item(), n)
        objs.update(loss.item(), n)

    logging.info(f'Epoch training average loss: {objs.avg:.4f}')
    logging.info(f'Epoch training average standard accuracy: {top1.avg:.4f}')
    logging.info(f'Epoch training average robust accuracy: {robust_top1.avg:.4f}')
    logging.info(f'Epoch training average standard TOP5 accuracy: {top5.avg:.4f}')
    logging.info(f'Epoch training average robust TOP5 accuracy: {robust_top5.avg:.4f}')
    logging.info(f'Epoch training average standard loss: {standard_losses.avg:.4f}')
    logging.info(f'Epoch training average robust loss: {robust_losses.avg:.4f}')
    return top1.avg, objs.avg, robust_top1.avg


def infer(valid_queue, model, criterion, args):
    """Evaluate `model` on clean and PGD-perturbed batches from `valid_queue`.

    The forward passes run in eval mode without gradient tracking. Only the
    attack itself is executed with gradients enabled, because crafting a PGD
    perturbation needs the gradient of the loss with respect to the input.

    Args:
        valid_queue: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate; must already live on the GPU.
        criterion: loss applied to ``(logits, targets)``.
        args: namespace providing `step_size`, `epsilon` and `num_steps`.

    Returns:
        Tuple ``(clean top-1 average, clean loss average, adversarial top-1
        average)`` as accumulated by the `utils.AvgrageMeter` instances.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    robust_top1 = utils.AvgrageMeter()  # accuracy on adversarial inputs
    robust_top5 = utils.AvgrageMeter()
    standard_losses = utils.AvgrageMeter()  # loss on clean inputs
    robust_losses = utils.AvgrageMeter()  # loss on adversarial inputs

    model.eval()

    with torch.no_grad():
        for images, target in valid_queue:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # Clean forward pass.
            logits_standard = model(images)
            loss_standard = criterion(logits_standard, target)

            # Gradient tracking is switched back on for the attack only: under
            # the surrounding no_grad the input gradient would be unavailable
            # unless madry_generate re-enables it internally. No optimizer is
            # passed during evaluation.
            with torch.enable_grad():
                images_pgd = madry_generate(model, images, target, None,
                                            step_size=args.step_size, epsilon=args.epsilon,
                                            perturb_steps=args.num_steps)
            # Drop any autograd graph still attached to the adversarial batch.
            images_pgd = images_pgd.detach()
            # madry_generate is defined elsewhere and may change the module
            # mode; make sure the adversarial forward pass runs in eval mode
            # (a no-op if the mode was left untouched).
            model.eval()

            # Adversarial forward pass.
            logits_pgd = model(images_pgd)
            loss_pgd = criterion(logits_pgd, target)

            prec1_standard, prec5_standard = utils.accuracy(logits_standard, target, topk=(1, 5))
            prec1_robust, prec5_robust = utils.accuracy(logits_pgd, target, topk=(1, 5))

            n = images.size(0)

            top1.update(prec1_standard.item(), n)
            top5.update(prec5_standard.item(), n)
            robust_top1.update(prec1_robust.item(), n)
            robust_top5.update(prec5_robust.item(), n)
            standard_losses.update(loss_standard.item(), n)
            robust_losses.update(loss_pgd.item(), n)

            # The reported validation objective is the clean loss only.
            objs.update(loss_standard.item(), n)

    logging.info(f'Epoch validation average loss: {objs.avg:.4f}')
    logging.info(f'Epoch validation average standard accuracy: {top1.avg:.4f}')
    logging.info(f'Epoch validation average robust accuracy: {robust_top1.avg:.4f}')
    logging.info(f'Epoch validation average standard TOP5 accuracy: {top5.avg:.4f}')
    logging.info(f'Epoch validation average robust TOP5 accuracy: {robust_top5.avg:.4f}')
    logging.info(f'Epoch validation average standard loss: {standard_losses.avg:.4f}')
    logging.info(f'Epoch validation average robust loss: {robust_losses.avg:.4f}')

    return top1.avg, objs.avg, robust_top1.avg


# Start the search only when this file is executed as a script; importing the
# module does not call main().
if __name__ == '__main__':
    main()
