import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.models as models
import math
from utils import warp

# Module-wide default device: CUDA when it is available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Ternary(nn.Module):
    """Patch-based structural loss between two image batches.

    Each input is first averaged over its channel axis. For every pixel the
    ``patch_size x patch_size`` neighbourhood is gathered with a one-hot
    convolution, the centre value is subtracted, and each difference ``d`` is
    squashed to ``d / sqrt(0.81 + d^2)``. The loss compares these descriptors
    with ``diff^2 / (0.1 + diff^2)``, averaged over the patch axis and then over
    all positions, with a border of ``patch_size // 2`` pixels zeroed out.

    Args:
        patch_size: side length of the square neighbourhood (default 7).
    """

    def __init__(self, patch_size=7):
        super(Ternary, self).__init__()
        self.patch_size = patch_size
        out_channels = patch_size * patch_size
        # One-hot kernel bank of shape (p*p, 1, p, p): output channel k copies
        # the k-th (row-major) pixel of the window. The kernel is created on CPU
        # and moved lazily to wherever the input lives, so the loss is not tied
        # to one global device or to float32.
        self.w = torch.eye(out_channels).reshape(
            out_channels, 1, patch_size, patch_size)

    def _kernel_for(self, tensor):
        """Return the kernel on ``tensor``'s device/dtype, caching the cast."""
        if self.w.device != tensor.device or self.w.dtype != tensor.dtype:
            self.w = self.w.to(device=tensor.device, dtype=tensor.dtype)
        return self.w

    def transform(self, tensor):
        """Map a ``(b, c, h, w)`` batch to its ``(b, p*p, h, w)`` descriptor."""
        tensor_ = tensor.mean(dim=1, keepdim=True)
        patches = F.conv2d(tensor_, self._kernel_for(tensor_),
                           padding=self.patch_size // 2, bias=None)
        loc_diff = patches - tensor_
        loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff ** 2)
        return loc_diff_norm

    def valid_mask(self, tensor):
        """Return a ``(b, 1, h, w)`` mask that is 0 on the border, 1 inside.

        The border is ``patch_size // 2`` pixels wide, i.e. the positions whose
        window reaches into the zero padding added by the convolution.
        """
        padding = self.patch_size // 2
        b, c, h, w = tensor.size()
        inner = tensor.new_ones(b, 1, h - 2 * padding, w - 2 * padding)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, x, y):
        """Return the scalar loss between ``x`` and ``y``.

        The descriptor of ``y`` is detached, so no gradient flows into ``y``.
        The final mean runs over every position, including the zeroed border.
        """
        loc_diff_x = self.transform(x)
        loc_diff_y = self.transform(y)
        diff = loc_diff_x - loc_diff_y.detach()
        dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
        mask = self.valid_mask(x)
        loss = (dist * mask).mean()
        return loss


