import torch
from torch import nn




class WhiteBoxModel(nn.Module):
    """Hand-written scorer comparing a squared-energy statistic of ``x`` to a reference.

    ``forward`` reduces ``x`` to a scalar ``output`` (a sum of squared values
    over a version-specific region of ``x``) and returns a ``(1, 2)`` tensor
    ``[[differ, similarity]]`` where::

        differ = |output - real_answer| / max(|output|, |real_answer|)
        similarity = 1 - differ

    Indexing in ``forward`` requires a 5-D ``x``; presumably
    (batch, channel, time, height, width) — TODO confirm with callers.

    Args:
        real_answer: Reference value that ``output`` is compared against
            (a Python number or a scalar tensor).
        version: Region to measure: ``"only_time"``, ``"solid"``,
            ``"moving_diagonal"`` or ``"solid_time"``. Any other value
            (including the default ``"square"``) makes ``forward`` raise
            ``ValueError``.
        device: Device the returned tensor is moved to.
    """

    _VERSIONS = ("only_time", "solid", "moving_diagonal", "solid_time")

    def __init__(self, real_answer, version="square", device="cpu"):
        super().__init__()
        # Inclusive range of indices along dim 2 of x that are measured.
        self.t1 = 4
        self.t2 = 13
        self.version = version
        self.real_answer = real_answer
        # Weight applied to the penalty frames of the "solid_time" version
        # (attribute name kept unchanged for compatibility).
        self.alpa = 5
        self.device = device
        # Per-frame shift (same amount along dims 3 and 4) of the moving
        # window used by "moving_diagonal" and "solid_time".
        self.ratio = 10
        # Frames that "solid_time" scores by the energy OUTSIDE the window
        # (scaled by self.alpa) instead of the energy inside it.
        self.penalty_frames = (7, 8)

    def _moving_window(self, x, i):
        """Return ``(start_h, end_h, start_w, end_w)`` of the moving window for frame ``i``.

        The window is ``2 * (H // 4)`` by ``W // 3`` (the same size as the
        "solid" region) and its corner moves ``self.ratio`` steps along both
        dim 3 and dim 4 per frame after ``self.t1``. Bounds past the frame edge
        are clipped by slicing, so the window can be partially or fully empty.
        """
        offset = (i - self.t1) * self.ratio
        win_h = 2 * (x.shape[3] // 4)
        win_w = x.shape[4] // 3
        return offset, offset + win_h, offset, offset + win_w

    def forward(self, x):
        """Score ``x`` against ``self.real_answer``.

        Side effects: stores the scalar statistic in ``self.tmp_out``. It also
        stores per-version detail in ``self.data_optimal``: the list of
        per-frame terms for "solid_time", otherwise the scalar statistic.

        Returns:
            Tensor of shape ``(1, 2)`` holding ``[differ, similarity]`` on
            ``self.device``, in the default float dtype.

        Raises:
            ValueError: if ``self.version`` is not a supported version.
        """
        frames = range(self.t1, self.t2 + 1)

        if self.version == "only_time":
            output = torch.sum(x[:, :, self.t1:self.t2 + 1] ** 2)
            result = output

        elif self.version == "solid":
            h_q = x.shape[3] // 4
            w_t = x.shape[4] // 3
            region = x[:, :, self.t1:self.t2 + 1, h_q:h_q * 3, w_t:w_t * 2]
            output = torch.sum(region ** 2)
            result = output

        elif self.version == "moving_diagonal":
            output = x.new_zeros(())
            for i in frames:
                sh, eh, sw, ew = self._moving_window(x, i)
                output = output + torch.sum(x[:, :, i, sh:eh, sw:ew] ** 2)
            result = output

        elif self.version == "solid_time":
            result = []
            for i in frames:
                sh, eh, sw, ew = self._moving_window(x, i)
                if i in self.penalty_frames:
                    # Energy of the frame with the window blanked out. Only
                    # this frame is cloned, so x itself is never modified.
                    frame = x[:, :, i].clone()
                    frame[:, :, sh:eh, sw:ew] = 0
                    result.append(self.alpa * torch.sum(frame ** 2))
                else:
                    result.append(torch.sum(x[:, :, i, sh:eh, sw:ew] ** 2))
            output = sum(result, x.new_zeros(()))

        else:
            raise ValueError(
                f"Unknown version {self.version!r}; expected one of {self._VERSIONS}"
            )

        self.tmp_out = output
        self.data_optimal = result
        return self._compare(output)

    def _compare(self, output):
        """Return ``[[differ, similarity]]`` for ``output`` vs ``self.real_answer``."""
        if not output.is_floating_point():
            output = output.to(torch.get_default_dtype())
        target = torch.as_tensor(
            self.real_answer, dtype=output.dtype, device=output.device
        )
        difference = torch.abs(output - target)
        scale = torch.maximum(torch.abs(output), torch.abs(target))
        # If both values are 0 they agree exactly. Substituting 1 for the
        # scale gives differ = 0 instead of 0/0 = NaN (and avoids NaN grads).
        safe_scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        differ = difference / safe_scale
        similarity = 1 - differ
        cl = torch.stack((differ, similarity))
        return cl.reshape(1, -1).to(device=self.device, dtype=torch.get_default_dtype())