import cv2
from pathlib import Path
import os
import torch
import numpy as np
from color_transforms_255 import ycbcr_to_rgb_255
import json 
import time

def to_torch(x, device='cpu'):
    """Convert a channels-last NumPy image (batch) to a float32 NCHW tensor.

    Non-ndarray inputs (e.g. tensors that are already converted) are returned
    unchanged and are NOT moved to ``device``.

    Args:
        x (np.ndarray): Image of shape (H, W), (H, W, C) or batch of shape
            (N, H, W, C).
        device (str, optional): Device to move the tensor to. Defaults to 'cpu'.

    Returns:
        torch.Tensor: float32 tensor of shape (N, C, H, W); single images get
        N == 1 and grayscale (H, W) images additionally get C == 1.
    """
    if not isinstance(x, np.ndarray):
        return x
    # torch.from_numpy rejects arrays with negative strides (for example the
    # result of img[..., ::-1] used for BGR<->RGB flips). ascontiguousarray
    # copies only when needed, so contiguous inputs still share memory.
    tensor = torch.from_numpy(np.ascontiguousarray(x))
    if tensor.dim() == 2:
        # Grayscale (H, W) -> (1, 1, H, W).
        tensor = tensor.unsqueeze(0).unsqueeze(0)
    elif tensor.dim() == 3:
        # Single image (H, W, C) -> (1, C, H, W).
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    else:
        # Batch (N, H, W, C) -> (N, C, H, W).
        tensor = tensor.permute(0, 3, 1, 2)
    return tensor.float().to(device)

def to_numpy(x):
    """Convert a batched channels-first tensor to a channels-last NumPy array.

    Tensors are detached, copied to host memory and permuted from
    (N, C, H, W) to (N, H, W, C), so they must be 4-D. Non-tensor inputs are
    used as-is. Any result that is not 4-D gets a leading batch axis of size 1.

    Args:
        x (torch.Tensor | np.ndarray): Tensor of shape (N, C, H, W), or an
            array that is already channels-last.

    Returns:
        np.ndarray: Array of shape (N, H, W, C).
    """
    if torch.is_tensor(x):
        arr = x.detach().cpu().permute(0, 2, 3, 1).numpy()
    else:
        arr = x
    if len(arr.shape) != 4:
        arr = arr[np.newaxis]
    return arr

def center_crop(image, size=256):
    """Crop a ``size`` x ``size`` square from the center of an image.

    If either spatial dimension is smaller than ``size``, the whole image is
    resized to ``size`` x ``size`` instead (aspect ratio is not preserved).

    Args:
        image (np.ndarray): Image indexed as (H, W) or (H, W, C).
        size (int, optional): Side length of the output square. Defaults to 256.

    Returns:
        np.ndarray: The center-cropped (a view into ``image``) or resized image.
    """
    height, width = image.shape[:2]
    # Compare the full dimensions with the crop size. The previous version
    # compared the half-size (the center coordinate) with 256, which resized
    # every image smaller than 512 px instead of cropping it.
    if height < size or width < size:
        return cv2.resize(image, (size, size))
    top = (height - size) // 2
    left = (width - size) // 2
    return image[top:top + size, left:left + size]


def _maincodec_output_2_rgb_01(x, output_range):
    """Converts main codec output to RGB [0,1] range.

    Args:
        x (torch.Tensor): The output tensor from the main codec.
        output_range (float): The maximum value of the output tensor's range (e.g., 255).

    Returns:
        torch.Tensor: The converted tensor, clamped to [0, 1].
    """
    return torch.clamp(x / output_range, 0,1) if x is not None else None
def ycbcr255_to_rgb_01(x, output_range):
    """Convert a YCbCr tensor (nominally on a [0, 255] scale) to RGB in [0, 1].

    The colour conversion itself is delegated to ``ycbcr_to_rgb_255``. Its
    result is divided by ``output_range`` and clamped to [0, 1].

    Args:
        x (torch.Tensor): YCbCr input tensor.
        output_range (float): Value that maps to 1.0 after conversion (e.g. 255).

    Returns:
        torch.Tensor: RGB tensor clamped to [0, 1].
    """
    rgb = ycbcr_to_rgb_255(x)
    return torch.clamp(rgb / output_range, min=0, max=1)

