from metrics import PSNR, SSIM, MSE, L_inf_dist, MAE, MSSSIM, vmaf, niqe
import torch 
import numpy as np

# Full-reference (FR) metric names: each one compares a distorted image with
# its reference. Aggregated result dictionaries are keyed by these names.
FR_COLS = [
    'msssim',
    'ssim',
    'mse',
    'psnr',
    'l_inf',
    'mae',
    'vmaf',
]

# No-reference (NR) metric names: each one scores a single image on its own.
# These names are also the keys looked up in the ``nr_models`` dict that is
# passed to evaluate_nr_metrics.
NR_COLS = [
    'niqe',
]

def evaluate_fr_metrics(src, dest, device='cuda:0'):
    """Calculates full-reference image quality metrics between source and destination images.

    Images are compared pairwise: ``src[i]`` against ``dest[i]`` for every index
    of ``src``. Each image gets a leading batch dimension of 1 and is moved to
    ``device`` once. Every metric then receives its own copy, so an in-place
    operation inside one metric cannot leak into the next.

    Args:
        src (torch.Tensor): The source/reference images (batch of tensors, shape [B, 3, H, W]).
        dest (torch.Tensor): The destination/distorted images (batch of tensors, shape [B, 3, H, W]).
            Must contain at least ``len(src)`` images; only the first ``len(src)`` are used.
        device (str, optional): The device to perform calculations on. Defaults to 'cuda:0'.

    Returns:
        dict: Maps each metric name ('mse', 'mae', 'ssim', 'psnr', 'l_inf',
        'msssim', 'vmaf') to a list holding one score per image pair, in batch
        order. Each score is whatever the corresponding metric function returns.
    """
    # Metric name -> metric function. Insertion order fixes both the evaluation
    # order and the key order of the returned dict.
    metric_fns = {
        'mse': MSE,
        'mae': MAE,
        'ssim': SSIM,
        'psnr': PSNR,
        'l_inf': L_inf_dist,
        'msssim': MSSSIM,
        'vmaf': vmaf,
    }
    fr_vals = {name: [] for name in metric_fns}
    # Index-based on purpose: dest is indexed in lock-step with src.
    for i in range(len(src)):
        # Add the batch dim and move to the target device once per image,
        # instead of once per metric.
        src_img = src[i].unsqueeze(0).to(device)
        dest_img = dest[i].unsqueeze(0).to(device)
        for name, metric_fn in metric_fns.items():
            # Fresh copies per metric, as before, so metrics stay isolated.
            fr_vals[name].append(metric_fn(src_img.clone(), dest_img.clone()))
    return fr_vals


def evaluate_nr_metrics(src, nr_models, device='cuda:0'):
    """Calculates no-reference image quality metrics for a batch of images.

    Every image gets a leading batch dimension of 1 and is moved to ``device``
    once. Each metric model then receives its own copy of the image.

    Args:
        src (torch.Tensor): A batch of source images, expected shape [N, C, H, W].
        nr_models (dict): A dictionary mapping NR metric names (see ``NR_COLS``)
            to their corresponding torch models. A metric with no entry in this
            dict is recorded as NaN instead of raising ``KeyError``.
        device (str, optional): The device to perform calculations on. Defaults to 'cuda:0'.

    Returns:
        dict: Maps each name in ``NR_COLS`` to a list holding one float score per
        image, in batch order. The score is NaN when the model is missing or
        when it returned None.
    """
    nr_vals = {k: [] for k in NR_COLS}
    for i in range(len(src)):
        # Hoisted out of the metric loop: one unsqueeze + device move per image.
        img = src[i].unsqueeze(0).to(device)
        for k in NR_COLS:
            model = nr_models.get(k)
            # Without a model (e.g. the caller passed an empty dict) there is
            # nothing to compute, so NaN is recorded instead of crashing.
            val = None if model is None else model(img.clone())
            nr_vals[k].append(np.nan if val is None else val.cpu().item())
    return nr_vals

def prepare_imgs(img_dict, device):
    """Normalise an image dictionary in place and return it.

    Any of the six expected image entries that is absent gets the value None.
    Every non-None value in the dict, including keys outside the expected six,
    is replaced by its ``.to(device)`` result.

    Args:
        img_dict (dict): Image tensors keyed by name; modified in place.
        device (str): Target device for the tensors.

    Returns:
        dict: The same ``img_dict`` object after normalisation.
    """
    expected_keys = (
        'clear',
        'attacked',
        'rec_def_clear',
        'rec_def_attacked',
        'rec_undef_clear',
        'rec_undef_attacked',
    )
    for key in expected_keys:
        img_dict.setdefault(key, None)

    # Snapshot the items: values are reassigned while iterating.
    for key, tensor in list(img_dict.items()):
        if tensor is None:
            continue
        img_dict[key] = tensor.to(device)
    return img_dict


