import copy
import glob
import hashlib
import logging
import math
import os
import random
import shutil
import tempfile
from collections import OrderedDict
from collections.abc import Iterable
from typing import Any, Dict, List, Tuple

import torch
from numpy.random import choice
from omegaconf import MISSING
from tensordict import LazyStackedTensorDict, TensorDict
from torch.multiprocessing import Process
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import BatchSampler, Sampler
from torchrl.objectives.value.functional import reward2go
from tqdm import tqdm

from . import pyaig
from .pyaig.aig_env import AIGEnv


class TT_Dataset(Dataset):
    def __init__(self, path: str | list[str], embedding_size: int) -> None:
        self.dataset = []
        self.embedding_size = embedding_size

        if isinstance(path, Iterable):
            tmp_paths = []
            for p in path:
                if "*" in path:
                    tmp_paths += glob.glob(p, recursive=True)
                else:
                    tmp_paths.append(path)
            path = tmp_paths
        else:
            if "*" in path:
                path = glob.glob(path, recursive=True)
            else:
                path = [path]
        path = list(set(path))
        tts = []
        for p in path:
            f = open(p, "r")
            tts += [tt[:-1] for tt in f.readlines()]
            f.close()
        tts = list(set(tts))
        for tt in tqdm(tts):
            scale = self.embedding_size // len(tt)
            if len(tt) * scale != self.embedding_size:
                raise ValueError(
                    f"Truth-table's length {len(tt)} cannot scale to the embedding size {self.embedding_size}"
                )
            self.dataset.append(
                TensorDict(
                    {
                        "num_inputs": torch.tensor(
                            [int(math.log2(len(tt)))], dtype=torch.int32
                        ),
                        "target": self._str_to_tensor(tt).repeat(scale).unsqueeze(0),
                    },
                    batch_size=[],
                )
            )

    def _str_to_tensor(self, tt: str) -> torch.Tensor:
        scale = self.embedding_size // len(tt)
        if len(tt) * scale != self.embedding_size:
            raise ValueError(
                f"Truth-table's length {len(tt)} cannot scale to the embedding size {self.embedding_size}"
            )
        return torch.tensor(
            [bit == "1" for bit in tt], dtype=torch.bool, requires_grad=False
        ).repeat(scale)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.dataset[index]

    def get_slice(self, rank: int, num_collectors: int) -> "TT_Dataset":
        new_dt = TT_Dataset([], self.embedding_size)
        new_dt.dataset = self.dataset[rank % num_collectors :: num_collectors]
        return new_dt


