"""
Reconstruct images from the original and perturbed features.
Please refer to the README.md file in the same directory for the usage.
"""
from __future__ import annotations

import argparse
import json
import os
from typing import Any, Optional
import traceback
import math

import numpy as np
import yaml
import pandas as pd
import torch
import wandb
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image

from metamer.icnn_replication import critic, image_domain
from metamer.icnn_replication.pipeline import FeatureInversionPipeline
from metamer.icnn_replication.snapshot import IntervalSnapshotSaver
from metamer.reconstruct.experiment_db import JSONExperimentDB
from metamer.reconstruct.models import load_generator, load_encoder_class, load_model_and_domain
from metamer.reconstruct.recon_common import (
    load_critic, load_optimizer_and_scheduler,
    DictFeaturesDataset, DictImageDataset
)
from metamer.icnn_replication.evaluation import (
    cosine_distance, correlation_distance, l2_distance
)

# Default image size for evaluation and generation (224 px on each side).
# Note that the images are resized to match the input size of the model.
IMG_SIZE = (224, 224)
# Floating-point dtype this script passes when loading the model and the
# generator and when casting the image/feature tensors.
DTYPE = torch.float32
EPSILON = 1e-8  # small value to avoid division by zero in calculations


# Feature perturbation
def _corr_distance(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return ``1 - r``, where ``r`` is the Pearson correlation of the flattened
    tensors ``a`` and ``b``.

    Population (biased) statistics are used for both the covariance and the
    standard deviations. ``EPSILON`` is added to the denominator so that a
    constant input gives distance 1 instead of dividing by zero.
    """
    centered_a, centered_b = a.flatten(), b.flatten()
    centered_a, centered_b = centered_a - centered_a.mean(), centered_b - centered_b.mean()

    covariance = torch.mean(centered_a * centered_b)
    scale = centered_a.std(unbiased=False) * centered_b.std(unbiased=False) + EPSILON
    correlation = (covariance / scale).item()
    return 1.0 - correlation


def add_noise(
    x: torch.Tensor,
    d_c: float,
    tol: float = 1e-3,
    max_iter: int = 30,
    seed: Optional[int] = None
) -> tuple[torch.Tensor, float, float]:
    """
    Add Gaussian noise (epsilon) so that corr-distance(x, x + epsilon) ≈ d_c within `tol`.

    A single noise vector is drawn once; only its scale is searched. The scale
    is first doubled until the target distance is bracketed, then refined by
    bisection.

    Parameters
    ----------
    x : torch.Tensor
        Original tensor.
    d_c : float
        Desired correlation distance (0 <= d_c < 1).
    tol : float, optional
        Allowed absolute error |d_actual - d_c|.
    max_iter : int, optional
        Maximum number of iterations for each of the bracketing and the
        bisection phases.
    seed : int, optional
        Random seed for reproducibility. When given, the global torch RNG is
        reseeded with it; if None, the current random state is used.

    Returns
    -------
    torch.Tensor
        x + epsilon. Its correlation distance to `x` is within `tol` of `d_c`
        when the search converges; otherwise it is the closest candidate that
        was evaluated.
    float
        Actual correlation distance of the returned tensor.
    float
        Scale of the noise vector used to perturb the original tensor.
    """
    if seed is not None:
        # NOTE: this reseeds the *global* torch RNG. Kept as-is so existing
        # runs stay reproducible.
        torch.manual_seed(seed)
    assert 0.0 <= d_c < 1.0, "0 <= d_c < 1 required."
    assert tol > 0.0, "tol must be positive."

    if d_c <= tol:  # no noise needed
        return x.clone(), 0.0, 0.0

    # Initial guess: for noise uncorrelated with x, corr(x, x + s*n) is about
    # sigma_x / sqrt(sigma_x^2 + s^2), so s = sigma_x * sqrt(1/r^2 - 1)
    # targets the correlation r = 1 - d_c.
    var_x = x.var(unbiased=False)
    target_r = 1.0 - d_c
    initial_std = math.sqrt(var_x) * math.sqrt(1.0 / target_r**2 - 1.0)

    noise = torch.randn_like(x)

    def dist(std: float) -> float:
        return _corr_distance(x, x + noise * std)

    # --- bracket the solution ------------------------------------------------
    lo, hi = 0.0, initial_std
    d_hi = dist(hi)
    if d_hi < d_c:                                  # initial guess too small
        for _ in range(max_iter):
            lo = hi
            hi *= 2.0
            d_hi = dist(hi)
            if d_hi >= d_c:
                break

    # --- bisection, remembering the closest candidate seen -------------------
    best_std, best_d = hi, d_hi
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        d_mid = dist(mid)
        if abs(d_mid - d_c) <= tol:                 # tolerance satisfied
            return x + noise * mid, d_mid, mid
        if abs(d_mid - d_c) < abs(best_d - d_c):
            best_std, best_d = mid, d_mid
        if d_mid < d_c:                             # need more noise
            lo = mid
        else:                                       # need less noise
            hi = mid

    # max_iter reached without meeting `tol`: return the nearest candidate
    return x + noise * best_std, best_d, best_std


def parse_model_name(config) -> str:
    """
    Derive the model identifier used in output and feature directory names.

    Resolution order:
      1. ``config['model']['model_alias']`` when present and truthy;
      2. for model names ending in ``-tfm`` (transformers), the ``pretrained``
         identifier with every ``/`` replaced by ``_``;
      3. otherwise ``config['model']['name']`` unchanged.
    """
    model_cfg = config['model']
    alias = model_cfg.get('model_alias')
    if alias:
        return alias

    model_name = model_cfg['name']
    if not model_name.endswith('-tfm'):
        return model_name
    return model_cfg['pretrained'].replace('/', '_')


def parse_feature_dir(config: dict[str, Any]) -> str:
    """
    Return the directory holding the precomputed features.

    An explicit top-level ``feature_dir`` entry takes precedence. Otherwise the
    path is ``output/readout_vision/features/<model name>/<dataset name>``.
    """
    if 'feature_dir' not in config:
        return os.path.join(
            'output', 'readout_vision', 'features',
            parse_model_name(config),
            config['data']['dataset_name'],
        )
    return config['feature_dir']


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the root directory for every result produced under ``config``.

    Layout: ``output/readout_vision/results/<model name>/<dataset name>/<exp_name>``.
    """
    path_parts = (
        'output', 'readout_vision', 'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*path_parts)


def initialize_experiment_db(config: dict[str, Any], output_dir: str) -> JSONExperimentDB:
    """
    Instantiate the JSON experiment database stored in ``output_dir`` and add
    every combination of parameters described by ``config`` to it.

    One experiment is added per (layer, image name, target correlation
    distance, noise seed). A target correlation distance of exactly 0.0 means
    "no noise". Such an entry is added once with ``noise_seed=None`` rather
    than once per seed in ``config['noise']['noise_seeds']``.

    Args:
        config (dict[str, Any]): Experiment configuration. Reads ``layers``,
            ``data.image_names``, ``noise.target_corr_dists`` and
            ``noise.noise_seeds``.
        output_dir (str): Directory holding ``experiment_db.json`` and
            ``experiment_db.lock``.

    Returns:
        JSONExperimentDB: The database after ``add_experiments`` was called.
    """
    experiment_db = JSONExperimentDB(
        exp_db_path=os.path.join(output_dir, 'experiment_db.json'),
        exp_db_lock_path=os.path.join(output_dir, 'experiment_db.lock'),
        param_clms=['layer', 'image_name', 'target_corr_dist', 'noise_seed']
    )

    noise_cfg = config['noise']
    new_experiments = []
    for layer in config['layers']:
        for image_name in config['data']['image_names']:
            for target_corr_dist in noise_cfg['target_corr_dists']:
                # corr dist 0 == no noise, so a single seed-less entry suffices
                seeds = [None] if target_corr_dist == 0.0 else noise_cfg['noise_seeds']
                new_experiments.extend(
                    {
                        'layer': layer,
                        'image_name': image_name,
                        'target_corr_dist': target_corr_dist,
                        'noise_seed': seed,
                    }
                    for seed in seeds
                )
    experiment_db.add_experiments(new_experiments)
    return experiment_db


def load_data(
        config: dict[str, Any],
        parameters: list[dict[str, Any]],
        image_dataset: DictImageDataset,
        feature_dataset: DictFeaturesDataset,
        device, dtype
    ):
    """
    Load images, true features, and noised features from the datasets.

    Each true feature gets noise via ``add_noise``, using the experiment's
    ``target_corr_dist`` and ``noise_seed`` and the tolerance
    ``config['noise']['tol']``. The noise is added before the tensors are
    moved to ``device``.

    Args:
        config (dict[str, Any]): Experiment configuration (``noise.tol`` is read).
        parameters (list[dict[str, Any]]): List of parameters for each experiment.
            All entries must refer to the same ``layer``.
        image_dataset (DictImageDataset): Dataset containing images.
        feature_dataset (DictFeaturesDataset): Dataset containing features.
        device: Device the returned tensors are moved to.
        dtype: dtype the returned tensors are cast to.
    Returns:
        images (torch.Tensor): image tensor with shape (batch, 3, h, w).
        true_features (dict[str, torch.Tensor]): {layer: true features stacked over the batch}.
        noised_features (dict[str, torch.Tensor]): {layer: noised features stacked over the batch}.
        info (dict[str, list]): per-sample true-target correlation / cosine / L2
            distances and the noise scales (``noise_stds``) that were used.
    Raises:
        ValueError: If ``parameters`` refer to more than one layer.
    """
    # load images (torch.stack raises on an empty batch, so parameters is non-empty below)
    images = torch.stack([image_dataset[p['image_name']] for p in parameters])
    images = images.to(device=device, dtype=dtype)

    # All experiments in a batch must target the same layer, because the
    # features are stacked into one tensor keyed by that layer.
    layers = {p['layer'] for p in parameters}
    if len(layers) > 1:
        raise ValueError(f'All experiments in a batch must share one layer, got: {layers}')
    layer = parameters[0]['layer']

    tol = config['noise']['tol']
    true_features = []
    noised_features = []
    true_target_corr_dists = []
    noise_stds = []
    for param in parameters:
        true_f = feature_dataset[param['image_name']][layer]  # torch.Tensor (feature_dim*)
        noised_f, dc, std = add_noise(true_f, param['target_corr_dist'], tol=tol, seed=param['noise_seed'])

        true_features.append(true_f)
        noised_features.append(noised_f)
        true_target_corr_dists.append(dc)
        noise_stds.append(std)

    # stack features
    true_features = {layer: torch.stack(true_features).to(device=device, dtype=dtype)}
    noised_features = {layer: torch.stack(noised_features).to(device=device, dtype=dtype)}

    # store the information about the noised features
    info = {
        'true_target_correlation_distance': true_target_corr_dists,
        'true_target_cosine_distance': cosine_distance(true_features[layer], noised_features[layer]).tolist(),
        'true_target_l2_distance': l2_distance(true_features[layer], noised_features[layer]).tolist(),
        'noise_stds': noise_stds,
    }
    return images, true_features, noised_features, info


class FeatureMetrics:
    """
    Feature-space evaluation callback for a single layer.

    Each call compares the reconstructed features with the original ("true")
    features and with the perturbed ("target") features, using cosine,
    correlation and L2 distances. Results are returned as Python lists.
    """

    def __init__(self, true_features: dict[str, torch.Tensor], target_features: dict[str, torch.Tensor]):
        # exactly one layer is supported; its name is taken from `true_features`
        assert len(true_features) == len(target_features) == 1, "Only one layer is supported for this script."
        self.layer = next(iter(true_features))
        self.true_features = true_features
        self.target_features = target_features

    def __call__(self, features: dict[str, torch.Tensor], _):
        """
        Calculate the evaluation metrics for the given features.

        Args:
            features (dict[str, torch.Tensor]): Reconstructed features keyed by layer name.
            _: Second positional argument passed by the caller; unused here.

        Returns:
            dict[str, list]: ``true_recon_*`` and ``target_recon_*`` distances,
            each converted with ``.tolist()``.
        """
        recon = features[self.layer]
        true_ref = self.true_features[self.layer]
        target_ref = self.target_features[self.layer]
        comparisons = (
            # true - recon metrics
            ('true_recon_cosine_distance', cosine_distance, true_ref),
            ('true_recon_correlation_distance', correlation_distance, true_ref),
            ('true_recon_l2_distance', l2_distance, true_ref),
            # target - recon metrics
            ('target_recon_cosine_distance', cosine_distance, target_ref),
            ('target_recon_correlation_distance', correlation_distance, target_ref),
            ('target_recon_l2_distance', l2_distance, target_ref),
        )
        return {key: metric(ref, recon).tolist() for key, metric, ref in comparisons}


class PixelMetrics:
    """
    Pixel-space evaluation callback comparing generated images against the
    true images with correlation, cosine and L2 distances.
    """

    def __init__(self, true_images: torch.Tensor, domain: image_domain.ImageDomain):
        """
        Args:
            true_images (torch.Tensor): True images in the common domain.
            domain (image_domain.ImageDomain | None): Domain whose ``receive``
                converts generated images before comparison; ``None`` skips
                the conversion.
        """
        self.true_images = true_images
        self.domain = domain

    def __call__(self, generated_images: torch.Tensor):
        """Return pixel-space distances between the true and generated images as lists."""
        candidates = generated_images if self.domain is None else self.domain.receive(generated_images)
        reference = self.true_images
        scores = {
            'pixel_correlation_distance': correlation_distance(reference, candidates),
            'pixel_cosine_distance': cosine_distance(reference, candidates),
            'pixel_l2_distance': l2_distance(reference, candidates),
        }
        # plain lists so that the results are JSON-serializable
        return {metric_name: score.tolist() for metric_name, score in scores.items()}


def resolve_path(output_dir: str, parameters: list[dict[str, Any]]):
    """
    Work out where each experiment's outputs are written.

    Every experiment gets
    ``<output_dir>/<layer>/corr_dist_<target_corr_dist>/noise_seed_<noise_seed>/<image_name>``.
    Its intermediate snapshots go to ``<that dir>/snapshot/step_{step}.png``.

    Returns:
        path_list (list[str]): One output directory per entry of ``parameters``.
        snapshot_saver (IntervalSnapshotSaver): Saver over all snapshot
            templates, with ``save_steps=1000``.
    """
    path_list = []
    snapshot_path_templates = []
    for param in parameters:
        run_dir = os.path.join(
            output_dir,
            param['layer'],
            f'corr_dist_{param["target_corr_dist"]}',
            f'noise_seed_{param["noise_seed"]}',
            param['image_name']
        )
        path_list.append(run_dir)
        snapshot_path_templates.append(os.path.join(run_dir, 'snapshot', 'step_{step}.png'))

    snapshot_saver = IntervalSnapshotSaver(
        path_templates=snapshot_path_templates,
        save_steps=1000,
    )
    return path_list, snapshot_saver


def resolve_wandb(config: dict[str, Any], parameters: list[dict[str, Any]]):
    """
    Start a wandb run when ``config['wandb']`` is truthy.

    The project is named after the model and dataset. The run is named after
    ``exp_name`` and the layer of the first experiment, on the assumption that
    all samples in a batch share a layer.

    Args:
        config (dict[str, Any]): Configuration dictionary.
        parameters (list[dict[str, Any]]): List of parameters for each experiment.
    Returns:
        use_wandb (bool): Whether to use wandb for logging.
        prefix (list[str] | None): Per-sample metric prefix for wandb logging,
            or None when wandb is disabled.
    """
    if not config.get('wandb', False):
        return False, None

    project = 'vision_readout_' + parse_model_name(config) + '_' + config['data']['dataset_name']
    run_name = config['exp_name'] + '_' + parameters[0]['layer']
    wandb.init(project=project, name=run_name, config=config)

    prefix = []
    for p in parameters:
        prefix.append(p['image_name'] + f'_distance{p["target_corr_dist"]}' + f'_seed{p["noise_seed"]}')
    return True, prefix


def save_results(
    save_dirs: list[str],
    info: dict[str, list],
    parameters: list[dict[str, Any]],
    generated_images: torch.Tensor,
    pipeline: FeatureInversionPipeline,
    target_features: Optional[dict[str, torch.Tensor]] = None
):
    """
    Save the final results of the experiment.

    For each run ``i`` the following files are written to ``save_dirs[i]``:

    * ``final.png``: the generated image, center-cropped to ``IMG_SIZE``;
    * ``history.csv``: this run's column of ``pipeline.history``;
    * ``summary.json``: parameters + initial (``info``) metrics + metrics of
      the last history entry;
    * ``target_features.pt``: the run's target features (only when
      ``target_features`` is given).

    Args:
        save_dirs (list[str]): List of directories to save the results for each parameter set.
        info (dict[str, list]): Dictionary containing information about the noise and true-target distances.
        parameters (list[dict[str, Any]]): List of parameters for each experiment. Not modified.
        generated_images (torch.Tensor): Finalized generated images. Each image is
            unpacked as (h, w, channels) below. This is presumably the layout
            returned by ``image_domain.finalize``; TODO confirm.
        pipeline (FeatureInversionPipeline): The pipeline used for feature inversion.
        target_features (dict[str, torch.Tensor], optional): Noised (target) features keyed by layer.

    Returns:
        list[dict[str, Any]]: One new summary dict per entry of ``parameters``.
    """
    history = pipeline.history  # list[dict[str, list[float]]], where list contains metrics of samples
    summaries = []
    histories = []
    for i, param in enumerate(parameters):
        # full history for this run
        run_history = [{name: values[i] for name, values in step.items()} for step in history]

        # short summary: parameters + initial metrics + final metrics.
        # Copy first so the caller's parameter dicts are not polluted with metrics.
        summary = dict(param)
        summary.update({name: values[i] for name, values in info.items()})
        if run_history:
            summary.update(run_history[-1])
        summaries.append(summary)
        histories.append(pd.DataFrame(run_history))

    crop_h, crop_w = IMG_SIZE
    for i, run_path in enumerate(save_dirs):
        os.makedirs(run_path, exist_ok=True)

        # save the image, center-cropped to IMG_SIZE (no negative offsets for smaller images)
        image = generated_images[i].detach().cpu().numpy().astype(np.uint8)
        h, w, _ = image.shape
        top = max((h - crop_h) // 2, 0)
        left = max((w - crop_w) // 2, 0)
        image = image[top: top + crop_h, left: left + crop_w]
        Image.fromarray(image).save(os.path.join(run_path, 'final.png'))

        # save history as csv and summary as json
        histories[i].to_csv(os.path.join(run_path, 'history.csv'), index=False)
        with open(os.path.join(run_path, 'summary.json'), 'w') as f:
            json.dump(summaries[i], f, indent=4)

        # save features
        if target_features is not None:
            layer = parameters[i]['layer']
            feature = target_features[layer][i].detach().cpu()
            torch.save(feature, os.path.join(run_path, 'target_features.pt'))
    return summaries


def main(config: dict[str, Any], device: str) -> None:
    """
    Run every pending reconstruction experiment described by ``config``.

    The datasets, the model/domain, the encoder class and the critic are built
    once, and ``config.yaml`` is written to the output directory. All parameter
    combinations are then added to the experiment DB. For each layer, batches
    of experiments are fetched and images are reconstructed from the noised
    features. Afterwards the results are saved and the batch is marked
    ``'finished'``. A failure marks the batch ``'error'``; Ctrl-C reverts it
    to ``'pending'`` and re-raises.

    Args:
        config (dict[str, Any]): Parsed YAML configuration.
        device (str): Torch device string, e.g. ``'cuda'``.
    """
    # instantiate dataset classes
    image_dataset = DictImageDataset(
        root_path=config['data']['image_dir'],
        extension=config['data']['image_ext'],
        stimulus_names=config['data']['image_names'],
        transform=Compose([Resize(IMG_SIZE), ToTensor()])
    )
    feature_dataset = DictFeaturesDataset(
        root_path=parse_feature_dir(config),
        layer_path_names=config['layers'],
        stimulus_names=config['data']['image_names'],
        transform=None,
        return_type='tensor'
    )

    # Load the model once; the encoder is re-initialized for each layer setting.
    model, domain = load_model_and_domain(config['model'], device=device, dtype=DTYPE)
    encoder_class = load_encoder_class(config['model'])
    # distinct local name so the imported `critic` module is not shadowed
    feature_critic = load_critic(config['critic'])

    # base output directory of all experiments under this config
    output_dir = parse_output_dir(config)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    exp_db = initialize_experiment_db(config, output_dir)

    # Different layers are always run separately because of the encoder implementation.
    for layer in config['layers']:
        encoder = encoder_class(
            feature_network=model,
            layer_names=[layer],
            layer_mapping=config['layer_mapping'],
            domain=domain,
            device=device,
        )

        while True:
            # fetch parameters from the database: list[dict{str, Any}]
            parameters = exp_db.get_next_experiments(config['batch_size'], layer=layer)
            if not parameters:
                print(f"No more parameters to process for layer {layer}.")
                break

            # Bind before `try`: the handlers read it even when the failure
            # happens before wandb is configured, and must not see the
            # previous batch's value.
            use_wandb = False
            try:
                # the generator depends on the batch size, so it is built per batch
                generator = load_generator(config['generator'], batch_size=len(parameters), dtype=DTYPE, device=device)
                optimizer, scheduler = load_optimizer_and_scheduler(config['optimizer'], generator)

                # Load images and features, and calculate the metrics on target features
                images, true_features, noised_features, info = load_data(config, parameters, image_dataset, feature_dataset, device, DTYPE)

                # resolve saving path and wandb settings
                save_dirs, snapshot_saver = resolve_path(output_dir, parameters)
                use_wandb, wandb_names = resolve_wandb(config, parameters)

                pipeline = FeatureInversionPipeline(
                    generator=generator,
                    encoder=encoder,
                    critic=feature_critic,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    log_interval=-1,
                    with_wandb=use_wandb,
                    eval_metrics=FeatureMetrics(true_features, target_features=noised_features),
                    snapshot_saver=snapshot_saver,
                    pixel_eval_metrics=PixelMetrics(images, domain=image_domain.CenterCropDomain(IMG_SIZE)),
                    **config.get('pipeline', {})
                )
                pipeline.reset_states()
                generated_images = pipeline(noised_features, wandb_names=wandb_names)
                generated_images = image_domain.finalize(generated_images)

                # save results; summaries: list[dict[str, Any]] = parameters + summary of metrics
                target_features = noised_features if config.get('save_target_features', False) else None
                summaries = save_results(save_dirs, info, parameters, generated_images, pipeline, target_features)

                # update the experiment status in the database
                exp_db.update_experiment_info(summaries)
                exp_db.mark_experiment_status(parameters, 'finished')

                if use_wandb:
                    wandb.finish()

            except KeyboardInterrupt:
                print("Interrupted. Reverting status to 'pending'.")
                exp_db.mark_experiment_status(parameters, 'pending')
                if use_wandb:
                    wandb.finish()
                raise
            except Exception as e:
                print(f'\nError: {e}. Marking status as "error".')
                traceback.print_exc()
                exp_db.mark_experiment_status(parameters, 'error')
                if use_wandb:
                    wandb.finish()

        # delete encoder for this layer to avoid feature_extractor memory leak
        del encoder.feature_extractor
        del encoder


def _search_max_batch_size(fits, start: int, limit: int) -> int:
    """
    Return the largest batch size in ``[1, limit]`` for which ``fits(size)`` is
    True, or 0 if none fits.

    ``fits`` is a callable taking an int and returning a bool, assumed to be
    monotone: if a size fits, every smaller size fits. The size starts at
    ``start`` and is doubled while it fits (capped at ``limit``). A binary
    search then runs between the largest fitting and the smallest failing size.
    """
    if limit < 1:
        return 0
    low = 0                       # largest size known to fit (0: none yet)
    high = max(1, min(start, limit))

    # exponential phase
    while True:
        if not fits(high):
            high -= 1             # `high` failed, so the answer is at most high - 1
            break
        low = high
        if high >= limit:
            break
        high = min(high * 2, limit)

    # binary search phase on (low, high]
    while low < high:
        mid = (low + high + 1) // 2
        if fits(mid):
            low = mid
        else:
            high = mid - 1
    return low


def find_max_batch_size(config: dict, device: str, limit: int = 704):
    """
    Find and print, for every layer in ``config['layers']``, the largest batch
    size (up to ``limit``) whose reconstruction run succeeds on ``device``.

    Each candidate is a short run (``pipeline.num_iterations = 32``) against a
    throw-away experiment DB in a temporary directory. The directory is
    removed once the layer is done.

    Note: ``config`` is modified in place (``exp_name`` and
    ``pipeline.num_iterations`` are overridden).
    """
    import tempfile

    # instantiate dataset classes
    image_dataset = DictImageDataset(
        root_path=config['data']['image_dir'],
        extension=config['data']['image_ext'],
        stimulus_names=config['data']['image_names'],
        transform=Compose([Resize(IMG_SIZE), ToTensor()])
    )
    feature_dataset = DictFeaturesDataset(
        root_path=parse_feature_dir(config),
        layer_path_names=config['layers'],
        stimulus_names=config['data']['image_names'],
        transform=None,
        return_type='tensor'
    )

    # Load the model once; the encoder is re-initialized for each candidate.
    print('Loading model...')
    model, domain = load_model_and_domain(config['model'], device=device, dtype=DTYPE)
    encoder_class = load_encoder_class(config['model'])
    # distinct local name so the imported `critic` module is not shadowed
    feature_critic = load_critic(config['critic'])
    print('Model and tokenizer loaded.')

    # override config parameters (`pipeline` may be absent, as elsewhere in this file)
    config['exp_name'] = 'find_max_batch_size'
    config.setdefault('pipeline', {})['num_iterations'] = 32

    for layer in config['layers']:
        with tempfile.TemporaryDirectory(prefix='find_max_batch_size_') as output_dir:
            exp_db = initialize_experiment_db(config, output_dir)

            def fits(batch_size: int) -> bool:
                return batch_size_fits(
                    config, exp_db, batch_size, device, layer, feature_critic,
                    model, encoder_class, domain, image_dataset, feature_dataset,
                )

            max_batch_size = _search_max_batch_size(fits, start=10, limit=limit)

            # report before the temporary directory is cleaned up
            if max_batch_size == 0:
                print(f"[RESULT] No batch size fits layer_{layer}.")
            else:
                print(f"[RESULT] Maximum batch size that fits layer_{layer}: {max_batch_size}")


def batch_size_fits(config: dict[str, Any], exp_db, batch_size, device, layer, critic, model, encoder_class, domain, image_dataset, feature_dataset) -> bool:
    """
    Try one short reconstruction run with ``batch_size`` experiments of ``layer``.

    Returns True when the run completes. The run is reported as not fitting
    when it raises a ``RuntimeError`` (recent PyTorch raises CUDA OOM as a
    RuntimeError subclass), or any other exception whose message mentions
    "memory". All other exceptions propagate.

    The fetched experiments are always reset to ``'pending'`` so that the same
    DB can be reused for the next candidate. The encoder's feature extractor is
    always released, as ``main`` does.
    """
    import gc

    # release memory held by previously tested candidates before this attempt
    gc.collect()
    torch.cuda.empty_cache()

    print('Testing batch size:', batch_size)
    fits = True
    parameters = exp_db.get_next_experiments(batch_size, layer=layer)
    if len(parameters) < batch_size:
        print('WARNING: batch size hit the limit of the samples in the database.')
        print('Actual batch size:', len(parameters))

    encoder = None
    try:
        # wrap model with the encoder class for feature extraction
        encoder = encoder_class(
            feature_network=model,
            layer_names=[layer],
            layer_mapping=config['layer_mapping'],
            domain=domain,
            device=device,
        )

        # the generator depends on the batch size, so it is built per candidate
        generator = load_generator(config['generator'], batch_size=len(parameters), dtype=DTYPE, device=device)
        optimizer, scheduler = load_optimizer_and_scheduler(config['optimizer'], generator)

        # Load images and features
        images, true_features, noised_features, _ = load_data(config, parameters, image_dataset, feature_dataset, device, DTYPE)

        pipeline = FeatureInversionPipeline(
            generator=generator,
            encoder=encoder,
            critic=critic,
            optimizer=optimizer,
            scheduler=scheduler,
            log_interval=-1,
            with_wandb=False,
            eval_metrics=FeatureMetrics(true_features, target_features=noised_features),
            snapshot_saver=None,
            pixel_eval_metrics=PixelMetrics(images, domain=image_domain.CenterCropDomain(IMG_SIZE)),
            **config.get('pipeline', {})
        )
        pipeline.reset_states()
        _ = pipeline(noised_features, wandb_names=None)

    except KeyboardInterrupt:
        print("Interrupted. Reverting status to 'pending'.")
        raise
    except RuntimeError as e:
        print(e)
        fits = False
    except Exception as e:
        print(f'Error: {e}.')
        traceback.print_exc()
        if 'memory' in str(e).lower():
            fits = False
        else:
            raise
    finally:
        # undo the experiment db status change anyway
        exp_db.mark_experiment_status(parameters, 'pending')
        # release the feature extractor so it does not accumulate across candidates
        if encoder is not None:
            del encoder.feature_extractor

    if fits:
        print('Batch size', batch_size, 'fits.')
    else:
        print('Batch size', batch_size, 'does not fit.')
    return fits
    

if __name__ == "__main__":
    # command-line interface
    parser = argparse.ArgumentParser(description="Reconstruct images from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--find_max_batch_size", action='store_true', help="Find the maximum batch size that fits on the device.")
    args = parser.parse_args()

    with open(args.config_path, "r") as config_file:
        config = yaml.safe_load(config_file)

    # image names are either inline or listed in a separate YAML file
    data_cfg = config['data']
    if 'image_names' not in data_cfg:
        with open(data_cfg['image_names_path'], 'r') as names_file:
            data_cfg['image_names'] = yaml.safe_load(names_file)

    if not args.find_max_batch_size:
        main(config, device=args.device)
    else:
        print("Finding maximum batch size...")
        find_max_batch_size(config, args.device)