import os

import argparse
import time
import random
import numpy as np

import torch
from torch.optim import AdamW, SGD, Optimizer
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

from datasets import DATASETS, get_dataset
# from train_utils import AverageMeter, accuracy, log, test
# from train_utils import prologue, seed_everything

from torchvision.utils import save_image

from torch.nn import CrossEntropyLoss

from consistency import consistency_loss

from DRM_sigma_est import DiffusionModel

from transformers import AutoModelForImageClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Command-line configuration. NOTE: arguments are parsed at import time and the
# functions below read the resulting module-level `args` directly.
parser = argparse.ArgumentParser(
    description='Fine-tune a ViT classifier on diffusion-denoised images')
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument("--diffusion_path", type=str, help="path to diffusion model",
                    default="models/diffusion/cifar10_uncond_50M_500K.pt")
parser.add_argument('--base_vit_path', type=str, default="models/vit/base")
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.00002, type=float,
                    help='initial learning rate', dest='lr')
# NOTE(review): --opt is parsed but main() in this file always builds AdamW.
parser.add_argument('--opt', type=str, choices=['sgd', 'adamw'], default='sgd')
parser.add_argument('--weight-decay', '--wd', default=0.01, type=float,
                    metavar='W', help='weight decay (default: 0.01)')
parser.add_argument('--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--id', default=None, type=int,
                    help='experiment id, `randint(10000)` if None')
parser.add_argument('--num_noise_vec', default=1, type=int,
                    help="number of noise vectors. `m` in the paper.")

parser.add_argument('--sigma_cand', type=float, nargs='+', default=[0.25, 0.5, 1.0],
                    help='sigma candidates')

# NOTE(review): --checkpoint_freq and --trainset_suffix are parsed but not read
# anywhere in this file.
parser.add_argument("--checkpoint_freq", default=50, type=int)
parser.add_argument('--timestamp', default="0", type=str,
                    help='timestamp of the run')

parser.add_argument('--trainset_suffix', type=str, default='finetune_vit',
                    help='suffix for dataset name')

# Required: both the output directory and the sigma-label files are derived
# from this path, so there is no usable default.
parser.add_argument('--sigma_label_path', type=str, required=True,
                    help='directory containing the predict_train.npy / '
                         'predict_test.npy sigma labels; outputs are written beneath it')

args = parser.parse_args()

# e.g. [0.25, 0.5, 1.0] -> "0.250_0.500_1.000"; used to name the dataset splits.
sigma_cand_str = "_".join(f"{sigma:.3f}" for sigma in args.sigma_cand)

args.outdir = os.path.join(args.sigma_label_path, "finetune_vit")

def _sigmas_to_timesteps(diffusion, sigmas):
    """Map each noise level in ``sigmas`` to a diffusion timestep.

    For each ``s`` the target ratio is ``2 * s``. The factor 2 presumably
    accounts for images being rescaled from [0, 1] to [-1, 1] before
    diffusion; confirm this against DiffusionModel. Timestep ``t`` has ratio
    ``sqrt_one_minus_alphas_cumprod[t] / sqrt_alphas_cumprod[t]``. The search
    starts at ``t = 1`` and returns the first timestep whose ratio is
    ``>=`` the target. ``t = 0`` is returned only for a non-positive target.

    Raises:
        ValueError: if the schedule never reaches a target ratio (the original
            loop would have failed with a bare IndexError instead).
    """
    sqrt_a = diffusion.sqrt_alphas_cumprod
    sqrt_1ma = diffusion.sqrt_one_minus_alphas_cumprod
    last = len(sqrt_a) - 1
    timesteps = []
    for s in sigmas:
        target_sigma = s * 2
        real_sigma = 0
        t_this = 0
        while real_sigma < target_sigma:
            if t_this >= last:
                raise ValueError(
                    f"sigma {s} (target ratio {target_sigma}) exceeds the "
                    f"noise schedule's maximum at timestep {last}")
            t_this += 1
            real_sigma = sqrt_1ma[t_this] / sqrt_a[t_this]
        timesteps.append(t_this)
    return timesteps


