import os
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DataCollator:
    """Pad a list of per-sample feature dicts into one batch of tensors.

    Every example must provide ``"input_ids"``; all other fields are optional
    and are emitted as ``None`` when the first example does not carry them.
    Padding is zero-filled and placed on the side given by
    ``tokenizer.padding_side`` (``"right"`` pads at the end, anything else at
    the start).

    Args:
        tokenizer: Object exposing a ``padding_side`` attribute.
        pad_to_multiple_of: If set, the padded sequence length is rounded up
            to a multiple of this value. ``None`` disables rounding.
    """

    # Fields laid out per example as (depth, n_positions, width), paired with
    # the dtype they are cast to (``None`` keeps the example's own dtype).
    # In the batch they are returned as (batch, seq_len, depth, width).
    _TREE_FIELDS = (
        ("output_topk_ids", torch.long),
        ("output_topk_probs", None),
        ("sample_idxs", torch.long),
        ("sample_logits", None),
        ("sampled_ids", torch.long),
        ("sampled_scores", torch.float),
    )

    def __init__(self, tokenizer, pad_to_multiple_of=None):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, examples):
        return self.torch_call(examples)

    @staticmethod
    def _torchfy(data):
        """Convert an ndarray, tensor or list to a ``torch.Tensor``."""
        if isinstance(data, np.ndarray):
            return torch.from_numpy(data)
        if isinstance(data, torch.Tensor):
            return data
        if isinstance(data, list):
            return torch.tensor(data)
        raise ValueError(f"Unsupported data type: {type(data)}")

    def _pad_field(self, examples, key, seq_len, dtype=None, seq_axis=0, valid_mask=None):
        """Zero-pad ``example[key]`` along ``seq_axis`` to ``seq_len`` and stack.

        Args:
            examples: The examples being collated.
            key: Field to read from every example.
            seq_len: Target length of the sequence axis.
            dtype: Output dtype; ``None`` uses the dtype of the first example.
            seq_axis: Axis of the per-example tensor that is padded.
            valid_mask: Optional ``(batch, seq_len)`` tensor. When given, row
                ``i`` selects the padded positions that receive the entries
                of example ``i`` (in order); the number of non-zero entries
                must equal the example's length along ``seq_axis``.

        Returns:
            A tensor with a new leading batch axis.
        """
        if dtype is None:
            dtype = self._torchfy(examples[0][key]).dtype
        assert dtype != torch.bool, "bool is not supported (transformers Trainer does not support bool)"

        pad_left = self.tokenizer.padding_side != "right"
        padded = []
        for i, example in enumerate(examples):
            data = self._torchfy(example[key]).to(dtype)
            data_len = data.size(seq_axis)

            target_shape = list(data.shape)
            target_shape[seq_axis] = seq_len
            target = torch.zeros(target_shape, dtype=dtype)

            if valid_mask is not None:
                # The mask is already padded, so its non-zero positions are
                # final indices for either padding side; no extra offset.
                # flatten() keeps the index 1-D even for a single position.
                index = valid_mask[i].nonzero().flatten()
            else:
                start = seq_len - data_len if pad_left else 0
                index = torch.arange(start, start + data_len)

            # Adding into zeros at unique indices is a plain scatter.
            target.index_add_(seq_axis, index, data)
            padded.append(target)

        return torch.stack(padded, dim=0).detach()

    def torch_call(self, examples):
        """Collate ``examples`` into a dict of padded tensors.

        Returns a dict with keys ``input_ids``, ``labels`` (a copy of the
        padded ``input_ids``), ``loss_masks``, ``hidden_states``,
        ``base_hidden_states`` and the fields in ``_TREE_FIELDS``. Optional
        fields missing from the first example map to ``None``.

        Raises:
            ValueError: If ``examples`` is empty or a field has an
                unsupported container type.
        """
        if not examples:
            raise ValueError("DataCollator received an empty list of examples")

        first = examples[0]

        def present(key):
            return first.get(key) is not None

        multiple = self.pad_to_multiple_of or 1
        longest = max(len(ex["input_ids"]) for ex in examples)
        seq_len = (longest + multiple - 1) // multiple * multiple

        batch = {}
        batch["input_ids"] = self._pad_field(examples, "input_ids", seq_len, torch.long)
        # Separate storage so in-place edits of labels never touch input_ids.
        batch["labels"] = batch["input_ids"].clone()
        batch["loss_masks"] = (
            self._pad_field(examples, "loss_masks", seq_len, torch.long)
            if present("loss_masks")
            else None
        )
        batch["hidden_states"] = (
            self._pad_field(examples, "hidden_states", seq_len)
            if present("hidden_states")
            else None
        )
        # Prefer the dedicated "base_hidden_states" entry when examples carry
        # one; fall back to "hidden_states" for examples without it.
        base_key = "base_hidden_states" if present("base_hidden_states") else "hidden_states"
        batch["base_hidden_states"] = (
            self._pad_field(examples, base_key, seq_len)
            if present(base_key)
            else None
        )

        for key, dtype in self._TREE_FIELDS:
            if not present(key):
                batch[key] = None
                continue
            stacked = self._pad_field(
                examples,
                key,
                seq_len,
                dtype,
                seq_axis=1,
                valid_mask=batch["loss_masks"],
            )
            # (batch, depth, seq_len, width) -> (batch, seq_len, depth, width)
            batch[key] = stacked.permute(0, 2, 1, 3)

        return batch


