from omegaconf import DictConfig
from hydra.utils import instantiate
from utils.wandb_utils import download_run
from experiment.src_ace.cf_dataset import CFDataset
from experiment.src_ace.simsiam import get_simsiam_dist
from torch.utils.data import DataLoader
from torch.nn.functional import cosine_similarity

from tqdm import tqdm

import numpy as np
import torch
import utils

import logging
log = logging.getLogger(__name__)

def get_fabric(config):
    """Build, seed and launch the Fabric object described by ``config``.

    ``config.fabric`` is handed to Hydra's ``instantiate``. The resulting object
    (presumably a Lightning ``Fabric`` instance; confirm against the config) is
    seeded with ``config.exp.seed`` and then launched.

    Args:
        config: Hydra/OmegaConf config exposing ``fabric`` and ``exp.seed``.

    Returns:
        The launched Fabric object.
    """
    fabric_obj = instantiate(config.fabric)

    # Seed before launching so every spawned process starts from the same RNG state.
    seed = config.exp.seed
    fabric_obj.seed_everything(seed)

    fabric_obj.launch()
    return fabric_obj


@torch.no_grad()
def compute_FVA(oracle, dataloader):
    """Score every (real, counterfactual) batch with ``oracle`` and stack the scores.

    Runs under ``torch.no_grad()``, so no autograd graph is built.

    Args:
        oracle: Callable taking the two batch tensors ``(cl, cf)`` yielded by
            ``dataloader`` and returning a tensor of scores. It may return one
            value per sample or a single scalar per batch.
        dataloader: Iterable yielding ``(cl, cf)`` pairs.

    Returns:
        numpy array with all per-batch scores concatenated along the first
        axis, in dataloader order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    dists = []
    for cl, cf in tqdm(dataloader):
        scores = oracle(cl, cf).cpu().numpy()
        # np.concatenate rejects 0-d arrays; promote a per-batch scalar to shape (1,).
        dists.append(np.atleast_1d(scores))

    if not dists:
        # Otherwise np.concatenate fails with an opaque "need at least one array" error.
        raise ValueError("compute_FVA: dataloader yielded no batches to score")
    return np.concatenate(dists)



def run(config: DictConfig):
    """Log the mean SimSiam similarity between real images and their counterfactuals.

    The steps are:

    1. Download the real/counterfactual pairs stored in the W&B run
       ``config.exp.run_id``, filtered with ``config.exp.area_threshold``.
    2. Score the pairs with the SimSiam model loaded from
       ``config.s3.weights_path``.
    3. Log the mean score.

    Args:
        config: Hydra config. Read keys: ``exp.run_id``, ``exp.seed`` (via
            ``get_fabric``), ``exp.area_threshold``, ``fabric``, ``classifier``,
            ``wandb.entity``, ``wandb.project``, ``s3.batch_size`` and
            ``s3.weights_path``.

    Returns:
        0 on completion.

    Raises:
        AssertionError: If ``exp.run_id`` is missing or the run has no pairs.
    """
    utils.preprocess_config(config)
    assert config.exp.run_id is not None, "Please provide a run_id to calculate the metric"

    log.info('Launching Fabric')
    fabric = get_fabric(config)

    log.info('Building components')
    # NOTE(review): the classifier is instantiated and set up but never used by
    # this metric; confirm whether it can be dropped to save load time and memory.
    classifier = instantiate(config.classifier)
    classifier = fabric.setup(classifier)

    log.info("Connecting to wandb")
    project_name = f"{config.wandb.entity}/{config.wandb.project}"
    _run_dir, real_cf_pairs, _run_config = download_run(
        project_name, config.exp.run_id, area_threshold=config.exp.area_threshold
    )

    assert len(real_cf_pairs) > 0, "No real-cf pairs found"

    dataset = CFDataset(real_cf_pairs, normalize_for_s3=True)
    loader = DataLoader(
        dataset, batch_size=config.s3.batch_size,
        shuffle=False, num_workers=0
    )
    # NOTE(review): with more than one device, Fabric may shard this loader per
    # rank, so the mean below would cover only the local shard. Confirm, and
    # all_gather the results if so.
    loader = fabric.setup_dataloaders(loader)

    simsiam_model = get_simsiam_dist(config.s3.weights_path)
    simsiam_model = fabric.setup(simsiam_model)
    # This is pure inference: force eval mode so BatchNorm/Dropout layers (if any)
    # do not use batch statistics. Harmless if the model is already in eval mode.
    simsiam_model.eval()

    results = compute_FVA(simsiam_model, loader)

    # '.4f' prints 4 decimals; the previous '>4f' only set a minimum width of 4.
    log.info('SimSiam Similarity: {:.4f}'.format(np.mean(results).item()))
    return 0