class AIG_Dataset(Dataset):
    """Step-wise view of a single AIG.

    The dataset has one item per AND node (``len(self) == aig.n_ands()``).
    Item ``index`` exposes the first ``n_pis + index + 1`` node rows followed
    by the target rows, the matching slice of the action tensor, and the
    reward stored for that step.
    """

    def __init__(
        self,
        aig: str | pyaig.Learned_AIG,
        embedding_size: int,
        return_action_mask: bool = False,
        reward_type: str = "simple",
        const_node: bool = True,
        gamma: float = 0.99,
    ) -> None:
        """Load the AIG, prepare its tensors and precompute the rewards.

        Args:
            aig: Path handed to ``pyaig.Learned_AIG.read_aig`` or an already
                loaded ``Learned_AIG``.
            embedding_size: Forwarded to ``aig.prepare_data``.
            return_action_mask: When true, ``__getitem__`` also returns an
                ``"action_mask"`` entry.
            reward_type: ``"shaped"`` picks one of the shaped reward
                functions of ``AIGEnv``; any other value picks the simple one.
            const_node: Chooses between the ``_const`` and ``_noconst``
                shaped reward functions; unused for the simple reward.
            gamma: Discount handed to ``reward2go``.
        """
        self.aig = pyaig.Learned_AIG.read_aig(aig) if isinstance(aig, str) else aig
        self.return_action_mask = return_action_mask

        self.n_pis = self.aig.n_pis()
        self.n_pos = self.aig.n_pos()

        # Placeholders; update_embedding_size() fills them in because the
        # sentinel size -1 never equals the requested one.
        self.embedding_size = -1
        self.nodes = torch.empty(0)
        self.target = torch.empty(0)
        self.actions = torch.empty(0)
        self.edge_type_idx = torch.empty(0)
        self.left_parent_idx = torch.empty(0)
        self.right_parent_idx = torch.empty(0)
        self.update_embedding_size(embedding_size)
        assert self.nodes.numel() > 0

        # Select the terminal reward function.
        if reward_type != "shaped":
            reward_fn = AIGEnv._simple_reward_function
        elif const_node:
            reward_fn = AIGEnv._shaped_reward_function_const
        else:
            reward_fn = AIGEnv._shaped_reward_function_noconst
        final_reward = reward_fn(self.nodes, torch.tensor([self.n_pis]))

        # Only the last step carries a reward and ends the episode; the
        # per-step values come from reward2go with discount gamma.
        n_ands = self.aig.n_ands()
        sparse_reward = torch.zeros(n_ands, dtype=torch.float32)
        sparse_reward[-1] = final_reward
        terminal = torch.zeros_like(sparse_reward, dtype=torch.bool)
        terminal[-1] = True
        self.reward = reward2go(sparse_reward, terminal, gamma, time_dim=-1)

        # Lower-triangular (diagonal included) boolean mask, replicated
        # along a leading dimension of size 4.
        side = self.n_pis + n_ands + 1
        lower_triangle = torch.triu(
            torch.ones((side, side), dtype=torch.bool), diagonal=0
        ).T
        self.action_mask = lower_triangle.repeat(4, 1, 1)

    def update_embedding_size(self, embedding_size):
        """Re-run ``aig.prepare_data`` when ``embedding_size`` changed."""
        if embedding_size == self.embedding_size:
            return
        self.embedding_size = embedding_size
        prepared = self.aig.prepare_data(embedding_size)
        (
            self.nodes,
            self.target,
            self.actions,
            self.edge_type_idx,
            self.left_parent_idx,
            self.right_parent_idx,
        ) = prepared

    def get_data(self):
        """Return the whole graph as one unbatched ``TensorDict``."""
        fields = {
            "nodes": self.nodes,
            "target": self.target,
            "reward": self.reward,
            "actions": self.actions,
            "edge_type": self.edge_type_idx,
            "left_node": self.left_parent_idx,
            "right_node": self.right_parent_idx,
            # "action_mask": self.action_mask,
            "num_inputs": torch.tensor(self.n_pis, dtype=torch.int32),
            "num_outputs": torch.tensor(self.n_pos, dtype=torch.int32),
            "num_ands": torch.tensor(self.aig.n_ands(), dtype=torch.int32),
        }
        return TensorDict(fields, batch_size=[])

    def __len__(self) -> int:
        """Return the number of AND nodes of the AIG."""
        return self.aig.n_ands()

    def __getitem__(self, index: int) -> TensorDict:
        """Return the training sample for step ``index``.

        The entries addressed by the first ``index`` recorded
        (edge type, left parent, right parent) triples are zeroed in the
        returned actions and, when requested, set to ``True`` in the
        returned action mask.
        """
        num_nodes = self.n_pis + index + 1
        earlier_steps = (
            self.edge_type_idx[:index],
            self.left_parent_idx[:index],
            self.right_parent_idx[:index],
        )

        visible_nodes = self.nodes[:num_nodes]
        visible_actions = self.actions[:, :num_nodes, :num_nodes].clone()
        visible_actions[earlier_steps] = 0

        td = TensorDict(
            {
                "nodes": torch.cat((visible_nodes, self.target)),
                "actions": visible_actions,
                "reward": self.reward[index],
            },
            batch_size=[],
        )

        if not self.return_action_mask:
            return td

        visible_mask = self.action_mask[:, :num_nodes, :num_nodes].clone()
        visible_mask[earlier_steps] = True
        td["action_mask"] = visible_mask
        return td


