"""Common stuff for human evaluations."""
import collections.abc
import csv
import io
import itertools
import random
from typing import List, Tuple

from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.examination.top_examples import top_examples_common

###############################################################################


def make_top_example_groups(
    rng: random.Random,
    reader: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    n_groups: int,
    n_examples_per_group: int,
    shuffle_examples_in_group: bool,
    unique_top_examples: bool,
) -> Tuple[List[int], List[List['top_examples_common.TopExampleInfo']]]:
    """Sample distinct components and collect the top examples of each.

    Args:
        rng: Random source used to choose the components and, optionally, to
            shuffle each group.
        reader: Supplies `n_components` and the per-component top-example lookups.
        n_groups: Number of distinct components to sample (without replacement).
        n_examples_per_group: Value passed as `n_top_examples` to the lookup.
        shuffle_examples_in_group: When True, each returned group is shuffled
            in place with `rng`.
        unique_top_examples: When True, examples come from
            `reader.get_unique_top_examples_for_component`; otherwise from
            `reader.get_top_examples_for_component`.

    Returns:
        A tuple `(component_indices, example_groups)`; `example_groups[i]`
        belongs to `component_indices[i]`.
    """
    chosen_components = rng.sample(range(reader.n_components), k=n_groups)

    groups = []
    for comp in chosen_components:
        fetch = (
            reader.get_unique_top_examples_for_component
            if unique_top_examples
            else reader.get_top_examples_for_component
        )
        examples = fetch(component_index=comp, n_top_examples=n_examples_per_group)
        # NOTE(review): this shuffles the list object the reader returned; if the
        # reader ever hands out cached lists they would be mutated — confirm.
        if shuffle_examples_in_group:
            rng.shuffle(examples)
        groups.append(examples)

    return chosen_components, groups


def make_random_example_groups(
    rng: random.Random,
    reader: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    n_groups: int,
    n_examples_per_group: int,
) -> List[List['top_examples_common.TopExampleInfo']]:
    """Build `n_groups` groups of randomly chosen examples.

    Each group is an independent `rng.sample` of `n_examples_per_group` example
    indices from `range(reader.n_examples)`. Indices are therefore distinct
    within a group but may repeat across groups.

    Every example is materialized through `reader.make_top_example_info_by_indices`
    with `component_index=0`. The original author notes that the component index
    does not matter for random groups.

    Returns:
        One list of example infos per group, in sampling order.
    """
    example_pool = range(reader.n_examples)
    return [
        [
            reader.make_top_example_info_by_indices(example_index=ex_idx, component_index=0)
            for ex_idx in rng.sample(example_pool, k=n_examples_per_group)
        ]
        for _ in range(n_groups)
    ]


###############################################################################


class _OrderedDistinctPairs(collections.abc.Sequence):
    """Lazy, read-only sequence of ordered pairs ``(a, b)`` with ``a != b`` from ``range(n)``.

    Elements appear in exactly the order produced by::

        for a, b in itertools.combinations(range(n), r=2):
            pairs.append((a, b))
            pairs.append((b, a))

    Index ``2*m`` is the m-th lexicographic combination ``(a, b)`` with
    ``a < b``, and index ``2*m + 1`` is its reverse ``(b, a)``. Elements are
    computed on demand, so the view uses O(1) memory rather than O(n^2).
    """

    def __init__(self, n: int):
        # range(n) is empty for n <= 0, so there are no pairs in that case.
        self._n = max(n, 0)

    def __len__(self) -> int:
        return self._n * (self._n - 1)

    def _n_combinations_before(self, a: int) -> int:
        """Number of combinations ``(x, y)``, ``x < y``, whose first element ``x`` is below ``a``."""
        # Closed form of sum(n - 1 - x for x in range(a)); the product is always even.
        return a * (2 * self._n - a - 1) // 2

    def __getitem__(self, index: int) -> Tuple[int, int]:
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            # Required so Sequence-based iteration (used by random.sample on
            # small populations) terminates.
            raise IndexError('pair index out of range')

        combination_index, is_reversed = divmod(index, 2)

        # Binary search for the combination's first element: the largest `a`
        # in [0, n - 2] whose run of combinations starts at or before
        # combination_index.
        lo, hi = 0, self._n - 2
        while lo < hi:
            mid = (lo + hi + 1) // 2
            if self._n_combinations_before(mid) <= combination_index:
                lo = mid
            else:
                hi = mid - 1
        a = lo
        b = a + 1 + (combination_index - self._n_combinations_before(a))
        return (b, a) if is_reversed else (a, b)


def _select_non_same_pairs_of_indices(
    rng: random.Random,
    n: int,
    k: int,
) -> List[Tuple[int, int]]:
    """Sample ``k`` distinct ordered pairs ``(a, b)`` of unequal indices from ``range(n)``.

    Sampling is without replacement over *ordered* pairs. ``(a, b)`` and
    ``(b, a)`` are separate candidates, so both may be returned.

    Raises:
        ValueError: (from ``random.Random.sample``) if ``k`` is negative or
            larger than the number of available ordered pairs.
    """
    # random.sample only reads the population via len(), indexing and (for
    # small populations) iteration. The lazy view matches the previously
    # materialized list element for element, so the result for a given rng
    # state is unchanged while memory no longer grows quadratically in n.
    return list(rng.sample(_OrderedDistinctPairs(n), k=k))


