# from https://github.com/myungsub/CAIN/blob/master/utils.py,
# but removed the erroneous normalization and quantization steps from computing the PSNR.

from pytorch_msssim import ssim_matlab as calc_ssim
import math
import os
import torch
import shutil
import numpy as np
import random
import torch.distributed as dist
from torchvision import transforms
import torch.nn.functional as F


def set_seed(seed=None, cuda=False):  # change the seed here to reproduce a given model
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed shared by all generators. When ``None`` (the
            default) the function returns without touching any generator or
            the environment; previously this path wrote the invalid value
            ``"None"`` to ``PYTHONHASHSEED`` and then raised ``TypeError``
            inside ``torch.manual_seed``.
        cuda: When true, also seed every CUDA device with
            ``torch.cuda.manual_seed_all``.
    """
    if seed is None:
        # No seed requested: leave every generator in its current state.
        return
    random.seed(seed)
    # NOTE: hash randomization is fixed at interpreter start-up, so this
    # setting only takes effect in processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.enabled = False  # disabling cuDNN this way costs some efficiency


def init_meters(loss_str):
    """Create the meters used during training/evaluation.

    Args:
        loss_str: Loss specification string forwarded to ``init_losses``.

    Returns:
        A tuple ``(losses, psnrs, ssims)``: the value returned by
        ``init_losses(loss_str)`` followed by two fresh ``AverageMeter``
        instances for PSNR and SSIM.
    """
    # Tuple elements are evaluated left to right: losses, then PSNR, then SSIM.
    return init_losses(loss_str), AverageMeter(), AverageMeter()

def eval_metrics(output, gt, psnrs, ssims, num_gpu=None):
    """Update the PSNR and SSIM meters with one value per image in the batch.

    Args:
        output: Batch of predictions, indexed along dimension 0.
        gt: Batch of ground-truth images; ``gt.size(0)`` sets the number of
            images processed.
        psnrs: Meter whose ``update`` receives each per-image PSNR.
        ssims: Meter whose ``update`` receives each per-image SSIM.
        num_gpu: When not ``None``, each metric is synchronised with
            ``dist.barrier()`` and passed through
            ``reduce_tensor(value, num_gpu)`` before being recorded.
    """
    use_dist = num_gpu is not None

    def _across_ranks(value):
        # Combine a metric over all processes; identity when not distributed.
        if not use_dist:
            return value
        dist.barrier()
        return reduce_tensor(value, num_gpu)

    # PSNR is evaluated image by image because sum(log) != log(sum).
    for idx in range(gt.size(0)):
        pred_img = output[idx]
        gt_img = gt[idx]

        psnrs.update(_across_ranks(calc_psnr(pred_img, gt_img)))

        # SSIM is computed on single-image batches clamped to [0, 1].
        ssim_value = calc_ssim(
            pred_img.unsqueeze(0).clamp(0, 1),
            gt_img.unsqueeze(0).clamp(0, 1),
            val_range=1.,
        )
        ssims.update(_across_ranks(ssim_value))

def init_losses(loss_str):
    """Build one ``AverageMeter`` per loss term plus a ``'total'`` meter.

    Args:
        loss_str: Terms joined by ``'+'``; each term must split on ``'*'``
            into exactly two parts, the second being the loss type used as
            the dictionary key.

    Returns:
        Dict mapping each loss type (and ``'total'``) to an ``AverageMeter``.

    Raises:
        ValueError: If a term does not split into exactly two parts on ``'*'``.
    """
    meters = {}
    for term in loss_str.split('+'):
        # Only the part after '*' is used; the part before it is discarded.
        _prefix, kind = term.split('*')
        meters[kind] = AverageMeter()
    meters['total'] = AverageMeter()
    return meters

class AverageMeter:
    """Keep the most recent value of a quantity together with its running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        # val: last value seen; avg: running mean; sum: weighted sum; count: total weight
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def calc_psnr(pred, gt):
    """Return ``-10 * log10(MSE + 1e-8)`` between ``pred`` and ``gt``.

    This equals the PSNR in dB when the peak signal value is 1
    (inputs are presumably scaled to [0, 1] -- confirm against callers).
    """
    squared_error = torch.square(pred - gt)
    # The small offset keeps log10 finite when the two inputs are identical.
    mse = torch.mean(squared_error) + 1e-8
    return torch.log10(mse) * -10

# def calc_psnr(pred, gt):  # found online; does not work yet, needs fixing
#     mse = torch.mean( (pred / 1.0 - gt / 1.0) ** 2 )
#     if mse == 0:
#         return 100
#     PIXEL_MAX = 255.0
#     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

def save_checkpoint(state, directory, is_best, save_rank, filename='checkpoint.pth' ):
    """Save a checkpoint to disk from a single rank.

    Args:
        state: Object passed to ``torch.save``.
        directory: Target directory; created (with parents) if missing.
        is_best: When true, the saved file is also copied to
            ``model_best.pth`` in the same directory.
        save_rank: Only the process whose rank equals this value writes.
        filename: File name of the checkpoint inside ``directory``.
    """
    # exist_ok avoids the check-then-create race when several ranks reach
    # this point at the same time.
    os.makedirs(directory, exist_ok=True)
    checkpoint_path = os.path.join(directory, filename)

    # Without an initialised process group there is a single process, which
    # is treated as rank 0 instead of failing inside get_rank().
    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    rank = torch.distributed.get_rank() if distributed else 0
    if rank != save_rank:
        return

    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(directory, 'model_best.pth'))
    

def log_tensorboard(writer, loss, psnr, ssim, lpips, lr, timestep, mode='train'):
    """Write loss, PSNR and SSIM scalars (and the learning rate) to ``writer``.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)``.
        loss: Value logged under ``Loss/<mode>``.
        psnr: Value logged under ``PSNR/<mode>``.
        ssim: Value logged under ``SSIM/<mode>``.
        lpips: Accepted for interface compatibility; not logged.
        lr: Learning rate, logged under ``lr`` only when ``mode == 'train'``.
        timestep: Step index passed to every ``add_scalar`` call.
        mode: Tag suffix, e.g. ``'train'``.
    """
    # The original tag 'Loss/%s/%s' had two placeholders for one argument and
    # raised TypeError on every call; use the same layout as PSNR and SSIM.
    writer.add_scalar('Loss/%s' % mode, loss, timestep)
    writer.add_scalar('PSNR/%s' % mode, psnr, timestep)
    writer.add_scalar('SSIM/%s' % mode, ssim, timestep)
    if mode == 'train':
        writer.add_scalar('lr', lr, timestep)
        
def reduce_tensor(tensor, world_size):
    """Return a copy of ``tensor`` summed over all processes and divided by ``world_size``.

    The input tensor itself is left untouched; the reduction runs on a clone.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    # In-place division, matching the original ``/=`` on the cloned tensor.
    return averaged.div_(world_size)