class Geometry(nn.Module):
    """Per-channel patch-based structural loss between two feature batches.

    Unlike a channel-averaged descriptor, every channel is processed on its
    own: the ``(b, c, h, w)`` input is viewed as ``b*c`` single-channel maps,
    the ``patch_size x patch_size`` neighbourhood of each pixel is gathered with
    a one-hot convolution, the centre value is subtracted, and each difference
    ``d`` is squashed to ``d / sqrt(0.81 + d^2)``. The loss compares the two
    descriptors with ``diff^2 / (0.1 + diff^2)``, averaged over the descriptor
    axis and then over all positions, with a ``patch_size // 2`` border zeroed.

    Args:
        patch_size: side length of the square neighbourhood (default 3).
    """

    def __init__(self, patch_size=3):
        super(Geometry, self).__init__()
        self.patch_size = patch_size
        out_channels = patch_size * patch_size
        # One-hot kernel bank of shape (p*p, 1, p, p): output channel k copies
        # the k-th (row-major) pixel of the window. The kernel is created on CPU
        # and moved lazily to wherever the input lives, so the loss is not tied
        # to one global device or to float32.
        self.w = torch.eye(out_channels).reshape(
            out_channels, 1, patch_size, patch_size)

    def _kernel_for(self, tensor):
        """Return the kernel on ``tensor``'s device/dtype, caching the cast."""
        if self.w.device != tensor.device or self.w.dtype != tensor.dtype:
            self.w = self.w.to(device=tensor.device, dtype=tensor.dtype)
        return self.w

    def transform(self, tensor):
        """Map a ``(b, c, h, w)`` batch to its ``(b, c*p*p, h, w)`` descriptor."""
        b, c, h, w = tensor.size()
        # Fold channels into the batch axis so each channel is handled alone.
        tensor_ = tensor.reshape(b * c, 1, h, w)
        patches = F.conv2d(tensor_, self._kernel_for(tensor_),
                           padding=self.patch_size // 2, bias=None)
        loc_diff = patches - tensor_
        loc_diff_ = loc_diff.reshape(b, c * (self.patch_size ** 2), h, w)
        loc_diff_norm = loc_diff_ / torch.sqrt(0.81 + loc_diff_ ** 2)
        return loc_diff_norm

    def valid_mask(self, tensor):
        """Return a ``(b, 1, h, w)`` mask that is 0 on the border, 1 inside.

        The border is ``patch_size // 2`` pixels wide, i.e. the positions whose
        window reaches into the zero padding added by the convolution.
        """
        padding = self.patch_size // 2
        b, c, h, w = tensor.size()
        inner = tensor.new_ones(b, 1, h - 2 * padding, w - 2 * padding)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, x, y):
        """Return the scalar loss between ``x`` and ``y``.

        Neither descriptor is detached, so gradients flow into both inputs.
        The final mean runs over every position, including the zeroed border.
        """
        loc_diff_x = self.transform(x)
        loc_diff_y = self.transform(y)
        diff = loc_diff_x - loc_diff_y
        dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
        mask = self.valid_mask(x)
        loss = (dist * mask).mean()
        return loss


class Charbonnier_L1(nn.Module):
    """Smooth L1-like penalty ``sqrt(diff^2 + 1e-6)``, optionally masked."""

    def __init__(self):
        super().__init__()

    def forward(self, diff, mask=None):
        """Return the mean penalty of ``diff``.

        Args:
            diff: residual tensor.
            mask: optional weights broadcastable against ``diff``. When given,
                the weighted mean is divided by ``mask.mean() + 1e-9``.

        Returns:
            Scalar tensor.
        """
        penalty = (diff ** 2 + 1e-6) ** 0.5
        if mask is not None:
            weighted = penalty * mask
            return weighted.mean() / (mask.mean() + 1e-9)
        return penalty.mean()


class Charbonnier_Ada(nn.Module):
    """Charbonnier-style penalty whose exponent and epsilon depend on a weight."""

    def __init__(self):
        super().__init__()

    def forward(self, diff, weight):
        """Return ``mean((diff^2 + eps^2) ** alpha)``.

        ``alpha`` is ``weight / 2`` and ``eps`` is
        ``10 ** (-(10 * weight - 1) / 3)``.

        Args:
            diff: residual tensor.
            weight: value controlling both the exponent and the epsilon.

        Returns:
            Scalar tensor.
        """
        eps_exponent = -(10 * weight - 1) / 3
        epsilon = 10 ** eps_exponent
        alpha = weight / 2
        penalty = (diff ** 2 + epsilon ** 2) ** alpha
        return penalty.mean()