class AIG_Dataset_Collection(Dataset):
    """Collection of AIG records stored in a memory-mapped ``TensorDict``.

    Records are grouped in the cache under string keys of the form
    ``"<num_inputs>-<num_ands>"``. ``key_list`` is a flat list of
    ``(key, record_idx)`` pairs naming the records that belong to this
    collection. After ``initialize_buckets`` the records are additionally
    indexed by bucket: a graph with ``num_inputs`` inputs and ``num_ands``
    AND nodes is listed in buckets ``num_inputs + 1`` through
    ``num_inputs + num_ands``. Items are addressed with a
    ``(bucket_key, index_in_bucket)`` tuple.
    """

    def __init__(
        self,
        aigs: list[str | pyaig.Learned_AIG | AIG_Dataset] | str,
        embedding_size: int,
        return_action_mask: bool = False,
        reward_type: str = "simple",
        gamma: float = 0.99,
        const_node: bool = True,
        cache_path: str = "data/cache",
        fragments: int = 1,
        num_workers: int = 4,
        rebuild_cache: bool = False,
        proc_id: int | None = None,
    ) -> None:
        """Build the collection, loading from or creating the on-disk cache.

        Args:
            aigs: A path, a glob pattern (anything containing ``*``), or a
                list mixing paths, patterns, ``Learned_AIG`` and
                ``AIG_Dataset`` objects. An empty value creates an empty
                collection without touching the disk (used by
                ``split_data``).
            embedding_size: Forwarded to every ``AIG_Dataset``.
            return_action_mask: Forwarded to every ``AIG_Dataset``.
            reward_type: Forwarded to every ``AIG_Dataset``. Defaults to
                "simple".
            gamma: Forwarded to every ``AIG_Dataset``. Defaults to 0.99.
            const_node: Forwarded to every ``AIG_Dataset``. Defaults to True.
            cache_path: Root cache directory. The data lives in a
                sub-directory named after the SHA-1 of the sorted paths and
                the dataset parameters. Defaults to "data/cache".
            fragments: Number of consecutive slices the file list is cut
                into for parallel loading; slices are processed one after
                another. Defaults to 1.
            num_workers: Values above 1 load with that many processes per
                fragment; otherwise loading happens in the current process.
                Defaults to 4.
            rebuild_cache: Delete an existing cache directory and rebuild
                it. Defaults to False.
            proc_id: Worker index, only used to label and position the
                progress bar. Defaults to None.

        NOTE(review): the cache hash sorts and joins the decoded entries, so
        it appears to require every entry to be a path string - confirm
        before passing ``Learned_AIG``/``AIG_Dataset`` objects.
        """
        # Dataset parameters
        self.return_action_mask = return_action_mask
        self.embedding_size = embedding_size
        self.const_node = const_node
        self.reward_type = reward_type
        self.gamma = gamma
        self.max_seq_len = 0

        # Bucket management
        self.buckets = {}
        self.bucket_keys = []
        self.bucket_sizes: List[int] = []
        self.int2bucket_record = []

        # Data storage management
        self.key_list: List[
            Tuple[str, int]
        ] = []  # flat view of the data: (graph-size key, record_idx)
        self.key_map: Dict[
            str, int
        ] = {}  # maps a graph-size key ("num_inputs-num_ands") to an int
        self.data: TensorDict

        # Skip data loading for training or evaluation datasets
        if len(aigs) == 0:
            return

        log = logging.getLogger(__name__)

        aig_paths = self._decode_path(aigs)
        log.info(f"Total number of AIGs: {len(aig_paths)}")

        cache_path = os.path.join(cache_path, self._hash_file(aig_paths))
        self.cache_path = cache_path

        if rebuild_cache and os.path.isdir(cache_path):
            log.info(f"Removing cache: {cache_path}")
            # rmtree avoids passing the path through a shell and raises if
            # the directory cannot be removed.
            shutil.rmtree(cache_path)

        # There is no cache: the dataset needs to be loaded
        if rebuild_cache or not os.path.isdir(cache_path):
            if num_workers > 1:
                self._parallel_load(aig_paths, fragments, num_workers)
            else:
                records = self._load_serial(aig_paths, proc_id)
                self.data = self.save(records, cache_path)
        else:
            log.info(f"Loading dataset from cache: {cache_path}")
            self.data = self.load(cache_path)
        # self.data.share_memory_()

    def _load_serial(self, aig_paths: list, proc_id: int | None) -> list:
        """Read every entry of ``aig_paths`` in the current process."""
        desc = ""
        if proc_id is not None:
            desc += f"[Worker:{proc_id:3}]"
        desc += "[Reading AIG files]"

        records = []
        for aig in tqdm(aig_paths, desc=desc, position=proc_id):
            if isinstance(aig, (str, pyaig.Learned_AIG)):
                aig_dt = AIG_Dataset(
                    aig,
                    self.embedding_size,
                    self.return_action_mask,
                    self.reward_type,
                    self.const_node,  # maybe not needed
                    self.gamma,
                )
                aig_data = aig_dt.get_data()
            elif isinstance(aig, AIG_Dataset):
                aig.return_action_mask = self.return_action_mask
                aig.update_embedding_size(self.embedding_size)
                aig_data = aig.get_data()
            else:
                # Without this branch the previous record would be appended
                # a second time (or the name would be unbound).
                raise TypeError(f"Unsupported AIG entry of type {type(aig)!r}")
            records.append(aig_data)
        return records

    def _index_cache(self, cache) -> None:
        """Rebuild ``key_map`` and ``key_list`` from the keys of ``cache``."""
        self.key_map = {}
        self.key_list = []
        for i, key in enumerate(cache.keys()):
            self.key_map[key] = i
            size = len(cache[key])
            self.key_list.extend(zip((key,) * size, range(size)))

    def _parallel_load(self, aigs_paths: list[str], fragments: int, num_workers: int):
        """Load ``aigs_paths`` with worker processes and merge their caches.

        Each worker writes its own cache below a temporary directory; the
        partial caches are then concatenated per key into ``self.cache_path``.

        Raises:
            RuntimeError: If a worker process exits with a non-zero code.
        """
        with tempfile.TemporaryDirectory() as temp_dir_name:
            fragment_size = math.ceil(len(aigs_paths) / fragments)
            for j in range(fragments):
                fragment_aigs = aigs_paths[j * fragment_size : (j + 1) * fragment_size]
                chunk_size = math.ceil(len(fragment_aigs) / num_workers)

                processes = []
                for i in range(num_workers):
                    chunk = fragment_aigs[i * chunk_size : (i + 1) * chunk_size]
                    if not chunk:
                        # An empty chunk produces no cache; skip the spawn.
                        continue
                    p = Process(
                        target=multiprocessing_load_dataset,
                        args=(
                            chunk,
                            self.embedding_size,
                            self.return_action_mask,
                            self.reward_type,
                            self.gamma,
                            self.const_node,
                            temp_dir_name,
                            i,
                        ),
                    )
                    p.start()
                    processes.append(p)

                for p in processes:
                    p.join()

                # A crashed worker would otherwise leave an incomplete cache
                # that is persisted and silently reused on later runs.
                failed = [p.pid for p in processes if p.exitcode != 0]
                if failed:
                    raise RuntimeError(
                        f"AIG loading worker(s) {failed} exited with an error"
                    )

            cache = self._merge_worker_caches(temp_dir_name)

        self._index_cache(cache)
        self.data = cache

    def _merge_worker_caches(self, worker_root: str):
        """Concatenate the caches found under ``worker_root`` into the cache."""
        tmp_data = {}
        metadata = {}
        # Sorted so the record order does not depend on directory order.
        worker_dirs = sorted(f.path for f in os.scandir(worker_root) if f.is_dir())
        for tmp_dir in worker_dirs:
            dt_data = TensorDict.load_memmap(tmp_dir)  # type: ignore
            for key in dt_data.keys():
                if key not in metadata:
                    metadata[key] = []
                    tmp_data[key] = []
                metadata[key].append(dt_data[key].to("meta"))
                tmp_data[key].append(dt_data[key])

        # allocate disk space from the shape-only ("meta") description
        for key in list(metadata.keys()):
            metadata[key] = torch.cat(metadata[key], dim=0)

        os.makedirs(self.cache_path)
        cache = TensorDict(metadata, batch_size=[]).memmap_like(
            self.cache_path, num_threads=32
        )
        cache = cache.to("cpu")

        # write data to disk
        for key in tqdm(list(tmp_data.keys()), desc="[Merging cached data]"):
            torch.cat(tmp_data[key], out=cache[key], dim=0)
        return cache

    def _decode_path(self, aigs: list[str | pyaig.Learned_AIG | AIG_Dataset] | str):
        """Expand glob patterns in ``aigs`` and return a flat list of entries.

        Raises:
            TypeError: If ``aigs`` is neither a string nor an iterable.
        """
        if isinstance(aigs, str):
            entries = [aigs]
        elif isinstance(aigs, Iterable):
            entries = list(aigs)
        else:
            raise TypeError(
                f"aigs must be a string or an iterable, got {type(aigs)!r}"
            )

        decoded = []
        for aig in entries:
            if isinstance(aig, str) and "*" in aig:
                decoded += glob.glob(aig, recursive=True)
            else:
                decoded.append(aig)
        return decoded

    def _hash_file(self, aigs: list[str]) -> str:
        """Return the SHA-1 hex digest identifying this dataset configuration.

        The digest covers the sorted paths plus ``const_node``,
        ``embedding_size``, ``return_action_mask``, ``reward_type`` and
        ``gamma``. It is used as a cache directory name.
        """
        text = "".join(
            sorted(aigs)
            + [
                str(self.const_node)
                + str(self.embedding_size)
                + str(self.return_action_mask)
                + self.reward_type
                + str(self.gamma)
            ]
        )

        file_hash = str(hashlib.sha1(text.encode("utf-8")).hexdigest())
        return file_hash

    def initialize_buckets(self):
        """(Re)build the bucket index from ``key_list``.

        Every record is listed once in each bucket ``num_inputs + 1 + i`` for
        ``i`` in ``range(num_ands)``. ``max_seq_len`` is raised when a record
        with a larger ``num_inputs + num_ands + 1`` is found. The bucket
        structures are rebuilt from scratch so that calling this method more
        than once does not duplicate entries.
        """
        self.buckets = {}
        self.int2bucket_record = []

        for graph_size, record_idx in self.key_list:
            aig_data = self.get_record(graph_size, record_idx)
            num_inputs = torch.sym_int(aig_data["num_inputs"])  # type: ignore
            num_ands = torch.sym_int(aig_data["num_ands"])  # type: ignore
            seq_len = num_inputs + num_ands + 1  # type: ignore
            if seq_len > self.max_seq_len:
                self.max_seq_len = seq_len
            for i in range(num_ands):
                idx = num_inputs + 1 + i
                self.buckets.setdefault(idx, []).append((graph_size, record_idx))

        self.bucket_keys = list(self.buckets.keys())
        self.bucket_sizes = []
        for key in self.bucket_keys:
            bucket_len = len(self.buckets[key])
            self.bucket_sizes.append(bucket_len)
            self.int2bucket_record.extend(zip((key,) * bucket_len, range(bucket_len)))

        # Flat, sorted (bucket_key, index_in_bucket) list used to translate
        # integer sample indices into bucket coordinates.
        self.int2bucket_record.sort()

    def get_record(self, graph_size: str, idx: int):
        """Return a record of the memory-mapped tensordict.

        Args:
            graph_size (str): key of the form "num_inputs-num_ands"
            idx (int): index of the record inside the stack stored under
                that key

        Returns:
            TensorDict: the graph stored in the memory mapped tensordict
        """
        return self.data[graph_size][idx]

    def save(self, data, cache_path):
        """Stack ``data`` per graph size and write it to a new memmap cache.

        Args:
            data: Sequence of per-graph ``TensorDict`` objects carrying
                ``"num_inputs"`` and ``"num_ands"`` entries.
            cache_path: Directory to create; it must not exist yet.

        Returns:
            The memory-mapped ``TensorDict`` holding the stacked records.
        """
        stacks = OrderedDict()
        for aig_data in data:
            num_inputs = str(torch.sym_int(aig_data["num_inputs"]))
            num_ands = str(torch.sym_int(aig_data["num_ands"]))
            stacks.setdefault(num_inputs + "-" + num_ands, []).append(aig_data)

        cache = TensorDict({}, batch_size=[])
        self.key_map = {}
        self.key_list = []
        for i, key in enumerate(stacks.keys()):
            self.key_map[key] = i
            size = len(stacks[key])
            self.key_list.extend(zip((key,) * size, range(size)))
            # Shape-only placeholder so memmap_like can allocate the files.
            cache[key] = stacks[key][0].clone(False).to("meta").expand(size)

        os.makedirs(cache_path)
        cache = cache.memmap_like(cache_path, num_threads=16)

        for key in list(stacks.keys()):
            torch.stack(stacks[key], out=cache[key])
            # Release the in-memory copies as soon as they are on disk.
            del stacks[key]
        return cache

    def load(self, cache_path):
        """Open the memmap cache at ``cache_path`` and index its records."""
        cache = TensorDict.load_memmap(cache_path)
        self._index_cache(cache)
        return cache

    def __len__(self) -> int:
        """Return the total number of bucket entries."""
        return sum(self.bucket_sizes)

    def __getitem__(self, idx_tuple: Tuple[int, int]):
        """Return ``(bucket_key, record)`` for a ``(bucket_key, index)`` pair."""
        bucket_key, buck_idx = idx_tuple
        graph_size, record_idx = self.buckets[bucket_key][buck_idx]
        data = self.get_record(graph_size, record_idx)

        return (bucket_key, data)

    def _view(self, key_list: List[Tuple[str, int]]) -> "AIG_Dataset_Collection":
        """Return an empty collection sharing this cache, limited to ``key_list``."""
        view = AIG_Dataset_Collection(
            "",
            self.embedding_size,
            self.return_action_mask,
            self.reward_type,
            self.gamma,
            self.const_node,
        )
        view.data = self.data
        view.key_list = key_list
        view.key_map = self.key_map
        return view

    def split_data(
        self,
        split_ratio: float,
        seed: int | None = None,
    ) -> Tuple["AIG_Dataset_Collection", "AIG_Dataset_Collection"]:
        """Split the records into a training and an evaluation collection.

        Args:
            split_ratio: Fraction of the records assigned to evaluation.
            seed: Seed of the private random generator used for shuffling.

        Returns:
            ``(train_dataset, eval_dataset)``; both share this collection's
            cache and have their buckets initialised.
        """
        eval_size = int(len(self.key_list) * split_ratio)
        data_keys = copy.copy(self.key_list)
        random.Random(seed).shuffle(data_keys)

        eval_dataset = self._view(data_keys[:eval_size])
        train_dataset = self._view(data_keys[eval_size:])

        eval_dataset.initialize_buckets()
        train_dataset.initialize_buckets()

        return train_dataset, eval_dataset

    def get_generation_data(self, sample_size: int) -> List[TensorDict]:
        """Return up to ``sample_size`` randomly chosen generation inputs.

        Each returned ``TensorDict`` holds ``"num_inputs"``, ``"num_outputs"``
        (both with a leading dimension added) and ``"target"``, and is placed
        in shared memory.
        """
        tt = []
        rand_idxs = copy.copy(self.key_list)
        random.shuffle(rand_idxs)
        for graph_size, record_idx in rand_idxs[:sample_size]:
            aig_data = self.get_record(graph_size, record_idx)
            td = TensorDict(
                {
                    "num_inputs": aig_data["num_inputs"].unsqueeze(0),  # type: ignore
                    "num_outputs": aig_data["num_outputs"].unsqueeze(0),  # type: ignore
                    "target": aig_data["target"],  # type: ignore
                },
                batch_size=[],
            )
            td.share_memory_()
            tt.append(td)
        return tt


