"""Frechet Wavelet Distance computation."""

import os
import pathlib
from typing import List, Tuple

import numpy as np
import torch as th
import torchvision.transforms as tv
from tqdm import tqdm

from src.freq_math import calculate_frechet_distance, forward_wavelet_packet_transform
from src.utils import ImagePathDataset, _parse_args

th.set_default_dtype(th.float64)


IMAGE_EXTS = {"jpg", "jpeg", "png"}
NUM_PROCESSES = None


def compute_packet_statistics(
    dataloader: th.utils.data.DataLoader, wavelet: str, max_level: int, log_scale: bool
) -> Tuple[np.ndarray, ...]:
    """Compute wavelet packet transform across batches.

    Args:
        dataloader (th.utils.data.DataLoader): Torch dataloader.
        wavelet (str): Choice of wavelet.
        max_level (int): Wavelet decomposition level.
        log_scale (bool): Apply log scale.

    Returns:
        Tuple[np.ndarray, ...]: Mean and sigma for each packet.
    """
    packets = []
    device = th.device("cuda:0") if th.cuda.is_available() else th.device("cpu")
    for img_batch in tqdm(dataloader):
        if isinstance(img_batch, list):
            img_batch = img_batch[0]
        img_batch = img_batch.to(device)
        packets.append(
            forward_wavelet_packet_transform(
                img_batch, wavelet, max_level, log_scale
            ).cpu()
        )
    packet_tensor = th.cat(packets, dim=0)
    packet_tensor = th.permute(packet_tensor, (1, 0, 2, 3, 4))
    P, BS, C, H, W = packet_tensor.shape
    packet_tensor = th.reshape(packet_tensor, (P, BS, C * H * W))
    print("Computing mean and std for each packet.")
    mu = th.mean(packet_tensor, dim=1).numpy()
    sigma = th.stack(
        [th.cov(packet_tensor[p, :, :].T) for p in range(P)], dim=0
    ).numpy()
    return mu, sigma


def calculate_path_statistics(
    path: str, wavelet: str, max_level: int, log_scale: bool, batch_size: int
) -> Tuple[np.ndarray, ...]:
    """Compute mean and sigma for given path.

    Args:
        path (str): npz path or image directory.
        wavelet (str): Choice of wavelet.
        max_level (int): Decomposition level.
        log_scale (bool): Apply log scale.
        batch_size (int): Batch size for packet decomposition.

    Raises:
        ValueError: Error if mu and sigma cannot be calculated.

    Returns:
        Tuple[np.ndarray, ...]: Tuple containing mean and sigma for each packet.
    """
    mu, sigma = None, None
    if path.endswith(".npz"):
        with np.load(path) as fp:
            mu = fp["mu"][:]
            sigma = fp["sigma"][:]
    else:
        posfix_path = pathlib.Path(path)
        img_names = sorted(
            [name for ext in IMAGE_EXTS for name in posfix_path.glob(f"*.{ext}")]
        )
        dataloader = th.utils.data.DataLoader(
            ImagePathDataset(img_names, transforms=tv.ToTensor()),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=NUM_PROCESSES,
        )
        mu, sigma = compute_packet_statistics(
            dataloader=dataloader,
            wavelet=wavelet,
            max_level=max_level,
            log_scale=log_scale,
        )

    if (mu is None) or (sigma is None):
        raise ValueError(f"The file path: {path} is empty/doesn't have statistics.")
    return mu, sigma


def compute_avg_frechet_distance(mu1, mu2, sigma1, sigma2):
    """Compute avg frechet distance over packets."""
    frechet_distances = []
    for packet_no in tqdm(range(len(mu1))):
        fd = calculate_frechet_distance(
            mu1=mu1[packet_no, :],
            mu2=mu2[packet_no, :],
            sigma1=sigma1[packet_no, :, :],
            sigma2=sigma2[packet_no, :, :],
        )
        frechet_distances.append(fd)
    return np.mean(frechet_distances)


def compute_fwd(
    paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int
) -> float:
    """Compute Frechet Wavelet Distance.

    Args:
        paths (List[str]): List containing path of source and generated images.
        wavelet (str): Choice of wavelet.
        max_level (int): Decomposition level.
        log_scale (bool): Apply log scale.
        batch_size (int): Batch size for packet decomposition.

    Raises:
        RuntimeError: Error if path doesn't exist.

    Returns:
        float: Frechet Wavelet Distance.
    """
    for path in paths:
        if not os.path.exists(path):
            raise RuntimeError(f"Invalid path: {path}")

    print(f"Computing stats for path: {paths[0]}")
    mu_1, sigma_1 = calculate_path_statistics(
        paths[0], wavelet, max_level, log_scale, batch_size
    )
    print(f"Computing stats for path: {paths[1]}")
    mu_2, sigma_2 = calculate_path_statistics(
        paths[1], wavelet, max_level, log_scale, batch_size
    )

    print("Computing Frechet distances for each packet.")
    return compute_avg_frechet_distance(mu_1, mu_2, sigma_1, sigma_2)


def save_packets(
    paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int
) -> None:
    """Save packets.

    Args:
        paths (List[str]): List of paths containing input and output files.
        wavelet (str): Choice of wavelet.
        max_level (int): Decomposition level.
        log_scale (bool): Apply log scale.
        batch_size (int): Batch size for packet decomposition.

    Raises:
        RuntimeError: Error if input path is invalid.
        RuntimeError: Error if the output file already exists.
    """
    if not os.path.exists(paths[0]):
        raise RuntimeError(f"Invalid path: {paths[0]}")

    if os.path.exists(paths[1]):
        raise RuntimeError(f"Stats file already exists at the given path: {paths[1]}")

    print(f"Computing stats for path: {paths[0]}")
    mu_1, sigma_1 = calculate_path_statistics(
        paths[0], wavelet, max_level, log_scale, batch_size
    )
    np.savez_compressed(paths[1], mu=mu_1, sigma=sigma_1)


def main():
    """Compute FWD given paths."""
    global NUM_PROCESSES, IMAGE_EXTS

    th.manual_seed(0)
    th.use_deterministic_algorithms(True)
    args = _parse_args()
    print(args)
    if args.num_processes is None:
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            num_cpus = os.cpu_count()
        NUM_PROCESSES = min(num_cpus, 16) if num_cpus is not None else 0
    else:
        NUM_PROCESSES = args.num_processes
    print(f"Num work: {NUM_PROCESSES}")
    if args.save_packets:
        save_packets(
            args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size
        )
        return

    fwd = compute_fwd(
        args.path, args.wavelet, args.max_level, args.log_scale, args.batch_size
    )
    print(f"FWD: {fwd}")


if __name__ == "__main__":
    main()