def main():
    """Fine-tune a pretrained ViT on diffusion-denoised images.

    All configuration comes from the module-level ``args``. The function:
    evaluates the untouched model once, then trains and evaluates it for
    ``args.epochs`` epochs. After every epoch it appends a row to
    ``<outdir>/log.txt``, writes TensorBoard scalars and saves the model with
    ``save_pretrained`` to ``<outdir>/finetuned_vit``.
    """
    if args.sigma_label_path is None:
        parser.error("--sigma_label_path is required")

    # The --id help promises a random id when none is given, and
    # torch.manual_seed rejects None, so draw one explicitly.
    if args.id is None:
        args.id = random.randint(0, 9999)
    seed_everything(args.id)

    args.outdir = args.outdir + f"/{args.timestamp}"
    # exist_ok: also creates `checkpoints` when the run directory already exists.
    os.makedirs(os.path.join(args.outdir, "checkpoints"), exist_ok=True)

    logfilename = os.path.join(args.outdir, 'log.txt')
    # Header columns match the row written at the end of every epoch below.
    init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")
    writer = SummaryWriter(args.outdir)

    criterion = CrossEntropyLoss().to(device)

    starting_epoch = 0

    # Destination of the fine-tuned model; overwritten at the end of each epoch.
    model_path = os.path.join(args.outdir, 'finetuned_vit')

    train_split = f"train_{sigma_cand_str}_finetune_vit"
    test_split = f"test_{sigma_cand_str}_finetune_vit"

    # Per-sample sigma labels, presumably produced by an earlier
    # sigma-estimation run; os.path.join works with or without a trailing slash.
    train_dataset = get_dataset(args.dataset, train_split,
                                os.path.join(args.sigma_label_path, "predict_train.npy"))
    test_dataset = get_dataset(args.dataset, test_split,
                               os.path.join(args.sigma_label_path, "predict_test.npy"))

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers)

    denoiser = DiffusionModel(args.diffusion_path)

    # NOTE(review): num_labels is hard-coded to 10; confirm it matches
    # args.dataset when using datasets other than the CIFAR-10 default.
    model = AutoModelForImageClassification.from_pretrained(f"{args.base_vit_path}",
                                                            num_labels=10)
    model.to(device)

    # NOTE(review): --opt is parsed but AdamW is always used. The defaults of
    # --lr (2e-5) and --wd (0.01) match the previously hard-coded values.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=args.lr, weight_decay=args.weight_decay)

    # t[k] is the diffusion timestep used for samples whose sigma label is k.
    t = _sigmas_to_timesteps(denoiser.diffusion, args.sigma_cand)

    # Baseline evaluation of the not-yet-finetuned model (logged at step 0).
    test(test_loader, denoiser, t, model, criterion, 0,
         device, writer, args.print_freq,
         args.num_noise_vec)

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, denoiser, t, model, criterion, optimizer, epoch,
                                      device, writer)
        test_loss, test_acc = test(test_loader, denoiser, t, model, criterion, epoch,
                                   device, writer, args.print_freq,
                                   args.num_noise_vec)
        after = time.time()

        lr = optimizer.param_groups[0]['lr']
        log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
            epoch, after - before, lr,
            train_loss, train_acc, test_loss, test_acc))

        model.save_pretrained(model_path)

    # Flush pending TensorBoard events.
    writer.close()


def init_logfile(filename: str, text: str):
    """Create (or truncate) ``filename`` and write ``text`` plus a newline.

    NOTE: a second ``init_logfile`` with the same behavior is defined later in
    this module and rebinds this name at import time.
    """
    # Context manager closes the handle even if the write raises.
    with open(filename, 'w') as f:
        f.write(text + "\n")

def _chunk_minibatch(batch, num_batches):
    X, y, sigma = batch
    batch_size = len(X) // num_batches
    for i in range(num_batches):
        yield X[i*batch_size : (i+1)*batch_size], y[i*batch_size : (i+1)*batch_size], sigma[i*batch_size : (i+1)*batch_size]