class AIGBatchSampler(BatchSampler):
    """Random batch sampler that never mixes buckets within a batch.

    Every yielded batch is a list of ``(bucket_key, index_in_bucket)``
    tuples taken from a single bucket. Only full batches are produced; the
    leftover ``bucket_size % batch_size`` entries of a bucket are dropped.
    """

    def __init__(self, dataset: AIG_Dataset_Collection, batch_size: int):
        """Capture the bucket layout of ``dataset`` and count the batches."""
        self.batch_size = batch_size
        self.dataset = dataset
        self.bucket_keys = self.dataset.bucket_keys
        self.bucket_sizes = self.dataset.bucket_sizes
        # Number of full batches summed over all buckets.
        self.size = sum(sz // self.batch_size for sz in self.bucket_sizes)

    def __iter__(self):
        """Yield the full batches of every bucket in random order."""
        step = self.batch_size
        pending = []

        for pos, bucket_key in enumerate(self.bucket_keys):
            bucket_len = self.bucket_sizes[pos]
            order = list(range(bucket_len))
            random.shuffle(order)
            for start in range(0, (bucket_len // step) * step, step):
                pending.append((bucket_key, order[start : start + step]))

        # Interleave the buckets so batches do not arrive grouped by size.
        random.shuffle(pending)

        for bucket_key, members in pending:
            yield [(bucket_key, member) for member in members]

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.size


class EvalAIGBatchSampler(BatchSampler):
    """Deterministic batch sampler that walks the buckets in order.

    Every yielded batch is a list of ``(bucket_key, index_in_bucket)``
    tuples taken from a single bucket, with both elements plain Python
    ints (the same element types the training sampler yields). Only full
    batches are produced; the leftover ``bucket_size % batch_size`` entries
    of a bucket are dropped.
    """

    def __init__(self, dataset: "AIG_Dataset_Collection", batch_size: int):
        """Capture the bucket layout of ``dataset`` and count the batches."""
        self.batch_size = batch_size
        self.dataset = dataset
        self.bucket_keys = self.dataset.bucket_keys
        self.bucket_sizes = self.dataset.bucket_sizes
        # Number of full batches summed over all buckets.
        self.size = sum(sz // self.batch_size for sz in self.bucket_sizes)

    def __iter__(self):
        """Yield the full batches bucket by bucket, in index order."""
        step = self.batch_size
        for pos, key in enumerate(self.bucket_keys):
            bucket_key = int(key)
            full_batches = self.bucket_sizes[pos] // step
            for start in range(0, full_batches * step, step):
                # range() yields Python ints, not 0-d tensors.
                yield [(bucket_key, i) for i in range(start, start + step)]

    def __len__(self):
        """Return the number of batches per pass."""
        return self.size


def multiprocessing_load_dataset(
    aigs: list[str | pyaig.Learned_AIG | AIG_Dataset] | str,
    embedding_size: int,
    return_action_mask: bool = False,
    reward_type: str = "simple",
    gamma: float = 0.99,
    const_node: bool = True,
    cache_path: str = "data/cache",
    proc_id: int = 0,
):
    """Worker entry point: build a collection only for its on-disk cache.

    The collection is created in-process (``num_workers=0``, a single
    fragment, no forced rebuild) and discarded right away; nothing is
    returned to the caller.
    """
    settings = {
        "aigs": aigs,
        "embedding_size": embedding_size,
        "return_action_mask": return_action_mask,
        "reward_type": reward_type,
        "gamma": gamma,
        "const_node": const_node,
        "cache_path": cache_path,
        "fragments": 1,
        "num_workers": 0,
        "rebuild_cache": False,
        "proc_id": proc_id,
    }
    # The instance is not bound to a name, so it is released immediately.
    AIG_Dataset_Collection(**settings)


class DistributedBatchSampler(DistributedSampler):
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 10,
    ) -> None:
        super().__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(super().__iter__())
        batch_sampler = DistributedBatchSamplerHelper(
            self.dataset, batch_size=self.batch_size, indices=indices  # type: ignore
        )
        return iter(batch_sampler)

    def __len__(self) -> int:
        return self.num_samples // self.batch_size


class DistributedBatchSamplerHelper(Sampler):
    """Draw single-bucket batches from a given subset of sample indices.

    The integer ``indices`` are translated into ``(bucket_key, record_idx)``
    pairs through ``dataset.int2bucket_record`` and regrouped per bucket.
    The bucket of every batch is drawn up front, with probability
    proportional to the number of received indices that fall in it.
    """

    def __init__(
        self,
        dataset: AIG_Dataset_Collection,
        batch_size: int,
        indices: List[int] | None = None,
        shuffle: bool = True,
    ):
        """Group ``indices`` per bucket and pre-draw one bucket per batch."""
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = []
        self.buckets = {}

        if indices is None:
            return

        self.indices = indices
        for flat_idx in indices:
            bucket_key, rec_idx = self.dataset.int2bucket_record[flat_idx]
            self.buckets.setdefault(bucket_key, []).append((bucket_key, rec_idx))

        bucket_keys = list(self.buckets.keys())
        total = len(indices)
        bucket_probs = [len(self.buckets[b]) / total for b in bucket_keys]
        self.bucket_choices = choice(
            bucket_keys,
            total // self.batch_size,
            p=bucket_probs,
            replace=True,
        )

    def __iter__(self):
        """Yield one batch of ``(bucket_key, record_idx)`` pairs per draw."""
        for batch_no in range(len(self)):
            members = self.buckets[self.bucket_choices[batch_no]]
            # Sample with replacement only when the bucket cannot fill a
            # batch on its own.
            needs_replacement = len(members) < self.batch_size
            picks = choice(
                len(members),
                self.batch_size,
                replace=needs_replacement,
            )
            yield [members[pick] for pick in picks]

    def __len__(self):
        """Return the number of batches, ``len(indices) // batch_size``."""
        return len(self.indices) // self.batch_size
