import logging
import os
from collections import defaultdict
from typing import Dict

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

import utils
from utils import chunks

logger = logging.getLogger('custom')


class AccuracyCalculator:
    """ Calculates unimodal and crossmodal accuracies in observed space
    (not latent space).

    Model outputs (reconstruction / generation means) are used as the
    labelled support set and the observed data as the query set of
    ``calculate_accuracy``. Results are logged and saved with ``torch.save``.
    """

    def __init__(self, model, dataset, device, save_dir, split):
        """ Store the evaluation setup.

        Args:
            model: Object whose ``forward(x, eval=True, k=1)`` returns a dict
                with ``'reconstruction'`` and ``'ancestral_samples'`` entries.
            dataset: Object exposing ``x`` (indexable per modality),
                ``s['y']`` (labels), and usable with a ``DataLoader`` that
                yields pairs whose first element is the model input.
            device: Device the batches are moved to and distances computed on.
            save_dir: Directory the accuracy file is written to.
            split: Name of the data split; only used in the output file name.
        """
        self.model = model
        self.dataset = dataset
        self.device = device
        self.save_dir = save_dir
        self.split = split

    def run(self):
        """ Compute crossmodal and unimodal accuracies, then log and save. """
        output = self._get_model_output()
        accs = {}
        accs.update(self._get_crossmodal_accs(output))
        accs.update(self._get_unimodal_accs(output))
        self._finalize(accs)

    def _finalize(self, accs):
        """ Log every accuracy and save the dict to ``save_dir``. """
        for k, v in accs.items():
            logger.info(f'Accuracy for {k}: {v:.2f}%')
        save_path = os.path.join(
            self.save_dir,
            f'accuracies_in_observed_space_{self.split}.pt')
        torch.save(accs, save_path)

    def _get_unimodal_accs(self, output):
        """ Accuracy of each modality's reconstructions against its own data.

        The sample indices are split 50/50: reconstructions of one half form
        the support set, observed data of the other half the query set.

        NOTE: ``train_test_split`` is called without ``random_state``, so the
        split (and therefore the reported accuracy) varies between runs.

        Returns:
            Dict mapping ``'x1|x1'`` and ``'x2|x2'`` to accuracies in percent.
        """
        accs = {}
        mods = ['x1', 'x2']
        # Labels are shared by both modalities, so convert them only once.
        y = list(utils.to_np(self.dataset.s['y']))
        for i, m in enumerate(mods):
            idx = list(range(self.dataset.x[i].size(0)))
            idx_train, idx_test, y_train, y_test = train_test_split(
                idx, y, test_size=0.5)
            acc = calculate_accuracy(
                support={'x': output[f'{m}|{m}'][idx_train],
                         'y': torch.tensor(y_train)},
                query={'x': self.dataset.x[i][idx_test],
                       'y': torch.tensor(y_test)},
                device=self.device)
            accs[f'{m}|{m}'] = acc
        return accs

    def _get_crossmodal_accs(self, output):
        """ Accuracy of each modality generated from the other modality.

        The generated data of modality ``t`` (key ``'t|c'``) is the support
        set and the observed data of ``t`` the query set; both use all
        samples and the full label vector.

        Returns:
            Dict mapping ``'x1|x2'`` and ``'x2|x1'`` to accuracies in percent.
        """
        accs = {}
        mods = ['x1', 'x2']
        # Pairs every modality with the other one: (x1, x2) and (x2, x1).
        for t, c in zip(mods, mods[::-1]):
            acc = calculate_accuracy(
                support={'x': output[f'{t}|{c}'],
                         'y': self.dataset.s['y']},
                query={'x': self.dataset.x[0 if t == 'x1' else 1],
                       'y': self.dataset.s['y']},
                device=self.device)
            accs[f'{t}|{c}'] = acc
        return accs

    def _get_model_output(self, bs=128):
        """ Compute reconstruction means.

        Args:
            bs: Batch size of the (unshuffled) ``DataLoader``.

        Returns:
            Mapping from ``'t|c'`` keys to the per-batch means concatenated
            along the first dimension, in dataset order.
        """
        output = defaultdict(list)
        loader = DataLoader(self.dataset, bs, shuffle=False)
        # Evaluation only: with autograd enabled, every stored mean would keep
        # its batch's computation graph alive (if the model's parameters
        # require grad) until the whole dataset has been processed.
        with torch.no_grad():
            for inp in loader:
                inp = utils.to_device(inp, self.device)
                x, _ = inp
                cur_output = self.model.forward(x, eval=True, k=1)
                output = self._get_means_from_generations(output, cur_output)
        for k, v in output.items():
            output[k] = torch.cat(v)
        return output

    @staticmethod
    def _get_means_from_generations(output, cur_output):
        """ Append the distribution means of one batch to ``output``.

        Args:
            output: Mapping from ``'t|c'`` keys to lists; extended in place.
            cur_output: Model output for one batch, read at
                ``['ancestral_samples'][t][c][0]['dist']`` for ``t != c`` and
                at ``['reconstruction'][t][c]['dist']`` for ``t == c``.

        Returns:
            The same ``output`` mapping that was passed in.
        """
        mods = ['x1', 'x2']
        for t in mods:
            for c in mods:
                if t != c:
                    # Sample from prior over z
                    dist = cur_output['ancestral_samples'][t][c][0]['dist']
                else:
                    # Sample from posterior over z
                    dist = cur_output['reconstruction'][t][c]['dist']
                # Only the first entry along the leading axis of the mean is
                # kept (presumably the sample axis, since k=1 -- TODO confirm).
                output[f'{t}|{c}'].append(dist.mean[0])
        return output