def train(loader: DataLoader, denoiser, t, model: torch.nn.Module, criterion, optimizer: Optimizer,
          epoch: int, device: torch.device, writer=None):
    """Run one fine-tuning epoch of ``model`` on denoised images.

    Each loader batch is split into ``args.num_noise_vec`` chunks by
    ``_chunk_minibatch``. Each chunk is tiled ``args.num_noise_vec`` times,
    passed through ``denoiser`` and resized to 224x224. It is then classified
    by ``model`` and used for one optimizer step.

    Args:
        loader: yields ``(inputs, targets, sigma_labels)`` batches.
        denoiser: called as ``denoiser(inputs, t_batch)`` and returns
            ``(images, _)``; it is put in eval mode and not optimised here.
        t: list mapping a sigma label to a diffusion timestep. A non-list
            value is passed to the denoiser unchanged; presumably a single
            timestep, so confirm against DiffusionModel.
        model: classifier called as ``model(pixel_values=...)`` with ``.logits``.
        criterion: loss on ``(logits, targets)``; its result is averaged with
            ``.mean()``, so both reduced and ``reduction='none'`` losses work.
        optimizer: stepped once per chunk.
        epoch: used for progress printing and TensorBoard steps.
        device: device that inputs, targets and labels are moved to.
        writer: optional TensorBoard writer.

    Returns:
        ``(average loss, average top-1 accuracy in percent)`` over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    num_noise_vec = args.num_noise_vec
    # Sigma-label -> timestep lookup table, built once per epoch.
    t_table = torch.tensor(t) if isinstance(t, list) else None

    # switch to train mode
    model.train()
    denoiser.eval()

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        for inputs, targets, sigma_labels in _chunk_minibatch(batch, num_noise_vec):
            inputs, targets, sigma_labels = inputs.to(device), targets.to(device), sigma_labels.to(device)
            batch_size = inputs.size(0)

            # Tile inputs, targets AND sigma labels the same way, so every copy
            # keeps its own label and timestep (previously sigma_labels were not
            # tiled and t_batch was too short when num_noise_vec > 1).
            inputs = inputs.repeat(num_noise_vec, 1, 1, 1)
            targets = targets.repeat(num_noise_vec)
            sigma_labels = sigma_labels.repeat(num_noise_vec)

            # Assumes sigma_labels are integer indices into the sigma candidates.
            t_batch = t_table[sigma_labels.cpu()].tolist() if t_table is not None else t

            # The optimizer built in main() covers only `model`, so there is
            # no need to build an autograd graph through the denoiser.
            with torch.no_grad():
                imgs, _ = denoiser(inputs, t_batch)

            imgs = torch.nn.functional.interpolate(imgs, (224, 224), mode='bicubic', antialias=True)

            outputs = model(pixel_values=imgs).logits

            loss = criterion(outputs, targets).mean()

            # measure accuracy and record loss
            acc1 = accuracy(outputs, targets, topk=(1,))
            losses.update(loss.item(), batch_size)
            top1.update(acc1[0].item(), batch_size)

            # compute gradient and do an optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1))

    if writer:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)

    return (losses.avg, top1.avg)

def test(loader, denoiser, t, model, criterion, epoch, device,
        writer=None, print_freq=10, num_noise_vec=1):
    """Evaluate ``model`` on denoised images without gradient tracking.

    Inputs, targets and sigma labels of every batch are tiled
    ``num_noise_vec`` times. The inputs are then denoised, resized to 224x224
    and classified.

    Args:
        loader: yields ``(inputs, targets, sigma_labels)`` batches.
        denoiser: called as ``denoiser(inputs, t_batch)`` and returns
            ``(images, _)``; switched to eval mode here.
        t: list mapping a sigma label to a diffusion timestep. A non-list
            value is passed to the denoiser unchanged.
        model: classifier called as ``model(pixel_values=...)`` with ``.logits``.
        criterion: loss on ``(logits, targets)``, averaged with ``.mean()``.
        epoch: TensorBoard step.
        device: device that tensors are moved to.
        writer: optional TensorBoard writer.
        print_freq: print progress every ``print_freq`` batches.
        num_noise_vec: number of copies evaluated per sample.

    Returns:
        ``(average loss, average top-1 accuracy in percent)``, weighted by the
        number of (tiled) samples evaluated.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    t_table = torch.tensor(t) if isinstance(t, list) else None

    # switch to eval mode; main() calls this before train() has ever put the
    # denoiser into eval mode, so do it here as well.
    model.eval()
    denoiser.eval()

    with torch.no_grad():
        for i, (inputs, targets, sigma_labels) in enumerate(loader):
            # measure data loading time
            data_time.update(time.time() - end)

            inputs, targets, sigma_labels = inputs.to(device), targets.to(device), sigma_labels.to(device)

            # Tile all three consistently; previously only inputs were tiled,
            # so num_noise_vec > 1 failed with a batch-size mismatch in the loss.
            inputs = inputs.repeat(num_noise_vec, 1, 1, 1)
            targets = targets.repeat(num_noise_vec)
            sigma_labels = sigma_labels.repeat(num_noise_vec)

            # Assumes sigma_labels are integer indices into the sigma candidates.
            t_batch = t_table[sigma_labels.cpu()].tolist() if t_table is not None else t

            imgs, _ = denoiser(inputs, t_batch)

            imgs = torch.nn.functional.interpolate(imgs, (224, 224), mode='bicubic', antialias=True)

            outputs = model(pixel_values=imgs).logits

            loss = criterion(outputs, targets).mean()

            # measure accuracy and record loss
            acc1 = accuracy(outputs, targets, topk=(1, ))
            n_samples = inputs.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0].item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.avg:.3f}\t'
                      'Data {data_time.avg:.3f}\t'
                      'Loss {loss.avg:.4f}\t'
                      'Acc@1 {top1.avg:.3f}'.format(
                    i, len(loader), batch_time=batch_time, data_time=data_time,
                    loss=losses, top1=top1))

    if writer:
        writer.add_scalar('loss/test', losses.avg, epoch)
        writer.add_scalar('accuracy/test@1', top1.avg, epoch)

    return (losses.avg, top1.avg)


