# Thank the authors of pytorch-generative-model-collections and examples of pytorch.
# The github address is https://github.com/znxlwm/pytorch-generative-model-collections
# and https://github.com/pytorch/examples/blob/master/mnist/main.py respectively.
# Some parts of code are adapted from their repositories.

import os, torch
import torch.nn as nn
import numpy as np
import random
import math
import torch.backends.cudnn as cudnn

class AverageMeter(object):
    """Track the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` new samples and refresh ``avg``.

        Args:
            val: the latest value; it is added to the running sum weighted by ``n``.
            n: number of samples ``val`` stands for (default 1).

        If the accumulated count is still zero (e.g. ``n=0`` on a fresh meter),
        ``avg`` keeps its previous value instead of raising ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against an empty meter: dividing by a zero count would raise.
        if self.count:
            self.avg = self.sum / self.count

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random``, the torch
            CPU generator and the generators of all CUDA devices.

    ``cudnn.benchmark`` is switched off as well. With benchmarking enabled,
    cuDNN may pick a different convolution algorithm from run to run, which
    defeats ``cudnn.deterministic``.
    """
    cudnn.deterministic = True
    cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible CUDA device (a superset of torch.cuda.manual_seed).
    torch.cuda.manual_seed_all(seed)

def initialize_weights(net):
    """Re-initialise conv, transposed-conv and linear layers of ``net`` in place.

    Weights are drawn from N(0, 0.02) and biases are set to zero. Layers built
    with ``bias=False`` have ``bias is None``; their bias is skipped.

    Args:
        net: an ``nn.Module``. Every module yielded by ``net.modules()``
            (including ``net`` itself) is inspected.
    """
    init_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for m in net.modules():
        if not isinstance(m, init_types):
            continue
        m.weight.data.normal_(0, 0.02)
        # Unconditionally calling .zero_() here crashed on bias-free layers.
        if m.bias is not None:
            m.bias.data.zero_()

def get_optimizer(model, args, lr):
    """Build the optimizer named by ``args.optimizer_name`` for ``model``.

    Supported names (case-insensitive):
      * ``"adam"``: Adam with learning rate ``lr``. ``args.weight_decay`` is
        not applied.
      * ``"nesterov"``: SGD with Nesterov momentum ``args.momentum`` and
        ``args.weight_decay``.
      * ``"sgd"``: SGD with momentum ``args.momentum`` and ``args.weight_decay``.

    Args:
        model: module whose ``parameters()`` are optimised.
        args: namespace providing ``optimizer_name`` and, for the SGD
            variants, ``momentum`` and ``weight_decay``.
        lr: initial learning rate.

    Returns:
        A ``torch.optim.Optimizer``.

    Raises:
        ValueError: if ``args.optimizer_name`` is not one of the names above.
    """
    name = args.optimizer_name.lower()
    if name == "adam":
        return torch.optim.Adam(model.parameters(), lr)
    if name == "nesterov":
        # PyTorch rejects nesterov=True unless momentum > 0 (with zero
        # dampening), so the momentum must be passed explicitly.
        return torch.optim.SGD(
            model.parameters(), lr, momentum=args.momentum,
            weight_decay=args.weight_decay, nesterov=True,
        )
    if name == "sgd":
        return torch.optim.SGD(
            model.parameters(), lr, momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    raise ValueError(
        "Unsupported optimizer_name {!r}; expected 'adam', 'nesterov' or 'sgd'".format(
            args.optimizer_name
        )
    )

def adjust_learning_rate(args, optimizer, epoch, epochs):
    """Write the learning rate for ``epoch`` into every param group of ``optimizer``.

    The schedule is chosen by ``args.lr_cosine_decay``:
      * cosine: anneal from ``args.lr`` at epoch 0 down to
        ``args.lr * args.lr_decay_rate ** 2`` at ``epoch == epochs``;
      * step: multiply ``args.lr`` by ``args.lr_decay_rate`` once per
        milestone in ``args.lr_decay_epochs`` that ``epoch`` strictly exceeds.
    """
    base_lr = args.lr
    if not args.lr_cosine_decay:
        passed = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = base_lr * (args.lr_decay_rate ** passed) if passed > 0 else base_lr
    else:
        floor = base_lr * (args.lr_decay_rate ** 2)
        cos_factor = 1 + math.cos(math.pi * epoch / epochs)
        new_lr = floor + (base_lr - floor) * cos_factor / 2

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def initialize_average_meters():
    """Return a fresh dict mapping each tracked loss name to a new AverageMeter.

    Keys, in insertion order: 'sup_con_loss', 'ird_loss', 'unsup_con_loss', 'total'.
    """
    loss_names = ('sup_con_loss', 'ird_loss', 'unsup_con_loss', 'total')
    return {name: AverageMeter() for name in loss_names}

def update_average_meters(average_meters, labeled_losses):
    """Feed each loss in ``labeled_losses`` to the meter stored under the same key.

    Each meter's ``update`` is called with the loss value only, so its default
    weight applies. Indexing ``average_meters[name]`` raises (KeyError for a
    dict) when a loss name has no matching meter.
    """
    for name in labeled_losses:
        meter = average_meters[name]
        meter.update(labeled_losses[name])


def update_tensorboard(writer, average_meters, optimizer, task_number, epoch, key_name):
    """Log each meter's average (then reset the meter) plus the current LR.

    Every scalar goes to ``writer.add_scalar`` at step ``epoch`` under the tag
    ``key_name + str(task_number + 1) + '/' + name``, where the task index is
    shown 1-based. ``name`` is each meter key in turn, and finally ``'lr'``.
    The learning rate comes from the first param group of ``optimizer``.
    """
    prefix = key_name + str(task_number + 1) + '/'
    for name, meter in average_meters.items():
        writer.add_scalar(prefix + name, meter.avg, epoch)
        # Clear the meter so the next epoch starts from zero.
        meter.reset()

    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar(prefix + 'lr', current_lr, epoch)


def save_model(model, indexes, directory, task_number):
    """Save ``model.encoder``'s state dict and ``indexes`` for one task.

    Both files go into ``directory`` (not created here) and are named with the
    1-based task id ``k = task_number + 1``:
      * ``encoder_<k>``: written with ``torch.save``;
      * ``index_<k>.npy``: ``indexes`` converted by ``np.array``, written
        with ``np.save``.
    """
    task_id = task_number + 1
    encoder_file = os.path.join(directory, "encoder_{}".format(task_id))
    index_file = os.path.join(directory, "index_{}.npy".format(task_id))
    encoder_state = model.encoder.state_dict()
    torch.save(encoder_state, encoder_file)
    np.save(index_file, np.array(indexes))