import torch.nn as nn 
import numpy as np 
import torch
class image_reversion(nn.Module):
    """Module whose only learnable parameter is an image tensor.

    The image is stored channels-first as ``self.image`` with shape
    ``(batch, channel, height, width)`` so it can be optimised directly
    against a loss (see ``calculate_loss``).
    """

    # Per-channel statistics used by ``to_image_numpy`` to undo input
    # normalisation. These values match the mean/std commonly used for
    # ImageNet-normalised RGB inputs.
    _CHANNEL_MEAN = (0.485, 0.456, 0.406)
    _CHANNEL_STD = (0.229, 0.224, 0.225)

    # Layout heuristic used by ``paprameter_with_image``: a 4-D input whose
    # second dimension exceeds this value is assumed to be channels-last.
    _MAX_CHANNELS = 10

    def __init__(self, batch=1, channel=4, height=512, width=512, device='cuda') -> None:
        """Create the module with a randomly initialised image.

        Args:
            batch: Number of images held in the parameter.
            channel: Number of channels of the random initial image.
            height: Image height in pixels.
            width: Image width in pixels.
            device: Device that images passed to ``paprameter_with_image``
                are moved to. The random initial image is NOT moved here;
                use ``.to(device)`` on the module for that.
        """
        super().__init__()
        self.height = height
        self.width = width
        self.device = device
        self.batch = batch
        self.image = nn.Parameter(
            torch.randn(batch, channel, self.height, self.width),
            requires_grad=True,
        )

    def to_image_numpy(self):
        """Return the current image as a de-normalised NumPy array.

        Returns:
            Array of shape ``(batch, height, width, 3)`` with values clipped
            to ``[0, 1]``.

        Raises:
            ValueError: If the stored image does not have exactly 3 channels,
                since the de-normalisation statistics are per-RGB-channel.
        """
        # Detach from the graph, move to host memory, and go channels-last.
        pixels = self.image.detach().cpu().numpy().transpose(0, 2, 3, 1)
        if pixels.shape[-1] != len(self._CHANNEL_MEAN):
            raise ValueError(
                "to_image_numpy expects a 3-channel image, got "
                f"{pixels.shape[-1]} channels"
            )
        # Undo the (x - mean) / std normalisation.
        pixels = pixels * np.array(self._CHANNEL_STD) + np.array(self._CHANNEL_MEAN)
        return np.clip(pixels, 0, 1)

    def paprameter_with_image(self, image):
        """Replace the learnable image with ``image``.

        Accepts a NumPy array (or array-like) or a ``torch.Tensor`` that is
        2-D ``(H, W)``, 3-D (single image) or 4-D (batch). A 4-D input whose
        second dimension is larger than ``_MAX_CHANNELS`` is treated as
        channels-last and permuted to channels-first. Single-channel images
        are replicated to 3 channels and the 4th channel of a 4-channel
        image is dropped. Non-floating inputs (e.g. ``uint8``) are cast to
        ``float32`` without rescaling.

        Note:
            ``self.image`` becomes a NEW ``nn.Parameter`` object, so an
            optimizer built on the previous parameter must be re-created.

        Raises:
            ValueError: If ``image`` is not 2-D, 3-D or 4-D.
        """
        # Convert first so every later step is a tensor operation; this
        # avoids routing torch tensors through NumPy.
        if not isinstance(image, torch.Tensor):
            image = torch.tensor(np.asarray(image))

        # Add the missing leading dimensions BEFORE inspecting the layout,
        # otherwise the channels-last check would look at the wrong axis.
        if image.dim() == 2:
            image = image.unsqueeze(0).unsqueeze(0)
        elif image.dim() == 3:
            image = image.unsqueeze(0)
        if image.dim() != 4:
            raise ValueError(
                f"expected a 2-D, 3-D or 4-D image, got {image.dim()} dimensions"
            )

        # Heuristic: a "channel" axis this large is really the height axis
        # of a channels-last image.
        if image.shape[1] > self._MAX_CHANNELS:
            image = image.permute(0, 3, 1, 2).contiguous()

        if image.shape[1] == 1:
            # Grayscale -> replicate to 3 channels.
            image = torch.cat([image, image, image], dim=1)
        if image.shape[1] == 4:
            # Drop the 4th channel, keep the first three.
            image = image[:, :3, :, :]

        # Only floating/complex tensors can require gradients.
        if not (image.is_floating_point() or image.is_complex()):
            image = image.float()

        self.image = nn.Parameter(image.to(self.device), requires_grad=True)
        # Keep the cached size attributes in sync with the new image.
        self.batch = image.shape[0]
        self.height = image.shape[2]
        self.width = image.shape[3]

    def vb(self):
        """Return a smoothness penalty on the image.

        Computes the mean squared difference between horizontally adjacent
        pixels plus the mean squared difference between vertically adjacent
        pixels (a squared total-variation style regulariser).
        """
        # Differences between neighbours along the last (width) axis.
        diff_h = self.image[:, :, :, :-1] - self.image[:, :, :, 1:]
        # Differences between neighbours along the height axis.
        diff_v = self.image[:, :, :-1, :] - self.image[:, :, 1:, :]

        tv_loss = (diff_h ** 2).mean() + (diff_v ** 2).mean()
        return tv_loss

    def calculate_loss(self, vector_image, target_vector, decent=True):
        """Return the optimisation loss for the current image.

        The data term is the sum of squared differences between
        ``vector_image`` and ``target_vector`` divided by the sum of
        ``target_vector``; the smoothness penalty ``vb()`` is always added.

        Args:
            vector_image: Tensor to compare against the target.
            target_vector: Target tensor. Its sum is used as the normaliser,
                so a zero sum yields a non-finite loss.
            decent: If true, the data term is added (minimising the loss
                reduces the distance to the target); if false, the data term
                is negated (minimising the loss increases that distance).
        """
        square_loss = torch.square(vector_image - target_vector).sum() / torch.sum(target_vector)
        if decent:
            return square_loss + self.vb()
        return -square_loss + self.vb()