class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that (de)normalises each channel independently.

    With ``norm=True`` the layer computes ``(x - data_range * mean) / std``;
    with ``norm=False`` it computes ``x * std + data_range * mean``. The weight
    is diagonal, so channels are never mixed, and both weight and bias are
    frozen.

    Args:
        data_mean: per-channel means; its length sets the channel count.
        data_std: per-channel standard deviations (same length).
        data_range: scale applied to the means (default 1).
        norm: ``True`` to normalise, ``False`` to undo the normalisation.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        mean = torch.as_tensor(data_mean, dtype=torch.float32)
        std = torch.as_tensor(data_std, dtype=torch.float32)
        weight = torch.eye(c).view(c, c, 1, 1)
        if norm:
            weight = weight / std.view(c, 1, 1, 1)
            bias = -1 * data_range * mean / std
        else:
            weight = weight * std.view(c, 1, 1, 1)
            bias = data_range * mean
        with torch.no_grad():
            self.weight.copy_(weight)
            self.bias.copy_(bias)
        # Freeze the parameters themselves. Assigning `self.requires_grad`
        # would only create a plain attribute on the module and leave the
        # weight and bias trainable.
        for param in self.parameters():
            param.requires_grad = False



class GramPerceptualLoss(nn.Module):
    """Masked feature-correlation loss on frozen VGG-19 features.

    Inputs are normalised with ``MeanShift`` and passed through the first 12
    layers of ``vgg19().features``. For every (sample, mask-channel) pair with
    a non-empty mask, the features at the selected positions are correlated
    and the mean absolute difference of the two correlation matrices is added
    to the loss with weight ``1 / 3.7``.

    All parameters are frozen. The module is created on CPU; move it with
    ``.to(device)`` / ``.cuda()`` like any other module.
    """

    def __init__(self):
        super(GramPerceptualLoss, self).__init__()
        pretrained = True
        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
        # No hard-coded `.cuda()` here: it fails on CPU-only hosts, and the
        # whole module (VGG included) must be moved by the caller anyway.
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, mask):
        """Return the scalar loss between ``X`` and ``Y`` restricted by ``mask``.

        Args:
            X: predicted image batch.
            Y: reference image batch.
            mask: ``(batch, n_masks, h, w)`` tensor; positions with a value
                greater than 0 after nearest-neighbour resizing are selected.

        Returns:
            Scalar tensor; zero when no mask selects any position.
        """
        X = self.normalize(X)
        Y = self.normalize(Y)
        indices = [12]
        weights = [1.0 / 3.7]
        k = 0
        # Start from a tensor so the result is a tensor even when every mask
        # is empty.
        loss = X.new_zeros(())
        for i in range(indices[-1]):
            layer = self.vgg_pretrained_features[i]
            X = layer(X)
            Y = layer(Y)
            if (i + 1) not in indices:
                continue
            _, C, H, W = X.size()
            for m in range(mask.shape[0]):
                for n in range(mask.shape[1]):
                    if mask[m, n].sum() == 0:
                        continue
                    single_mask = mask[m:m + 1, n:n + 1]
                    single_mask = F.interpolate(single_mask, size=(H, W), mode='nearest')
                    selected = single_mask[0, 0] > 0

                    # Shape (C, K), K = number of selected positions.
                    single_masked_x = X[m, :, selected]
                    single_masked_y = Y[m, :, selected]

                    if single_masked_x.size(1) == 0 or single_masked_y.size(1) == 0:
                        continue

                    # NOTE(review): x^T @ x is a (K, K) position-by-position
                    # matrix, not the (C, C) channel Gram matrix; its memory
                    # grows with K^2. Kept as in the original - confirm intent.
                    gram_masked_x = torch.mm(torch.transpose(single_masked_x, 0, 1), single_masked_x)
                    gram_masked_y = torch.mm(torch.transpose(single_masked_y, 0, 1), single_masked_y)

                    loss = loss + (gram_masked_x - gram_masked_y).abs().mean() * weights[k]
            k += 1
        return loss



