import logging
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluation.images.embedder import get_img_from_cpt, ImageEmbedder, get_cpt_from_img

# Module-level logger; all messages in this file go to the logger named 'custom'.
logger = logging.getLogger('custom')


class VarianceCalculator:
    """
    Measures how strongly the samples a model generates for one input
    scatter around their own mean.

    For every dataset sample, ``k`` outputs are generated from the other
    modality and the variance of those outputs around their mean is
    computed; the per-sample variances are then averaged.
    """

    def __init__(self, model, dataset, device):
        """
        :param model: Model handed unchanged to the generation helpers
        :param dataset: Dataset wrapped in a ``DataLoader``; every batch must
        unpack as ``((x1, x2), _)``
        :param device: Device handed to the generation helpers and the
        image embedder
        """
        self.model = model
        self.dataset = dataset
        self.device = device

    def run(self, its=1, bs=10, **kwargs):
        """
        Repeats the variance computation ``its`` times and averages the
        results.

        :param its: Number of full passes over the dataset
        :param bs: Note that we will compute 100 importance samples
        (by default, see ``k`` in ``_run``) per batch sample, i.e.,
        the bs must be small
        :param kwargs: Forwarded to ``_run``; must contain ``mod``
        ('x1' or 'x2') and may contain ``k``
        :return: Mapping from result key to its average over all iterations
        """
        logger.info(f'Performing {its} iterations')
        variances = defaultdict(list)
        for _ in tqdm(range(its)):
            cur_v = self._run(bs, **kwargs)
            for k, v in cur_v.items():
                variances[k].append(v)
        variances = self.process_results(variances)
        return variances

    def process_results(self, variances):
        """
        Logs mean +/- std per key and collapses each list to its mean.

        :param variances: Mapping from key to a list of per-iteration values;
        modified in place
        :return: The same mapping, with every list replaced by its average
        """
        logger.info('Average variances of generated features to their mean:')
        # Only existing keys are reassigned, so mutating while iterating
        # over items() is safe here.
        for k, v in variances.items():
            avg = np.average(v)
            logger.info(f'{k}: {avg:.3f} +/- {np.std(v):.3f}')
            variances[k] = avg
        return variances

    def _run(self, bs, mod, k=100, **kwargs):
        """
        Computes intra-class variance of generated samples.
        (separately for each sample)

        :param bs: Batch size of the data loader
        :param mod: Modality to generate: 'x1' (generated from x2) or
        'x2' (generated from x1)
        :param k: Forwarded as ``k`` to the generation helper (number of
        generated samples per input, per the ``run`` docstring)
        :param kwargs: Forwarded unchanged to the generation helper
        :return: Dict mapping result key to the mean per-sample variance
        :raises ValueError: If ``mod`` is neither 'x1' nor 'x2'
        """
        # Fail fast, before the loader and embedder are built, and also when
        # the dataset is empty.
        if mod not in ('x1', 'x2'):
            raise ValueError(
                f"Unknown modality {mod!r}; expected 'x1' or 'x2'"
            )

        loader = DataLoader(self.dataset, batch_size=bs)
        # The embedder is only used for generated images, so skip building it
        # for mod == 'x2'.
        # NOTE: a second embedder ('resnet_ft', result key 'x1|x2_emb_ft')
        # used to be evaluated here as well and is currently disabled.
        embedder = ImageEmbedder('resnet', self.device) if mod == 'x1' else None

        # Fetch values iteratively due to high k
        vs = defaultdict(list)
        for (x1, x2), _ in loader:
            if mod == 'x1':
                # Caption to image space
                x = self._get_img_from_cpt(
                    self.model, x2, self.device, k=k, **kwargs
                )
                vs['x1|x2'].extend(self._compute_variances(x))
                vs['x1|x2_emb'].extend(
                    self._compute_variances(embedder.run(x))
                )
            else:
                # Image to caption space
                x = self._get_cpt_from_img(
                    self.model, x1, self.device, k=k, **kwargs
                )
                vs['x2|x1'].extend(self._compute_variances(x))
            # Drop the generated batch before the next one is fetched.
            del x

        return {key: np.mean(vals) for key, vals in vs.items()}

    def _get_img_from_cpt(self, *args, **kwargs):
        """Thin wrapper around ``get_img_from_cpt``; arguments pass through."""
        return get_img_from_cpt(*args, **kwargs)

    def _get_cpt_from_img(self, *args, **kwargs):
        """Thin wrapper around ``get_cpt_from_img``; arguments pass through."""
        return get_cpt_from_img(*args, **kwargs)

    @staticmethod
    def _compute_variances(x):
        """
        Computes, per batch sample, the variance over the K generated
        samples, averaged over all remaining dimensions.

        ``torch.var`` is used with its default (Bessel-corrected) estimator,
        so K == 1 yields NaN.

        :param x: N x K x D (feature space) or a 5-dimensional tensor
        (image space) whose first two dimensions are N and K
        :return: List(float) of length N
        """
        assert x.dim() in (3, 5), (
            f'expected a 3-D (feature space) or 5-D (image space) tensor, '
            f'got {x.dim()} dimensions'
        )
        variances = []
        # For every batch sample
        for ft in x:
            # Variance over K. Variance is unchanged by subtracting the mean,
            # so there is no need to centre the samples first (that would only
            # allocate another tensor of the same size as ft).
            var = torch.var(ft, dim=0)
            variances.append(var.mean().item())
        return variances
