"""Calculate FID score for sample images and reference images."""
import argparse
import json
import os

from pytorch_fid import fid_score

from src.utils import print_args


def parse_args():
    """Build the command line interface and return the parsed arguments.

    Returns:
        argparse.Namespace with the attributes ``sample_dir``, ``reference_dir``,
        ``db``, ``exp_name``, ``batch_size`` and ``device``.
    """
    cli = argparse.ArgumentParser(description="Calculate model behavior scores")

    # Mandatory string options, registered in the order they show up in --help.
    required_options = (
        ("--sample_dir", "directory path of samples generated by a model"),
        (
            "--reference_dir",
            "directory path of reference samples, from a dataset or a diffusion model",
        ),
        ("--db", "filepath of database for recording scores"),
        ("--exp_name", "experiment name to record in the database file"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional settings that fall back to defaults when omitted.
    cli.add_argument(
        "--batch_size", type=int, default=512, help="batch size for computation"
    )
    cli.add_argument(
        "--device", type=str, default="cuda:0", help="device used for computation"
    )

    return cli.parse_args()


def main(args):
    """Compute FID between sample and reference images and record the result.

    If ``args.sample_dir`` contains no subdirectories, one aggregate FID score is
    computed between ``args.sample_dir`` and ``args.reference_dir``. Otherwise
    every subdirectory is treated as a class: one score is computed per class
    against the same-named subdirectory of ``args.reference_dir``, and the
    unweighted mean over classes is reported as well.

    The parsed arguments plus the scores (formatted with 4 decimal places) are
    appended as a single JSON line to the file at ``args.db``.

    Args:
        args: parsed command line arguments providing ``sample_dir``,
            ``reference_dir``, ``db``, ``batch_size`` and ``device``. The
            namespace itself is left unmodified.
    """
    sample_dir = args.sample_dir
    # Copy the namespace contents: vars() returns the live __dict__, so writing
    # scores into it directly would add attributes to the caller's ``args``.
    info_dict = dict(vars(args))

    # Check if subdirectories exist for conditional image generation. Sorted
    # because os.listdir returns entries in arbitrary order, which would make
    # the log output and the record's key order vary between runs.
    subdir_list = sorted(
        entry
        for entry in os.listdir(sample_dir)
        if os.path.isdir(os.path.join(sample_dir, entry))
    )
    if not subdir_list:
        # Aggregate FID score. This is the standard practice even for conditional image
        # generation. For example, see
        # https://huggingface.co/docs/diffusers/main/en/conceptual/evaluation#class-conditioned-image-generation
        print("Calculating the FID score...")
        fid_value = fid_score.calculate_fid_given_paths(
            paths=[sample_dir, args.reference_dir],
            batch_size=args.batch_size,
            device=args.device,
            dims=2048,
        )
        fid_value_str = f"{fid_value:.4f}"
        print(f"FID score: {fid_value_str}")
        info_dict["fid_value"] = fid_value_str

    else:
        # Class-wise FID scores. If each class has too few reference samples, the
        # scores can be unstable.
        fid_total = 0.0
        for subdir in subdir_list:
            print(f"Calculating the FID score for class {subdir}...")
            fid_value = fid_score.calculate_fid_given_paths(
                paths=[
                    os.path.join(sample_dir, subdir),
                    os.path.join(args.reference_dir, subdir),
                ],
                batch_size=args.batch_size,
                device=args.device,
                dims=2048,
            )
            fid_value_str = f"{fid_value:.4f}"
            fid_total += fid_value

            print(f"FID score for {subdir}: {fid_value_str}")
            info_dict[f"fid_value/{subdir}"] = fid_value_str

        avg_fid_value_str = f"{fid_total / len(subdir_list):.4f}"
        print(f"Average FID score: {avg_fid_value_str}")
        info_dict["avg_fid_value"] = avg_fid_value_str

    # Create the database's parent directory if needed so that the scores
    # computed above are not lost to a missing output directory.
    db_dir = os.path.dirname(args.db)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    # Append-only: each run adds one JSON line to the database file.
    with open(args.db, "a") as f:
        f.write(json.dumps(info_dict) + "\n")
    print(f"Results saved to the database at {args.db}")


if __name__ == "__main__":
    # Script entry point: parse the CLI, echo the configuration, then compute
    # and record the FID score(s).
    cli_args = parse_args()
    print_args(cli_args)
    main(cli_args)
    print("Done!")