def calculate_accuracy(support: Dict[str, torch.Tensor],
                       query: Dict[str, torch.Tensor],
                       device,
                       bs=256) -> float:
    """ Return the accuracy, in percent, of labelling ``query`` via ``support``.

    The query indices are split with ``chunks`` (presumably into pieces of at
    most ``bs`` indices -- confirm against ``utils.chunks``); every piece is
    scored by ``_calc_accuracy`` and the per-sample outcomes are averaged.

    Args:
        support: Dict with keys ``'x'`` and ``'y'``; passed unchanged to
            ``_calc_accuracy`` for every chunk.
        query: Dict with keys ``'x'`` and ``'y'``; both are indexed with
            lists of positions along their first dimension.
        device: Forwarded to ``_calc_accuracy``.
        bs: Chunk size handed to ``chunks``.

    Returns:
        Mean per-sample correctness multiplied by 100, or NaN when the query
        set is empty.
    """
    num_queries = query['y'].size(0)
    if num_queries == 0:
        # np.mean([]) would also give NaN, but only after a RuntimeWarning;
        # make the empty case explicit and silent.
        return float('nan')
    accs = []
    for cur_idx in chunks(list(range(num_queries)), bs):
        cur_query = {'x': query['x'][cur_idx], 'y': query['y'][cur_idx]}
        accs.extend(_calc_accuracy(support, query=cur_query, device=device))
    return np.mean(accs) * 100


def _calc_accuracy(support, query, device):
    """ Score one batch of queries by nearest-neighbour label lookup.

    Each query takes the label of the support sample with the smallest
    distance as returned by ``utils.get_distance``.

    Returns:
        List with one entry per query: whether the predicted label equals
        the query's own label.
    """
    # Distance matrix, laid out as N x Q (support x query) per the original
    # author's note.
    distances = utils.get_distance(support['x'], query['x'], device)
    # Reduce over the support axis: index of the closest support sample
    # for every query, shape (Q,).
    nearest = distances.argmin(0)
    hits = torch.eq(support['y'][nearest], query['y'])
    return list(utils.to_np(hits))
