import os

import argparse
import time
import random
import numpy as np

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from architectures import ARCHITECTURES, get_architecture
from datasets import DATASETS, get_dataset, get_num_classes

from torchvision.utils import save_image

from torch.nn import CrossEntropyLoss
from torch.optim import AdamW, SGD, Optimizer
from torch.optim.lr_scheduler import StepLR, MultiStepLR, SequentialLR
from torch.utils.data import DataLoader, Subset

from consistency import consistency_loss

from DRM_sigma_est import DiffusionModel

from tensorboardX import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Command-line configuration. It is parsed once at import time into the
# module-level `args`, which the functions below read directly.
parser = argparse.ArgumentParser(
    description='Train a sigma-estimation classifier on diffusion-denoised images')
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('arch', type=str, choices=ARCHITECTURES)
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate', dest='lr')
parser.add_argument('--lr_step_size', type=int, default=30,
                    help='How often to decrease learning by gamma '
                         '(not used: the schedule is MultiStepLR over --lr_milestones).')
parser.add_argument('--lr_milestones', type=int, nargs='+', default=[50, 100],
                    help='milestones for MultiStepLR')
parser.add_argument('--lr_drop', type=int, default=1000,
                    help='When to drop the lr near the end of training '
                         '(appended to --lr_milestones)')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--opt', type=str, choices=['sgd', 'adamw'], default='sgd')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--noise_sd', default=0.0, type=float,
                    help="standard deviation of Gaussian noise for data augmentation")
parser.add_argument('--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--id', default=None, type=int,
                    help='experiment id, `randint(10000)` if None')

parser.add_argument('--num_noise_vec', default=1, type=int,
                    help="number of noise vectors. `m` in the paper.")

parser.add_argument("--diffusion_path", type=str, help="path to diffusion model",
                    default="models/diffusion/cifar10_uncond_50M_500K.pt")

parser.add_argument('--lbd', default=10., type=float,
                    help='lbd argument passed to consistency_loss (used with --loss_con)')
parser.add_argument('--eta', default=0.5, type=float,
                    help='eta argument passed to consistency_loss (used with --loss_con)')

#####################
# Options added by Salman et al. (2019)
parser.add_argument('--resume', action='store_true',
                    help='if true, tries to resume training from existing checkpoint '
                         '(not implemented in this script)')
parser.add_argument('--pretrained-model', type=str, default='',
                    help='Path to a pretrained model (not used by this script)')

parser.add_argument('--sigma_cand', type=float, nargs='+', default=[0.25, 0.5, 1.0],
                    help='sigma candidates')

parser.add_argument('--class_weights', action='store_true',
                    help='weight the classification loss by per-class weights '
                         'loaded from --sigma_label_dir')

parser.add_argument("--checkpoint_freq", default=50, type=int,
                    help='additionally save a numbered checkpoint every N epochs')

# Restricting choices here rejects typos up front instead of failing inside train().
parser.add_argument('--loss_cl', default='ce', type=str, choices=['ce', 'softce'],
                    help='type of classification loss')
parser.add_argument('--loss_con', action='store_true',
                    help='if true, use consistency loss')

parser.add_argument('--sigma_label_dir', type=str, default='data/sigma_label/base',
                    help='directory holding the sigma label files '
                         '(<sigma_cand>_train.npy, <sigma_cand>_test.npy)')

parser.add_argument('--timestamp', default="0", type=str,
                    help='timestamp of the run')

parser.add_argument('--round', default=0, type=int,
                    help='if > 0, outputs go to a roundN subdirectory')

args = parser.parse_args()

# The output directory encodes the experiment configuration, one path component
# per setting: dataset, number of noise vectors, sigma candidates + noise level,
# loss configuration, then the optional class-weight and round flags.
sigma_cand_str = "_".join("%.3f" % sigma for sigma in args.sigma_cand)

loss_str = f"{args.loss_cl}"
if args.loss_con:
    loss_str = loss_str + f"_con_lbd{args.lbd}_eta{args.eta}"

_outdir_parts = [
    f"logs/sigma_est/{args.dataset}",
    f"num_{args.num_noise_vec}",
    f"{sigma_cand_str}/noise_{args.noise_sd}",
    loss_str,
]
if args.class_weights:
    _outdir_parts.append("class_weights")
if args.round > 0:
    _outdir_parts.append(f"round{args.round}")
args.outdir = os.path.join(*_outdir_parts)


# Multiplied by max(sigma_cand) to normalise per-sample max_radius when
# weighting the consistency loss in train().
MAX_POSSIBLE_RADIUS = 3.2

