from omegaconf import DictConfig
from hydra.utils import instantiate
from .src_ace.fid import calculate_fid_given_image_sets
from utils.wandb_utils import download_run
from tqdm import tqdm
import numpy as np
from pathlib import Path
import shutil

import utils

import logging
log = logging.getLogger(__name__)

def get_fabric(config):
    """Create the Fabric object described by ``config.fabric``, seed it, launch it, and return it.

    The object is built with Hydra's ``instantiate``. It is seeded with
    ``config.exp.seed`` before ``launch()`` is called.
    """
    runtime = instantiate(config.fabric)
    experiment_seed = config.exp.seed
    runtime.seed_everything(experiment_seed)
    runtime.launch()
    return runtime


def prepare_split(real_cf_pairs: list[tuple[Path, Path]]) -> tuple[list[Path], list[Path]]:
    """Randomly split (real, counterfactual) pairs into two disjoint, equally sized halves.

    The pair indices are shuffled with NumPy's global RNG. Pairs at even
    shuffled positions contribute their real image (element 0) to the first
    list. Pairs at odd shuffled positions contribute their counterfactual
    (element 1) to the second list. No pair contributes to both lists.

    Args:
        real_cf_pairs: sequence of ``(real_path, counterfactual_path)`` tuples.

    Returns:
        ``(real_images, counterfactual_images)``, each of length
        ``len(real_cf_pairs) // 2``. When the number of pairs is odd, the last
        pair in the shuffled order is left out so both lists have the same size.
    """
    random_idxs = np.arange(len(real_cf_pairs))
    np.random.shuffle(random_idxs)

    # Make sure we have an equal number of images and counterfactuals: for an
    # odd count, ignore the final shuffled index. The former in-loop guard
    # (`i == len(...)`) could never fire, so odd inputs gave unequal lists.
    usable = len(random_idxs) - len(random_idxs) % 2

    first_list = [real_cf_pairs[idx][0] for idx in random_idxs[0:usable:2]]
    second_list = [real_cf_pairs[idx][1] for idx in random_idxs[1:usable:2]]

    return first_list, second_list


def run(config: DictConfig):
    """Compute the sFID metric for counterfactual images stored in one or two wandb runs.

    Steps:
      1. Preprocess ``config`` and validate the required options.
      2. Launch Fabric and build and set up the classifier from ``config.classifier``.
      3. Download (real, counterfactual) image pairs for ``config.exp.run_id``
         (and, if set, ``config.exp.another_run_id``) via ``download_run``.
      4. Repeat ``config.sfid.repeats`` times: randomly split the pairs with
         ``prepare_split`` into real images from one half and counterfactuals
         from the other half, then compute the FID between the two sets.
      5. Log the mean and standard deviation of the repeated FID values.

    Args:
        config: Hydra config. Reads ``exp.run_id``, ``exp.another_run_id``,
            ``exp.area_threshold``, ``wandb.entity``, ``wandb.project``,
            ``sfid.dims``, ``sfid.repeats`` and ``sfid.batch_size``, plus
            whatever ``utils.preprocess_config`` and ``get_fabric`` read.

    Returns:
        0. The metric is only logged.

    Raises:
        AssertionError: if ``exp.run_id`` is None or ``sfid.dims`` is not one
            of 64, 192, 768, 2048.
        ValueError: if fewer than two image pairs are available.
    """
    utils.preprocess_config(config)
    assert config.exp.run_id is not None, "Please provide a run_id to calculate the metric"
    # Validate before launching Fabric and downloading runs, so that a bad
    # config fails fast instead of after the expensive setup.
    assert config.sfid.dims in (64, 192, 768, 2048)

    log.info('Launching Fabric')
    fabric = get_fabric(config)

    log.info('Building components')
    # NOTE(review): the classifier is set up but never used by the sFID
    # computation below. It is kept so instantiation/setup side effects are unchanged.
    classifier = instantiate(config.classifier)
    classifier = fabric.setup(classifier)

    log.info("Connecting to wandb")
    project_name = f"{config.wandb.entity}/{config.wandb.project}"

    _run_dir, image_pairs, _ = download_run(project_name, config.exp.run_id, area_threshold=config.exp.area_threshold)
    if config.exp.another_run_id is not None:
        _, another_image_pairs, _ = download_run(project_name, config.exp.another_run_id, area_threshold=config.exp.area_threshold)
        image_pairs.extend(another_image_pairs)

    # Each split needs at least one real image and one counterfactual taken
    # from different pairs.
    if len(image_pairs) < 2:
        raise ValueError(
            f"sFID needs at least 2 (real, counterfactual) image pairs, got {len(image_pairs)}"
        )

    log.info("Calculating sFID")

    # Let PIL load image files that are truncated instead of raising on them.
    from PIL import ImageFile
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    fids = []
    for _ in tqdm(range(config.sfid.repeats), desc="Calculating sFID"):
        # A fresh random split on every repeat.
        split_real_images, split_cf_images = prepare_split(image_pairs)

        fid = calculate_fid_given_image_sets(
            [split_real_images, split_cf_images],
            config.sfid.batch_size,
            fabric.device,
            dims=config.sfid.dims,
        )
        fids.append(fid)

    fids = np.array(fids)
    log.info(f"sFID: {np.mean(fids)}")
    log.info(f"sFID std over {len(fids)} repeats: {np.std(fids)}")

    return 0