def seed_everything(seed, strict=False):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: value passed to every seeding function.
        strict: if true, also force deterministic cuDNN kernels and disable
            cuDNN autotuning (``benchmark``).
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    if not strict:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class AverageMeter:
    """Running, count-weighted mean of a scalar metric.

    Attributes:
        val: most recently recorded value.
        sum: sum of ``value * weight`` over all updates.
        count: total weight recorded.
        avg: ``sum / count``, refreshed on every update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (percent) of ``output`` scores against ``target``.

    Args:
        output: score tensor indexed as ``(sample, class)``.
        target: tensor of true class indices, one per sample.
        topk: the ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k`` in ``topk`` (same
        order), each holding the percentage of samples whose target is among
        the ``k`` highest-scoring classes.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
        ranked = top_idx.t()  # row r holds every sample's (r+1)-th best guess
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        scale = 100.0 / n_samples
        return [hits[:k].flatten().float().sum(0, keepdim=True).mul_(scale) for k in topk]


def init_logfile(filename: str, text: str):
    """Create (or truncate) ``filename`` and write ``text`` plus a newline.

    Used to write the header row of a tab-separated training log.
    """
    # Context manager closes the handle even if the write raises.
    with open(filename, 'w') as f:
        f.write(text + "\n")


def log(filename: str, text: str):
    """Append ``text`` plus a newline to ``filename``, creating it if needed."""
    # Context manager closes the handle even if the write raises.
    with open(filename, 'a') as f:
        f.write(text + "\n")


# Run training only when executed as a script. Note that command-line
# arguments are already parsed at import time (module-level `args`).
if __name__ == "__main__":
    main()