def _find_diffusion_timestep(diffusion, target_sigma):
    """Return the first timestep ``t >= 1`` whose noise level reaches ``target_sigma``.

    The noise level at step ``t`` is
    ``sqrt_one_minus_alphas_cumprod[t] / sqrt_alphas_cumprod[t]``. When
    ``target_sigma <= 0`` no search happens and ``0`` is returned.

    Raises:
        ValueError: if no timestep in the schedule reaches ``target_sigma``.
    """
    sqrt_alphas = diffusion.sqrt_alphas_cumprod
    sqrt_one_minus_alphas = diffusion.sqrt_one_minus_alphas_cumprod
    t = 0
    real_sigma = 0
    while real_sigma < target_sigma:
        t += 1
        if t >= len(sqrt_alphas):
            raise ValueError(
                f"no diffusion timestep reaches noise level {target_sigma}")
        real_sigma = sqrt_one_minus_alphas[t] / sqrt_alphas[t]
    return t


def _save_checkpoint(path, epoch, arch, model, optimizer):
    """Save model and optimizer state to ``path``; stores ``epoch + 1`` as 'epoch'."""
    torch.save({
        'epoch': epoch + 1,
        'arch': arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def main():
    """Train the sigma estimator configured by the module-level ``args``.

    Seeds all RNGs with ``args.id``, builds the sigma-labelled train/test
    loaders, the classifier, optimizer, MultiStepLR schedule and diffusion
    denoiser, then for each epoch trains, evaluates, appends a row to
    ``<outdir>/log.txt`` and overwrites ``<outdir>/checkpoint.pth.tar``. Every
    ``args.checkpoint_freq`` epochs a numbered copy is also written to
    ``<outdir>/checkpoints/``.
    """
    # `--id` promises a random id when omitted; torch.manual_seed(None) would raise.
    if args.id is None:
        args.id = random.randint(0, 9999)
    seed_everything(args.id)

    args.outdir = args.outdir + f"/{args.arch}/{args.id}/{args.timestamp}"
    checkpoint_dir = os.path.join(args.outdir, "checkpoints")
    # exist_ok: a rerun into an existing outdir must still get its checkpoints/ dir.
    os.makedirs(checkpoint_dir, exist_ok=True)

    sigma_cand_str = "_".join(f"{sigma:.3f}" for sigma in args.sigma_cand)

    sigma_label_path_train = os.path.join(args.sigma_label_dir, f"{sigma_cand_str}_train.npy")
    sigma_label_path_test = os.path.join(args.sigma_label_dir, f"{sigma_cand_str}_test.npy")
    train_dataset = get_dataset(args.dataset, "train_sigma_est", sigma_label_path_train)
    test_dataset = get_dataset(args.dataset, "test_sigma_est", sigma_label_path_test)

    pin_memory = (args.dataset == "imagenet")

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)

    # One output class per sigma candidate.
    model = get_architecture(args.arch, args.dataset, len(args.sigma_cand), False)

    logfilename = os.path.join(args.outdir, 'log.txt')
    # Header columns match the seven values written per epoch below.
    init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")
    writer = SummaryWriter(args.outdir)

    params = model.parameters()
    print(args.lr)
    if args.opt == 'sgd':
        optimizer = SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.opt == 'adamw':
        optimizer = AdamW(params, lr=args.lr, weight_decay=args.weight_decay)

    # Build a new list so the parsed args (and argparse's default list) are not mutated.
    milestones = list(args.lr_milestones) + [args.lr_drop]

    starting_epoch = 0

    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')

    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=args.gamma,
                            last_epoch=starting_epoch - 1)

    if args.class_weights:
        class_weights_path = f"{args.sigma_label_dir}/{sigma_cand_str}_train_class_weights.npy"
        class_weights = np.load(class_weights_path)
        print(class_weights)
    else:
        class_weights = np.ones(len(args.sigma_cand))

    denoiser = DiffusionModel(args.diffusion_path)

    # Diffusion step whose noise level first reaches 2 * noise_sd. Presumably the
    # factor 2 maps [0, 1] pixel noise onto the model's [-1, 1] scale — confirm.
    t = _find_diffusion_timestep(denoiser.diffusion, args.noise_sd * 2)

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, denoiser, t, model, optimizer, epoch,
                                      args.noise_sd, device, writer, class_weights)
        test_loss, test_acc = test(test_loader, denoiser, t, model, epoch, args.noise_sd,
                                   device, args, writer, args.print_freq, args.num_noise_vec)
        after = time.time()

        log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
            epoch, after - before,
            scheduler.get_last_lr()[0], train_loss, train_acc, test_loss, test_acc))

        scheduler.step()

        _save_checkpoint(model_path, epoch, args.arch, model, optimizer)

        if (epoch + 1) % args.checkpoint_freq == 0:
            _save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint{epoch}.pth.tar'),
                             epoch, args.arch, model, optimizer)