class Dataset(torch.utils.data.Dataset):
    """Dataset of pickled per-sample feature dicts stored one per file.

    Args:
        data_dir: Directory holding the sample files (``str`` or ``Path``).
        split: Only files whose name starts with this prefix are used.
        debug: If true, keep only the last 200 files of the listing.
        add_feature_noise: If true, add uniform noise in ``[-0.1, 0.1)`` to
            the returned ``hidden_states`` (``base_hidden_states`` stays
            noise-free).
        tokenizer: Optional object with a ``decode`` method used to build
            the ``"prompt"`` entry; ``None`` leaves the prompt as ``None``.
        max_length: Maximum number of sequence positions kept per sample.
    """

    # Optional fields sliced along axis 1 to the number of valid positions.
    _TREE_KEYS = (
        "output_topk_ids",
        "output_topk_probs",
        "sample_idxs",
        "sample_logits",
        "sampled_ids",
        "sampled_scores",
    )

    def __init__(
            self,
            data_dir,
            split="train",
            debug=False,
            add_feature_noise=False,
            tokenizer=None,
            max_length=2048,
    ):
        self.data_dir = data_dir
        self.split = split
        self.add_feature_noise = add_feature_noise
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Collect every file whose name starts with the split prefix, in
        # reverse lexicographic order. Path() lets callers pass a plain str.
        root = Path(data_dir)
        self.files = [
            root / name
            for name in sorted(os.listdir(data_dir), reverse=True)
            if name.startswith(split)
        ]

        if debug:
            self.files = self.files[-200:]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return its (truncated) feature dict.

        Fields absent from the pickled dict are returned as ``None``.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this dataset at files from a trusted source.
        with open(self.files[idx], "rb") as handle:
            data = pickle.load(handle)

        input_ids = data["input_ids"][:self.max_length]

        prompt = None
        if self.tokenizer is not None:
            prompt = self.tokenizer.decode(
                input_ids,
                skip_special_tokens=False,
            )

        hidden_states = None
        base_hidden_states = None
        if "hidden_state" in data:
            clean = data["hidden_state"]
            if self.add_feature_noise:
                noise = torch.rand_like(clean) * 0.2 - 0.1
            else:
                noise = 0
            hidden_states = (clean + noise)[:self.max_length]
            base_hidden_states = clean[:self.max_length]

        if "loss_mask" in data:
            loss_masks = data["loss_mask"][:self.max_length]
            # Number of positions flagged in the truncated mask.
            valid_length = sum(int(flag) for flag in loss_masks)
        else:
            loss_masks = None
            # No mask to count; presumably every kept position is valid, so
            # cap at max_length -- TODO confirm against the data producer.
            valid_length = self.max_length

        item = {
            "input_ids": input_ids,
            "hidden_states": hidden_states,
            "base_hidden_states": base_hidden_states,
            "loss_masks": loss_masks,
            "prompt": prompt,
        }
        for key in self._TREE_KEYS:
            item[key] = data[key][:, :valid_length] if key in data else None
        return item
