import torch
from omegaconf import DictConfig
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from utils.wandb_utils import download_run
from experiment.src_ace.cf_dataset import CFDataset

import utils

import logging
log = logging.getLogger(__name__)

def get_fabric(config):
    """Build, seed and launch the Fabric object described by ``config.fabric``.

    Args:
        config: Hydra config; ``config.fabric`` is passed to ``instantiate``
            and ``config.exp.seed`` is used for ``seed_everything``.

    Returns:
        The launched Fabric instance.
    """
    fabric_obj = instantiate(config.fabric)
    seed = config.exp.seed
    fabric_obj.seed_everything(seed)
    fabric_obj.launch()
    return fabric_obj

# Example wandb run ids for testing:
# c2y5y5sn
# dus1sjd5

def preprocess(x):
    """Affinely rescale ``x`` so that 0 -> -1, 0.5 -> 0 and 1 -> 1.

    Works elementwise on tensors as well as on plain numbers.
    NOTE(review): presumably inputs are images in [0, 1] and LPIPS expects
    [-1, 1]; confirm against the LPIPS module configured in ``config.lpips``.
    """
    shifted = x - 0.5
    return shifted * 2


def _load_cf_pairs(config: DictConfig) -> list:
    """Download image pairs from wandb and return the ``(a, b)`` pairs to score.

    Images are fetched with ``download_run`` from the wandb project
    ``{config.wandb.entity}/{config.wandb.project}``, filtered with
    ``config.exp.area_threshold``.

    * If ``config.exp.run_id == config.exp.another_run_id``, one run is
      downloaded and every pair is scored as ``(pair[1], pair[0])``.
    * Otherwise both runs are downloaded and the ``pair[1]`` elements of the
      two runs are matched positionally.

    NOTE(review): assumes ``pair[1]`` is the counterfactual (VCE) and that both
    runs list their images in the same order; confirm against ``download_run``.
    """
    area_thr = config.exp.area_threshold
    project_name = f"{config.wandb.entity}/{config.wandb.project}"

    # The first run is needed in both cases, so download it only once.
    _, image_pairs, _ = download_run(project_name, config.exp.run_id, area_threshold=area_thr)
    cf_images = [pair[1] for pair in image_pairs]

    if config.exp.run_id == config.exp.another_run_id:
        cf_images_another = [pair[0] for pair in image_pairs]
    else:
        _, another_image_pairs, _ = download_run(
            project_name, config.exp.another_run_id, area_threshold=area_thr)
        cf_images_another = [pair[1] for pair in another_image_pairs]

    if len(cf_images) != len(cf_images_another):
        # zip() silently drops the tail of the longer list; make that visible.
        log.warning(
            f"Runs returned different numbers of images "
            f"({len(cf_images)} vs {len(cf_images_another)}); "
            f"only the first {min(len(cf_images), len(cf_images_another))} pairs are scored")

    return list(zip(cf_images, cf_images_another))


def run(config: DictConfig):
    '''
    Computes diversity, i.e., mean LPIPS between VCEs for the same image.

    Requires ``config.exp.run_id`` and ``config.exp.another_run_id``. Images
    from those wandb runs are paired, and each pair is mapped through
    ``preprocess`` and scored with the LPIPS module from ``config.lpips``.
    The mean score is logged.

    Returns:
        0 (also when no pairs were scored; a warning is logged in that case).

    Raises:
        AssertionError: if ``run_id`` or ``another_run_id`` is missing.
    '''
    utils.preprocess_config(config)
    assert config.exp.run_id is not None, "Please provide a run_id to calculate the metric"
    assert config.exp.another_run_id is not None, "Please provide another_run_id to calculate the metric"

    log.info('Launching Fabric')
    fabric = get_fabric(config)

    log.info('Building components')
    lpips = instantiate(config.lpips)
    lpips = fabric.setup(lpips)

    log.info("Connecting to wandb")
    cf_pairs = _load_cf_pairs(config)

    dataset = CFDataset(cf_pairs)
    loader = DataLoader(
        dataset, batch_size=config.exp.batch_size,
        shuffle=False, num_workers=0
    )
    # NOTE(review): setup_dataloaders may shard the data across processes in
    # multi-device runs, in which case `scores` below holds only this
    # process's share; gather across ranks if running on more than one device.
    loader = fabric.setup_dataloaders(loader)

    log.info("Calculating diversity")

    scores = []
    with torch.no_grad():
        for batch_0, batch_1 in loader:
            lpips_val = lpips(
                preprocess(batch_0), preprocess(batch_1)).flatten().tolist()
            scores.extend(lpips_val)

    if not scores:
        # Avoid ZeroDivisionError when the runs yielded no usable pairs.
        log.warning("No image pairs were scored; mean diversity is undefined")
        return 0

    log.info(f"Mean diversity: {sum(scores) / len(scores)}")

    return 0