def _chunk_minibatch(batch, num_batches):
    X, y, max_radius, radii = batch
    batch_size = len(X) // num_batches
    for i in range(num_batches):
        yield X[i*batch_size : (i+1)*batch_size], y[i*batch_size : (i+1)*batch_size], \
                max_radius[i*batch_size : (i+1)*batch_size], radii[i*batch_size : (i+1)*batch_size]


def train(loader: DataLoader, denoiser, t, model: torch.nn.Module, optimizer: Optimizer,
          epoch: int, noise_sd: float, device: torch.device, writer=None, class_weights=None):
    """Run one training epoch of the sigma estimator on denoised inputs.

    Each loader batch is split into ``args.num_noise_vec`` equal chunks. Every
    chunk is repeated ``num_noise_vec`` times, passed through ``denoiser`` at
    timestep ``t`` and classified by ``model``. The loss is the classification
    loss selected by ``args.loss_cl``, multiplied per sample by the class
    weight of its target when ``args.class_weights`` is set. When
    ``args.loss_con`` is set, a ``consistency_loss`` term over the noise copies
    is added.

    Args:
        loader: yields ``(inputs, targets, max_radius, radii)`` batches.
        denoiser: frozen diffusion denoiser called as ``denoiser(inputs, t)``.
        t: diffusion timestep passed to the denoiser.
        model: classifier being trained.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used for printing and TensorBoard logging.
        noise_sd: unused here; kept for interface compatibility.
        device: device the batch tensors are moved to.
        writer: optional TensorBoard writer.
        class_weights: per-class weights indexed by target label. Only read
            when ``args.class_weights`` is set.

    Returns:
        ``(average loss, average top-1 accuracy in percent)`` over the epoch.

    Raises:
        ValueError: if ``args.loss_cl`` is neither ``'ce'`` nor ``'softce'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    if args.loss_cl == 'ce':
        criterion_cl = CrossEntropyLoss(reduction='none')
    elif args.loss_cl != 'softce':
        raise ValueError(f"unknown loss_cl {args.loss_cl!r}; expected 'ce' or 'softce'")

    if args.class_weights:
        class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)

    num_noise_vec = args.num_noise_vec
    con_scale = MAX_POSSIBLE_RADIUS * max(args.sigma_cand)

    # switch to train mode
    model.train()
    denoiser.eval()

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        for inputs, targets, max_radius, radii in _chunk_minibatch(batch, num_noise_vec):
            inputs = inputs.to(device)
            targets = targets.to(device)
            max_radius = max_radius.to(device)
            radii = radii.to(device)
            batch_size = inputs.size(0)

            if num_noise_vec > 1:
                # Tile the chunk so each sample appears num_noise_vec times,
                # each copy getting its own noise inside the denoiser.
                targets = targets.repeat(num_noise_vec)
                radii = radii.repeat(num_noise_vec, 1)
                max_radius = max_radius.repeat(num_noise_vec)
                inputs = inputs.repeat(num_noise_vec, 1, 1, 1)

            # The denoiser is frozen and not optimized. Skipping its autograd
            # graph avoids backpropagating through it on loss.backward().
            with torch.no_grad():
                imgs, _ = denoiser(inputs, t)

            outputs = model(imgs)

            if args.class_weights:
                weights = class_weights[targets]
            else:
                weights = torch.ones_like(targets, dtype=torch.float32, device=device)

            if args.loss_cl == 'ce':
                loss_cl = criterion_cl(outputs, targets)
            else:
                loss_cl = softXEnt(outputs, radii, 'soft_max')

            loss_cl = loss_cl * weights

            if args.loss_con:
                logits_chunk = torch.chunk(outputs, num_noise_vec, dim=0)
                loss_con, _loss_kl_mean, _loss_ent_mean = consistency_loss(
                    logits_chunk, args.lbd, args.eta)
                con_loss_weights = max_radius[:batch_size] / con_scale
                loss_con = loss_con * con_loss_weights * weights[:batch_size]
                loss = loss_cl.mean() + loss_con.mean()
            else:
                loss = loss_cl.mean()

            # measure accuracy and record loss
            acc1 = accuracy(outputs, targets, topk=(1,))
            losses.update(loss.item(), batch_size)
            top1.update(acc1[0].item(), batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1))

    if writer:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)

    return (losses.avg, top1.avg)

def softXEnt(output, target, type='linear'):
    """Per-sample cross-entropy of ``output`` logits against a soft target.

    ``target`` is first clamped to at least 1e-6, then turned into a
    distribution along dim 1:

    * ``'linear'``: divide by its row sum;
    * ``'soft_max'``: apply softmax;
    * anything else: use the clamped values as they are.

    Returns ``-(p * log_softmax(output)).sum(dim=1)``, one value per row.
    """
    weights = target.clamp(min=1e-6)

    if type == 'soft_max':
        weights = torch.nn.functional.softmax(weights, dim=1)
    elif type == 'linear':
        weights = weights / weights.sum(dim=1, keepdim=True)

    log_p = torch.nn.functional.log_softmax(output, dim=1)
    return (weights * log_p).sum(dim=1).neg()


def test(loader, denoiser, t, model, epoch, noise_sd, device, args,
        writer=None, print_freq=10, num_noise_vec=1):
    """Evaluate ``model`` on denoised inputs and return ``(avg loss, avg top-1 acc)``.

    Each batch is passed once through ``denoiser(inputs, t)`` and then
    ``model``. Cross-entropy loss and top-1 accuracy (percent) are averaged
    over samples.

    Args:
        loader: yields ``(inputs, targets, max_radius, radii)`` batches; only
            inputs and targets are used here.
        denoiser: diffusion denoiser called as ``denoiser(inputs, t)``.
        t: diffusion timestep passed to the denoiser.
        model: classifier being evaluated.
        epoch: epoch index for TensorBoard logging.
        noise_sd: unused; kept for interface compatibility.
        device: device the batch tensors are moved to.
        args: parsed arguments (unused here; kept for interface compatibility).
        writer: optional TensorBoard writer.
        print_freq: print progress every this many batches.
        num_noise_vec: unused. Only one noisy copy of each input is scored,
            so no extra copies are computed.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    model.eval()
    denoiser.eval()
    criterion = CrossEntropyLoss(reduction='none').to(device)

    with torch.no_grad():
        for i, (inputs, targets, _max_radius, _radii) in enumerate(loader):
            data_time.update(time.time() - end)

            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            imgs, _ = denoiser(inputs, t)
            outputs = model(imgs)

            loss = criterion(outputs, targets).mean()
            acc1 = accuracy(outputs, targets, topk=(1, ))
            losses.update(loss.item(), batch_size)
            top1.update(acc1[0].item(), batch_size)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.avg:.3f}\t'
                      'Data {data_time.avg:.3f}\t'
                      'Loss {loss.avg:.4f}\t'
                      'Acc@1 {top1.avg:.3f}'.format(
                    i, len(loader), batch_time=batch_time, data_time=data_time,
                    loss=losses, top1=top1))

    if writer:
        writer.add_scalar('loss/test', losses.avg, epoch)
        writer.add_scalar('accuracy/test@1', top1.avg, epoch)

    return (losses.avg, top1.avg)


def seed_everything(seed, strict=False):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    With ``strict=True``, cuDNN is also switched to deterministic algorithms
    and its autotuner (``benchmark``) is disabled.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    if not strict:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class AverageMeter:
    """Track the latest value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, sum, count and mean back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted with weight ``n``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (percent) of ``output`` scores against ``target``.

    Args:
        output: scores with one row per sample; top-k is taken along dim 1.
        target: integer labels, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        One 1-element tensor per entry of ``topk``. Each holds 100 times the
        fraction of samples whose label is among their top-k predictions.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
        ranked = top_idx.t()  # row r holds every sample's r-th best prediction
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].flatten().float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]


def init_logfile(filename: str, text: str):
    """Create (or truncate) ``filename`` and write ``text`` followed by a newline."""
    # `with` guarantees the file is closed even if the write raises.
    with open(filename, 'w') as f:
        f.write(text + "\n")


def log(filename: str, text: str):
    """Append ``text`` followed by a newline to ``filename``."""
    # `with` guarantees the file is closed even if the write raises.
    with open(filename, 'a') as f:
        f.write(text + "\n")



# Start training only when run as a script. Note that argument parsing and
# output-dir setup above run at module level, so they execute on import too.
if __name__ == "__main__":
    main()
