import os
import time
import argparse
import logging
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
from torch.autograd import Variable

import utils
import ranking_201          # 你的 NB-201 单分支版 ranking_201.py
from model_search_201 import Network
from genotypes_201 import PRIMITIVES, Genotype

parser = argparse.ArgumentParser("cifar")
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
parser.add_argument('--samples', type=int, default=5, help='number of samples for estimation')
parser.add_argument('--data', type=str, default="", help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=200, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.0, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=5, help='total number of layers')
# Cutout is on by default. `--cutout` (store_true, default True) can never turn it
# off, so `--no_cutout` writes the same `cutout` destination with False.
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--no_cutout', dest='cutout', action='store_false', default=True,
                    help='disable cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--seed', type=int, default=777, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# An empty save_dir means checkpoints are written relative to the working directory.
parser.add_argument('--save_dir', type=str, default="",
                    help='directory to save checkpoints')

# Sub-sample sizes drawn from the train / valid pools; 0 means "use the whole pool".
parser.add_argument('--train_size', type=int, default=1000,
                    help='number of training samples drawn from the train pool (0 = all)')
parser.add_argument('--valid_size', type=int, default=500,
                    help='number of validation samples drawn from the valid pool (0 = all)')
# Number of epochs of weight-only training before architecture logits are updated.
parser.add_argument('--warmup', type=int, default=40,
                    help='epochs of weight training before alpha updates start')

args = parser.parse_args()

CIFAR_CLASSES = 10

# The log goes only to a timestamped file in the working directory; nothing is
# echoed to the console by logging.
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(f'search_{timestamp}.txt')],
)
# os.makedirs('') raises FileNotFoundError, so only create a directory when one
# was actually requested (the default "" means the current directory).
if args.save_dir:
    os.makedirs(args.save_dir, exist_ok=True)


def _split_indices(num_train, train_portion, train_size, valid_size, seed):
    """Choose the search-train and search-valid index subsets of the training split.

    The first ``floor(train_portion * num_train)`` indices form the train pool and
    the remaining ones form the valid pool. When ``train_size`` / ``valid_size`` is
    non-zero, that many indices (capped at the pool size) are drawn without
    replacement from the pool. Otherwise the whole pool is used.

    Args:
        num_train: total number of samples in the training split.
        train_portion: fraction of the split assigned to the train pool.
        train_size: requested train subset size; falsy means "whole pool".
        valid_size: requested valid subset size; falsy means "whole pool".
        seed: seed for the ``np.random.default_rng`` generator used for sampling.

    Returns:
        ``(train_indices, valid_indices)``. Each is a NumPy array when sampled and
        a list when the whole pool is used.
    """
    indices = list(range(num_train))
    split = int(np.floor(train_portion * num_train))
    train_pool = indices[:split]
    valid_pool = indices[split:num_train]

    # One generator is shared and the train subset is drawn first, so the chosen
    # indices are reproducible for a given seed.
    rng = np.random.default_rng(seed)
    if train_size:
        train_indices = rng.choice(train_pool, size=min(train_size, split), replace=False)
    else:
        train_indices = train_pool
    if valid_size:
        valid_indices = rng.choice(valid_pool, size=min(valid_size, num_train - split), replace=False)
    else:
        valid_indices = valid_pool
    return train_indices, valid_indices


