import torch
import torch.nn as nn
from transforms.image_transforms import holz_transform
import torch.nn.functional as F
import numpy as np
import torchvision
from IPython import embed

class HolzClassifier(nn.Module):
    """Small CNN producing 2 logits from a normalized, frequency-masked input.

    ``forward`` first runs ``transform`` (``holz_transform`` followed by
    standardization with ``dct_mean``/``dct_var`` and multiplication by
    ``freq_mask``), then the CNN in ``forward_sans``. ``forward_sans`` runs
    the CNN alone, for inputs that are already transformed.

    CNN layout: conv(3->3) -> ReLU -> conv(3->8) -> ReLU -> avgpool(2)
    -> conv(8->16) -> ReLU -> avgpool(2) -> conv(16->32) -> ReLU
    -> flatten -> linear(-> 2). All convolutions are 3x3 with padding 1,
    so only the two pools change the spatial size.

    Args:
        dct_mean: Tensor subtracted from the transformed input. It must
            broadcast against the output of ``holz_transform``.
        dct_var: Tensor whose square root divides the centered input.
            NOTE(review): zero entries produce inf/NaN; no epsilon is added.
        freq_mask: Tensor multiplied element-wise into the normalized input.
        img_size: Side length (height == width) of the tensor that reaches
            ``conv1``. The default of 128 reproduces the original hard-coded
            ``fc1`` input size of 32 * 32 * 32.
    """

    def __init__(self, dct_mean: torch.Tensor, dct_var: torch.Tensor,
                 freq_mask: torch.Tensor, img_size: int = 128):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool1 = nn.AvgPool2d(2)
        self.conv3 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool2 = nn.AvgPool2d(2)
        # Each AvgPool2d(2) floors the side length by half, so conv4 emits
        # 32 channels of (img_size // 4) x (img_size // 4) features.
        self.fc1 = nn.Linear(32 * (img_size // 4) ** 2, 2)

        # Non-persistent buffers move with .to()/.cuda() but are left out of
        # state_dict().
        self.register_buffer('dct_mean', dct_mean, persistent=False)
        self.register_buffer('dct_var', dct_var, persistent=False)
        self.register_buffer('freq_mask', freq_mask, persistent=False)

    def transform(self, x):
        """Return ``holz_transform(x)`` standardized by the stored stats, then masked."""
        tensor = holz_transform(x)
        tensor = (tensor - self.dct_mean) / torch.sqrt(self.dct_var)
        return tensor * self.freq_mask

    def forward_sans(self, x):
        """Run the CNN on an already-transformed input and return (batch, 2) logits."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)
        x = F.relu(self.conv3(x))
        x = self.pool2(x)
        x = F.relu(self.conv4(x))
        x = torch.flatten(x, 1)
        return self.fc1(x)

    def forward(self, x):
        """Transform ``x`` and classify it; returns (batch, 2) logits."""
        # Delegate so the two entry points cannot drift apart.
        return self.forward_sans(self.transform(x))

class autoattack_wrapper:
    """Combine an ensemble of classifiers by averaging in probability space.

    The output is ``log(mean_i softmax(model_i(x)))``, computed stably in
    log space. Instances are callable (``wrapper(x)``), which is how
    AutoAttack-style evaluators invoke a model.

    Args:
        ensemble_list: Iterable of callables. Each is called as ``model(x)``
            and must return logits whose last dimension indexes classes.
    """

    def __init__(self, ensemble_list):
        self.ensemble_list = ensemble_list

    def forward(self, x):
        """Return the log of the ensemble-mean class probabilities for ``x``.

        Raises:
            ValueError: if the ensemble is empty.
        """
        import math

        log_probs = [F.log_softmax(model(x), dim=-1) for model in self.ensemble_list]
        if not log_probs:
            raise ValueError("autoattack_wrapper needs at least one model")
        # log(mean(p_i)) == logsumexp(log p_i) - log(N). Unlike log(mean(softmax)),
        # this never underflows to -inf for confident logits, so gradients
        # taken by attacks stay finite.
        return torch.logsumexp(torch.stack(log_probs), dim=0) - math.log(len(log_probs))

    def __call__(self, x):
        # Attack libraries call model(x); route that through forward().
        return self.forward(x)