class VGGPerceptualLoss(nn.Module):
    """Masked L1 perceptual loss on frozen VGG-19 features.

    Inputs are normalised with ``MeanShift`` and passed through the first 30
    layers of ``vgg19().features``. After layers 2, 7, 12, 21 and 30 the
    absolute feature difference is averaged inside every (sample, mask-channel)
    region and accumulated with a per-layer weight. The sum is divided by
    ``batch * n_masks``.

    All parameters are frozen. The module is created on CPU; move it with
    ``.to(device)`` / ``.cuda()`` like any other module.
    """

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        pretrained = True
        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
        # No hard-coded `.cuda()` here: it fails on CPU-only hosts, and the
        # whole module (VGG included) must be moved by the caller anyway.
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, mask):
        """Return the scalar loss between ``X`` and ``Y`` weighted by ``mask``.

        Args:
            X: predicted image batch.
            Y: reference image batch.
            mask: ``(batch, n_masks, h, w)`` float tensor; it is resized to each
                feature resolution with nearest-neighbour interpolation.

        Returns:
            Scalar tensor.
        """
        X = self.normalize(X)
        Y = self.normalize(Y)
        indices = [2, 7, 12, 21, 30]
        # NOTE(review): the last weight is 10/1.5, not 1.0/1.5. Kept exactly as
        # in the original - confirm it is intentional.
        weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
        k = 0
        loss = X.new_zeros(())
        for i in range(indices[-1]):
            layer = self.vgg_pretrained_features[i]
            X = layer(X)
            Y = layer(Y)
            if (i + 1) not in indices:
                continue
            _, C, H, W = X.size()
            # Both are invariant across the (m, n) loops: compute once per layer.
            # Nearest-neighbour resizing acts on every (sample, channel) plane
            # independently, so resizing the whole mask equals resizing slices.
            abs_diff = (X - Y).abs()
            layer_mask = F.interpolate(mask, size=(H, W), mode='nearest')
            for m in range(mask.shape[0]):
                for n in range(mask.shape[1]):
                    single_mask = layer_mask[m:m + 1, n:n + 1]
                    region = (abs_diff[m:m + 1] * single_mask).mean() / (single_mask.mean() + 1e-9)
                    loss = loss + region * weights[k]
            k += 1
        return loss / (mask.shape[0] * mask.shape[1])



def gradient(data):
    """Return first-order forward differences of ``data`` along axes 3 and 2.

    Args:
        data: tensor indexed with four axes.

    Returns:
        ``(D_dx, D_dy)``: the difference of neighbouring entries along axis 3
        (one shorter on that axis) and along axis 2 (one shorter on that axis).
    """
    horizontal = data[:, :, :, 1:] - data[:, :, :, :-1]
    vertical = data[:, :, 1:, :] - data[:, :, :-1, :]
    return horizontal, vertical

