import sys
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"
import numpy as np
import numpy as np
import torch
from dncnn import DnCNN
from torch.utils.data.dataloader import DataLoader
from datasets_n2self import TrainDataset_Supervised, TestDataset
import cv2


# Evaluate on the GPU when one is visible; fall back to CPU so the script can
# still run (checkpoints are loaded with map_location=device below).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _save_rgb_as_png(path, rgb, scale):
    """Write an H x W x 3 float image to *path* as an 8-bit PNG.

    Values are multiplied by *scale*, rounded and clipped to [0, 255], and the
    channel axis is reversed because cv2.imwrite expects BGR order (assumes the
    network tensors are RGB -- TODO confirm against TestDataset).
    cv2.imwrite reports failure by returning False rather than raising, so a
    failed write is reported on stderr instead of being silently ignored.
    """
    img = np.clip(np.rint(rgb[:, :, ::-1] * scale), 0, 255).astype(np.uint8)
    if not cv2.imwrite(path, img):
        print("warning: cv2.imwrite failed for %s" % path, file=sys.stderr)


# Checkpoints evaluated with inputs divided by 255, paired by index with the
# label printed next to each checkpoint's mean PSNR.
model_names = [
    "trained_models/noise2true.pt",
    "trained_models/noise2self.pt",
    "trained_models/ours_ssrl_noise2self.pt",
    "trained_models/neighbor2neighbor.pt",
    "trained_models/ours_ssrl_neighbor2neighbor.pt",
]
main_paths = [
    "Noise2True",
    "Noise2Self",
    "Proposed SSRL in Noise2Self setup",
    "Neighbor2Neighbor",
    "Proposed SSRL in Neighbor2Neighbor setup",
]

# Test set used for every checkpoint (the Noise2Same section reads it too).
total_dirname = "./data/Set5"
# total_dirname = "./data/BSD300_test"

# cv2.imwrite does not create missing directories (it just returns False).
output_dirname = "predicted_outputs"
os.makedirs(output_dirname, exist_ok=True)

for method_ind, model_name in enumerate(model_names):
    model = DnCNN(3, num_of_layers=17)
    model.load_state_dict(torch.load(model_name, map_location=device))
    model = model.to(device)
    model.eval()

    # Checkpoint stem, e.g. "trained_models/noise2self.pt" -> "noise2self".
    # (The previous model_name[14:-3] slice kept the leading "/".)
    model_tag = os.path.splitext(os.path.basename(model_name))[0]

    eval_dataset = TestDataset(total_dirname)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    psnrs = 0
    test_ind = 0
    for ind, (noisy, clean) in enumerate(eval_dataloader):
        # Network inputs/targets are divided by 255; presumably the dataset
        # yields 0-255 pixel values -- TODO confirm in TestDataset.
        noisy = noisy.to(device) / 255
        clean = np.clip((clean / 255).cpu().numpy(), 0, 1).astype(np.float64)

        # Pure inference: no_grad avoids building an autograd graph per image.
        with torch.no_grad():
            output = model(noisy).cpu().numpy()

        # PSNR on the prediction clipped to the valid [0, 1] range.
        denoised = np.clip(output, 0, 1).astype(np.float64)
        mse = np.mean(np.square(denoised * 255 - clean * 255))
        psnr = 20 * np.log10(255) - 10 * np.log10(mse)

        psnrs += psnr
        test_ind += 1

        # Save prediction and ground truth (batch size is 1, so take item 0).
        _save_rgb_as_png(
            os.path.join(output_dirname, "%s_%d.png" % (model_tag, ind)),
            np.transpose(output[0], (1, 2, 0)),
            255,
        )
        _save_rgb_as_png(
            os.path.join(output_dirname, "clean_%d.png" % ind),
            np.transpose(clean[0], (1, 2, 0)),
            255,
        )

    if test_ind == 0:
        raise RuntimeError("no test images found in %s" % total_dirname)
    best_psnr = psnrs / test_ind
    print("%s PSNR : %f" % (main_paths[method_ind], best_psnr))
    
# Noise2Same checkpoints, paired by index with their printed labels. These
# models take per-image, per-channel standardized inputs (zero mean, unit
# std) instead of inputs divided by 255, so they get their own loop.
model_namess = [
    "trained_models/noise2same.pt",
    "trained_models/ours_ssrl_noise2same.pt",
]
main_paths_n2same = ["Noise2Same", "Proposed SSRL in Noise2Same setup"]

# cv2.imwrite does not create missing directories (it just returns False).
os.makedirs("predicted_outputs", exist_ok=True)

for method_ind, model_name in enumerate(model_namess):
    model = DnCNN(3, num_of_layers=17)
    model.load_state_dict(torch.load(model_name, map_location=device))
    model = model.to(device)
    model.eval()

    # Checkpoint stem, e.g. "trained_models/noise2same.pt" -> "noise2same".
    # (The previous model_name[14:-3] slice kept the leading "/".)
    model_tag = os.path.splitext(os.path.basename(model_name))[0]

    eval_dataset = TestDataset(total_dirname)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    psnrs = 0
    test_ind = 0
    for ind, (noisy, clean) in enumerate(eval_dataloader):
        b, c, h, w = noisy.shape
        noisy = noisy.to(device)

        # Per-image, per-channel statistics over all pixels; they are reused
        # to map the network output back to the original pixel scale.
        flat = noisy.reshape(b, c, -1)
        noisy_mean = flat.mean(-1)[:, :, None, None]
        # clamp_min guards a constant channel (std == 0) against inf/NaN.
        noisy_std = flat.std(-1).clamp_min(1e-8)[:, :, None, None]

        # Pure inference: no_grad avoids building an autograd graph per image.
        with torch.no_grad():
            output = model((noisy - noisy_mean) / noisy_std) * noisy_std + noisy_mean
        output = output.cpu().numpy()

        # Score the prediction clipped to the valid pixel range, matching the
        # clipping applied to the other methods, so all PSNRs share one
        # protocol. Assumes pixel values in 0-255 (the PSNR peak is 255).
        denoised = np.clip(output, 0, 255).astype(np.float64)
        clean = np.clip(clean.cpu().numpy(), 0, 255).astype(np.float64)
        mse = np.mean(np.square(denoised - clean))
        psnr = 20 * np.log10(255) - 10 * np.log10(mse)

        psnrs += psnr
        test_ind += 1

        # cv2 expects BGR, so reverse the channels (assumes RGB tensors --
        # TODO confirm); convert explicitly to 8-bit instead of relying on
        # imwrite's implicit saturate-cast of float images.
        pred_bgr = np.transpose(output[0], (1, 2, 0))[:, :, [2, 1, 0]]
        path = "predicted_outputs/%s_%d.png" % (model_tag, ind)
        if not cv2.imwrite(path, np.clip(np.rint(pred_bgr), 0, 255).astype(np.uint8)):
            print("warning: cv2.imwrite failed for %s" % path, file=sys.stderr)

    if test_ind == 0:
        raise RuntimeError("no test images found in %s" % total_dirname)
    best_psnr = psnrs / test_ind
    print("%s PSNR : %f" % (main_paths_n2same[method_ind], best_psnr))
    
    
