"""Common code for developing NMF on BERT stuff."""
import collections
import functools
import os
from typing import Optional, Sequence, Union

import numpy as np
import tensorflow as tf
import torch
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

from em.datasets import glue
from em.fishers import diagonal
# from em.fishers import per_example

from em.util import hdf5_util


# Data locations depend on the host: use the shared '/fruitbasket' mount when
# it is present, otherwise fall back to a directory under the user's home.
# No per-example Fisher directory is configured for the local fallback.
if not os.path.exists('/fruitbasket'):
    FISHER_DIR = os.path.expanduser('~/Desktop/projects_data/extract_merge1/fishers0')
    PER_EXAMPLES_FISHERS_DIR = None
else:
    FISHER_DIR = '/fruitbasket/users/m/project_data/extract_merge1/fishers0'
    PER_EXAMPLES_FISHERS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/per_example_fishers0'

# Default `max_length` used when tokenizing GLUE examples.
SEQ_LEN = 128

###############################################################################


def separate_tf_torch_gpus():
    """Point PyTorch at the last GPU listed in ``CUDA_VISIBLE_DEVICES``.

    At least two visible GPUs are required. Judging by the name, the intent is
    presumably to keep PyTorch off the device(s) TensorFlow uses; this
    function only configures the PyTorch side.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is unset or lists fewer than
            two devices.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    # Entries may be integer indices or GPU UUIDs, so they are not parsed:
    # only the count matters, because torch device indices are relative to
    # the visible set. Empty entries (e.g. from a trailing comma) are ignored.
    devices = [d.strip() for d in visible.split(',') if d.strip()]
    if len(devices) < 2:
        raise RuntimeError(
            'separate_tf_torch_gpus needs at least 2 GPUs in '
            f'CUDA_VISIBLE_DEVICES, got {visible!r}.'
        )
    torch.cuda.set_device(len(devices) - 1)

###############################################################################


def create_dataset_for_per_example_fishers(
    tokenizer: AutoTokenizer,
    batch_size: int,
    task: str,
    sequence_length: int = SEQ_LEN,
    shuffle_buffer_size: int = 1000,
):
    """Build an endless, shuffled, batched GLUE training dataset.

    Args:
        tokenizer: tokenizer forwarded to ``glue.load_glue_dataset``.
        batch_size: number of examples per batch.
        task: GLUE task name forwarded to ``glue.load_glue_dataset``.
        sequence_length: ``max_length`` used when tokenizing.
        shuffle_buffer_size: size of the shuffle buffer. Larger values shuffle
            more thoroughly at the cost of memory. Defaults to 1000, the value
            previously hard-coded here.

    Returns:
        The ``'train'`` split of ``task``, repeated indefinitely, shuffled and
        grouped into batches of ``batch_size``.
    """
    glue_ds = glue.load_glue_dataset(
        task=task,
        split='train',
        tokenizer=tokenizer,
        max_length=sequence_length,
    )
    # Repeating before shuffling gives an endless stream, and lets the shuffle
    # buffer mix examples across epoch boundaries.
    return glue_ds.repeat().shuffle(shuffle_buffer_size).batch(batch_size)


def _get_sparse_indices(indices_or_sparse: Union[tf.Tensor, tf.sparse.SparseTensor]) -> tf.Tensor:
    """Return the index tensor described by ``indices_or_sparse``.

    A ``tf.sparse.SparseTensor`` contributes its ``.indices``; anything else is
    assumed to already be an index tensor and is returned unchanged.
    """
    is_sparse = isinstance(indices_or_sparse, tf.sparse.SparseTensor)
    return indices_or_sparse.indices if is_sparse else indices_or_sparse


@tf.function
def _compute_batch_fishers(
    trained_model,
    variables,
    sparse_indices,
    batch,
    expectation_wrt_logits=True,
):
    """Compute the flattened per-example diagonal Fishers for one batch.

    Args:
        trained_model: model passed to ``diagonal.compute_exact_fisher_for_batch``.
        variables: variables whose Fisher diagonals are computed.
        sparse_indices: ``None`` for dense Fishers. Otherwise, one index tensor
            per variable (aligned with ``variables``), applied per example with
            ``tf.gather_nd`` to keep only those Fisher entries.
        batch: mapping of model inputs for one batch.
        expectation_wrt_logits: forwarded to
            ``diagonal.compute_exact_fisher_for_batch``.

    Returns:
        Tensor of shape ``[n_batch_examples, n_entries]``. Each row holds one
        example's Fisher entries for all variables, concatenated in
        ``variables`` order.
    """
    n_batch_examples = diagonal.batch_size_from_batch(batch)

    batch_fishers = diagonal.compute_exact_fisher_for_batch(
        batch,
        trained_model,
        variables,
        expectation_wrt_logits=expectation_wrt_logits,
        per_example=True,
    )
    if sparse_indices is not None:
        # Map over the leading (per-example) axis so every example keeps the
        # same subset of entries.
        batch_fishers = [
            tf.vectorized_map(functools.partial(tf.gather_nd, indices=inds), f)
            for f, inds in zip(batch_fishers, sparse_indices)
        ]

    batch_fishers = [
        tf.reshape(f, [n_batch_examples, -1])
        for f in batch_fishers
    ]
    return tf.concat(batch_fishers, axis=-1)


def get_fishers_with_examples(
    trained_model: TFAutoModelForSequenceClassification,
    pe_ds: tf.data.Dataset,
    n_examples: int,
    variables: Sequence[tf.Variable],
    sparse_indices: Optional[Sequence[Union[tf.Tensor, tf.sparse.SparseTensor]]] = None,
    *,
    normalize_fishers: bool = True,
    expectation_wrt_logits: bool = True,
):
    """Collect examples from ``pe_ds`` together with their per-example Fishers.

    If ``sparse_indices`` is given, only that fixed sparse subset of Fisher
    entries is returned. Otherwise the dense Fishers are returned.

    Args:
        trained_model: model whose Fishers are computed.
        pe_ds: dataset yielding ``(batch, label)`` pairs, where ``batch`` is a
            mapping of feature name to tensor.
        n_examples: number of examples to collect. Fewer are returned if
            ``pe_ds`` runs out first.
        variables: variables whose Fisher diagonals are computed.
        sparse_indices: optional per-variable index tensors or
            ``tf.sparse.SparseTensor``s (whose ``.indices`` are used), aligned
            with ``variables``.
        normalize_fishers: if true, L2-normalize each example's Fisher row.
        expectation_wrt_logits: forwarded to the Fisher computation.

    Returns:
        ``(examples, labels, fishers)``:
        - ``examples`` maps each feature name to its batches concatenated along
          axis 0.
        - ``labels`` is the labels concatenated the same way.
        - ``fishers`` has one row per example.
        Each is truncated to at most ``n_examples`` rows.

    Raises:
        ValueError: if ``n_examples`` is not positive, if ``sparse_indices`` and
            ``variables`` differ in length, or if ``pe_ds`` yields no batches.
    """
    if n_examples <= 0:
        raise ValueError(f'n_examples must be positive, got {n_examples}.')
    if sparse_indices is not None:
        if len(sparse_indices) != len(variables):
            raise ValueError(
                f'Got {len(sparse_indices)} sparse_indices for '
                f'{len(variables)} variables; they must match.'
            )
        sparse_indices = [_get_sparse_indices(s) for s in sparse_indices]

    batches = collections.defaultdict(list)
    labels = []
    fishers = []

    n_examples_processed = 0
    for batch, label in pe_ds:
        for k, v in batch.items():
            batches[k].append(v)
        labels.append(label)

        fishers.append(
            _compute_batch_fishers(
                trained_model,
                variables,
                sparse_indices,
                batch,
                expectation_wrt_logits=expectation_wrt_logits,
            )
        )

        n_examples_processed += diagonal.batch_size_from_batch(batch)
        # Stop as soon as enough examples have been seen, rather than pulling
        # (and discarding) one more batch from the dataset.
        if n_examples_processed >= n_examples:
            break

    if not fishers:
        raise ValueError('pe_ds yielded no batches.')

    # The last batch may overshoot n_examples, so trim everything to match.
    examples = {
        k: tf.concat(v, axis=0)[:n_examples]
        for k, v in batches.items()
    }
    labels = tf.concat(labels, axis=0)[:n_examples]
    fishers = tf.concat(fishers, axis=0)[:n_examples]

    if normalize_fishers:
        fishers = tf.linalg.l2_normalize(fishers, axis=-1)

    return examples, labels, fishers

###############################################################################


# Archive keys shared by save_per_example_fishers / load_per_example_fishers.
# Example features are stored under a prefix so a feature name can never
# collide with the labels/fishers entries.
_PE_EXAMPLES_KEY_PREFIX = 'examples.'
_PE_LABELS_KEY = 'labels'
_PE_FISHERS_KEY = 'fishers'


def save_per_example_fishers(examples, labels, fishers, filepath: str):
    """Save examples, labels and per-example Fishers to a single file.

    The data is written as an uncompressed NumPy ``.npz`` archive at exactly
    ``filepath``; no extension is appended. Read it back with
    ``load_per_example_fishers``.

    Args:
        examples: mapping of feature name to array-like (NumPy arrays or
            eager tensors, e.g. as returned by ``get_fishers_with_examples``).
        labels: array-like of labels.
        fishers: array-like of per-example Fishers.
        filepath: destination path. An existing file is overwritten.
    """
    arrays = {
        _PE_EXAMPLES_KEY_PREFIX + name: np.asarray(value)
        for name, value in examples.items()
    }
    arrays[_PE_LABELS_KEY] = np.asarray(labels)
    arrays[_PE_FISHERS_KEY] = np.asarray(fishers)
    # Writing through an open handle stops np.savez from appending '.npz' to
    # the path, so load_per_example_fishers(filepath) finds the same file.
    with open(filepath, 'wb') as f:
        np.savez(f, **arrays)


def load_per_example_fishers(filepath: str):
    """Load data written by ``save_per_example_fishers``.

    Args:
        filepath: path previously passed to ``save_per_example_fishers``.

    Returns:
        ``(examples, labels, fishers)`` as NumPy arrays. ``examples`` is a dict
        keyed by the original feature names.
    """
    # allow_pickle=False: only plain numeric arrays are expected, and this
    # refuses to unpickle arbitrary objects from the file.
    with np.load(filepath, allow_pickle=False) as archive:
        prefix_len = len(_PE_EXAMPLES_KEY_PREFIX)
        examples = {
            key[prefix_len:]: archive[key]
            for key in archive.files
            if key.startswith(_PE_EXAMPLES_KEY_PREFIX)
        }
        labels = archive[_PE_LABELS_KEY]
        fishers = archive[_PE_FISHERS_KEY]
    return examples, labels, fishers


###############################################################################

def print_top_examples(W: np.ndarray, tokenizer, examples, labels, component: int, n_examples: int):
    """Print the examples with the largest weight in one column of ``W``.

    Args:
        W: weight matrix indexed as ``W[example, component]``. Presumably the
            per-example factor of an NMF (TODO confirm against callers).
        tokenizer: tokenizer used to decode ``examples['input_ids']``.
        examples: mapping with an ``'input_ids'`` entry indexed by example.
        labels: per-example labels, indexed like the rows of ``W``.
        component: column of ``W`` to rank examples by.
        n_examples: how many examples to print. Clamped to the number of rows
            of ``W``.

    Prints one ``"{label}: {text}"`` line per example, highest weight first,
    with pad tokens removed from the decoded text.
    """
    # tf.math.top_k fails when k exceeds the number of entries, so clamp it.
    k = min(n_examples, W.shape[0])
    _, inds = tf.math.top_k(W[:, component], k=k)
    pad_token = tokenizer.pad_token
    for ind in inds:
        label = labels[ind]
        if isinstance(label, tf.Tensor):
            label = label.numpy()
        example = tokenizer.decode(examples['input_ids'][ind])
        # Some tokenizers define no pad token (pad_token is None), and
        # str.replace(None, '') would raise a TypeError.
        if pad_token:
            example = example.replace(pad_token, '')
        example = example.strip()
        print(f'{label}: {example}')
