import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import csv, os
import torch
import torch.nn.functional as F

def calculate_ssim(img1, img2, window_size=11, sigma=1.5, data_range=1.0):
    """Return the mean 3-D SSIM between two volumes using a separable Gaussian window.

    Args:
        img1, img2: ``torch.Tensor`` or ``numpy.ndarray`` volumes of matching
            shape, laid out as ``(D, H, W)``, ``(C, D, H, W)`` or
            ``(N, C, D, H, W)``; missing leading dimensions are added.
            Numpy inputs are converted to float32, non-floating tensors are
            cast to float32, and ``img2`` is moved to ``img1``'s device/dtype.
        window_size: length of the 1-D Gaussian applied along each spatial axis.
        sigma: standard deviation of that Gaussian.
        data_range: dynamic range of the values (max - min). The stabilising
            constants are ``(0.01 * data_range) ** 2`` and
            ``(0.03 * data_range) ** 2``; the default 1.0 matches the original
            fixed constants and assumes values in [0, 1].

    Returns:
        0-d tensor: the SSIM map averaged over batch, channel and all voxels.
        Each channel is filtered independently (grouped convolution).

    Note:
        Borders are zero-padded, so local statistics near the volume edges
        include zeros from outside the volume.
    """
    if not isinstance(img1, torch.Tensor):
        img1 = torch.from_numpy(img1).float()
    if not isinstance(img2, torch.Tensor):
        img2 = torch.from_numpy(img2).float()
    if not img1.is_floating_point():
        img1 = img1.float()
    # conv3d requires input and weight to share dtype and device.
    img2 = img2.to(device=img1.device, dtype=img1.dtype)

    def _to_5d(t):
        # conv3d expects (N, C, D, H, W); promote bare or un-batched volumes.
        if t.dim() == 3:
            return t[None, None]
        if t.dim() == 4:
            return t[None]
        return t

    img1 = _to_5d(img1)
    img2 = _to_5d(img2)
    channels = img1.shape[1]

    # 1-D Gaussian, normalised to sum to 1, centred on window_size // 2.
    coords = torch.arange(window_size, dtype=torch.float64) - window_size // 2
    gauss = torch.exp(-coords ** 2 / float(2 * sigma ** 2))
    gauss = (gauss / gauss.sum()).to(device=img1.device, dtype=img1.dtype)

    pad = window_size // 2
    # One (weight, padding) pair per spatial axis; the 3-D Gaussian is separable,
    # so three 1-D passes equal one full 3-D convolution. Weights are repeated
    # per channel and applied with groups=channels so channels don't mix.
    passes = (
        (gauss.view(1, 1, -1, 1, 1).repeat(channels, 1, 1, 1, 1), (pad, 0, 0)),
        (gauss.view(1, 1, 1, -1, 1).repeat(channels, 1, 1, 1, 1), (0, pad, 0)),
        (gauss.view(1, 1, 1, 1, -1).repeat(channels, 1, 1, 1, 1), (0, 0, pad)),
    )

    def _blur(x):
        # Gaussian-weighted local mean of x.
        for weight, padding in passes:
            x = F.conv3d(x, weight, padding=padding, groups=channels)
        return x

    mu1 = _blur(img1)
    mu2 = _blur(img2)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = _blur(img1 * img1) - mu1_sq
    sigma2_sq = _blur(img2 * img2) - mu2_sq
    sigma12 = _blur(img1 * img2) - mu1_mu2

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return ssim_map.mean()

def calculate_mae(img1, img2):
    """Return the mean absolute difference between two arrays of equal shape.

    Args:
        img1, img2: array-likes of broadcast-compatible shape.

    Returns:
        numpy.float64 mean of ``|img1 - img2|`` over all elements.
    """
    # Subtract in float64: for unsigned-integer inputs (e.g. uint8 images)
    # a plain ``img1 - img2`` wraps around instead of going negative.
    diff = np.subtract(img1, img2, dtype=np.float64)
    return np.mean(np.abs(diff))

