"""
Calculate representation sizes and save them in {exp_dir}/summary/representation_size.

Usage 
```
python scripts/250527_readout_vision/calculate_size.py /path/to/config.yaml
```
"""
import argparse
import os
from typing import Any

import yaml
import pandas as pd


def parse_model_name(config) -> str:
    """
    Build a directory-safe model name from the configuration.

    Reads ``config['model']['pretrained']`` and substitutes every '/' with '_'
    so the identifier can be used as a single path component.
    """
    pretrained = config['model']['pretrained']
    # Splitting on '/' and re-joining with '_' replaces each separator.
    return '_'.join(pretrained.split('/'))


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the experiment's output directory path from the configuration.

    The path is ``output/readout_vae/results/<model>/<dataset>/<exp_name>``,
    where ``<model>`` comes from ``parse_model_name(config)``, ``<dataset>``
    from ``config['data']['dataset_name']`` and ``<exp_name>`` from
    ``config['exp_name']``.
    """
    path_parts = (
        'output',
        'readout_vae',
        'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*path_parts)


def calculate_representation_size(
        results: pd.DataFrame,
        dx_col: str,
        dh_col: str,
        dx_thres: float,
        fill_na_with_zero: bool = True
    ) -> tuple[pd.Series, pd.DataFrame]:
    """
    Calculate the representation size based on the results DataFrame.

    For each image, the representation size is the largest ``dh_col`` value
    among the rows whose ``dx_col`` value is strictly below ``dx_thres``.

    Args:
        results (pd.DataFrame): DataFrame containing the experiment results.
            The dataframe should have the following columns:
                - 'image_name': Sample name (rows are grouped by this column).
                - dx_col: Column name for the input distance values.
                - dh_col: Column name for the feature distance values.
        dx_col (str): Column name for the input distance values.
        dh_col (str): Column name for the feature distance values.
        dx_thres (float): Threshold applied to the input distance (``dx_col``);
            only rows with ``dx_col < dx_thres`` contribute to the size.
        fill_na_with_zero (bool): Whether to replace nan sizes with zero before
            aggregating across images. If False, nan sizes are skipped by the
            mean/std instead.

    Returns:
        size (pd.Series): Representation size per image, named
            'representation_size' and indexed by 'image_name'.
            The size of an image with no row below the threshold is nan
            (this series is never zero-filled).
        size_agg (pd.DataFrame): Representation size aggregated across images.
            A single row with index label 'latent' and columns
            size_mean, size_std.
    """
    # Keep dh only where the input distance is below the threshold; all other
    # rows become nan so that the group-wise max ignores them. An image with
    # no qualifying row therefore ends up with a nan size.
    qualifying_dh = results[dh_col].where(results[dx_col] < dx_thres)
    size = (
        qualifying_dh
        .groupby(results['image_name'])
        .max()
        .rename('representation_size')
    )

    # aggregate across images, honoring the caller's nan policy
    per_image = size.fillna(0) if fill_na_with_zero else size
    mean = per_image.mean()
    std = per_image.std()
    size_agg = pd.DataFrame({'size_mean': mean, 'size_std': std}, index=['latent'])
    return size, size_agg


def main(config: dict, dx_col: str, dh_col: str, dx_thres: float) -> None:
    """
    Compute representation sizes for one experiment and write them to CSV.

    Reads ``experiment_db.json`` from the experiment directory derived from
    ``config`` and writes two files into
    ``{exp_dir}/summary/representation_size``:
        - size.csv: one row per image (columns: image_name, representation_size).
        - size_aggregated.csv: mean/std across images (columns: size_mean, size_std).

    Args:
        config (dict): Experiment configuration, used to locate the experiment directory.
        dx_col (str): Column name for the input distance values.
        dh_col (str): Column name for the feature distance values.
        dx_thres (float): Threshold applied to ``dx_col``.
    """
    # directory settings
    exp_dir = parse_output_dir(config)
    output_dir = os.path.join(exp_dir, 'summary', 'representation_size')
    os.makedirs(output_dir, exist_ok=True)
    size_path = os.path.join(output_dir, 'size.csv')
    size_agg_path = os.path.join(output_dir, 'size_aggregated.csv')

    # load results from experiment_db.json
    results = pd.read_json(os.path.join(exp_dir, 'experiment_db.json'))
    size, size_agg = calculate_representation_size(
        results,
        dx_col=dx_col,
        dh_col=dh_col,
        dx_thres=dx_thres,
        fill_na_with_zero=True,
    )

    # save results
    # `size` is a Series indexed by image name, so the index must be written;
    # dropping it would leave the values with no way to tell which image
    # each one belongs to.
    size.to_frame(name='representation_size').to_csv(size_path, index_label='image_name')
    # The aggregate is a single row, so its index label is not written.
    size_agg.to_csv(size_agg_path, index=False)


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, load the YAML config and
    # run the representation-size calculation.
    parser = argparse.ArgumentParser(
        description="Calculate representation sizes from the experiment results and save them as CSV files."
    )
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument('--dx_col', type=str, default='pixel_correlation_distance',
                        help="Column name for the input distance values.")
    parser.add_argument('--dh_col', type=str, default='target_corr',
                        help="Column name for the feature distance values.")
    parser.add_argument('--dx_thres', type=float, default=0.1,
                        help="Threshold for input distance (dx_col) to determine representation size.")
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    # NOTE(review): nothing visible in this file reads config['data']['image_names']
    # afterwards; the loading is kept so the config has the same shape as before.
    if 'image_names' not in config['data']:
        # use image_names_path
        with open(config['data']['image_names_path'], 'r') as f:
            config['data']['image_names'] = yaml.safe_load(f)

    main(
        config,
        dx_col=args.dx_col,
        dh_col=args.dh_col,
        dx_thres=args.dx_thres,
    )