from util.logger import logger

from typing import Optional, Tuple, Union, List

import torch

from PIL import Image

from tqdm.auto import tqdm

import gc

from util.image_util import img_pil_to_tensor

from lpips import LPIPS


def cal_lpips(
    img_pil_list_1: Union[Image.Image, List[Image.Image]], 
    img_pil_list_2: Union[Image.Image, List[Image.Image]], 
    img_size: Optional[Union[int, Tuple[int, int]]] = None, 

    # ---------= [LPIPS] =---------
    lpips_net_type: Optional[str] = "alex",  # ["alex", "vgg"]
    batch_size: Optional[int] = 1, 

    device: Optional[str] = "cpu", 

    model: Optional["torch.nn.Module"] = None, 

    disable_tqdm: Optional[bool] = False
) -> Tuple[
    List[float], 
    "torch.nn.Module"
]:
    """Compute the LPIPS score of each aligned image pair.

    The k-th image of `img_pil_list_1` is compared with the k-th image of
    `img_pil_list_2`. A single image may be passed instead of a list on
    either side; it is treated as a one-element list.

    Args:
        img_pil_list_1: First image (or list of images) of each pair.
        img_pil_list_2: Second image (or list of images) of each pair.
            Must contain at least as many images as `img_pil_list_1`;
            any extra images are ignored.
        img_size: Size forwarded to `img_pil_to_tensor` for every image.
            When `None`, `(height, width)` of the first image in
            `img_pil_list_1` is used.
        lpips_net_type: Backbone passed to `LPIPS(net = ...)` when `model`
            is not supplied.
        batch_size: Number of *images* per forward pass. Each pair
            contributes two images, so `batch_size // 2` pairs are scored
            at a time (never fewer than one pair).
        device: Device on which the model and the batches are placed.
        model: Model to reuse. When `None`, a new `LPIPS` model is built
            and moved to `device`.
        disable_tqdm: Whether to hide the progress bar.

    Returns:
        A tuple `(lpips_score_list, model)`: one float per pair, in input
        order, and the model that was used (so callers can reuse it).
        An empty `img_pil_list_1` yields an empty score list.

    Raises:
        ValueError: If `img_pil_list_2` is shorter than `img_pil_list_1`.
    """
    if not isinstance(img_pil_list_1, list):
        img_pil_list_1 = [img_pil_list_1]
    if not isinstance(img_pil_list_2, list):
        img_pil_list_2 = [img_pil_list_2]

    num_img = len(img_pil_list_1)

    # Fail fast (before loading the model or converting any image) instead of
    # hitting an `IndexError` in the middle of the batch loop.
    if len(img_pil_list_2) < num_img:
        raise ValueError(
            f"`img_pil_list_2` has {len(img_pil_list_2)} image(s) but "
            f"`img_pil_list_1` has {num_img}; every image needs a partner."
        )

    # `Image.size` is (width, height); skip the lookup for an empty input so
    # that it produces an empty result rather than an `IndexError`.
    if img_size is None and num_img > 0:
        width, height = img_pil_list_1[0].size
        img_size = (height, width)

    if model is None:
        model = LPIPS(net = lpips_net_type).to(device)

    # NOTE(review): the reference LPIPS model expects inputs scaled to
    # [-1, 1] -- confirm that `img_pil_to_tensor` produces that range.
    img_tensor_list_1 = [
        img_pil_to_tensor(
            img_pil = img_pil, 
            img_size = img_size
        )
        for img_pil in img_pil_list_1
    ]
    img_tensor_list_2 = [
        img_pil_to_tensor(
            img_pil = img_pil, 
            img_size = img_size
        )
        for img_pil in img_pil_list_2
    ]

    # `batch_size // 2` is 0 for the default `batch_size = 1`, which would make
    # `range()` raise "arg 3 must not be zero"; always score at least one pair.
    pairs_per_batch = max(1, batch_size // 2)

    lpips_score_list = []

    for idx_st in tqdm(
        range(0, num_img, pairs_per_batch), 

        desc = "[Compute LPIPS]", 

        disable = disable_tqdm
    ):
        idx_ed_plus_one = min(idx_st + pairs_per_batch, num_img)

        # ---------= [Prepare Batch Image Tensor] =---------
        # Stack first, then move once: a single transfer per batch.
        batch_img_tensor_1 = torch.stack(
            img_tensor_list_1[idx_st: idx_ed_plus_one]
        ).to(device)
        batch_img_tensor_2 = torch.stack(
            img_tensor_list_2[idx_st: idx_ed_plus_one]
        ).to(device)

        # ---------= [Compute LPIPS] =---------
        with torch.no_grad():
            batch_lpips_score = model(batch_img_tensor_1, batch_img_tensor_2)

        # One single-element entry per pair along the first dimension.
        lpips_score_list += [
            pair_score.item()
            for pair_score in batch_lpips_score
        ]

        # ---------= [Clean Up] =---------
        del batch_img_tensor_1, batch_img_tensor_2, batch_lpips_score
        gc.collect()
        torch.cuda.empty_cache()

    # ---------= [Clean Up] =---------
    del img_tensor_list_1, img_tensor_list_2
    gc.collect()
    torch.cuda.empty_cache()

    # `cal_lpips()` done
    return (
        lpips_score_list, 
        model
    )


def cal_mean_pairwise_distance_lpips(
    img_pil_list: List[Image.Image], 
    img_size: Optional[Union[int, Tuple[int, int]]] = None, 

    # ---------= [LPIPS] =---------
    lpips_net_type: Optional[str] = "alex",  # ["alex", "vgg"]
    batch_size: Optional[int] = 1, 

    device: Optional[str] = "cpu", 

    model: Optional["torch.nn.Module"] = None, 

    disable_tqdm: Optional[bool] = False
) -> Tuple[float, "torch.nn.Module"]:
    """Compute the mean LPIPS score over all unordered image pairs.

    Every pair `(i, j)` with `i < j` from `img_pil_list` is scored once and
    the scores are averaged.

    Args:
        img_pil_list: Images to compare with one another.
        img_size: Size forwarded to `img_pil_to_tensor` for every image.
            When `None`, `(height, width)` of the first image is used.
        lpips_net_type: Backbone passed to `LPIPS(net = ...)` when `model`
            is not supplied.
        batch_size: Number of *pairs* scored per forward pass (values
            below 1 are treated as 1).
        device: Device on which the model and the batches are placed.
        model: Model to reuse. When `None`, a new `LPIPS` model is built
            and moved to `device`.
        disable_tqdm: Whether to hide the progress bar.

    Returns:
        A tuple `(lpips_score, model)`: the mean pairwise score and the
        model that was used (so callers can reuse it). With fewer than two
        images there are no pairs and the score is `0.0`.
    """
    num_img = len(img_pil_list)

    # `Image.size` is (width, height); skip the lookup for an empty input so
    # that it does not raise an `IndexError`.
    if img_size is None and num_img > 0:
        width, height = img_pil_list[0].size
        img_size = (height, width)

    if model is None:
        model = LPIPS(net = lpips_net_type).to(device)

    # NOTE(review): the reference LPIPS model expects inputs scaled to
    # [-1, 1] -- confirm that `img_pil_to_tensor` produces that range.
    img_tensor_list = [
        img_pil_to_tensor(
            img_pil = img_pil, 
            img_size = img_size
        )
        for img_pil in img_pil_list
    ]

    # All unordered pairs (i < j); there are num_img * (num_img - 1) / 2.
    idx_pair_list = [
        (i, j)
        for i in range(num_img)
        for j in range(i + 1, num_img)
    ]
    num_pair = len(idx_pair_list)

    # A zero step makes `range()` raise and a negative one silently skips
    # every pair, so always advance by at least one pair.
    pairs_per_batch = max(1, batch_size)

    sum_lpips_score = 0.0

    for idx_st in tqdm(
        range(0, num_pair, pairs_per_batch), 

        desc = "[Compute MPD]", 

        disable = disable_tqdm
    ):
        batch_idx_pair_list = idx_pair_list[idx_st: idx_st + pairs_per_batch]

        # Stack first, then move once: a single transfer per batch.
        batch_img_tensor_1 = torch.stack(
            [img_tensor_list[i] for (i, _) in batch_idx_pair_list]
        ).to(device)
        batch_img_tensor_2 = torch.stack(
            [img_tensor_list[j] for (_, j) in batch_idx_pair_list]
        ).to(device)

        with torch.no_grad():
            batch_lpips_score = model(batch_img_tensor_1, batch_img_tensor_2)

        sum_lpips_score += batch_lpips_score.sum().item()

        # ---------= [Clean Up] =---------
        del batch_img_tensor_1, batch_img_tensor_2, batch_lpips_score
        gc.collect()
        torch.cuda.empty_cache()

    # Without any pair the mean is undefined; report 0.0 instead of leaving
    # the result unassigned (which raised `UnboundLocalError` at the return).
    lpips_score = sum_lpips_score / num_pair if num_pair > 0 else 0.0

    return (
        lpips_score, 
        model
    )
    