import dataclasses
from random import random

import numpy as np
import torch

import mqar_zoology.associative_recall as zoology


# PAD = 0


@dataclasses.dataclass
class MqarBatch:
    x_ids: torch.Tensor
    y_true_ids: torch.Tensor
    V: int
    L: int
    N_facts: int
    size: int


# Visualization

def visualize_sample(context, baseline, *, pad_token: int = 0):
    """Print one sample on a single line: context tokens, ``|``, then baseline tokens.

    Tokens equal to ``pad_token`` are rendered as ``.``; all other tokens are
    rendered with ``str``. Tokens within each part are space-separated.

    Args:
        context: Iterable of token ids shown before the separator.
        baseline: Iterable of token ids shown after the separator.
        pad_token: Token id rendered as ``.``. Defaults to 0, the value of the
            module-level ``PAD`` that is commented out (referencing that global
            directly raised ``NameError``).

    Returns:
        The line that was printed (without the trailing newline).
    """

    def tok_str(tok) -> str:
        return '.' if tok == pad_token else str(tok)

    context_part = ' '.join(tok_str(tok) for tok in context)
    baseline_part = ' '.join(tok_str(tok) for tok in baseline)
    line = context_part + ' | ' + baseline_part

    print(line)
    return line


# Wrappers

def generate_mqar_batch(
    V: int,
    L: int,
    N_facts: int | list[int],
    batch_size: int,
    seed: int | None = None,

    # additional kwargs: (zoology)
    power_a: float = 1.0,  # for non-uniform distribution set to 0.01 or other values
    random_non_queries: bool = False,
    include_slices: bool = False,

) -> MqarBatch:
    """Generate a batch of MQAR examples via ``zoology.multiquery_ar``.

    Args:
        V: Vocabulary size, forwarded to zoology as ``vocab_size``.
        L: Input sequence length, forwarded to zoology as ``input_seq_len``.
        N_facts: Number of key/value pairs per example. If an int, all examples
            share that count and are generated in one zoology call. If a
            list/tuple of ints, it is treated as the inclusive integer range
            ``[min(N_facts), max(N_facts)]``, and each example independently
            draws its count uniformly from that range.
        batch_size: Number of examples to generate.
        seed: Seed for the wrapper's ``np.random.RandomState``. All randomness
            (pair counts and the seeds passed to zoology) derives from it, so a
            fixed seed gives a reproducible batch; ``None`` does not.
        power_a: Forwarded to zoology unchanged.
        random_non_queries: Forwarded to zoology unchanged.
        include_slices: Forwarded to zoology unchanged.

    Returns:
        An ``MqarBatch``. Its ``N_facts`` is the requested int, or, for a list,
        the (float) mean of the range the counts are drawn from.

    Raises:
        ValueError: If ``N_facts`` is not an int or a non-empty list/tuple of ints.
    """

    rng = np.random.RandomState(seed)

    # Drawn first, exactly as before, so the int path yields the same data for
    # a given seed.
    pseudorandom_seed = rng.randint(2 ** 31)

    def _get_mqar_batch(num_kv_pairs: int, num_examples: int, example_seed):
        return zoology.multiquery_ar(
            vocab_size=V,
            num_examples=num_examples,
            input_seq_len=L,
            num_kv_pairs=num_kv_pairs,
            power_a=power_a,
            random_non_queries=random_non_queries,
            seed=example_seed,
            include_slices=include_slices,
        )

    # bool is a subclass of int; it was rejected before and stays rejected.
    if isinstance(N_facts, (int, np.integer)) and not isinstance(N_facts, bool):
        data = _get_mqar_batch(
            num_kv_pairs=int(N_facts),
            num_examples=batch_size,
            example_seed=pseudorandom_seed,
        )
        x = data.inputs
        y = data.labels
        mean_N_facts = int(N_facts)

    elif isinstance(N_facts, (list, tuple)):
        if len(N_facts) == 0:
            raise ValueError("N_facts list must not be empty")

        # Inclusive range [min, max] of allowed pair counts.
        n_range = np.arange(min(N_facts), max(N_facts) + 1)
        # Use the seeded rng (not the global np.random) so seeded calls are
        # reproducible.
        n_per_example = rng.choice(n_range, size=batch_size)
        # Give each example its own zoology seed. Reusing a single seed would
        # presumably make every example with the same pair count identical,
        # since zoology is seeded explicitly on each call.
        example_seeds = rng.randint(2 ** 31, size=batch_size)

        x_list, y_list = [], []
        for n, example_seed in zip(n_per_example, example_seeds):
            data = _get_mqar_batch(
                num_kv_pairs=int(n),  # plain int rather than np.int64
                num_examples=1,
                example_seed=int(example_seed),
            )
            x_list.append(data.inputs)
            y_list.append(data.labels)
        x = torch.vstack(x_list)
        y = torch.vstack(y_list)
        # Mean of the distribution the counts are actually drawn from.
        mean_N_facts = float(n_range.mean())

    else:
        raise ValueError("N_facts should be either an int or a list of ints")

    # Labels are returned exactly as zoology produced them; positions zoology
    # marks as ignored are not remapped to a pad id.
    return MqarBatch(
        x_ids=x,
        y_true_ids=y,
        V=V,
        L=L,
        N_facts=mean_N_facts,
        size=batch_size,
    )


if __name__ == '__main__':
    # Small demo: generate a handful of MQAR samples and print them.

    N_facts = 8
    # N_facts = 16

    N_queries = N_facts

    # Two tokens per fact and per query (presumably key + value).
    L_facts = 2 * N_facts
    L_queries = 2 * N_queries

    # L_other = 0
    L_other = 10

    L = L_facts + L_queries + L_other

    V = 100
    # V = 1000
    # V = 8192

    batch_size = 5

    args = dict(
        V=V,
        L=L,
        N_facts=N_facts,
        batch_size=batch_size,
    )

    kwargs = dict(
        power_a=1.0,
        random_non_queries=True,
        include_slices=False,
        seed=0,
    )

    # generate_mqar_batch returns an MqarBatch dataclass, not an (x, y) tuple,
    # so the tensors must be read by field name.
    batch = generate_mqar_batch(**args, **kwargs)

    x = batch.x_ids.cpu().numpy()
    y = batch.y_true_ids.cpu().numpy()

    for context, target in zip(x, y):
        visualize_sample(context, target)
