from torch.utils.data import DataLoader
from utils import *
from network.Network import *
import torchvision
from utils.load_test_setting import *

import os, random, numpy as np, torch

# Seed handed to set_seed() below for the Python, NumPy and torch RNGs.
SEED = 1

def set_seed(seed: int = 42, deterministic: bool = True) -> torch.Generator:
    """Seed every global RNG this script touches and optionally force determinism.

    Args:
        seed: Value fed to Python's ``random``, NumPy's global RNG, torch's CPU
            RNG and the RNG of every CUDA device.
        deterministic: When true, select deterministic cuDNN kernels, turn off
            cuDNN benchmarking, ask torch for deterministic algorithms (warning
            rather than raising when none exists) and disable TF32 for both
            matmul and cuDNN.

    Returns:
        A separate ``torch.Generator`` seeded with ``seed``, e.g. for use as a
        ``DataLoader`` generator.

    Note:
        ``PYTHONHASHSEED`` is exported here, but per the Python docs it is read
        at interpreter startup, so it only affects child processes.
    """
    # Environment first; CUBLAS_WORKSPACE_CONFIG is only set if the caller has
    # not already chosen a value (PyTorch's reproducibility notes require it
    # for deterministic cuBLAS use).
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")

    # Python, NumPy, torch CPU and all CUDA devices share the same seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # warn_only: ops lacking a deterministic kernel warn instead of raising.
        torch.use_deterministic_algorithms(True, warn_only=True)
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    # Generator.manual_seed returns the generator itself.
    return torch.Generator().manual_seed(seed)

def seed_worker(worker_id: int) -> None:
    """DataLoader ``worker_init_fn`` that reseeds ``random`` and NumPy from torch.

    The current torch seed is folded into 32 bits (NumPy only accepts seeds in
    ``[0, 2**32)``) and used for both Python's and NumPy's global RNGs.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# ========= Repro seeds =========
# Seed all RNGs before anything else is built. GEN is a dedicated generator
# intended for a DataLoader; it is not passed to the loader below, which uses
# shuffle=False and num_workers=0.
GEN = set_seed(SEED)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

network = Network(noise_layers=noise_layers, device=device, lr=0.1)

# Checkpoint for epoch `model_epoch`; result_folder is concatenated as-is,
# so it is expected to end with a path separator.
EC_path = result_folder + f"models/EC_{model_epoch}.pth"
network.load_model(EC_path)

test_dataset = CoCoDataset(os.path.join(dataset_path, "test"), H, W)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

print("\nStart Testing : \n\n")

# Running sums of per-batch metrics; the report divides them by `num` batches.
test_result = {
    "acc": 0.0,
    "psnr": 0.0,
}

# NOTE(review): saved_iterations is never used in this script, but drawing it
# advances NumPy's global RNG, so it is kept to leave the sampled messages
# unchanged. np.random.choice raises ValueError if save_images_number exceeds
# the dataset size.
saved_iterations = np.random.choice(
    np.arange(len(test_dataset)), size=save_images_number, replace=False
)

num = 0  # number of batches evaluated

# Inference only: switch to eval mode once instead of on every batch.
network.encoder_decoder.eval()

for images in test_dataloader:
    images = images.to(network.device)
    # One random bit string per image, so the message batch matches the image
    # batch even when batch_size > 1.
    messages = torch.Tensor(
        np.random.choice([0, 1], (images.shape[0], message_length))
    ).to(network.device)

    with torch.no_grad():
        encoded_images = network.encoder_decoder.encoder(images, messages)
        # Scale the embedded residual by strength_factor.
        encoded_images = images + (encoded_images - images) * strength_factor

        noised_images = network.encoder_decoder.noise([encoded_images, images])
        decoded_messages = network.encoder_decoder.decoder(noised_images)

        # max_val=2 presumably because images are in [-1, 1] — TODO confirm.
        # NOTE(review): recent kornia versions return -PSNR from psnr_loss;
        # verify the sign against the installed kornia version.
        psnr = kornia.losses.psnr_loss(encoded_images, images, 2).item()

    acc, _ = network._bit_accuracy_and_error(messages, decoded_messages)

    result = {
        "acc": acc,
        "psnr": psnr,
    }
    for key, value in result.items():
        test_result[key] += float(value)

    num += 1

# Mean of each metric over the `num` evaluated batches, each entry followed by
# a comma, e.g. "acc=...,psnr=...,".
summary = "".join(f"{key}={test_result[key] / num}," for key in test_result)
content = f"Average : \n{summary}\n"

with open(test_log, "a") as log_file:
    log_file.write(content)

print(content)