# Disabled earlier WarpLoss variant, kept as a bare string literal so it is
# never executed. It weighted the gradient differences by where `mask` is
# locally constant; the active WarpLoss defined after it ignores `mask`.
# NOTE(review): the signature names `warp_img2` but the body reads
# `warped_img2`, so this version would fail if re-enabled as is.
'''
class WarpLoss(nn.Module):
    def __init__(self):
        super(WarpLoss, self).__init__()
    
    def forward(self, warped_img1, warp_img2, imgt, mask):  
        loss = torch.tensor(0.0).cuda()

        img1_dx, img1_dy = gradient(warped_img1)
        img2_dx, img2_dy = gradient(warped_img2)

        imgt_dx, imgt_dy = gradient(imgt)

        for b in range(mask.shape[0]):
            for n in range(mask.shape[1]):
                if mask[b,n].sum() == 0: 
                    continue
                weights_y = (mask[b:b+1, n, :, 1:, :] - mask[b:b+1, n, :, :-1, :] == 0).float()
                weights_x = (mask[b:b+1, n, :, :, 1:] - mask[b:b+1, n, :, :, :-1] == 0).float()

                loss_x = weights_x * (imgt_dx[b:b+1, :, :, :] - img1_dx[b:b+1, :, :, :]).abs() + weights_x * (imgt_dx[b:b+1, :, :, :] - img2_dx[b:b+1, :, :, :]).abs()
                loss_y = weights_y * (imgt_dy[b:b+1, :, :, :] - img1_dy[b:b+1, :, :, :]).abs() + weights_y * (imgt_dy[b:b+1, :, :, :] - img2_dy[b:b+1, :, :, :]).abs()

                loss += loss_x.mean() / 2.0 + loss_y.mean() / 2.0

        return loss / (mask.shape[0])
'''
class WarpLoss(nn.Module):
    """Gradient-domain L1 loss between two warped images and a target image."""

    def __init__(self):
        super(WarpLoss, self).__init__()

    def forward(self, warped_img1, warp_img2, imgt, mask):
        """Return the summed gradient loss of both warped images against ``imgt``.

        For each warped image the mean absolute difference of its horizontal
        and vertical finite differences to those of ``imgt`` is computed; the
        two directions are averaged and the two images are summed.

        Args:
            warped_img1: first warped image batch.
            warp_img2: second warped image batch. The parameter name is kept
                as is for keyword callers.
            imgt: target image batch.
            mask: accepted for interface compatibility; not used.

        Returns:
            Scalar tensor on the device of the inputs.
        """
        imgt_dx, imgt_dy = gradient(imgt)

        # Start from a plain 0 so the result follows the inputs' device instead
        # of forcing CUDA.
        loss = 0
        # The body previously read an undefined `warped_img2`; the parameter is
        # `warp_img2`.
        for warped in (warped_img1, warp_img2):
            warped_dx, warped_dy = gradient(warped)
            loss_x = (warped_dx - imgt_dx).abs()
            loss_y = (warped_dy - imgt_dy).abs()
            loss = loss + (loss_x.mean() / 2.0 + loss_y.mean() / 2.0)

        return loss


class ImageLoss(nn.Module):
    """Region-wise Charbonnier loss with softmax re-weighting of the regions.

    For every sample, the penalty ``sqrt(diff^2 + 1e-6)`` is averaged inside
    each non-empty mask channel and inside the remainder (pixels where one
    minus the sum of the used masks is still positive). The region losses are
    combined with weights ``softmax(region_losses)``, so regions with a larger
    loss get a larger weight. The result is averaged over the batch.
    """

    def __init__(self):
        super(ImageLoss, self).__init__()

    def forward(self, pred_image, image, mask):
        """Return the batch-averaged, region-weighted loss.

        Args:
            pred_image: predicted image batch.
            image: reference image batch.
            mask: ``(B, N, H, W)`` tensor of region masks.

        Returns:
            Scalar tensor on ``mask.device``.
        """
        B, N, H, W = mask.size()
        loss = torch.tensor(0.0, device=mask.device)

        for b in range(B):
            diff = pred_image[b:b + 1] - image[b:b + 1]
            # The penalty map does not depend on the region: compute it once
            # per sample instead of once per mask.
            penalty = (diff ** 2 + 1e-6) ** 0.5

            loss_image = []
            mask_left = torch.ones(1, 1, H, W, device=mask.device)
            for n in range(N):
                if mask[b, n].sum() == 0:
                    continue
                single_mask = mask[b:b + 1, n:n + 1]
                mask_left = mask_left - single_mask
                loss_image.append(
                    (penalty * single_mask).mean() / (single_mask.mean() + 1e-9))

            # Remainder region: pixels not covered by any of the used masks.
            single_mask_left = (mask_left > 0).float()
            if single_mask_left.sum() > 0:
                loss_image.append(
                    (penalty * single_mask_left).mean() / (single_mask_left.mean() + 1e-9))

            if loss_image:  # nothing to add when no region is left
                loss_image_tensor = torch.stack(loss_image)
                # The weights are not detached, so gradients also flow through
                # the softmax.
                loss_weights = F.softmax(loss_image_tensor, dim=0)
                loss = loss + torch.sum(loss_weights * loss_image_tensor)

        return loss / B  # average over the batch



