
import torch
import torch.nn as nn
import numpy as np
import scipy.linalg as linalg
from scipy.stats import entropy
import skimage.io as io
import cv2

from skimage.metrics import structural_similarity,peak_signal_noise_ratio

from inception import InceptionV3
import pandas as pd
import math
import sys
sys.path.append('./PerceptualSimilarity')
import lpips

class LPIPS_calcuator():
    """Callable wrapper around the AlexNet-based LPIPS perceptual distance.

    The network is built once in ``__init__`` and moved to the GPU; calling
    the instance evaluates the distance between two batches of images.
    """

    def __init__(self):
        self.model = lpips.LPIPS(net='alex', version='0.1').cuda()

    def __call__(self, imgs1, imgs2):
        """Return the LPIPS distance between two image batches.

        Params:
        -- imgs1, imgs2 : numpy arrays laid out as (n, c, h, w) with values
                          in [-1, 1] (per the original inline note).
        Returns:
        -- numpy array holding whatever ``self.model.forward`` produces for
           the two batches.
        """
        # Follow the device the model actually lives on instead of assuming
        # CUDA; fall back to CUDA for a parameter-less model.
        param = next(self.model.parameters(), None)
        device = param.device if param is not None else torch.device('cuda')

        imgs1 = torch.from_numpy(imgs1).to(device)
        imgs2 = torch.from_numpy(imgs2).to(device)

        # Pure evaluation: no autograd graph is needed, and a tensor that
        # requires grad cannot be converted with .numpy().
        with torch.no_grad():
            dist = self.model.forward(imgs1, imgs2)
        return dist.detach().cpu().numpy()
    
    
  
  
from torchvision.models.inception import inception_v3  


class IS():
    """Inception Score calculator backed by torchvision's Inception v3."""

    def __init__(self):
        self.model = inception_v3(pretrained=True, transform_input=False).cuda()

    def __call__(self, imgs, batch_size=16, splits=10, get_std=False):
        """Compute the Inception Score of the generated images ``imgs``.

        Params:
        -- imgs       : numpy array of shape (N, 3, H, W) normalized to the
                        range [0, 1].
        -- batch_size : batch size for feeding into Inception v3; must be
                        positive and strictly smaller than N.
        -- splits     : number of splits the predictions are divided into.
        -- get_std    : if True, also return the standard deviation of the
                        per-split scores.
        Returns:
        -- mean of the per-split scores, or ``(mean, std)`` if ``get_std``.
        """
        N = imgs.shape[0]

        assert batch_size > 0
        assert N > batch_size

        self.model.eval()

        # Follow the device the model lives on instead of assuming CUDA;
        # fall back to CUDA for a parameter-less model.
        param = next(self.model.parameters(), None)
        device = param.device if param is not None else torch.device('cuda')

        up = nn.Upsample(size=(299, 299), mode='bilinear')

        def get_pred(x):
            # Resize to the 299x299 resolution used for the Inception input.
            x = up(x)
            # NOTE(review): the batch arrives in [-1, 1] and is then scaled
            # by per-channel std and shifted by per-channel mean; confirm
            # this is the preprocessing the pretrained weights expect.
            x[:, 0] = x[:, 0] * 0.229 + 0.485
            x[:, 1] = x[:, 1] * 0.224 + 0.456
            x[:, 2] = x[:, 2] * 0.225 + 0.406

            x = self.model(x)
            # Explicit class dimension; relying on the implicit dim is
            # deprecated.
            return torch.softmax(x, dim=1).cpu().numpy()

        # Class-probability predictions for every image.
        preds = np.zeros((N, 1000))

        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            for i in range(math.ceil(N / batch_size)):
                start = i * batch_size
                batch = imgs[start:start + batch_size]
                batch = torch.from_numpy(batch).to(device) * 2 - 1  # [-1, 1]
                preds[start:start + len(batch)] = get_pred(batch)

        # Mean KL divergence between p(y|x) and p(y), per split.
        split_scores = []
        for i in range(splits):
            lo = i * preds.shape[0] // splits
            hi = (i + 1) * preds.shape[0] // splits
            part = preds[lo:hi, :]
            marginal = np.expand_dims(np.mean(part, 0), 0)
            kl = part * (np.log(part) - np.log(marginal))
            kl = np.mean(np.sum(kl, 1))
            split_scores.append(np.exp(kl))

        if get_std:
            return np.mean(split_scores), np.std(split_scores)
        return np.mean(split_scores)