def _mean_or_nan(values):
    """Return the NaN-ignoring mean of ``values``, or NaN if every entry is NaN.

    ``values`` is either a list of per-image scores or the scalar ``np.nan``
    placeholder used for skipped comparisons. An empty list also yields NaN.
    """
    return np.nan if np.isnan(values).all() else np.nanmean(values)


def _fr_scores_or_nan(ref, dist, device):
    """Run evaluate_fr_metrics on (ref, dist), or return NaN placeholders if either is None."""
    if ref is None or dist is None:
        return {col: np.nan for col in FR_COLS}
    return evaluate_fr_metrics(ref, dist, device=device)


def _nr_scores_or_nan(imgs, nr_models, device):
    """Run evaluate_nr_metrics on imgs, or return NaN placeholders if imgs is None."""
    if imgs is None:
        return {col: np.nan for col in NR_COLS}
    return evaluate_nr_metrics(imgs, nr_models, device=device)


def evaluate_codec_image_quality(img_dict, nr_models=None, device='cuda:0'):
    """Performs a comprehensive image quality assessment for various codec evaluation scenarios.

    Every available pairing of original, attacked and reconstructed images
    (with and without defence) is scored with full-reference metrics. Every
    available image batch is scored with no-reference metrics. Per-image scores
    are then averaged, ignoring NaNs.

    Args:
        img_dict (dict): Image batches keyed by 'clear', 'attacked',
            'rec_def_clear', 'rec_def_attacked', 'rec_undef_clear' and
            'rec_undef_attacked'. Missing keys are treated as absent (None).
            The dict is updated in place by ``prepare_imgs``.
        nr_models (dict, optional): No-reference metric models, passed to
            ``evaluate_nr_metrics``. Defaults to an empty dict.
        device (str, optional): The device for calculations. Defaults to 'cuda:0'.

    Returns:
        dict: Mean scores keyed ``f'{metric}_{comparison}'`` for FR metrics and
        ``f'{metric}_{images}'`` for NR metrics. An entry is NaN when one of its
        images is missing or when all of its per-image scores are NaN.
    """
    if nr_models is None:
        nr_models = {}
    img_dict = prepare_imgs(img_dict, device)

    clear = img_dict['clear']
    attacked = img_dict['attacked']
    rec_def_clear = img_dict['rec_def_clear']
    rec_def_attacked = img_dict['rec_def_attacked']
    rec_undef_clear = img_dict['rec_undef_clear']
    rec_undef_attacked = img_dict['rec_undef_attacked']

    # Full-reference comparisons: result-key suffix -> (reference, distorted).
    # Order matches the per-metric key order of the returned dict.
    fr_pairs = {
        # clear vs reconstructed clear (with defence as preprocessing)
        'clear_defended-rec-clear': (clear, rec_def_clear),
        # attacked vs reconstructed attacked (with defence as preprocessing)
        'attacked_defended-rec-attacked': (attacked, rec_def_attacked),
        # clear vs attacked, without reconstruction
        'clear_attacked': (clear, attacked),
        # clear vs reconstructed clear (WITHOUT defence)
        'clear_rec-clear': (clear, rec_undef_clear),
        # attacked vs reconstructed attacked (WITHOUT defence)
        'attacked_rec-attacked': (attacked, rec_undef_attacked),
        # reconstructed clear vs reconstructed attacked (WITHOUT defence)
        'rec-clear_rec-attacked': (rec_undef_clear, rec_undef_attacked),
        # reconstructed clear vs reconstructed attacked (with defence)
        'defended-rec-clear_defended-rec-attacked': (rec_def_clear, rec_def_attacked),
    }
    # No-reference inputs: result-key suffix -> image batch.
    nr_inputs = {
        'clear': clear,
        'attacked': attacked,
        'rec-attacked': rec_undef_attacked,
        'rec-clear': rec_undef_clear,
        'defended-rec-clear': rec_def_clear,
        'defended-rec-attacked': rec_def_attacked,
    }

    fr_results = {
        name: _fr_scores_or_nan(ref, dist, device)
        for name, (ref, dist) in fr_pairs.items()
    }
    nr_results = {
        name: _nr_scores_or_nan(imgs, nr_models, device)
        for name, imgs in nr_inputs.items()
    }

    iqa_res = {}
    for col in FR_COLS:
        for name, scores in fr_results.items():
            iqa_res[f'{col}_{name}'] = _mean_or_nan(scores[col])
    for col in NR_COLS:
        for name, scores in nr_results.items():
            iqa_res[f'{col}_{name}'] = _mean_or_nan(scores[col])
    return iqa_res


