import os
import torch
import scipy.io as scio
import time
import numpy as np
from torch.autograd import Variable
import datetime
import torch.nn.functional as F
import random
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image
from option import opt
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
from architecture import *
from utils import *

# Inference requires CUDA: check first so the script fails fast, before any
# CUDA-specific object is constructed. RuntimeError subclasses Exception, so
# code that caught the previous bare Exception still catches this.
if not torch.cuda.is_available():
    raise RuntimeError('NO GPU!')

# NOTE(review): `scaler` is not referenced anywhere else in this file (the
# evaluation runs under torch.no_grad without autocast); kept so the
# module-level name remains available.
scaler = torch.GradScaler(device="cuda")

torch.backends.cudnn.enabled = True
# NOTE: set_seed() later in this file sets benchmark back to False.
torch.backends.cudnn.benchmark = True

def set_seed(seed):
    """Seed every random number generator the script touches and pin cuDNN.

    Args:
        seed: Integer seed shared by Python's ``random``, NumPy's global RNG
            and PyTorch (CPU generator plus the CUDA generators).

    Side effects:
        Turns cuDNN autotuning (``benchmark``) off, turns cuDNN deterministic
        mode on, and writes ``PYTHONHASHSEED`` and ``CUBLAS_WORKSPACE_CONFIG``
        into ``os.environ``.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # NOTE: PYTHONHASHSEED only influences interpreters started after this
    # point (e.g. subprocesses); hash randomization of the running process is
    # fixed at startup. CUBLAS_WORKSPACE_CONFIG only takes effect if set before
    # cuBLAS is first used.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })
    
set_seed(42)

# init mask
# NOTE(review): the trailing 10 is passed straight to init_mask; presumably the
# number of test scenes / batch size — confirm against utils.init_mask.
mask3d_batch_test, input_mask_test = init_mask(opt.mask_path, opt.input_mask, 10)

# dataset
test_data = LoadTest(opt.test_path)

# saving path
date_time = str(datetime.datetime.now())
date_time = time2file_name(date_time)  # NOTE(review): not used further in this file
result_path = opt.outf + opt.name + '/result/'
model_path = opt.outf + opt.name + '/model/'
# exist_ok=True replaces the check-then-create pattern, which could race if
# another process created the directory in between.
os.makedirs(result_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

# Use the bundled checkpoint only when no path was supplied; previously the
# command-line value was unconditionally overwritten and silently ignored.
if not getattr(opt, 'pretrained_model_path', None):
    opt.pretrained_model_path = 'checkpoint/CIDNet_9stg'
model = model_generator(opt.method, opt.pretrained_model_path).cuda()

def test():
    """Run the model once over the whole test set and report PSNR/SSIM.

    Splits the ground truth with ``chroma_intensity_decom`` and modulates the
    sensing mask by the intensity component. It then simulates the measurement
    with ``init_meas``, runs ``model`` without gradients and multiplies the
    output by the intensity map. Finally it scores every sample against the
    ground truth and prints per-sample and average metrics.

    Returns:
        tuple: ``(pred, truth, psnr_list, ssim_list, psnr_mean, ssim_mean)``.

        - ``pred`` / ``truth``: float32 numpy arrays with axes permuted by
          ``(0, 2, 3, 1)`` (presumably (N, C, H, W) -> (N, H, W, C); confirm
          against LoadTest).
        - ``psnr_list`` / ``ssim_list``: one numpy value per sample.
        - ``psnr_mean`` / ``ssim_mean``: the numpy means of those lists.
    """
    test_gt = test_data.cuda().float()
    test_gt_chroma, test_gt_intensity = chroma_intensity_decom(test_gt)
    mask3d_batch_test_new = mask3d_batch_test * test_gt_intensity
    input_meas = init_meas(test_gt_chroma, mask3d_batch_test_new, opt.input_setting)
    input_mask_test_new = chroma_intensity_mask(
        input_mask_test, test_gt_intensity, input_mask_sign=opt.input_mask)

    model.eval()
    psnr_list, ssim_list = [], []
    with torch.no_grad():
        # Re-apply the intensity map to the network output before scoring.
        model_out = model(input_meas, input_mask_test_new) * test_gt_intensity
        # Iterating a tensor walks its first (sample) dimension, matching
        # model_out[k, :, :, :] / test_gt[k, :, :, :].
        for out_k, gt_k in zip(model_out, test_gt):
            psnr_list.append(torch_psnr(out_k, gt_k).cpu().numpy())
            ssim_list.append(torch_ssim(out_k, gt_k).cpu().numpy())

    pred = np.transpose(model_out.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    truth = np.transpose(test_gt.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    psnr_mean = np.mean(np.asarray(psnr_list))
    ssim_mean = np.mean(np.asarray(ssim_list))

    # printing
    for i, (p, s) in enumerate(zip(psnr_list, ssim_list), start=1):
        print(f"Sample {i}: PSNR = {p:.4f}, SSIM = {s:.6f}")
    print(f"\n==> Average PSNR: {psnr_mean:.4f}")
    print(f"==> Average SSIM: {ssim_mean:.6f}")
    return pred, truth, psnr_list, ssim_list, psnr_mean, ssim_mean

def main():
    """Evaluate the pretrained model on the test set.

    Returns:
        The tuple produced by ``test()``:
        ``(pred, truth, psnr_list, ssim_list, psnr_mean, ssim_mean)``.
        It is returned rather than unpacked into unused locals, so callers can
        reuse the metrics.
    """
    return test()


if __name__ == '__main__':
    torch.backends.cudnn.enabled = True
    # Do not switch cudnn.benchmark back on here. set_seed() disabled it and
    # enabled deterministic mode; autotuning can pick different algorithms from
    # run to run, which would make the reported metrics non-reproducible.
    main()