def _ssim_channel_kwargs(ndim):
    """Return ``ssim`` keyword args marking the last axis as channels (none for 2-D)."""
    if ndim < 3:
        # 2-D inputs are single-channel grayscale; treating their last axis as
        # channels would compute a per-column 1-D SSIM instead.
        return {}
    import inspect
    # scikit-image 0.19 deprecated ``multichannel`` in favour of ``channel_axis``;
    # use whichever keyword the installed version understands.
    if 'channel_axis' in inspect.signature(ssim).parameters:
        return {'channel_axis': -1}
    return {'multichannel': True}


def calculate_metrics(pred, target, data_range=None):
    """Return ``(psnr, ssim, mae)`` of ``pred`` against the reference ``target``.

    Either array is divided by 255 when its maximum exceeds 1.0 (assumed to be
    8-bit intensities); otherwise it is used as-is.

    Args:
        pred: predicted image as a numpy array.
        target: reference image as a numpy array, same shape as ``pred``.
            For 3-D arrays the last axis is treated as colour channels for
            SSIM; 2-D arrays are treated as grayscale.
        data_range: value range (max - min) shared by PSNR and SSIM. When
            ``None`` it is inferred after rescaling: 2.0 if either image has
            negative values (assumed [-1, 1] data), else 1.0 (assumed [0, 1]).

    Returns:
        Tuple ``(psnr_value, ssim_value, mae_value)``.
    """
    if pred.max() > 1.0:
        pred = pred / 255.0
    if target.max() > 1.0:
        target = target / 255.0

    if data_range is None:
        # PSNR and SSIM must use the same range; previously PSNR used 2.0 and
        # SSIM 1.0, which inflated PSNR by ~6 dB for [0, 1] data.
        data_range = 2.0 if min(pred.min(), target.min()) < 0 else 1.0

    psnr_value = psnr(target, pred, data_range=data_range)
    ssim_value = ssim(target, pred, data_range=data_range,
                      **_ssim_channel_kwargs(target.ndim))
    mae_value = calculate_mae(target, pred)

    return psnr_value, ssim_value, mae_value

def save_metrics(opt, metrics_list, save_dir):
    """Write per-image metrics to CSV and append their averages to a text log.

    Args:
        opt: options object; only ``opt.which_epoch`` is read.
        metrics_list: iterable of rows ``(image_name, psnr, ssim, mae)``; the
            three numeric fields must be convertible with ``float``.
        save_dir: output directory, created if missing.

    Side effects:
        - ``<save_dir>/metrics_results.csv`` is overwritten with a header row
          followed by every row of ``metrics_list``.
        - A four-line summary (epoch, average PSNR/SSIM/MAE) is appended to
          ``<save_dir>/metrics_average.txt``; averages are ``nan`` when
          ``metrics_list`` is empty.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Materialise once: the rows are iterated twice below, and a generator
    # would be exhausted by the CSV write before the averages are computed.
    rows = list(metrics_list)

    csv_path = os.path.join(save_dir, 'metrics_results.csv')
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['Image', 'PSNR', 'SSIM', 'MAE'])
        writer.writerows(rows)

    if rows:
        metrics_array = np.array(
            [[float(r[1]), float(r[2]), float(r[3])] for r in rows],
            dtype=np.float64,
        )
        avg_metrics = metrics_array.mean(axis=0)
    else:
        # np.mean of an empty array yields a 0-d nan that cannot be indexed.
        avg_metrics = np.full(3, np.nan)

    avg_path = os.path.join(save_dir, 'metrics_average.txt')
    with open(avg_path, 'a', encoding='utf-8') as f:
        f.write(f'Which Epoch: {opt.which_epoch}\n')
        f.write(f'Average PSNR: {avg_metrics[0]:.4f}\n')
        f.write(f'Average SSIM: {avg_metrics[1]:.4f}\n')
        f.write(f'Average MAE: {avg_metrics[2]:.4f}\n')