from traditional_reference_codec import jpeg2k_compress, jpeg2k_compress_fix_bpp

def evaluate_reference_codec(img_dict, target_quality, nr_models, target_bpp, device, dump_path):
    """Benchmark the JPEG2000 reference codec on the clear and attacked images.

    Two settings are run one after the other:

    * ``jpeg2k_compress`` driven by ``target_quality``; result keys end in ``_jpeg``.
    * ``jpeg2k_compress_fix_bpp`` driven by ``[target_bpp]``; result keys end in
      ``_jpeg-fix``.

    For each setting, the clear batch and then the attacked batch are compressed.
    FR metrics are averaged for original vs. reconstruction and for clear
    reconstruction vs. attacked reconstruction. NR metrics are averaged on both
    reconstructions.

    Args:
        img_dict (dict): Image batches; the 'clear' and 'attacked' entries are used.
        target_quality (list): Quality target forwarded unchanged to ``jpeg2k_compress``.
        nr_models (dict): No-reference metric models, forwarded to ``evaluate_nr_metrics``.
        target_bpp (float): Bits-per-pixel target, forwarded as ``[target_bpp]`` to
            ``jpeg2k_compress_fix_bpp``.
        device (str): Device used for the images and the metric computation.
        dump_path (str): Working path forwarded to every compression call.

    Returns:
        tuple: ``(ref_codec_res, rec_images)``.
            ``ref_codec_res`` holds the first bpp value reported by each
            compression call plus the averaged FR/NR scores.
            ``rec_images`` maps 'ref_codec_attacked' and
            'ref_codec_fix_bpp_attacked' to the reconstructed attacked batches.
    """
    def nan_aware_mean(values):
        # NaN when every score is NaN; otherwise individual NaNs are ignored.
        return np.nan if np.isnan(values).all() else np.nanmean(values)

    def run_setting(images, compress, setting):
        # Clear is compressed before attacked; both calls share dump_path.
        rec_clear, bpp_clear = compress(images['clear'], dump_path, setting, device)
        rec_attacked, bpp_attacked = compress(images['attacked'], dump_path, setting, device)
        bpp_clear, bpp_attacked = bpp_clear[0], bpp_attacked[0]
        fr = {
            'clear_rec-clear': evaluate_fr_metrics(images['clear'], rec_clear, device=device),
            'attacked_rec-attacked': evaluate_fr_metrics(images['attacked'], rec_attacked, device=device),
            'rec-clear_rec-attacked': evaluate_fr_metrics(rec_clear, rec_attacked, device=device),
        }
        nr = {
            'rec-attacked': evaluate_nr_metrics(rec_attacked, nr_models, device=device),
            'rec-clear': evaluate_nr_metrics(rec_clear, nr_models, device=device),
        }
        return rec_attacked, bpp_clear, bpp_attacked, fr, nr

    images = prepare_imgs(img_dict, device)
    att_q, bpp_clear_q, bpp_att_q, fr_q, nr_q = run_setting(images, jpeg2k_compress, target_quality)
    att_b, bpp_clear_b, bpp_att_b, fr_b, nr_b = run_setting(images, jpeg2k_compress_fix_bpp, [target_bpp])

    ref_codec_res = {
        'bpp_jpeg-clear': float(bpp_clear_q),
        'bpp_jpeg-attacked': float(bpp_att_q),
        'bpp_jpeg-clear-fix': float(bpp_clear_b),
        'bpp_jpeg-attacked-fix': float(bpp_att_b),
    }
    settings = (('jpeg', fr_q, nr_q), ('jpeg-fix', fr_b, nr_b))
    for col in FR_COLS:
        for suffix, fr, _ in settings:
            for pair, scores in fr.items():
                ref_codec_res[f'{col}_{pair}_{suffix}'] = np.mean(scores[col])
    for col in NR_COLS:
        for suffix, _, nr in settings:
            for name, scores in nr.items():
                ref_codec_res[f'{col}_{name}_{suffix}'] = nan_aware_mean(scores[col])

    return ref_codec_res, {'ref_codec_attacked': att_q, 'ref_codec_fix_bpp_attacked': att_b}

