import os
import sys
import torch
import torch.nn as nn
import numpy

from PIL import Image
from criteria.dwt import HaarTransform, InverseHaarTransform
from criteria.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix

def scale_image(image, scale):
    """Downsample ``image`` by an integer factor using a box (mean) filter.

    A ``scale x scale`` kernel whose taps all equal ``1 / scale**2`` is built
    and applied with ``upfirdn2d(..., down=scale)``. Each output value is
    therefore the average of a ``scale x scale`` patch. Presumably ``image``
    is an NCHW tensor, as ``upfirdn2d`` expects (TODO confirm against the op).

    Args:
        image: Tensor to downsample.
        scale: Positive integer downsampling factor.

    Returns:
        The tensor returned by ``upfirdn2d``, downsampled by ``scale``.
    """
    # Build the kernel on the input's device and dtype. Previously it was
    # forced onto CUDA as float32, which broke CPU and half/double inputs.
    taps = torch.full((1, scale), 1.0 / scale, device=image.device, dtype=image.dtype)
    # The outer product of two 1/scale box filters gives a 2-D kernel whose
    # taps sum to 1, i.e. a mean filter.
    kernel = taps.T * taps
    return upfirdn2d(image, kernel, down=scale)

class Hierarchical_Wavelet_Loss(nn.Module):
    """Multi-level Haar-wavelet loss over the high-frequency sub-bands.

    Both images are decomposed repeatedly with ``HaarTransform``. At every
    level, the DWT output is split into four equal channel chunks, named
    LL, LH, HL, HH in that order. The LL chunk is fed to the next level. For
    each level listed in ``wrange``, the LH, HL and HH chunks of the two
    images are compared with an L1 or L2 loss. Their sum is weighted by
    ``1 / 2 ** (level - 1)``, so finer (earlier) levels count more.
    """

    def __init__(self, wrange=None, kind='l1', scale=1, num_levels=8, device='cuda'):
        """Configure the loss.

        Args:
            wrange: Container of level indices (0-based) that contribute to
                the loss. Defaults to every level ``0 .. num_levels - 1``.
                Indices ``>= num_levels`` are ignored.
            kind: ``'l1'`` for ``nn.L1Loss`` or ``'l2'`` for ``nn.MSELoss``.
            scale: If not 1, every high-frequency band is box-downsampled by
                this integer factor with ``scale_image`` before comparison.
            num_levels: Maximum number of DWT levels to compute (was fixed
                at 8).
            device: Device the sub-modules are moved to. The default
                ``'cuda'`` matches the previous hard-coded ``.cuda()`` calls.

        Raises:
            ValueError: If ``kind`` is neither ``'l1'`` nor ``'l2'``.
        """
        super().__init__()

        # A None default avoids sharing one mutable list across instances.
        self.wrange = list(range(num_levels)) if wrange is None else wrange
        self.scale = scale
        self.num_levels = num_levels
        if kind == 'l1':
            self.loss = nn.L1Loss().to(device)
        elif kind == 'l2':
            self.loss = nn.MSELoss().to(device)
        else:
            # Previously an unknown kind left self.loss unset and failed
            # later in forward() with an unrelated-looking AttributeError.
            raise ValueError(f"kind must be 'l1' or 'l2', got {kind!r}")
        # The argument 3 is presumably the channel count, i.e. RGB inputs are
        # assumed. TODO confirm against HaarTransform's signature.
        self.dwt = HaarTransform(3).to(device)

    def forward(self, image1, image2):
        """Return the weighted wavelet loss between ``image1`` and ``image2``.

        Returns the integer ``0`` when no level in ``wrange`` is reached;
        otherwise returns a tensor.
        """
        loss = 0
        # Levels beyond the last requested one cannot change the result.
        # Skipping them saves work and avoids failing on small images.
        last_level = max(self.wrange, default=-1)
        levels_needed = min(self.num_levels, last_level + 1)

        low1, low2 = image1, image2
        for level in range(levels_needed):
            low1, lh1, hl1, hh1 = self.dwt(low1).chunk(4, 1)
            low2, lh2, hl2, hh2 = self.dwt(low2).chunk(4, 1)

            if level not in self.wrange:
                continue

            if self.scale != 1:
                lh1, hl1, hh1 = (scale_image(band, self.scale) for band in (lh1, hl1, hh1))
                lh2, hl2, hh2 = (scale_image(band, self.scale) for band in (lh2, hl2, hh2))

            band_loss = self.loss(lh1, lh2) + self.loss(hl1, hl2) + self.loss(hh1, hh2)
            # Weight the whole level. The previous expression divided only
            # the HH term because of operator precedence.
            loss = loss + band_loss / (2 ** (level - 1))
        return loss