def _make_queue(dataset, indices):
    """Return a DataLoader that visits ``indices`` of ``dataset`` in random order."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices),
        pin_memory=True, num_workers=2,
    )


def main():
    """Run the NB-201 single-branch search.

    Each epoch trains the supernet weights with SGD on a train subset of the
    CIFAR-10 training split and evaluates on a disjoint valid subset. Once
    ``epoch >= warmup``, each epoch also does the following:

    - estimates per-(edge, op) values with ``ranking_201.compute_value``;
    - adds the resulting ``ranking_201.update_alpha`` step to ``model.alphas``;
    - logs the derived genotype;
    - saves a checkpoint to ``args.save_dir``.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    print('gpu device: %d' % args.gpu)
    print('args: %s' % args)

    criterion = nn.CrossEntropyLoss().cuda()
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion).cuda()

    # (edge, op) identifiers for the normal cell only: 6 edges x len(PRIMITIVES) ops.
    ops = np.array([f'normal_{e}_{o}' for e in range(6) for o in range(len(PRIMITIVES))])

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    # The official test split is loaded here only so its size can be reported.
    # Search uses two disjoint subsets of the training split instead.
    test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
    print(len(train_data), len(test_data))

    train_indices, valid_indices = _split_indices(
        len(train_data), args.train_portion, args.train_size, args.valid_size, args.seed
    )
    train_queue = _make_queue(train_data, train_indices)
    # Both queues read from train_data, so validation batches also go through train_transform.
    valid_queue = _make_queue(train_data, valid_indices)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min
    )

    # Momentum buffer for the alpha updates, shaped like the architecture logits.
    prev_buf = torch.zeros_like(model.alphas).cuda()

    # Fall back to the historical value when the CLI does not define --warmup.
    warmup = getattr(args, 'warmup', 40)
    if warmup >= args.epochs:
        msg = ('warmup (%d) >= epochs (%d): architecture logits will never be updated '
               'and no checkpoint will be saved' % (warmup, args.epochs))
        print('WARNING: ' + msg)
        logging.warning(msg)

    time1 = time.time()

    for epoch in range(args.epochs):
        lr = scheduler.get_last_lr()[0]
        print('epoch: %d, lr: %e' % (epoch, lr))
        logging.info('epoch: %d, lr: %e', epoch, lr)

        train_acc, _ = train(train_queue, model, criterion, optimizer, args.grad_clip)
        logging.info('train acc: %f', train_acc)

        valid_acc, _ = infer(valid_queue, model, criterion)
        logging.info('valid acc: %f', valid_acc)

        if epoch >= warmup:
            normal_values = ranking_201.compute_value(valid_queue, model, ops, args.samples)
            logging.info(f'normal estimation values:\n{normal_values}')

            delta_alpha, prev_buf = ranking_201.update_alpha(
                normal_values, prev_buf, step_size=0.2, momentum=0.8
            )

            # Alternative uncertainty-aware estimator, kept for reference:
            # mean_vals, var_vals = ranking_201.compute_value_mc(valid_queue, model, num_batches=3, num_contexts=4,
            #                                        tau_list=(4.0,2.0,1.0,0.5), eps=0.1, per_node_stratify=True)
            # delta_alpha, prev_buf = ranking_201.update_alpha_uncertainty(mean_vals, var_vals, prev_buf, step_size=0.2, momentum=0.8)

            # Apply the step to the raw logits outside autograd.
            model.alphas.data.add_(delta_alpha)
            logging.info(f'alpha logits:\n{model.alphas}')

            cur_genotype = ranking_201.ranking(model.get_alphas(), epoch)
            print('genotype for current epoch: ', cur_genotype)
            logging.info('genotype for current epoch: %s', str(cur_genotype))

            ckpt_path = os.path.join(args.save_dir, f'checkpoint_epoch{epoch:03d}.pth')
            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'alphas': model.alphas.data.cpu(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
            }, ckpt_path)
            print(f"Saved checkpoint to {ckpt_path}")

        scheduler.step()

    time2 = time.time()
    logging.info(f'total cost {time2 - time1:.2f}s')


def train(train_queue, model, criterion, optimizer, grad_clip=5.0):
    """Run one pass of weight training over ``train_queue``.

    Args:
        train_queue: iterable of ``(inputs, targets)`` batches.
        model: network whose call ``model(inputs)`` returns logits.
        criterion: loss applied to ``(logits, targets)``.
        optimizer: optimizer stepped once per batch.
        grad_clip: max total gradient norm passed to ``nn.utils.clip_grad_norm_``.

    Returns:
        ``(top1_avg, loss_avg)``, sample-weighted averages over the pass.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    for inputs, targets in train_queue:
        # Kept per-step as in the original so train mode is guaranteed for every batch.
        model.train()
        n = inputs.size(0)
        # Plain tensors replace the deprecated autograd.Variable wrapper. Batches come
        # from pinned memory, so both copies can be asynchronous (matches infer()).
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        logits = model(inputs)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        objs.update(loss.item(), n)

    logging.info(f'Epoch train avg loss: {objs.avg:.4f}')
    logging.info(f'Epoch train avg top1: {top1.avg:.4f}')
    logging.info(f'Epoch train avg top5: {top5.avg:.4f}')
    return top1.avg, objs.avg


@torch.no_grad()
def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on every batch of ``valid_queue`` in eval mode, without autograd.

    Args:
        valid_queue: iterable of ``(inputs, targets)`` batches.
        model: network returning logits; switched to ``eval()`` before the loop.
        criterion: loss applied to ``(logits, targets)``.

    Returns:
        ``(top1_avg, loss_avg)``, sample-weighted averages over all batches.
    """
    loss_meter, top1_meter, top5_meter = (utils.AvgrageMeter() for _ in range(3))

    model.eval()
    for images, labels in valid_queue:
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        batch = images.size(0)

        out = model(images)
        batch_loss = criterion(out, labels)
        acc1, acc5 = utils.accuracy(out, labels, topk=(1, 5))

        top1_meter.update(acc1.item(), batch)
        top5_meter.update(acc5.item(), batch)
        loss_meter.update(batch_loss.item(), batch)

    for tag, meter in (('loss', loss_meter), ('top1', top1_meter), ('top5', top5_meter)):
        logging.info(f'Epoch valid avg {tag}: {meter.avg:.4f}')
    return top1_meter.avg, loss_meter.avg


if __name__ == "__main__":
    # Start the search only when this file is executed as a script.
    main()
