from omegaconf import DictConfig
from hydra.utils import instantiate
from pytorch_fid.fid_score import calculate_fid_given_paths
from utils.wandb_utils import download_run
from experiment.src_ace.cf_dataset import CFDataset
from experiment.src_ace.cout_metrics import evaluate
from torch.utils.data import DataLoader

import utils

import logging
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)

def get_fabric(config):
    """Create, seed and launch the Fabric object described by ``config``.

    Args:
        config: Hydra/OmegaConf config. ``config.fabric`` is handed to
            ``hydra.utils.instantiate``; ``config.exp.seed`` is the seed
            passed to ``seed_everything``.

    Returns:
        The instantiated Fabric object, after ``launch()`` has been called.
    """
    accelerator = instantiate(config.fabric)
    seed = config.exp.seed
    accelerator.seed_everything(seed)
    accelerator.launch()
    return accelerator


def run(config: DictConfig):
    """Compute and log the COUT metric for counterfactuals stored in a W&B run.

    Values read from ``config``:
      - ``exp.run_id``: id of the W&B run to download (required).
      - ``exp.area_threshold``: forwarded to ``download_run``.
      - ``exp.batch_size``: batch size of the evaluation DataLoader.
      - ``wandb.entity`` / ``wandb.project``: W&B project that holds the run.
      - ``classifier``: Hydra config of the classifier to instantiate.
      - ``fabric`` / ``exp.seed``: consumed by ``get_fabric``.

    ``target_id`` and ``guide_id`` passed to ``evaluate`` come from the
    downloaded run's own config (``run_config["exp"]``), not from ``config``.

    Args:
        config: Hydra config for this evaluation.

    Returns:
        0 once the metric has been computed and logged.

    Raises:
        ValueError: if ``config.exp.run_id`` is not set.
        RuntimeError: if the downloaded run yields no real/counterfactual pairs.
    """
    utils.preprocess_config(config)
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would let a None run_id reach download_run.
    if config.exp.run_id is None:
        raise ValueError("Please provide a run_id to calculate the metric")

    log.info('Launching Fabric')
    fabric = get_fabric(config)

    log.info('Building components')
    classifier = instantiate(config.classifier)
    # NOTE(review): the classifier is not explicitly put in eval mode here;
    # presumably `evaluate` handles that — confirm.
    classifier = fabric.setup(classifier)

    log.info("Connecting to wandb")

    project_name = f"{config.wandb.entity}/{config.wandb.project}"
    # The first returned value (the run directory) is not needed for this metric.
    _, real_cf_pairs, run_config = download_run(
        project_name, config.exp.run_id, area_threshold=config.exp.area_threshold
    )

    target_id = run_config["exp"]["target_id"]
    guide_id = run_config["exp"]["guide_id"]

    if len(real_cf_pairs) == 0:
        raise RuntimeError("No real-cf pairs found")

    dataset = CFDataset(real_cf_pairs)
    loader = DataLoader(
        dataset, batch_size=config.exp.batch_size,
        shuffle=False, num_workers=0
    )

    # Let Fabric place batches on the right device / wrap for its strategy.
    loader = fabric.setup_dataloaders(loader)

    log.info("Calculating COUT")
    cout = evaluate(
        target_id, guide_id, classifier, loader, fabric.device
    )
    # `evaluate` returns an indexable result; only its first element is logged.
    log.info(f"COUT: {cout[0]}")

    return 0