"""
Reconstruct images from original and perturbed VAE encodings.
Use the decoder to reconstruct images from the encodings.
"""

from __future__ import annotations

import argparse

import math
import os
import json
from typing import Any

import numpy as np
import torch
import tqdm
import yaml
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
from diffusers.models import AutoencoderKL

from metamer.icnn_replication.evaluation import (
    cosine_distance, correlation_distance, l2_distance
)


# Small constant added to the denominator in `_corr_distance` so that the
# division stays finite when either input has zero standard deviation.
EPSILON = 1e-8


def get_latent(model: torch.nn.Module, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Encode a batch of images and return the parameters of the latent distribution.

    Args:
        model (torch.nn.Module): Model exposing ``encode(images)`` whose result
            has a ``latent_dist`` attribute with ``mean`` and ``logvar``.
        images (torch.Tensor): Batch of images. Shape (B, C, H, W).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(mean, logvar)`` read from
        ``model.encode(images).latent_dist``.
    """
    posterior = model.encode(images).latent_dist
    latent_mean = posterior.mean
    latent_logvar = posterior.logvar
    return latent_mean, latent_logvar


def load_images(config: dict[str, Any], image_names: list[str], device: str) -> torch.Tensor:
    """
    Load the named images from disk and return them as one batched tensor.

    Each image is read from ``<image_dir>/<name><image_ext>``, converted to RGB,
    resized to a square of ``config['image_size']`` pixels, converted with
    torchvision's ``ToTensor`` and finally stacked along a new leading dimension.

    Args:
        config (dict[str, Any]): Configuration. Reads ``config['image_size']``,
            ``config['data']['image_dir']`` and ``config['data']['image_ext']``.
        image_names (list[str]): Image file names without extension.
        device (str): Device the stacked batch is moved to.

    Returns:
        torch.Tensor: Stacked image tensors on ``device``.
    """
    totensor = ToTensor()
    size = config['image_size']

    image_dir = config['data']['image_dir']
    image_ext = config['data']['image_ext']
    images = []
    for name in image_names:
        path = os.path.join(image_dir, f"{name}{image_ext}")
        # Open inside a context manager so the underlying file handle is
        # released right away instead of whenever the image object is
        # garbage-collected; convert()/resize() return independent images.
        with Image.open(path) as raw_image:
            image = raw_image.convert("RGB").resize((size, size))
        images.append(totensor(image))
    return torch.stack(images).to(device)


def _corr_distance(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return ``1 - r`` for two tensors, where ``r`` is their sample correlation.

    Both inputs are flattened and mean-centred. The covariance is divided by the
    product of the (biased) standard deviations plus ``EPSILON``.
    """
    centered_a = a.flatten()
    centered_a = centered_a - centered_a.mean()
    centered_b = b.flatten()
    centered_b = centered_b - centered_b.mean()

    covariance = (centered_a * centered_b).mean()
    # EPSILON keeps the ratio finite when one of the inputs is constant.
    scale = centered_a.std(unbiased=False) * centered_b.std(unbiased=False) + EPSILON
    correlation = (covariance / scale).item()
    return 1.0 - correlation


def add_noise(
    x: torch.Tensor,
    d_c: float,
    tol: float = 1e-3,
    max_iter: int = 30,
    seed: int | None = None
) -> tuple[torch.Tensor, float, float]:
    r"""
    Add Gaussian noise \epsilon so that corr-distance(x, x+\epsilon) ≈ d_c within `tol`.
    Uses a fixed noise vector and bisection on its scale.

    Parameters
    ----------
    x : torch.Tensor
        Original tensor.
    d_c : float
        Desired correlation distance (0 ≤ d_c < 1).
    tol : float, optional
        Allowed absolute error |d_actual - d_c|.
    max_iter : int, optional
        Maximum iterations of each root-finding phase (bracketing, then bisection).
    seed : int or None, optional
        Random seed for reproducibility. When given, the *global* torch RNG is
        re-seeded via ``torch.manual_seed``. If None, uses the current random state.

    Returns
    -------
    torch.Tensor
        x + \epsilon whose sample correlation distance is within `tol`
        (or the nearest value reached in `max_iter` iterations).
    float
        Actual correlation distance of the returned tensor.
    float
        Scale of the noise vector used to perturb the original tensor.

    Raises
    ------
    AssertionError
        If `d_c` is outside [0, 1) or `tol` is not positive.
    """
    if seed is not None:
        torch.manual_seed(seed)
    assert 0.0 <= d_c < 1.0, "0 <= d_c < 1 required."
    assert tol > 0.0, "tol must be positive."

    if d_c <= tol:  # no noise needed
        return x.clone(), 0.0, 0.0

    # Initial guess: for noise exactly uncorrelated with x the correlation is
    # r = sigma_x / sqrt(sigma_x^2 + std^2), hence std = sigma_x * sqrt(1/r^2 - 1).
    var_x = x.var(unbiased=False)
    target_r = 1.0 - d_c                       # desired correlation
    initial_std = math.sqrt(var_x) * math.sqrt(1.0 / target_r**2 - 1.0)

    # The noise direction is drawn once; only its scale is searched.
    noise = torch.randn_like(x)

    def dist(std: float) -> float:
        return _corr_distance(x, x + noise * std)

    # --- bracket the solution ------------------------------------------------
    lo, hi = 0.0, initial_std
    d_hi = dist(hi)
    # Closest candidate evaluated so far. It is only used for the fallback
    # return, which keeps that return well-defined even when `max_iter` is 0.
    best_std, best_dist = hi, d_hi
    if d_hi < d_c:                                  # initial guess too small
        for _ in range(max_iter):
            lo = hi
            hi *= 2.0
            d_hi = dist(hi)
            if abs(d_hi - d_c) < abs(best_dist - d_c):
                best_std, best_dist = hi, d_hi
            if d_hi >= d_c:
                break

    # --- bisection -----------------------------------------------------------
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        d_mid = dist(mid)
        if abs(d_mid - d_c) <= tol:             # tolerance satisfied
            return x + noise * mid, d_mid, mid
        if abs(d_mid - d_c) < abs(best_dist - d_c):
            best_std, best_dist = mid, d_mid
        if d_mid < d_c:                         # need more noise
            lo = mid
        else:                                   # need less noise
            hi = mid

    # max_iter reached – return the closest scale that was evaluated
    return x + noise * best_std, best_dist, best_std


def get_noised_latent(mean, logvar, target_corr, seed, device, noise_type='normal'):
    """
    Perturb every latent in a batch towards a target correlation distance.

    Args:
        mean: Batch of latent means; indexed along its first dimension.
        logvar: Latent log-variances. Currently unused by this function.
        target_corr: Target correlation distance forwarded to ``add_noise``.
        seed: Seed forwarded unchanged to every ``add_noise`` call, so each
            sample in the batch is perturbed with the same seed.
        device: Device on which the output buffer is allocated.
        noise_type: Kind of noise to add; only ``'normal'`` is supported.

    Returns:
        Tensor shaped like ``mean`` holding the perturbed latents.

    Raises:
        ValueError: If ``noise_type`` is not ``'normal'``.
    """
    perturbed = torch.empty_like(mean).to(device)

    if noise_type != 'normal':
        raise ValueError(f"Unknown noise type: {noise_type}")

    n_items = mean.shape[0]
    for idx in range(n_items):
        # add_noise also returns the reached distance and the noise scale;
        # only the perturbed tensor is kept.
        sample, _, _ = add_noise(mean[idx], target_corr, seed=seed)
        perturbed[idx] = sample
    return perturbed


def eval_images(images: torch.Tensor, recon_images: torch.Tensor) -> list[dict[str, float]]:
    """
    Compare original and reconstructed images with the pixel-space metrics.

    The imported ``correlation_distance``, ``cosine_distance`` and
    ``l2_distance`` are each called on the two batches; their outputs are
    converted with ``.tolist()`` and regrouped per sample.

    Args:
        images (torch.Tensor): Batch of original images.
        recon_images (torch.Tensor): Batch of reconstructed images.

    Returns:
        list[dict[str, float]]: One dict per entry of ``images`` mapping metric
        name to that sample's value (assumes each metric yields one value per
        sample — confirm against the metric implementations).
    """
    metric_fns = {
        'pixel_correlation_distance': correlation_distance,
        'pixel_cosine_distance': cosine_distance,
        'pixel_l2_distance': l2_distance,
    }

    # metric name -> raw metric output for the whole batch
    raw_values = {
        metric_name: metric_fn(images, recon_images)
        for metric_name, metric_fn in metric_fns.items()
    }
    # plain Python lists so the values can be JSON-serialized later
    listed_values = {metric_name: values.tolist() for metric_name, values in raw_values.items()}

    # regroup: one dict of metrics per sample
    per_sample = []
    for idx in range(len(images)):
        per_sample.append({metric_name: values[idx] for metric_name, values in listed_values.items()})
    return per_sample


def parse_model_name(config) -> str:
    """
    Build a directory-friendly model name from the configuration.

    Reads ``config['model']['pretrained']`` and replaces every ``/`` with ``_``.
    """
    pretrained_id = config['model']['pretrained']
    safe_name = pretrained_id.replace('/', '_')
    return safe_name


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the experiment's output directory path from the configuration.

    The path is
    ``output/readout_vae/results/<model name>/<dataset name>/<experiment name>``,
    with the model name taken from ``parse_model_name`` and the other two parts
    from ``config['data']['dataset_name']`` and ``config['exp_name']``.
    """
    path_parts = [
        'output',
        'readout_vae',
        'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    ]
    return os.path.join(*path_parts)


def main(config: dict[str, Any], device: str) -> None:
    """
    Run the reconstruction experiment described by ``config``.

    For every batch of images the VAE latent mean is computed once. Then, for
    each combination of noise seed and target correlation distance, the latents
    are perturbed, decoded, compared against the original images and saved as
    PNG files. All per-image metrics are written to ``experiment_db.json`` in
    the output directory once every batch has been processed.

    Args:
        config (dict[str, Any]): Experiment configuration. This function reads
            ``data.image_names``, ``batch_size``, ``model.pretrained``,
            ``noise.noise_seeds``, ``noise.target_corr_dists`` and
            ``noise.type``; the helpers it calls read further keys.
        device (str): Device used for the model and the image batches.
    """
    image_names = config['data']['image_names']
    n_samples = len(image_names)
    batch_size = config['batch_size']
    n_iter = int(np.ceil(n_samples / batch_size))

    output_dir = parse_output_dir(config)
    os.makedirs(output_dir, exist_ok=True)

    model = AutoencoderKL.from_pretrained(config['model']['pretrained'])
    model.to(device)

    # helper functions
    topil = ToPILImage()

    # results (list[dict])
    results = []

    with torch.no_grad():
        for i in tqdm.tqdm(range(n_iter), total=n_iter):
            # obtain batch of images
            names = image_names[i * batch_size:(i + 1) * batch_size]
            images = load_images(config, names, device)

            # normalize images into [-1, 1]
            normalized_images = (images.clamp(0, 1) - 0.5) * 2.0

            # encodings
            mean, logvar = get_latent(model, normalized_images)

            for noise_seed in config['noise']['noise_seeds']:

                for target_corr in config['noise']['target_corr_dists']:
                    # add noise to the encodings
                    noised_latents = get_noised_latent(mean, logvar, target_corr, noise_seed, device, config['noise']['type'])

                    # decode
                    recon_images = model.decode(noised_latents).sample

                    # scale back to [0, 1]
                    recon_images = (recon_images.clamp(-1, 1) + 1) * 0.5

                    # evaluate images
                    eval_metrics = eval_images(images, recon_images)

                    # The save directory depends only on the noise settings, so
                    # it is built and created once per batch, not once per image.
                    image_save_dir = os.path.join(output_dir, f'corr_dist_{target_corr}', f'noise_seed_{noise_seed}')
                    os.makedirs(image_save_dir, exist_ok=True)

                    # one device-to-host copy for the whole batch
                    recon_images_cpu = recon_images.cpu()

                    for j, name in enumerate(names):
                        # store results
                        r = eval_metrics[j]
                        r['image_name'] = name
                        r['target_corr'] = target_corr
                        r['noise_seed'] = noise_seed
                        results.append(r)

                        # save reconstructed image as png
                        image_save_path = os.path.join(image_save_dir, f'{name}.png')
                        recon_image = topil(recon_images_cpu[j])
                        recon_image.save(image_save_path)

    # save results
    results_save_path = os.path.join(output_dir, 'experiment_db.json')
    with open(results_save_path, 'w') as f:
        json.dump(results, f, indent=4)


if __name__ == "__main__":
    # Command-line interface: a positional config path and an optional device.
    cli = argparse.ArgumentParser(description="Reconstruct images from the original and perturbed features.")
    cli.add_argument("config_path", type=str, help="Path to the configuration file.")
    cli.add_argument("--device", type=str, default='cuda')
    cli_args = cli.parse_args()

    # The experiment configuration is a YAML file.
    with open(cli_args.config_path, "r") as config_file:
        config = yaml.safe_load(config_file)

    # When the image names are not listed inline, load them from the YAML file
    # referenced by `image_names_path`.
    data_section = config['data']
    if 'image_names' not in data_section:
        with open(data_section['image_names_path'], 'r') as names_file:
            data_section['image_names'] = yaml.safe_load(names_file)

    main(config, device=cli_args.device)