from omegaconf import DictConfig
from hydra.utils import instantiate
from utils.wandb_utils import download_run
from pathlib import Path
from .src_ace.fid import calculate_fid_given_image_sets

import utils

import logging
log = logging.getLogger(__name__)

def get_fabric(config):
    """Instantiate, seed and launch the fabric object described by ``config``.

    ``config.fabric`` is turned into an object via Hydra's ``instantiate``;
    presumably this is a Lightning ``Fabric`` (it exposes ``seed_everything``,
    ``launch`` and, in ``run``, ``setup``) — confirm against the config.

    Args:
        config: experiment config; reads ``config.fabric`` and ``config.exp.seed``.

    Returns:
        The instantiated object after ``seed_everything`` and ``launch`` were called.
    """
    launched = instantiate(config.fabric)
    seed = config.exp.seed
    launched.seed_everything(seed)
    launched.launch()
    return launched


def run(config: DictConfig):
    """Compute and log the FID between the two image sets of a wandb run.

    Downloads the image pairs of ``config.exp.run_id`` (and, when set, of
    ``config.exp.another_run_id``) from the wandb project
    ``"{config.wandb.entity}/{config.wandb.project}"``. It splits each pair into
    its first element ("real") and second element ("cf", presumably the
    counterfactual), and logs the FID between those two sets.

    Args:
        config: experiment config. Fields read here: ``exp.run_id``,
            ``exp.another_run_id``, ``exp.area_threshold``, ``classifier``,
            ``wandb.entity``, ``wandb.project``, ``fid.batch_size``,
            ``fid.dims`` (plus ``fabric`` / ``exp.seed`` via ``get_fabric``).

    Returns:
        0. The FID value is only logged, not returned.

    Raises:
        ValueError: if ``config.exp.run_id`` is not set, or if no image pairs
            were downloaded for the requested run(s).
    """
    utils.preprocess_config(config)
    # Explicit raise instead of ``assert``: asserts are removed under ``python -O``.
    if config.exp.run_id is None:
        raise ValueError("Please provide a run_id to calculate the metric")

    log.info("Launching Fabric")
    fabric = get_fabric(config)

    log.info("Building components")
    # NOTE(review): the classifier is set up but not used by the FID computation
    # below; kept so any side effects of instantiation/setup are preserved —
    # confirm whether it can be removed.
    classifier = instantiate(config.classifier)
    classifier = fabric.setup(classifier)

    log.info("Connecting to wandb")

    project_name = f"{config.wandb.entity}/{config.wandb.project}"
    area_threshold = config.exp.area_threshold

    # Only the middle element of download_run's 3-tuple (the image pairs) is used.
    _, downloaded_pairs, _ = download_run(project_name, config.exp.run_id, area_threshold=area_threshold)
    # Copy so extending below never mutates the object download_run returned
    # and works for any iterable, not only lists.
    image_pairs = list(downloaded_pairs)
    if config.exp.another_run_id is not None:
        _, another_image_pairs, _ = download_run(
            project_name, config.exp.another_run_id, area_threshold=area_threshold
        )
        image_pairs.extend(another_image_pairs)

    # Fail early with a clear message instead of deep inside the FID computation.
    if not image_pairs:
        raise ValueError(
            f"No image pairs downloaded for run_id={config.exp.run_id!r} "
            f"(another_run_id={config.exp.another_run_id!r})"
        )

    real_images = [pair[0] for pair in image_pairs]
    cf_images = [pair[1] for pair in image_pairs]

    log.info("Calculating FID")
    from PIL import ImageFile
    # Let PIL load truncated image files instead of raising; note this is a
    # process-wide PIL setting and stays enabled after this function returns.
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    fid = calculate_fid_given_image_sets(
        [real_images, cf_images],
        config.fid.batch_size,
        fabric.device,
        dims=config.fid.dims,
    )
    log.info("FID: %s", fid)

    return 0