class FID():
    """Frechet Inception Distance (FID) between two sets of images.

    Calculates the Frechet Inception Distance (FID) to evaluate GANs
    The FID metric calculates the distance between two distributions of images.
    Typically, we have summary statistics (mean & covariance matrix) of one
    of these distributions, while the 2nd distribution is given by a GAN.
    When run as a stand-alone program, it compares the distribution of
    images that are stored as PNG/JPEG at a specified location with a
    distribution given by summary statistics (in pickle format).
    The FID is calculated by assuming that X_1 and X_2 are the activations of
    the pool_3 layer of the inception net for generated samples and real world
    samples respectively.
    See --help to see further details.
    Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
    of Tensorflow
    Copyright 2018 Institute of Bioinformatics, JKU Linz
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at
       http://www.apache.org/licenses/LICENSE-2.0
    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
    """

    def __init__(self):
        self.dims = 2048        # dimensionality of the Inception features
        self.batch_size = 16    # images per forward pass
        self.cuda = True        # move batches (and the model) to the GPU
        self.verbose = False

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
        self.model = InceptionV3([block_idx])
        if self.cuda:
            # TODO: put model into specific GPU
            self.model.cuda()

    def __call__(self, images, gt_path):
        """Return the FID between generated and ground-truth images.

        Params:
        -- images  : generated images, numpy array (n_images, 3, hi, wi) with
                     values between 0 and 1.
        -- gt_path : ground-truth images in the same array format. Despite
                     the name it is NOT a filesystem path here: it is passed
                     straight to ``calculate_activation_statistics``. Use
                     ``calculate_from_disk`` for directories.
        """
        print('calculate gt_path statistics...')
        m1, s1 = self.calculate_activation_statistics(gt_path, self.verbose)
        print('calculate generated_images statistics...')
        m2, s2 = self.calculate_activation_statistics(images, self.verbose)
        fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
        return fid_value

    def calculate_from_disk(self, generated_path, gt_path):
        """Return the FID between two image directories.

        Raises RuntimeError if either directory does not exist.
        """
        import os

        if not os.path.exists(gt_path):
            raise RuntimeError('Invalid path: %s' % gt_path)
        if not os.path.exists(generated_path):
            raise RuntimeError('Invalid path: %s' % generated_path)

        print('calculate gt_path statistics...')
        m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
        print('calculate generated_path statistics...')
        m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose)
        print('calculate frechet distance...')
        fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
        print('fid_distance %f' % (fid_value))
        return fid_value

    def compute_statistics_of_path(self, path, verbose):
        """Return (mu, sigma) for the images found in directory ``path``.

        If ``path/statistics.npz`` exists it is loaded; otherwise all
        ``*.jpg`` / ``*.png`` files are read, the statistics are computed and
        cached to that file.
        """
        # Local imports: these names were previously used without being
        # imported anywhere in the module.
        import os
        import pathlib

        npz_file = os.path.join(path, 'statistics.npz')
        if os.path.exists(npz_file):
            with np.load(npz_file) as f:
                m, s = f['mu'][:], f['sigma'][:]
        else:
            path = pathlib.Path(path)
            files = list(path.glob('*.jpg')) + list(path.glob('*.png'))

            # `io` is the module-level skimage.io import.
            imgs = np.array([io.imread(str(fn)).astype(np.float32) for fn in files])

            # Bring images to shape (B, 3, H, W)
            imgs = imgs.transpose((0, 3, 1, 2))

            # Rescale images to be between 0 and 1
            imgs /= 255

            m, s = self.calculate_activation_statistics(imgs, verbose)
            np.savez(npz_file, mu=m, sigma=s)

        return m, s

    def calculate_activation_statistics(self, images, verbose):
        """Calculation of the statistics used by the FID.
        Params:
        -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
                         must lie between 0 and 1.
        -- verbose     : If set to True, the number of calculated batches is
                         reported.
        Returns:
        -- mu    : The mean over samples of the activations of the pool_3 layer of
                   the inception model.
        -- sigma : The covariance matrix of the activations of the pool_3 layer of
                   the inception model.
        """
        act = self.get_activations(images, verbose)
        mu = np.mean(act, axis=0)
        sigma = np.cov(act, rowvar=False)
        return mu, sigma

    def get_activations(self, images, verbose=False):
        """Calculates the activations of the pool_3 layer for all images.
        Params:
        -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
                         must lie between 0 and 1.
        -- verbose     : If set to True, the number of calculated batches is
                         reported.
        Returns:
        -- A numpy array of dimension (num images, dims) that contains the
           activations of the given tensor when feeding inception with the
           query tensor. Every input image is used, including a trailing
           partial batch.
        Raises:
        -- ValueError if ``images`` is empty.
        """
        self.model.eval()

        n_images = images.shape[0]
        if n_images == 0:
            raise ValueError('get_activations requires at least one image')

        # Work on a local copy so a small input does not permanently shrink
        # the configured batch size for later calls.
        batch_size = self.batch_size
        if batch_size > n_images:
            print(('Warning: batch size is bigger than the data size. '
                   'Setting batch size to data size'))
            batch_size = n_images

        # Ceil division: the last, possibly smaller, batch is processed too
        # instead of being silently dropped.
        n_batches = (n_images + batch_size - 1) // batch_size

        pred_arr = np.empty((n_images, self.dims))
        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            for i in range(n_batches):
                if verbose:
                    print('\rPropagating batch %d/%d' % (i + 1, n_batches))
                start = i * batch_size
                end = min(start + batch_size, n_images)

                batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
                if self.cuda:
                    batch = batch.cuda()

                pred = self.model(batch)[0]

                # If model output is not scalar, apply global spatial average pooling.
                # This happens if you choose a dimensionality not equal 2048.
                if pred.shape[2] != 1 or pred.shape[3] != 1:
                    pred = nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

                pred_arr[start:end] = pred.detach().cpu().numpy().reshape(end - start, -1)

        if verbose:
            print(' done')

        return pred_arr

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
                d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1   : Numpy array containing the activations of a layer of the
                   inception net (like returned by the function 'get_predictions')
                   for generated samples.
        -- mu2   : The sample mean over activations, precalculated on a
                   representative data set.
        -- sigma1: The covariance matrix over activations for generated samples.
        -- sigma2: The covariance matrix over activations, precalculated on a
                   representative data set.
        -- eps   : Value added to the covariance diagonals when the matrix
                   square root is not finite.
        Returns:
        --   : The Frechet Distance.
        Raises:
        -- ValueError if the matrix square root has a non-negligible
           imaginary component on its diagonal.
        """

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        # Product might be almost singular. The `disp` keyword is deprecated
        # in recent SciPy, so rely on the default single-array return.
        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return (diff.dot(diff) + np.trace(sigma1) +
                np.trace(sigma2) - 2 * tr_covmean)
                
def compare_ssim(img1,img2,get_map=False):
    """Structural similarity between two multichannel images.

    img1, img2: arrays of shape (m, n, 3).
    NOTE(review): the original inline note gives the value range as (0, 1),
    while ``data_range=255`` is what is passed below -- confirm which range
    callers really use.
    NOTE(review): newer scikit-image releases use ``channel_axis`` in place
    of ``multichannel`` -- confirm against the installed version.

    Returns the scalar SSIM, or ``(ssim, ssim_map)`` when ``get_map`` is
    truthy.
    """
    if not get_map:
        return structural_similarity(img1, img2, data_range=255, multichannel=True)

    score, ssim_map = structural_similarity(
        img1, img2, data_range=255, multichannel=True, full=True)
    return score, ssim_map
           
def compare_psnr(img1,img2):
    """Peak signal-to-noise ratio between two images, using a data range of 255."""
    return peak_signal_noise_ratio(img1, img2, data_range=255)

def labelcolormap(N):
    """Return an (N, 3) uint8 RGB colour table for N segmentation labels.

    For N == 20 a fixed hand-picked palette is returned (CelebAMask-HQ, per
    the original comment). For any other N the colours are generated by
    spreading the bits of the label index over the high bits of the three
    channels, so label 0 is black, 1 is (128, 0, 0), 2 is (0, 128, 0), ...
    """
    if N == 20: # CelebAMask-HQ
        cmap = np.array([(0,  0,  0), (204, 0,  0), (76, 153, 0),
                     (204, 204, 0), (51, 51, 255), (204, 0, 204), (0, 255, 255),
                     (51, 255, 255), (102, 51, 0), (255, 0, 0), (102, 204, 0),
                     (255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153),
                     (0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0), (255, 0, 255)],
                     dtype=np.uint8)
    else:
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            label = i
            for j in range(7):
                # The three lowest bits of `label` feed r, g and b. This is
                # done with plain integer arithmetic because the `uint82bin`
                # helper previously called here is not defined in this module.
                shift = 7 - j
                r ^= (label & 1) << shift
                g ^= ((label >> 1) & 1) << shift
                b ^= ((label >> 2) & 1) << shift
                label >>= 3
            cmap[i, 0] = r
            cmap[i, 1] = g
            cmap[i, 2] = b
    return cmap

from numpy.random import random_sample    
class TPSWarp(object):
    """Random thin-plate-spline (TPS) warp applied to an image tensor.

    A regular grid of control points is randomly rotated, scaled, translated
    and jittered; OpenCV's thin-plate-spline shape transformer is then
    estimated between the regular and the perturbed grid and used to warp
    the image.

    Defaults:
    nonlinear_pert_range: [-2, 2] (random perturbation of each control point
        in x and y)
    rot_range: rotation within +/- pi/8
    scale_range: scale factor between 1.05 and 1.15
    trans_range: translation within +/- 10
    """
    def __init__(self, image_size, margin, num_vertical_points, num_horizontal_points,
                 nonlinear_pert_range=(-2, 2),
                 rot_range=(-np.pi/8, np.pi/8),
                 scale_range=(1.05, 1.15),
                 trans_range=(-10, 10), append_offset_channels=False):
        """
        image_size: indexable pair; element 0 bounds the horizontal control
            points and element 1 the vertical ones.
        margin: distance kept between the control grid and the image border.
        num_vertical_points, num_horizontal_points: control grid resolution.
        *_range: (low, high) pairs for the random warp parameters.
        append_offset_channels: if True, ``__call__`` appends two channels
            holding the warped normalized coordinate grid.
        """
        # Immutable defaults, copied into per-instance lists, so neither the
        # defaults nor a caller's sequence is shared with the instance.
        self.nonlinear_pert_range = list(nonlinear_pert_range)
        self.rot_range = list(rot_range)
        self.scale_range = list(scale_range)
        self.trans_range = list(trans_range)
        self.num_points = num_horizontal_points*num_vertical_points
        self.append_offset_channels = append_offset_channels
        horizontal_points = np.linspace(margin, image_size[0] - margin, num_horizontal_points)
        vertical_points = np.linspace(margin, image_size[1] - margin, num_vertical_points)
        xv, yv = np.meshgrid(horizontal_points, vertical_points, indexing='xy')
        xv = xv.reshape(1, -1, 1)
        yv = yv.reshape(1, -1, 1)
        # Control points as a (1, num_points, 2) array of (x, y) pairs.
        self.grid = np.concatenate((xv, yv), axis=2)
        self.matches = list()

        # TPS define the alignment between source and target grid points
        # here, we just assume nth source keypoint aligns to nth target keypoint
        for i in range(self.num_points):
            self.matches.append(cv2.DMatch(i, i, 0))

    def sample_warp(self):
        """Sample the random affine part of the warp.

        Returns ``(rotscale, t_x, t_y)``: a 2x2 nested list combining the
        sampled rotation and per-axis scales, and the sampled translations.
        """

        # will be on the right side of the multiply, e.g ([x,y] * w
        rot = random_sample() * (self.rot_range[1] - self.rot_range[0]) + self.rot_range[0]
        sc_x = random_sample() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
        sc_y = random_sample() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
        t_x = random_sample() * (self.trans_range[1] - self.trans_range[0]) + self.trans_range[0]
        t_y = random_sample() * (self.trans_range[1] - self.trans_range[0]) + self.trans_range[0]
        # return a transposed matrix
        rotscale = [[ sc_x*np.cos(rot), -np.sin(rot)],
               [ np.sin(rot),  sc_y*np.cos(rot)]]
        return rotscale, t_x, t_y

    def random_perturb(self):
        """Return a random offset for every control point.

        The result has the shape of ``self.grid`` with values drawn
        uniformly from ``nonlinear_pert_range``.
        """
        low, high = self.nonlinear_pert_range[0], self.nonlinear_pert_range[1]
        perturb_mat = random_sample(self.grid.shape) * (high - low) + low
        return perturb_mat

    def __call__(self, img, tps=None):
        """Warp ``img`` and return the warped tensor.

        img: torch tensor laid out (C, H, W); it is converted to an
            (H, W, C) numpy array for OpenCV and back afterwards.
        tps: optional, already-estimated OpenCV shape transformer. When
            None, a new random warp is sampled and estimated.
        Returns a torch tensor laid out (C', H, W); two extra channels are
        present when ``append_offset_channels`` is set.
        """

        # construct the transformed grid from the regular grid
        img_as_arr = np.transpose(img.numpy(), (1, 2, 0))
        if tps is None:
            warp_matrix, t_x, t_y = self.sample_warp()
            perturb_mat = self.random_perturb()
            center = np.array([[[self.grid[:, :, 0].max()/2.0 + t_x, self.grid[:, :, 1].max()/2.0 + t_y]]])

            target_grid = np.matmul((self.grid - center), warp_matrix) + perturb_mat + center
            tps = cv2.createThinPlateSplineShapeTransformer()
            tps.estimateTransformation(self.grid, target_grid, self.matches)
        img_as_arr = tps.warpImage(img_as_arr, borderMode=cv2.BORDER_REPLICATE)
        dims = img_as_arr.shape

        if self.append_offset_channels:  # extract ground truth warping offsets
            full_grid_x, full_grid_y = np.meshgrid(np.arange(dims[1]), np.arange(dims[0]))
            dims_half_x = dims[1]/2.0
            dims_half_y = dims[0]/2.0
            # Pixel coordinates normalized around the image centre.
            full_grid_x = (full_grid_x - dims_half_x)/dims_half_x
            full_grid_y = (full_grid_y - dims_half_y)/dims_half_y
            full_grid = np.concatenate((np.expand_dims(full_grid_x, 2), np.expand_dims(full_grid_y, 2)), axis=2)
            # Pixels with no source get the sentinel border value -1024.
            img_coord_arr = tps.warpImage(full_grid.astype(np.float32), borderValue=-1024)
            displacement = img_coord_arr
            img_as_arr = np.concatenate((img_as_arr, displacement), 2)

        # convert back to a (C, H, W) tensor and return
        out_img = torch.from_numpy(img_as_arr).permute(2, 0, 1)
        return out_img

class Colorize(object):
    def __init__(self, n=20):
        self.cmap = labelcolormap(n)
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
            
        size = gray_image.shape
        color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0)

        idx=np.unique(gray_image)
        for label in idx:
            mask = (label == gray_image)
            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]
        color_image = color_image.float()/255.0 * 2 - 1
        return color_image  

def gram_matrix(input):
    """Batched Gram matrix of a 4-D feature tensor.

    input: tensor of size (batch, channels, height, width).
    Returns a (batch, channels, channels) tensor: the product of the
    flattened feature maps with their transpose, divided by
    height * width * channels.
    """
    batch, channels, height, width = input.size()

    # Flatten each feature map into a row vector.
    flat = input.view(batch, channels, height * width)

    # Channel-by-channel inner products for every sample in the batch.
    gram = torch.bmm(flat, flat.transpose(2, 1))

    # Normalize by the number of elements per sample.
    return gram.div(height * width * channels)


def patch_gram_matrix(input, pose_map):
    """Gram matrices of ``input`` masked by each channel of ``pose_map``.

    input: feature tensor, batchsize x C x H x W.
    pose_map: conditional map (e.g., semantic map), batchsize x K x H x W.
    Returns a contiguous tensor stacking one batchsize x C x C Gram matrix
    per pose_map channel along dimension 1.
    """
    # Both tensors must be 4-D; the unpacking enforces that.
    batchsize, _, _, _ = pose_map.size()
    n_batch, n_chan, height, width = input.size()

    def masked_gram(channel_idx):
        # One map channel, flattened so it broadcasts over the features.
        mask = pose_map[:, channel_idx].view(batchsize, 1, height * width)
        masked = mask * input.view(n_batch, n_chan, height * width)
        return gram_matrix(masked.view(n_batch, n_chan, height, width))

    grams = [masked_gram(i) for i in range(pose_map.shape[1])]
    return torch.cat([g.unsqueeze(1) for g in grams], 1).contiguous()

def tensor2im(tensor,path,write=True):
    """Convert a (C, H, W) tensor in [-1, 1] to an (H, W, C) uint8 image.

    tensor: image tensor; values outside [-1, 1] are clipped.
    path: destination file, only used when ``write`` is True.
    write: if True, the image is also written to ``path`` with OpenCV; the
        channel order is reversed for the write only.
    Returns the uint8 image array in the tensor's own channel order.
    """
    tensor = torch.clip(tensor, -1, 1)
    # detach() so tensors that still carry autograd history (e.g. raw
    # network outputs) can be converted to numpy.
    array = tensor.detach().cpu().numpy()
    im = ((array + 1.0) * 0.5 * 255).astype(np.uint8)
    im = np.transpose(im, axes=[1, 2, 0])
    if write:
        cv2.imwrite(path, im[..., ::-1])
    return im
    

def iuv2smpluv(iuv,transformer):
    """Map an IUV array to global UV coordinates via ``transformer``.

    iuv: array whose last axis holds (u, v, index); the first two entries
        are passed as the UV coordinates and the third as the index.
    transformer: object providing ``GetGlobalUV(indices, uv)``.
    Returns whatever ``transformer.GetGlobalUV`` returns.
    """
    part_index = iuv[..., 2]
    local_uv = iuv[..., :2]
    return transformer.GetGlobalUV(part_index, local_uv)
    
    
import torch
import torch.nn.functional as F
import os
import math

class TexTransformer:
    """Converts per-part (I, U, V) coordinates into global texture UVs using
    a lookup table loaded from disk."""

    def __init__(self, path_to_texmap='/net/ivcfs4/mnt/data/nnli/tryon/coordinate_based_inpainting-main/data/smpltexmap.npy'):
        # Lookup table indexed as texmap[part - 1][u, v, {0, 1}].
        self.texmap = np.load(path_to_texmap)

    def GetGlobalUV(self, I, UV):
        """Look up the global UV of every pixel.

        I: (H, W) array of part indices; indices 1..24 are looked up.
        UV: (H, W, 2) array of integer per-part coordinates used to index
            the lookup table.
        Returns an (H, W, 2) float32 array with values rescaled to [-1, 1]
        (v flipped); entries left at exactly -1 are replaced by NaN.
        """
        part_ids = I.reshape(-1)
        u_local = UV[:, :, 0].reshape(-1)
        v_local = UV[:, :, 1].reshape(-1)

        # -1 marks pixels that belong to no part.
        global_uv = np.full((part_ids.shape[0], 2), -1.0)
        for part in range(1, 25):
            rows = np.flatnonzero(part_ids == part)
            u = u_local[rows]
            v = v_local[rows]
            table = self.texmap[part - 1]
            u_glob = table[u, v, 0]
            v_glob = table[u, v, 1]
            # Rescale: v is flipped to 1 - 2v, u becomes 2u - 1.
            global_uv[rows, 1] = np.ones_like(v_glob) - 2*v_glob
            global_uv[rows, 0] = 2*u_glob - np.ones_like(u_glob)

        result = global_uv.reshape(UV.shape[0], UV.shape[1], 2).astype(np.float32)
        result[result == -1.] = np.nan
        return result

    def ShowSurrealTexture(self, UV_global, texpath):
        """Sample the texture image at ``texpath`` with global UVs.

        UV_global: (H, W, 2) array of global UV coordinates.
        Returns an (H, W, 3) uint8 image; pixels whose scaled u coordinate
        is not positive stay black.
        """
        texture = io.imread(texpath)
        n_pixels = UV_global.shape[0]*UV_global.shape[1]
        flat_image = np.zeros((n_pixels, 3), dtype=np.uint8)

        # Map UVs from [-1, 1] to texture pixel coordinates.
        coords = UV_global.reshape(-1, 2)
        coords = (texture.shape[0]-1)/2.0*(coords + np.ones_like(coords))
        # Round to the nearest texel.
        texels = np.floor(coords + 0.5*np.ones_like(coords)).astype(int)

        valid = np.nonzero(coords[:, 0] > 0)[0]
        texels = texels[valid]
        flat_image[valid, :] = texture[texels[:, 1], texels[:, 0], :]

        return flat_image.reshape(UV_global.shape[0], UV_global.shape[1], 3)
    
def make_colorwheel():
    '''
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Returns a (55, 3) float array of RGB values in [0, 255].
    '''
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))

    # Each segment holds one channel at 255 while another ramps up or down:
    # (length, channel held at 255, ramping channel, ramp rises?)
    segments = (
        (RY, 0, 1, True),
        (YG, 1, 0, False),
        (GC, 1, 2, True),
        (CB, 2, 1, False),
        (BM, 2, 0, True),
        (MR, 0, 2, False),
    )

    start = 0
    for length, full_channel, ramp_channel, rising in segments:
        stop = start + length
        ramp = np.floor(255*np.arange(0, length)/length)
        colorwheel[start:stop, full_channel] = 255
        colorwheel[start:stop, ramp_channel] = ramp if rising else 255 - ramp
        start = stop
    return colorwheel


class flow2color():
# code from: https://github.com/tomrunia/OpticalFlow_Visualization
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
    """Callable that renders a 2-channel flow tensor as a colour image tensor
    in [-1, 1] using the flow colour wheel."""

    def __init__(self):
        self.colorwheel = make_colorwheel()

    def flow_compute_color(self, u, v, convert_to_bgr=False):
        '''
        Applies the flow color wheel to (possibly clipped) flow components u and v.
        According to the C++ source code of Daniel Scharstein
        According to the Matlab source code of Deqing Sun
        :param u: np.ndarray, input horizontal flow
        :param v: np.ndarray, input vertical flow
        :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
        :return: np.ndarray of shape [H,W,3], dtype uint8
        '''
        flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
        ncols = self.colorwheel.shape[0]

        rad = np.sqrt(np.square(u) + np.square(v))
        # Flow direction mapped to a fractional colour-wheel position.
        a = np.arctan2(-v, -u)/np.pi
        fk = (a+1) / 2*(ncols-1)
        k0 = np.floor(fk).astype(np.int32)
        k1 = k0 + 1
        k1[k1 == ncols] = 0  # wrap around the wheel
        f = fk - k0

        # Magnitude threshold; the same mask applies to every channel.
        idx = (rad > 10)
        for i in range(self.colorwheel.shape[1]):
            tmp = self.colorwheel[:, i]
            col0 = tmp[k0] / 255.0
            col1 = tmp[k1] / 255.0
            # Linear interpolation between the two neighbouring wheel colours.
            col = (1-f)*col0 + f*col1

            col[idx] = 1 - rad[idx] * (1-col[idx])
            col[~idx] = col[~idx] * 0.75   # out of range?

            # Note the 2-i => BGR instead of RGB
            ch_idx = 2-i if convert_to_bgr else i
            # For rad > 10 the expression above can fall far below 0; clamp
            # before the uint8 cast so the result does not depend on
            # out-of-range float->uint8 conversion (wrap-around).
            flow_image[:, :, ch_idx] = np.clip(np.floor(255 * col), 0, 255)

        return flow_image

    def __call__(self, flow_uv, clip_flow=None, convert_to_bgr=False):
        '''
        Renders a flow tensor as a colour image tensor.
        According to the C++ source code of Daniel Scharstein
        According to the Matlab source code of Deqing Sun
        :param flow_uv: torch tensor of shape [2,H,W]; if it does not have
            exactly three dimensions, its first element is used
        :param clip_flow: float, maximum clipping value for flow
        :param convert_to_bgr: bool, output BGR instead of RGB channel order
        :return: float tensor of shape [3,H,W] with values in [-1, 1]
        '''
        if flow_uv.dim() != 3:
            flow_uv = flow_uv[0]
        flow_uv = flow_uv.permute(1, 2, 0).cpu().detach().numpy()

        assert flow_uv.ndim == 3, 'input flow must have three dimensions'
        assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'

        if clip_flow is not None:
            # NOTE(review): the lower bound 0 also removes negative flow
            # components -- confirm this is intended.
            flow_uv = np.clip(flow_uv, 0, clip_flow)

        # NOTE(review): channel 1 is used as u and channel 0 as v -- confirm
        # this matches the channel order produced by the flow source.
        u = flow_uv[:, :, 1]
        v = flow_uv[:, :, 0]

        image = self.flow_compute_color(u, v, convert_to_bgr)
        image = torch.tensor(image).float().permute(2, 0, 1)/255.0 * 2 - 1
        return image
        
import torchvision.models.vgg as models
        
class VGG19(torch.nn.Module):
    """Frozen, pretrained VGG19 feature extractor exposing 16 named blocks.

    ``forward`` takes images in [-1, 1], rescales them to [0, 1], normalizes
    them with fixed per-channel mean/std constants and returns a dict with
    the output of each block from ``relu1_1`` to ``relu5_4``.
    """

    # (block name, (first layer index, one-past-last layer index)) into
    # ``vgg19.features``. Layers 14-15 were previously appended to relu3_2
    # by mistake, leaving relu3_3 empty (an identity block) and making the
    # 'relu3_2' output one convolution too deep.
    # NOTE: fixing this renames the sub-module keys "relu3_2.14"/"relu3_2.15"
    # to "relu3_3.14"/"relu3_3.15" in this module's state_dict.
    _LAYER_SLICES = (
        ('relu1_1', (0, 2)),
        ('relu1_2', (2, 4)),
        ('relu2_1', (4, 7)),
        ('relu2_2', (7, 9)),
        ('relu3_1', (9, 12)),
        ('relu3_2', (12, 14)),
        ('relu3_3', (14, 16)),
        ('relu3_4', (16, 18)),
        ('relu4_1', (18, 21)),
        ('relu4_2', (21, 23)),
        ('relu4_3', (23, 25)),
        ('relu4_4', (25, 27)),
        ('relu5_1', (27, 30)),
        ('relu5_2', (30, 32)),
        ('relu5_3', (32, 34)),
        ('relu5_4', (34, 36)),
    )

    def __init__(self):
        super(VGG19, self).__init__()

        vgg19 = models.vgg19(pretrained=True)
        features = vgg19.features

        for param in features.parameters():
            param.requires_grad_(False)

        features.cuda()

        # Split the feature stack into consecutive named blocks; layer names
        # inside each block are the original layer indices.
        for name, (start, stop) in self._LAYER_SLICES:
            block = torch.nn.Sequential()
            for x in range(start, stop):
                block.add_module(str(x), features[x])
            setattr(self, name, block)

        # Per-channel normalization constants, shaped for broadcasting over
        # (N, 3, H, W). view() replaces the deprecated non-inplace resize().
        mean = torch.FloatTensor([0.485, 0.456, 0.406])
        self.mean = mean.view(1, 3, 1, 1).cuda()

        std = torch.FloatTensor([0.229, 0.224, 0.225])
        self.std = std.view(1, 3, 1, 1).cuda()

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return a dict mapping each block name to its activation.

        x: image batch in [-1, 1], shape (N, 3, H, W).
        """
        x = (x + 1)/2  # [-1, 1] => [0, 1]
        x = (x - self.mean)/self.std

        out = {}
        # Blocks are chained: each consumes the previous block's output.
        for name, _ in self._LAYER_SLICES:
            x = getattr(self, name)(x)
            out[name] = x
        return out
        
class EMAHelper(object):
    """Keeps an exponential moving average ("shadow") of a module's
    trainable parameters.

    mu is the decay: each update sets shadow = (1 - mu) * param + mu * shadow.
    """

    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}  # parameter name -> averaged tensor

    @staticmethod
    def _trainable(module):
        """Yield (name, param) for parameters with requires_grad, unwrapping
        nn.DataParallel first."""
        if isinstance(module, nn.DataParallel):
            module = module.module
        return ((name, param) for name, param in module.named_parameters()
                if param.requires_grad)

    def register(self, module):
        """Initialise the shadow with a copy of the current parameters."""
        for name, param in self._trainable(module):
            self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend the current parameters into the shadow."""
        for name, param in self._trainable(module):
            self.shadow[name].data = (
                1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        """Overwrite the module's parameters with the shadow values."""
        for name, param in self._trainable(module):
            param.data.copy_(self.shadow[name].data)

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow dict with ``state_dict``."""
        self.shadow = state_dict