def save_image(img, dump_path, img_name, img_type):
    """Write an image tensor to ``<dump_path>/<img_type>_<img_name>.png``.

    Only the first image of a 4-D batch is saved. Values are clamped to
    [0, 1], scaled by 255 and converted from RGB to OpenCV's BGR order.

    Args:
        img (torch.Tensor): Image of shape (C, H, W) or batch (N, C, H, W),
            expected to be RGB in [0, 1].
        dump_path (str): Directory to write into (must already exist).
        img_name (str): Base name of the image file.
        img_type (str): Prefix distinguishing the image kind (e.g. 'attacked').

    Returns:
        bool: ``True`` if ``cv2.imwrite`` reported success, ``False`` otherwise.
    """
    if len(img.shape) == 4:
        img = img[0]
    # to_numpy expects a batched (N, C, H, W) tensor, hence unsqueeze/squeeze.
    rgb = to_numpy(torch.clamp(img, 0, 1).unsqueeze(0)).squeeze(0) * 255
    out_path = os.path.join(dump_path, f'{img_type}_{img_name}.png')
    ok = cv2.imwrite(out_path, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    # cv2.imwrite signals most failures (e.g. missing directory) via its
    # return value rather than an exception, so surface it instead of
    # silently dropping the image.
    if not ok:
        print(f'[Warning] save_image: failed to write {out_path}')
    return ok


def apply_attack(model, attack_callback, dist_images, device='cpu', variable_params=None, seed=42, is_jpegai=False, loss_func=None, loss_func_name='undefined'):
    """Applies an adversarial attack to a batch of images.

    The model is switched to train mode for the duration of the attack and is
    always put back into eval mode afterwards, even if the attack raises.

    Args:
        model (torch.nn.Module): The model to be attacked.
        attack_callback (callable): Performs the attack. It is called with a
            clone of ``dist_images`` plus the keyword arguments ``model``,
            ``device``, ``is_jpegai``, ``loss_func``, ``loss_func_name`` and
            every entry of ``variable_params``.
        dist_images (torch.Tensor): The input images to attack (not modified).
        device (str, optional): The device to run the attack on. Defaults to 'cpu'.
        variable_params (dict, optional): Attack-specific keyword arguments.
            Defaults to None (no extra arguments).
        seed (int, optional): Seed passed to ``torch.manual_seed`` before the
            attack. Defaults to 42.
        is_jpegai (bool, optional): Flag for JPEGAI models. Defaults to False.
        loss_func (callable, optional): The loss function for the attack.
            Defaults to None.
        loss_func_name (str, optional): Name of the loss function. Defaults to
            'undefined'.

    Returns:
        tuple | None: ``(attacked_images, attack_time_seconds)``, or ``None``
        if the callback returned ``None``.
    """
    # None instead of a shared mutable {} default.
    if variable_params is None:
        variable_params = {}
    model.train()
    try:
        torch.manual_seed(seed)
        t0 = time.time()
        attacked_images = attack_callback(dist_images.clone(), model=model, device=device, is_jpegai=is_jpegai, loss_func=loss_func, loss_func_name=loss_func_name, **variable_params)
        attack_time = time.time() - t0
    finally:
        # Restore eval mode even on failure so later codec runs are unaffected.
        model.eval()
    if attacked_images is None:
        return None
    return attacked_images, attack_time

def apply_codec(img, model, is_main, is_jpegai, fn, torch_seed, device, output_range, mainc_save_name='test'):
    """Run a codec on an image and collect reconstruction, rate and timing stats.

    Args:
        img (torch.Tensor): The input image tensor.
        model: The codec. When ``is_main`` is set, its ``forward`` is called
            with ``bits_pathes``/``rec_pathes`` output paths. Otherwise it is
            called as ``model(img, return_bpp=True)``. Either way it must
            return a dict with at least ``'x_hat'`` and ``'bpp'`` keys
            (``'real_bpp'`` is optional).
        is_main (bool): Whether ``model`` is the JPEGAI main codec.
        is_jpegai (bool): Whether ``model`` is a JPEGAI model. Non-main JPEGAI
            reconstructions are converted from YCbCr to RGB.
        fn (str): Original image filename. For the main codec its stem names
            the files written to the current directory
            (``./<stem>_<mainc_save_name>.bits`` and ``.png``).
        torch_seed (int): Seed passed to ``torch.manual_seed`` before coding.
        device (str | torch.device): Device the image is moved to.
        output_range (float): Upper bound of the codec's output scale (e.g. 255).
        mainc_save_name (str, optional): Suffix for the main codec's output
            files. Defaults to 'test'.

    Returns:
        dict: ``rec_img`` is the reconstruction, or ``None`` on failure. It is
        in [0, 1] when ``is_main`` or ``is_jpegai`` is set, and clamped to
        [0, output_range] otherwise. ``bpp`` is a float. ``real_bpp`` is a
        float, or ``nan`` when the codec does not report it. ``codec_time``
        is in seconds.
    """
    res = {}
    img = img.to(device)
    time_st = time.time()
    torch.manual_seed(torch_seed)
    if is_main:
        stem = Path(fn).stem
        outs = model.forward(img, bits_pathes=[f'./{stem}_{mainc_save_name}.bits'], rec_pathes=[f'./{stem}_{mainc_save_name}.png'])
        if outs['x_hat'] is None:
            print(f'[Warning] JPEGAI Main codec failure. fn={fn}, save_name={mainc_save_name}')
    else:
        outs = model(img, return_bpp=True)
    if img.is_cuda:
        # CUDA kernels are launched asynchronously. Without a sync the timer
        # would mostly measure kernel launches, not the actual coding work.
        torch.cuda.synchronize(img.device)
    delta_time = time.time() - time_st
    rec_imgs = outs['x_hat']
    if rec_imgs is not None:
        rec_imgs = torch.clamp(rec_imgs, 0, output_range)
        if is_main:
            rec_imgs = _maincodec_output_2_rgb_01(rec_imgs, output_range)
        elif is_jpegai:
            rec_imgs = ycbcr255_to_rgb_01(rec_imgs, output_range)
    res['rec_img'] = rec_imgs
    res['bpp'] = float(outs['bpp'])
    res['real_bpp'] = float(outs['real_bpp']) if 'real_bpp' in outs else np.nan
    res['codec_time'] = delta_time
    return res

def load_defence_params_json(preset_name, presets_path='defence/defence_presets.json'):
    """Collect every defence parameter's value for one preset.

    The JSON file maps each parameter name to an object whose ``'presets'``
    entry is keyed by the string form of the preset name.

    Args:
        preset_name (str or int): Preset to select. ``-1`` disables presets.
        presets_path (str, optional): JSON file holding the presets.
            Defaults to 'defence/defence_presets.json'.

    Returns:
        dict: ``{param_name: value}`` for the chosen preset, or an empty dict
        when ``preset_name`` is ``-1``.
    """
    if preset_name == -1:
        print(f'[Warning] defence: Preset == -1 was passed: ignoring presets, using global default params')
        return {}
    preset_key = f'{preset_name}'
    with open(presets_path) as json_file:
        presets = json.load(json_file)
    return {name: cfg['presets'][preset_key] for name, cfg in presets.items()}

def load_attack_params_json(preset_name, attack_name, presets_path='attack_presets_codecs.json'):
    """Loads attack parameters from a JSON file for a specific attack and preset.

    The JSON file maps each attack name to a list of preset dicts. The preset
    is selected by integer index.

    Args:
        preset_name (str or int): Index of the preset, converted with ``int()``.
            ``-1`` (or ``'-1'``) disables presets.
        attack_name (str): The name of the attack.
        presets_path (str, optional): Path to the JSON file with presets.
            Defaults to 'attack_presets_codecs.json'.

    Returns:
        dict: The selected preset's parameters, or an empty dict when presets
        are disabled.

    Raises:
        ValueError: If ``preset_name`` cannot be converted to an int.
        KeyError: If ``attack_name`` is not in the presets file.
        IndexError: If the preset index is out of range.
    """
    # Normalize first: the string '-1' used to slip past the `!= -1` check
    # and then silently select the LAST preset via negative list indexing.
    preset_idx = int(preset_name)
    if preset_idx == -1:
        print('[Warning] attack: Preset == -1 was passed: ignoring presets, using global default params')
        return {}
    with open(presets_path, encoding='utf-8') as json_file:
        presets = json.load(json_file)
    return presets[attack_name][preset_idx]
    

def fill_df_metadata(df, dataset_name, atk_name, atk_preset_name, defence_preset_name):
    """Return a copy of ``df`` with experiment metadata columns set.

    Every row receives the same values. The attack preset is written to both
    ``'preset'`` and ``'attack_preset'``. ``'defence_preset'`` is always cast
    to int, and ``'attack_preset'`` is cast to int only when
    ``atk_preset_name`` is an int.

    Args:
        df (pd.DataFrame): Source frame; it is not modified.
        dataset_name (str): The name of the dataset.
        atk_name (str): The name of the attack.
        atk_preset_name (str or int): The name or index of the attack preset.
        defence_preset_name (int): The index of the defence preset.

    Returns:
        pd.DataFrame: The copied frame with metadata columns added/overwritten.
    """
    filled_df = df.copy()
    metadata = {
        'test_dataset': dataset_name,
        'attack': atk_name,
        'preset': atk_preset_name,
        'attack_preset': atk_preset_name,
        'defence_preset': defence_preset_name,
    }
    for column, value in metadata.items():
        filled_df[column] = value

    int_columns = ['defence_preset']
    if isinstance(atk_preset_name, int):
        int_columns.append('attack_preset')
    for column in int_columns:
        filled_df[column] = filled_df[column].astype(int)
    return filled_df