import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from omegaconf import OmegaConf
from utils import instantiate_from_config
import math

class Model(nn.Module):
    """Small MLP classifier producing two logits per input sample.

    Each sample is flattened (all dims after the batch dim), passed through a
    single hidden ``Linear`` layer with a ``LeakyReLU`` activation, and then
    projected to 2 output logits.

    Args:
        flatten_dim: number of features of one sample after flattening.
        hidden_len: width of the hidden layer.
    """

    def __init__(self, flatten_dim=200, hidden_len=128):
        super().__init__()
        self.flatten_dim = flatten_dim
        self.hidden_len = hidden_len
        layers = [
            nn.Flatten(),
            nn.Linear(flatten_dim, hidden_len),
            nn.LeakyReLU(),
            nn.Linear(hidden_len, 2),
        ]
        # A single nn.Sequential stored as `self.model` keeps the state_dict
        # keys (model.1.*, model.3.*) identical for saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``(batch, 2)`` logits for input batch ``x``."""
        logits = self.model(x)
        return logits


def _sample_batch(mapping_method, batch_size, device):
    """Draw one balanced batch of watermarked and standard-normal samples.

    Returns:
        x: ``batch_size`` rows from ``mapping_method.embed_watermark`` (cast to
           float32) followed by the same number of ``torch.randn`` rows of the
           same shape, concatenated along dim 0.
        y: one-hot float targets matching ``x``: ``[0, 1]`` for the watermarked
           rows and ``[1, 0]`` for the Gaussian rows.
    """
    if hasattr(mapping_method, 'lm'):
        # Random binary message of length `lm` for each sample.
        message = torch.randint(0, 2, (batch_size, mapping_method.lm)).to(device)
    else:
        message = None
    fake_data = mapping_method.embed_watermark(message).to(torch.float32)
    # randn_like already matches fake_data's dtype and device.
    real_data = torch.randn_like(fake_data)
    x = torch.cat((fake_data, real_data), dim=0)
    y = torch.tensor([[0, 1]] * batch_size + [[1, 0]] * batch_size).to(x)
    return x, y


def undetectability_test(mapping_method, train_batch_size, test_batch_size, device,
                         train_steps=1500, test_steps=1000, lr=1e-4, verbose=True):
    """Train an MLP to separate watermarked samples from Gaussian noise and report its accuracy.

    A fresh ``Model`` is trained with ``BCEWithLogitsLoss`` and AdamW on batches
    of ``embed_watermark`` output versus ``torch.randn`` samples. It is then
    evaluated on ``test_steps`` fresh batches. An average accuracy near 0.5
    means the discriminator cannot tell the two apart.

    Note: ``mapping_method.batch_size`` is overwritten (first with
    ``train_batch_size``, then with ``test_batch_size``) and is not restored.

    Args:
        mapping_method: object exposing ``latent_shape``, a writable
            ``batch_size`` and ``embed_watermark(message)``; if it has an
            ``lm`` attribute, random binary messages of that length are passed.
        train_batch_size: samples of each class per training step.
        test_batch_size: samples of each class per evaluation step.
        device: torch device for the discriminator and generated messages.
        train_steps: number of optimisation steps (default 1500).
        test_steps: number of evaluation batches (default 1000, must be >= 1).
        lr: AdamW learning rate (default 1e-4).
        verbose: print per-step loss/accuracy and the final average.

    Returns:
        Mean detection accuracy over the evaluation batches, as a float.

    Raises:
        ValueError: if ``test_steps`` is less than 1.
    """
    if test_steps < 1:
        raise ValueError(f"test_steps must be >= 1, got {test_steps}")

    model = Model(flatten_dim=math.prod(tuple(mapping_method.latent_shape))).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # --- training ---
    mapping_method.batch_size = train_batch_size
    model.train()
    for step in range(train_steps):
        # Data generation needs no gradients; only the discriminator is trained.
        with torch.no_grad():
            x, y = _sample_batch(mapping_method, train_batch_size, device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        if verbose:
            print(f"Epoch {step + 1}, Loss: {loss.item()}, mean: {x.mean()}, std: {x.std()}")

    # --- evaluation ---
    mapping_method.batch_size = test_batch_size
    model.eval()
    acc_list = []
    with torch.no_grad():
        for _ in range(test_steps):
            x, y = _sample_batch(mapping_method, test_batch_size, device)
            # Column 1 of the target is 1 for watermarked rows, so argmax == 1
            # means "predicted watermarked". Sigmoid is monotonic, so argmax
            # over the raw logits gives the same prediction.
            prediction = model(x).argmax(dim=1)
            acc = prediction.eq(y[:, 1].long()).float().mean().item()
            acc_list.append(acc)
            if verbose:
                print(acc)

    mean_acc = sum(acc_list) / len(acc_list)
    if verbose:
        print('Average detection accuracy: ', mean_acc)
    return mean_acc



if __name__ == '__main__':
    # Use the first GPU when available; fall back to CPU instead of crashing
    # on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    train_batch_size = 16
    test_batch_size = 16

    conf = OmegaConf.load('./options/ours_test.yaml')
    mapping_method_opt = conf.MappingModule
    mapping_method_opt.opts.device = device
    mapping_method_opt.opts.latent_shape = (4, 64, 64)
    mapping_method_opt.opts.batch_size = train_batch_size

    # Construct the mapping module with autograd disabled, so any tensor ops
    # run during construction are not recorded.
    with torch.no_grad():
        mapping_method = instantiate_from_config(mapping_method_opt)
    undetectability_test(mapping_method, train_batch_size, test_batch_size, device)
