import torch
from omegaconf import DictConfig
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from utils.wandb_utils import download_run
from experiment.src_ace.cf_dataset import CFDataset

import utils

import logging
# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def get_fabric(config):
    """Instantiate, seed and launch the Fabric object described by ``config``.

    The object is built from ``config.fabric``, seeded with
    ``config.exp.seed`` and launched before being returned.
    """
    fab = instantiate(config.fabric)
    seed = config.exp.seed
    fab.seed_everything(seed)
    fab.launch()
    return fab


def preprocess(x):
    """Shift ``x`` by 0.5 and double it, mapping [0, 1] onto [-1, 1]."""
    centered = x - 0.5
    return centered * 2

# Example id for testing (presumably a wandb run id for `config.exp.run_id`):
# x3mwveru

def run(config: DictConfig):
    """Compute and log the mean LPIPS over the image pairs of a wandb run.

    Image pairs are fetched with ``download_run`` for ``config.exp.run_id``
    (``config.exp.area_threshold`` is forwarded as ``area_threshold``),
    wrapped in a ``CFDataset`` and scored batch by batch with the LPIPS
    module instantiated from ``config.lpips``. Per the original description,
    each pair is the original image and its VCE.

    Args:
        config: Experiment configuration. Reads ``fabric``, ``lpips``,
            ``exp.seed``, ``exp.run_id``, ``exp.area_threshold``,
            ``exp.batch_size``, ``wandb.entity`` and ``wandb.project``.

    Returns:
        Always ``0``; the mean score is only written to the log.

    Raises:
        ValueError: If ``config.exp.run_id`` is ``None``.
    """
    utils.preprocess_config(config)
    # Raise explicitly rather than assert: asserts are stripped under -O.
    if config.exp.run_id is None:
        raise ValueError("Please provide a run_id to calculate the metric")

    log.info("Launching Fabric")
    fabric = get_fabric(config)

    log.info("Building components")
    lpips = instantiate(config.lpips)
    lpips = fabric.setup(lpips)

    log.info("Connecting to wandb")
    area_thr = config.exp.area_threshold
    project_name = f"{config.wandb.entity}/{config.wandb.project}"

    # Only the middle element of the returned triple is used here.
    _, image_pairs, _ = download_run(
        project_name, config.exp.run_id, area_threshold=area_thr
    )

    dataset = CFDataset(image_pairs)
    loader = DataLoader(
        dataset,
        batch_size=config.exp.batch_size,
        shuffle=False,
        num_workers=0,
    )
    loader = fabric.setup_dataloaders(loader)

    log.info("Calculating LPIPS")

    scores = []

    with torch.no_grad():
        for batch_0, batch_1 in loader:
            # NOTE(review): rescaling presumably assumes images in [0, 1] so
            # the metric receives inputs in [-1, 1] -- confirm with CFDataset.
            lpips_val = lpips(preprocess(batch_0), preprocess(batch_1))
            scores.extend(lpips_val.flatten().tolist())

    # Guard against an empty run: the mean would divide by zero.
    if not scores:
        log.warning("No image pairs were scored; mean LPIPS is undefined")
        return 0

    log.info("Mean LPIPS: %s", sum(scores) / len(scores))

    return 0