import os
import copy
import time
import math
import utils
import torch.nn as nn

import torch
from tqdm import tqdm
import logging
from torch.autograd import Variable
from evaluate import evaluate, evaluate_kd
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from my_loss_function import loss_label_smoothing, loss_kd_regularization, loss_kd, loss_kd_self, loss_pseudo_kd, \
    loss_CE


class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Args:
        name: label printed first by ``__str__``.
        fmt: format spec (including the leading colon) applied to both the
            current value and the average, e.g. ``':.4f'``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the current value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the average.

        Args:
            val: per-sample value (it is weighted by ``n`` in the running sum).
            n: number of samples ``val`` represents.

        A zero-weight update on an empty meter leaves ``avg`` at 0 instead of
        raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # count can still be 0 here if every update so far had n == 0.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


# KD train and evaluate
def train_and_evaluate_kd(model, teacher_model, train_dataloader, val_dataloader, test_dataloader, optimizer,
                          loss_fn_kd, warmup_scheduler, params, args, restore_file=None):
    """KD-train ``model`` against ``teacher_model``, evaluating after every epoch.

    Each epoch:
    * runs one pass of ``train_kd``;
    * evaluates on ``val_dataloader``;
    * saves a checkpoint via ``utils.save_checkpoint`` into ``args.model_dir``;
    * writes JSON metric summaries and TensorBoard scalars into
      ``args.model_dir + '/tensorboard/'``.

    For CIFAR datasets the test set is also evaluated whenever validation
    accuracy reaches a new best, and once more after the last epoch.

    Args:
        model: student network being trained.
        teacher_model: teacher network; put in eval mode and only evaluated.
        train_dataloader, val_dataloader, test_dataloader: data loaders.
            ``test_dataloader`` is also used to report the teacher's accuracy.
        optimizer: optimizer over the student's parameters.
        loss_fn_kd: distillation loss, forwarded to ``train_kd`` and ``evaluate``.
        warmup_scheduler: per-batch LR scheduler used during epoch 0.
        params: hyper-parameters; reads ``num_epochs`` and ``dataset``.
        args: run configuration; reads ``model_dir``.
        restore_file: optional checkpoint name without the ``.pth.tar``
            suffix, resolved relative to ``args.model_dir``.
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        # Use the argument actually passed in. The original read
        # args.restore_file and silently ignored this parameter.
        restore_path = os.path.join(args.model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    # TensorBoard events and the JSON metric summaries share one directory.
    log_dir = args.model_dir + '/tensorboard/'
    model_dir = log_dir
    writer = SummaryWriter(log_dir=log_dir)

    try:
        best_val_acc = 0.0
        teacher_model.eval()
        teacher_acc = evaluate_kd(teacher_model, test_dataloader, params)
        print(">>>>>>>>>The teacher accuracy: {}>>>>>>>>>".format(teacher_acc['accuracy']))

        # Multiply the LR by 0.2 at the milestone epochs 60, 120 and 160.
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
        for epoch in range(params.num_epochs):
            # Epoch 0 is the warm-up epoch: train_kd steps warmup_scheduler per
            # batch while epoch <= 0, so the epoch-level schedule starts at epoch 1.
            if epoch > 0:
                scheduler.step()
            logging.info("Epoch {}/{}, lr:{}".format(epoch + 1, params.num_epochs, optimizer.param_groups[0]['lr']))

            # KD Train
            train_acc, train_loss = train_kd(model, teacher_model, optimizer, loss_fn_kd, train_dataloader,
                                             warmup_scheduler, params, args, epoch)
            # Evaluate
            val_metrics = evaluate(model, loss_fn_kd, val_dataloader, params, args, 'val')

            val_acc = val_metrics['accuracy']
            # ">=" means a tie with the best-so-far also counts as a new best.
            is_best = val_acc >= best_val_acc

            # Save weights
            utils.save_checkpoint({'epoch': epoch + 1,
                                   'state_dict': model.state_dict(),
                                   'optim_dict': optimizer.state_dict()},
                                  is_best=is_best,
                                  checkpoint=args.model_dir)

            if is_best:
                logging.info("- Found new val best accuracy")
                best_val_acc = val_acc

                if params.dataset.startswith('cifar'):
                    test_metrics = evaluate(model, loss_fn_kd, test_dataloader, params, args, 'test')
                    logging.info(f"- New test accuracy: {test_metrics['accuracy']}")

                # Save best val metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir, "eval_best_results.json")
                utils.save_dict_to_json(val_metrics, best_json_path)
                if params.dataset.startswith('cifar'):
                    best_json_path = os.path.join(model_dir, "test_best_results.json")
                    utils.save_dict_to_json(test_metrics, best_json_path)

            # Save latest val metrics in a json file in the model directory
            last_json_path = os.path.join(model_dir, "eval_last_results.json")
            utils.save_dict_to_json(val_metrics, last_json_path)

            # Tensorboard (the "Test_*" tags actually hold validation metrics)
            writer.add_scalar('Train_accuracy', train_acc, epoch)
            writer.add_scalar('Train_loss', train_loss, epoch)
            writer.add_scalar('Test_accuracy', val_metrics['accuracy'], epoch)
            writer.add_scalar('Test_loss', val_metrics['loss'], epoch)

        if params.dataset.startswith('cifar'):
            last_json_path = os.path.join(model_dir, "test_last_results.json")
            test_metrics = evaluate(model, loss_fn_kd, test_dataloader, params, args, 'test')
            utils.save_dict_to_json(test_metrics, last_json_path)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()


# Defining train_kd functions
def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, warmup_scheduler, params, args, epoch, flag=None):
    """Run one epoch of knowledge-distillation training of ``model``.

    Args:
        model: student network; set to train mode and updated in place.
        teacher_model: teacher network; set to eval mode and only run forward.
        optimizer: optimizer over the student's parameters.
        loss_fn_kd: called as ``loss_fn_kd(student_out, labels, teacher_out, params)``.
        dataloader: iterable of ``(inputs, labels)`` batches (moved to CUDA here).
        warmup_scheduler: stepped once per batch while ``epoch <= 0``.
        params: hyper-parameter object forwarded to ``loss_fn_kd``.
        args: not used in this function; kept for signature compatibility.
        epoch: current 0-based epoch index.
        flag: not used in this function; kept for signature compatibility.

    Returns:
        ``(accuracy_percent, losses.avg)``. Accuracy is top-1 over the epoch (0.0 for an
        empty dataloader). ``losses.avg`` comes from ``utils.AverageMeter`` updated with
        per-batch loss floats and batch sizes; presumably it is the batch-size-weighted
        mean loss.
    """
    # set model to training mode
    model.train()
    teacher_model.eval()
    loss_avg = utils.RunningAverage()
    losses = utils.AverageMeter()
    total = 0
    correct = 0
    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as t:
        for train_batch, labels_batch in dataloader:
            if epoch <= 0:
                warmup_scheduler.step()

            train_batch, labels_batch = train_batch.cuda(), labels_batch.cuda()

            # compute model output
            output_batch = model(train_batch)

            # The teacher is frozen: run it without recording an autograd graph.
            # The original built the graph and then detached the output, which
            # wasted memory and compute.
            with torch.no_grad():
                output_teacher_batch = teacher_model(train_batch).cuda()

            loss = loss_fn_kd(output_batch, labels_batch, output_teacher_batch, params)

            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()

            # performs updates using calculated gradients
            optimizer.step()

            _, predicted = output_batch.max(1)
            total += labels_batch.size(0)
            correct += predicted.eq(labels_batch).sum().item()
            # update the average loss
            loss_avg.update(loss.data)
            losses.update(loss.item(), train_batch.size(0))

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), lr='{:05.6f}'.format(optimizer.param_groups[0]['lr']))
            t.update()

    # Avoid ZeroDivisionError when the dataloader yields no batches.
    acc = 100. * correct / total if total else 0.0
    logging.info("- Train accuracy: {acc:.4f}, training loss: {loss:.4f}".format(acc=acc, loss=losses.avg))
    return acc, losses.avg


# normal training
def train_and_evaluate(model, train_dataloader, val_dataloader, test_dataloader, optimizer,
                       loss_fn, params, model_dir, warmup_scheduler, args, restore_file=None, teacher=None):
    """Train ``model`` without an external pre-trained teacher, evaluating after every epoch.

    The training variant is chosen by the first truthy flag, in this order:
    ``args.regularization``, ``args.pseudo_kd``, ``args.pseudo_kd_beta``,
    ``args.beta_ls``, ``args.label_smoothing``. If none is set, plain training is used.
    Each variant logs TensorBoard scalars to its own sub-directory of
    ``args.model_dir``.

    Args:
        model: network being trained.
        train_dataloader, val_dataloader, test_dataloader: data loaders.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss forwarded to ``train`` and ``evaluate``.
        params: hyper-parameters. Reads ``num_epochs`` and ``dataset``.
            ``lambda_p`` is doubled in place at epochs 60, 120 and 160.
        model_dir: directory for the JSON metric summaries in plain training.
            When a variant flag is set, it is replaced by that variant's log directory.
        warmup_scheduler: per-batch LR scheduler used during epoch 0.
        args: run configuration; reads ``model_dir`` and the variant flags.
        restore_file: optional checkpoint name without the ``.pth.tar``
            suffix, resolved relative to ``args.model_dir``.
        teacher: initial teacher handed to ``train``. How it is replaced depends on the variant:
            * ``pseudo_kd_beta``: replaced at the start of every epoch from epoch 2 on
              by a frozen copy of ``model``.
            * ``beta_ls``: replaced at epoch 0 by a copy of ``model`` that ``train``
              updates as an EMA.
            * Every other mode: set to ``None``.
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        # Use the argument actually passed in. The original read
        # args.restore_file and silently ignored this parameter.
        restore_path = os.path.join(args.model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    # dir setting, tensorboard events will save in the directory.
    # First truthy flag wins. For any variant the JSON summaries go to its log dir too.
    log_dir = args.model_dir + '/base_train/'
    variant_dirs = (
        ('regularization', '/Tf-KD_regularization/'),
        ('pseudo_kd', '/Pseudo-KD/'),
        ('pseudo_kd_beta', '/Pseudo-KD-beta/'),
        ('beta_ls', '/beta_ls/'),
        ('label_smoothing', '/label_smoothing/'),
    )
    for flag_name, sub_dir in variant_dirs:
        if getattr(args, flag_name):
            log_dir = args.model_dir + sub_dir
            model_dir = log_dir
            break
    writer = SummaryWriter(log_dir=log_dir)

    try:
        best_val_acc = 0.0

        # learning rate schedulers
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

        for epoch in range(params.num_epochs):
            # Epoch 0 is the warm-up epoch: train() steps warmup_scheduler per
            # batch while epoch <= 0, so the epoch-level schedule starts at epoch 1.
            if epoch > 0:
                # NOTE(review): passing `epoch` to step() is deprecated in recent
                # PyTorch. It is kept on purpose. With an explicit epoch, MultiStepLR
                # recomputes the LR in closed form from its base LR. A bare step()
                # would carry forward the current LR (possibly the warm-up's last
                # value), so switching could change the LR trajectory.
                scheduler.step(epoch)
                if epoch in [60, 120, 160]:
                    params.lambda_p *= 2
            logging.info("Epoch {}/{}, lr:{}".format(epoch + 1, params.num_epochs, optimizer.param_groups[0]['lr']))

            if args.pseudo_kd_beta:
                # Self-teacher: a frozen snapshot of the model as it stood at the end of
                # the previous epoch, refreshed every epoch once epoch >= 2.
                if epoch > 1:
                    teacher = copy.deepcopy(model)
                    teacher = teacher.cuda()
                    teacher.eval()
                    print(f'current teacher model from last epoch')
            elif args.beta_ls:
                # Created once. train() then updates its parameters as an EMA of the student's.
                if epoch == 0:
                    teacher = copy.deepcopy(model)
                    teacher = teacher.cuda()
                    teacher.train()
                    print(f'current teacher model from current epoch')
            else:
                teacher = None

            # one full pass over the training set
            train_acc, train_loss = train(model, optimizer, loss_fn, train_dataloader, params, epoch, warmup_scheduler,
                                          args, teacher)

            # Evaluate for one epoch on validation set
            val_metrics = evaluate(model, loss_fn, val_dataloader, params, args, 'val')

            val_acc = val_metrics['accuracy']
            # ">=" means a tie with the best-so-far also counts as a new best.
            is_best = val_acc >= best_val_acc

            if is_best:
                logging.info("- Found new val best accuracy")
                best_val_acc = val_acc

                if params.dataset.startswith('cifar'):
                    test_metrics = evaluate(model, loss_fn, test_dataloader, params, args, 'test')
                    logging.info(f"- New test accuracy: {test_metrics['accuracy']}")

                # Save best val metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir, "eval_best_results.json")
                utils.save_dict_to_json(val_metrics, best_json_path)
                if params.dataset.startswith('cifar'):
                    best_json_path = os.path.join(model_dir, "test_best_results.json")
                    utils.save_dict_to_json(test_metrics, best_json_path)

            # Save latest val metrics in a json file in the model directory
            last_json_path = os.path.join(model_dir, "eval_last_results.json")
            utils.save_dict_to_json(val_metrics, last_json_path)

            # Tensorboard (the "Test_*" tags actually hold validation metrics)
            writer.add_scalar('Train_accuracy', train_acc, epoch)
            writer.add_scalar('Train_loss', train_loss, epoch)
            writer.add_scalar('Test_accuracy', val_metrics['accuracy'], epoch)
            writer.add_scalar('Test_loss', val_metrics['loss'], epoch)

        if params.dataset.startswith('cifar'):
            last_json_path = os.path.join(model_dir, "test_last_results.json")
            test_metrics = evaluate(model, loss_fn, test_dataloader, params, args, 'test')
            utils.save_dict_to_json(test_metrics, last_json_path)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()


# normal training function
def train(model, optimizer, loss_fn, dataloader, params, epoch, warmup_scheduler, args, teacher_model=None):
    """Run one epoch of normal training (no external KD teacher).

    The loss is chosen by the first truthy flag, checked in this order:

    * ``args.regularization``: ``loss_fn(output, labels, params)``
    * ``args.pseudo_kd``: ``loss_fn(output, labels, params)``
    * ``args.pseudo_kd_beta``: ``loss_CE(output, labels, params)`` for epochs 0-1,
      then ``loss_fn(output, labels, teacher_output, params)``
    * ``args.beta_ls``: ``loss_fn(output, labels, teacher_output, params)``. After
      each optimizer step, ``teacher_model``'s parameters are moved towards the
      student's by an exponential moving average.
    * ``args.label_smoothing``: ``loss_fn(output, labels, params)``
    * otherwise: ``loss_fn(output, labels)``

    Args:
        model: network being trained; set to train mode and updated in place.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss function; its call signature depends on the variant (above).
        dataloader: iterable of ``(inputs, labels)`` batches (moved to CUDA here).
        params: hyper-parameter object forwarded to the loss.
        epoch: current 0-based epoch index.
        warmup_scheduler: stepped once per batch while ``epoch <= 0``.
        args: run configuration; reads the variant flags.
        teacher_model: self-teacher, required for ``beta_ls`` and for
            ``pseudo_kd_beta`` from epoch 2 on.

    Returns:
        ``(accuracy_percent, losses.avg)``. Accuracy is top-1 over the epoch (0.0 for an
        empty dataloader). ``losses.avg`` comes from ``utils.AverageMeter`` updated with
        per-batch loss floats (the original fed GPU tensors, unlike ``train_kd``) and
        batch sizes.
    """
    # set model to training mode
    model.train()
    loss_avg = utils.RunningAverage()
    losses = utils.AverageMeter()
    total = 0
    correct = 0
    # Global optimisation-step counter; drives the EMA decay for beta_ls.
    global_step = epoch * len(dataloader)
    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as t:
        for train_batch, labels_batch in dataloader:

            train_batch, labels_batch = train_batch.cuda(), labels_batch.cuda()
            if epoch <= 0:
                warmup_scheduler.step()

            optimizer.zero_grad()
            output_batch = model(train_batch)

            if args.regularization:
                loss = loss_fn(output_batch, labels_batch, params)
            elif args.pseudo_kd:
                loss = loss_fn(output_batch, labels_batch, params)
            elif args.pseudo_kd_beta:
                if epoch < 2:
                    # No self-teacher snapshot is used yet: plain cross-entropy.
                    loss = loss_CE(output_batch, labels_batch, params)
                else:
                    # Teacher targets are constants; don't record an autograd graph.
                    with torch.no_grad():
                        output_teacher_batch = teacher_model(train_batch).cuda()
                    loss = loss_fn(output_batch, labels_batch, output_teacher_batch, params)
            elif args.beta_ls:
                # no_grad does not change the forward pass itself (e.g. a teacher in
                # train mode still updates BN running stats); it only skips the graph.
                with torch.no_grad():
                    output_teacher_batch = teacher_model(train_batch).cuda()
                loss = loss_fn(output_batch, labels_batch, output_teacher_batch, params)
            elif args.label_smoothing:
                loss = loss_fn(output_batch, labels_batch, params)
            else:
                loss = loss_fn(output_batch, labels_batch)

            loss.backward()
            optimizer.step()

            if args.beta_ls:
                global_step += 1
                # EMA decay grows towards 1 with the step count and is capped at 0.999.
                alpha_now = min(1 - 1 / (global_step + 1), 0.999)
                for ema_param, param in zip(teacher_model.parameters(), model.parameters()):
                    # ema = alpha * ema + (1 - alpha) * param. Uses the keyword `alpha`
                    # form, since add_(Number, Tensor) is a deprecated overload.
                    ema_param.data.mul_(alpha_now).add_(param.data, alpha=1 - alpha_now)

            _, predicted = output_batch.max(1)
            total += labels_batch.size(0)
            correct += predicted.eq(labels_batch).sum().item()

            # update the average loss
            loss_avg.update(loss.data)
            losses.update(loss.item(), train_batch.size(0))

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), lr='{:05.6f}'.format(optimizer.param_groups[0]['lr']))
            t.update()

    # Avoid ZeroDivisionError when the dataloader yields no batches.
    acc = 100. * correct / total if total else 0.0
    logging.info("- Train accuracy: {acc: .4f}, training loss: {loss: .4f}".format(acc=acc, loss=losses.avg))
    return acc, losses.avg