def make_same_component_example_group_pairs(
    rng: random.Random,
    reader_1: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    reader_2: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    n_groups: int,
    n_examples_per_group: int,
    *,
    shuffle_examples_in_group: bool,
    shuffle_pair_order: bool,
    unique_top_examples: bool,
) -> List[Tuple[List['top_examples_common.TopExampleInfo'], List['top_examples_common.TopExampleInfo']]]:
    """Pair up the two readers' example groups for the same randomly chosen components.

    `n_groups` distinct component indices are sampled from
    `range(reader_1.n_components)`. For each one, a group is fetched from
    `reader_1` and then from `reader_2`.

    Args:
        rng: Random source for component sampling and all shuffles.
        reader_1, reader_2: Readers that must report equal `n_components`
            (checked with `assert`).
        n_groups: Number of components, and hence pairs, to produce.
        n_examples_per_group: Value passed as `n_top_examples` to the lookup.
        shuffle_examples_in_group: Shuffle each fetched group in place with `rng`.
        shuffle_pair_order: Randomly swap which reader's group comes first.
        unique_top_examples: Use `get_unique_top_examples_for_component` instead
            of `get_top_examples_for_component`.

    Returns:
        A list of 2-tuples of example groups, one tuple per sampled component.
    """
    assert reader_1.n_components == reader_2.n_components

    def load_group(reader, component_index):
        # Fetch one reader's examples for a component, optionally shuffled in place.
        if unique_top_examples:
            examples = reader.get_unique_top_examples_for_component(
                component_index=component_index, n_top_examples=n_examples_per_group)
        else:
            examples = reader.get_top_examples_for_component(
                component_index=component_index, n_top_examples=n_examples_per_group)
        if shuffle_examples_in_group:
            rng.shuffle(examples)
        return examples

    sampled_components = rng.sample(range(reader_1.n_components), k=n_groups)

    result = []
    for comp in sampled_components:
        # reader_1's group is always fetched (and shuffled) before reader_2's,
        # which keeps the rng draw order stable.
        both = [load_group(reader_1, comp), load_group(reader_2, comp)]
        if shuffle_pair_order:
            rng.shuffle(both)
        result.append(tuple(both))
    return result


def make_different_component_example_group_pairs(
    rng: random.Random,
    reader_1: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    reader_2: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
    n_groups: int,
    n_examples_per_group: int,
    *,
    shuffle_examples_in_group: bool,
    shuffle_pair_order: bool,
    unique_top_examples: bool,
) -> List[Tuple[List['top_examples_common.TopExampleInfo'], List['top_examples_common.TopExampleInfo']]]:
    """Pair `reader_1`'s group for one component with `reader_2`'s group for a different one.

    The `(component_1, component_2)` pairs come from
    `_select_non_same_pairs_of_indices`, so the two indices in a pair always
    differ and no ordered pair is used twice.

    Args:
        rng: Random source for pair selection and all shuffles.
        reader_1, reader_2: Readers that must report equal `n_components`
            (checked with `assert`).
        n_groups: Number of pairs to produce.
        n_examples_per_group: Value passed as `n_top_examples` to the lookup.
        shuffle_examples_in_group: Shuffle each fetched group in place with `rng`.
        shuffle_pair_order: Randomly swap which reader's group comes first.
        unique_top_examples: Use `get_unique_top_examples_for_component` instead
            of `get_top_examples_for_component`.

    Returns:
        A list of 2-tuples of example groups, one per selected index pair.
    """
    assert reader_1.n_components == reader_2.n_components

    def load_group(reader, component_index):
        # Fetch one reader's examples for a component, optionally shuffled in place.
        if unique_top_examples:
            examples = reader.get_unique_top_examples_for_component(
                component_index=component_index, n_top_examples=n_examples_per_group)
        else:
            examples = reader.get_top_examples_for_component(
                component_index=component_index, n_top_examples=n_examples_per_group)
        if shuffle_examples_in_group:
            rng.shuffle(examples)
        return examples

    index_pairs = _select_non_same_pairs_of_indices(rng, reader_1.n_components, k=n_groups)

    result = []
    for first_comp, second_comp in index_pairs:
        # reader_1's group is fetched before reader_2's so rng draws happen in
        # a fixed order.
        both = [load_group(reader_1, first_comp), load_group(reader_2, second_comp)]
        if shuffle_pair_order:
            rng.shuffle(both)
        result.append(tuple(both))
    return result


###############################################################################


def make_evaluation_csv(
    headers: List[str],
    n_items: int,
) -> str:
    """Build a blank evaluation sheet as CSV text.

    The first row is `headers`. It is followed by `n_items` rows. Each row's
    first cell holds the 1-based item number ("1", ..., str(n_items)), and its
    remaining cells are empty, so every row has `len(headers)` fields. The
    first header should therefore be something like "Group".

    Args:
        headers: Column names; the first column receives the item numbers.
        n_items: Number of numbered rows to emit after the header row.

    Returns:
        The CSV text as written by `csv.writer` with its default dialect
        (CRLF row terminators).
    """
    # One empty cell for each column after the numbering column. This must be
    # a list of empty strings: multiplying the string "" yields "", which
    # unpacks to no cells at all and silently drops the trailing columns.
    blank_cells = [''] * (len(headers) - 1)

    rows = [list(headers)]
    rows.extend([str(item_number), *blank_cells] for item_number in range(1, n_items + 1))

    output = io.StringIO()
    csv.writer(output).writerows(rows)
    return output.getvalue()
