
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import device
from botorch.utils.transforms import normalize, unnormalize
# dtype = torch.double

def generate_initial_data(vae_model, score_model, n, d):
    """Sample random latent codes, decode them and score the decoded outputs.

    Args:
        vae_model: object exposing ``decode(z)``; presumably a trained VAE
            whose decoder yields MNIST-sized images (confirm against caller).
        score_model: classifier forwarded to ``score_func``.
        n: number of latent samples to draw.
        d: dimensionality of each latent sample.

    Returns:
        A ``(decoded, objective)`` pair: the decoder outputs for the sampled
        codes and their scores with a trailing singleton dimension added.
    """
    # Latent codes are drawn uniformly from [0, 1) on the shared device.
    latent_codes = torch.rand(n, d, device=device)
    decoded = vae_model.decode(latent_codes)
    # score_func yields one value per sample; unsqueeze makes it a column.
    objective = score_func(score_model, decoded).unsqueeze(-1)
    return decoded, objective


class Net(nn.Module):
    """Small convolutional classifier that outputs log-probabilities for 10 classes.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed a
    two-layer fully connected head. The flatten step expects every sample to
    reduce to a 50 x 4 x 4 feature map, which is what a 1 x 28 x 28 input
    produces (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state-dict keys, so they must not change.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return log-softmax class scores, one row of 10 values per sample."""
        # Both convolutional stages share the same conv -> ReLU -> pool shape.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)


def score(y):
    """Score digit values with a squared exponential peaked at the digit 3.

    The result is ``exp(-2 * (y - 3) ** 2)`` evaluated elementwise, so an
    input of 3 scores 1 and the score decays quickly on either side of it.
    """
    squared_offset = (y - 3) ** 2
    exponent = -2 * squared_offset
    return torch.exp(exponent)

def score_func(cnn_model, x):
    """Return the expected score of each image under the CNN classifier.

    Every input is reshaped to a single-channel 28 x 28 image and passed
    through ``cnn_model``, whose outputs are treated as log-probabilities.
    The per-class scores from ``score`` are then averaged under those class
    probabilities.

    Args:
        cnn_model: classifier returning log-probabilities of shape
            (batch, num_classes).
        x: tensor holding ``batch`` images; each must contain 28 * 28 values.

    Returns:
        Tensor of shape (batch,) with one expected score per image. It is
        computed under ``torch.no_grad()`` and carries no gradient.
    """
    with torch.no_grad():
        # reshape (unlike view) also accepts non-contiguous inputs.
        images = x.reshape(x.shape[0], 1, 28, 28)
        probs = torch.exp(cnn_model(images))  # batch x num_classes
        # Build the class indices on the classifier output's own device so the
        # product below cannot hit a device mismatch with a module-level global.
        class_ids = torch.arange(probs.shape[1], device=probs.device)
        scores = score(class_ids).expand(probs.shape)
        return (probs * scores).sum(dim=1)

def init_model(cnn_weights_path="pretrained_models/mnist_cnn.pt"):
    """Build the CNN classifier and load its pretrained weights.

    Args:
        cnn_weights_path: location of the saved state dict. Defaults to the
            checkpoint path this module has always used, so existing calls
            without arguments behave as before.

    Returns:
        A ``Net`` instance placed on the shared ``device`` with the loaded
        weights applied.
    """
    cnn_model = Net().to(device=device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    cnn_state_dict = torch.load(cnn_weights_path, map_location=device, weights_only=True)
    cnn_model.load_state_dict(cnn_state_dict)

    return cnn_model