from tqdm import tqdm
import os

"""
This is not a general-purpose utils module.
It contains the helpers that XXXX-1 needs to run the GPT-Neo world-knowledge experiment.
"""


def get_indices_where_minumum_n_tokens(
    dataset,
    tokenizer,
    num_min_token_ids=10,
    num_max_token_ids=20,
    max_num_samples=10_000,
    split="train",
):
    """Return indices of samples whose token count lies in an inclusive range.

    Scans ``dataset[split]`` from the start, tokenizes each sample's ``"text"``
    field with ``tokenizer.encode`` and keeps index ``i`` when
    ``num_min_token_ids <= len(token_ids) <= num_max_token_ids``.

    Args:
        dataset: Mapping from split name to an indexable, sized collection of
            samples; each sample must support ``sample["text"]``.
        tokenizer: Object with an ``encode(text)`` method returning a sequence
            of token ids.
        num_min_token_ids: Minimum number of token ids (inclusive).
        num_max_token_ids: Maximum number of token ids (inclusive).
        max_num_samples: Scan at most this many samples (indices
            ``0 .. max_num_samples - 1``); ``None`` scans the whole split.
        split: Name of the split to scan.

    Returns:
        List of matching indices into ``dataset[split]``, in ascending order.
    """
    samples = dataset[split]  # look the split up once, not twice per iteration

    num_to_scan = len(samples)
    if max_num_samples is not None:
        # Bound the range up front so exactly ``max_num_samples`` samples are
        # scanned and the progress bar total matches the work actually done.
        num_to_scan = min(num_to_scan, max_num_samples)

    indices_in_range = []
    for i in tqdm(
        range(num_to_scan),
        desc=f"Filtering dataset to find samples of lengths between: {num_min_token_ids} and {num_max_token_ids} tokens",
    ):
        num_tokens = len(tokenizer.encode(samples[i]["text"]))
        if num_min_token_ids <= num_tokens <= num_max_token_ids:
            indices_in_range.append(i)

    return indices_in_range


class NiceWikiTextDataset:
    """Index-remapping view over one split of a dataset.

    Position ``idx`` of this view maps to sample ``indices[idx]`` of
    ``dataset[split]``. A pre-filtered list of indices (for example, samples
    selected by token length) can therefore be iterated as a dense,
    zero-based dataset.
    """

    def __init__(self, dataset, indices, split="train"):
        """
        Args:
            dataset: Mapping from split name to an indexable collection of samples.
            indices: Indices into ``dataset[split]`` to expose, in order.
            split: Split that ``indices`` refer to. Defaults to ``"train"`` for
                backward compatibility; pass the same split the indices were
                computed on.
        """
        self.dataset = dataset
        self.indices = indices
        self.split = split

    def __len__(self):
        """Return the number of exposed samples, i.e. ``len(indices)``."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``dataset[split][indices[idx]]``."""
        return self.dataset[self.split][self.indices[idx]]


import torch.nn as nn
from typing import List
from .hook import ForwardHook


class Sandbox:
    """Wraps a model in ``nn.DataParallel`` and collects forward-hook outputs per batch."""

    def __init__(self, tokenizer, model, device_ids, forward_hooks: List[ForwardHook]):
        """
        Args:
            tokenizer: Callable tokenizer that also exposes ``encode``; it is
                called with ``padding=True, return_tensors="pt"``.
            model: Model to wrap in ``nn.DataParallel`` over ``device_ids``; its
                output must expose a ``.logits`` attribute.
            device_ids: Device ids forwarded to ``nn.DataParallel``.
            forward_hooks: Hook objects whose ``.output`` attribute is read after
                each forward pass. This class does not register them; presumably
                they are already attached to submodules of ``model`` (TODO confirm).
        """
        self.tokenizer = tokenizer
        self.device_ids = device_ids
        self.model = nn.DataParallel(model, device_ids=device_ids)
        self.forward_hooks = forward_hooks

    def get_ffn_hook_output_for_batch(self, batch, device):
        """Run one forward pass over ``batch["text"]`` and collect the hook outputs.

        Args:
            batch: Mapping with a ``"text"`` key holding a list of strings.
            device: Device that ``input_ids`` / ``attention_mask`` are moved to
                for the forward pass.

        Returns:
            dict with keys:
              ``"logits"``: the model output's ``.logits``;
              ``"input_ids"`` / ``"attention_mask"``: the padded tokenizer tensors
              as produced by the tokenizer (not the copies moved to ``device``);
              ``"sequence_lengths"``: per-text token counts from ``tokenizer.encode``;
              ``"hook_outputs"``: ``[hook.output for hook in self.forward_hooks]``.
        """
        tokenized_text = self.tokenizer(
            batch["text"], padding=True, return_tensors="pt"
        )
        input_ids = tokenized_text["input_ids"]
        attention_mask = tokenized_text["attention_mask"]

        # Unpadded lengths, obtained by re-encoding each text on its own.
        sequence_lengths = [len(self.tokenizer.encode(text)) for text in batch["text"]]

        # Call the module instead of ``.forward`` so that hooks registered on the
        # DataParallel wrapper itself (and global module hooks) also run.
        # NOTE(review): autograd is enabled here, so the returned logits and hook
        # outputs may keep the computation graph alive; wrap the call site in
        # ``torch.no_grad()`` if gradients are not needed.
        # NOTE(review): with several device_ids, DataParallel replicates the model
        # per device. If each hook object keeps a single ``.output``, only one
        # replica's value may survive; confirm against ForwardHook.
        logits = self.model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
        ).logits

        return {
            "logits": logits,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "sequence_lengths": sequence_lengths,
            "hook_outputs": [hook.output for hook in self.forward_hooks],
        }


class ResultFolder:
    """Read-only, index-addressable view over per-sample JSON result files.

    Item ``idx`` is loaded from
    ``<folder>/actual_dataset_idx_{filtered_indices[idx]}.json``.
    """

    # Result file names are ``<prefix><dataset index><suffix>``.
    _FILE_PREFIX = "actual_dataset_idx_"
    _FILE_SUFFIX = ".json"

    def __init__(self, folder: str, num_items=None, filtered_indices=None):
        """
        Args:
            folder: Directory that contains the result files.
            num_items: Value reported by ``len()``. Defaults to the number of
                ``actual_dataset_idx_*.json`` entries in ``folder``. Other
                entries (stray notes, hidden files) are not counted, because
                ``__getitem__`` could never load them.
            filtered_indices: Dataset index for each item, used to build the
                file name. Defaults to ``0 .. num_items - 1``.

        NOTE(review): ``num_items`` is not checked against
        ``len(filtered_indices)`` when both are available.
        """
        self.folder = folder
        if num_items is None:
            num_items = sum(
                1
                for name in os.listdir(folder)
                if name.startswith(self._FILE_PREFIX)
                and name.endswith(self._FILE_SUFFIX)
            )
        self.num_items = num_items

        if filtered_indices is None:
            filtered_indices = list(range(self.num_items))
        self.filtered_indices = filtered_indices

    def __getitem__(self, idx):
        """Load and return the JSON result stored for item ``idx``."""
        # NOTE(review): ``load_json_as_dict`` is not among this module's imports;
        # confirm it is defined/available at runtime.
        return load_json_as_dict(
            filename=os.path.join(
                self.folder,
                f"{self._FILE_PREFIX}{self.filtered_indices[idx]}{self._FILE_SUFFIX}",
            )
        )

    def __len__(self):
        """Return ``num_items``."""
        return self.num_items
