import random
from typing import Any

import numpy
import torch
from tensordict import TensorClass, TensorDict
from data.dataset import AIGRecord


class TrainCollate(torch.nn.Module):
    """Collate ``(slice_idx, TensorDict)`` pairs into one stacked training batch.

    Each sample is cut at ``slice_idx``: the target truth table is prepended to
    the first ``slice_idx`` node rows, the action tensor is cropped to
    ``slice_idx x slice_idx`` and the actions already performed before the
    slice are zeroed. Optional outputs (action mask, causal mask, reward) are
    selected by the constructor flags.

    Args:
        device: Target device or ``None``. Anything ``torch.device`` accepts
            (``torch.device``, ``"cuda"``, ``"cuda:0"``, ...). It is only used
            to decide whether the batch is made contiguous; data is not moved.
        dtype: Dtype of the returned ``nodes`` and ``actions``.
        embedding_size: Stored on the instance; not used by this class.
        train_value: If True, also return the per-slice ``reward``.
        const_node: If False, the first node row is dropped and the edge
            parent indices are shifted by -1 (presumably that row is the
            constant node -- TODO confirm).
        return_action_mask: If True, also return ``action_mask``.
        get_causal_mask: If True, also return ``causal_mask``; otherwise the
            batch's ``causal_mask`` is set to ``None``.
        negation_prob: Probability of negating the target truth table.
        permutation_prob: Probability of randomly permuting the truth-table
            columns of every node row.
    """

    def __init__(
        self,
        device,
        dtype=torch.float32,
        embedding_size=64,
        train_value=False,
        const_node=True,
        return_action_mask=True,
        get_causal_mask=False,
        negation_prob=0.5,
        permutation_prob=0.5,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.embedding_size = embedding_size
        self.const_node = const_node
        self.return_action_mask = return_action_mask
        self.get_causal_mask = get_causal_mask
        self.train_value = train_value
        self.negation_prob = negation_prob
        self.permutation_prob = permutation_prob
        # Keys kept by prepare_sample; optional outputs are appended per flag.
        self.keys = ["nodes", "actions"]
        if return_action_mask:
            self.keys.append("action_mask")
        if get_causal_mask:
            self.keys.append("causal_mask")
        if train_value:
            self.keys.append("reward")
        # When True, the negated target is appended as extra target rows
        # instead of randomly replacing the target.
        self.add_target_negation = False

    @staticmethod
    def _is_cuda_device(device) -> bool:
        """Return True when ``device`` denotes a CUDA device.

        Normalizing through ``torch.device`` accepts strings and indices as well
        as ``torch.device`` objects (plain ``device.type`` fails on ``"cuda"``).
        """
        return device is not None and torch.device(device).type == "cuda"

    def __call__(self, batch):
        """Turn a list of ``(slice_idx, td)`` pairs into one stacked TensorDict."""
        batch = self.batch_data(batch)
        # The batch is only made contiguous here; it is not moved to the device.
        if self._is_cuda_device(self.device):
            batch = batch.contiguous()
        if not self.get_causal_mask:
            batch["causal_mask"] = None  # type: ignore
        return batch

    def batch_data(self, batch):
        """Prepare every item of ``batch`` and stack the results."""
        return torch.stack([self.prepare_sample(item) for item in batch])

    def prepare_sample(self, batch):
        """Build one training sample from a ``(slice_idx, td)`` pair.

        ``td`` is expected to provide ``num_inputs``, ``num_outputs``, ``nodes``,
        ``target``, ``actions``, ``edge_type``, ``left_node``, ``right_node``
        and, when ``train_value`` is set, ``reward``.

        Returns:
            A TensorDict restricted to ``self.keys``.
        """
        slice_idx, td = batch

        # Deep copy so the in-place edits below never touch the dataset's tensors.
        td = td.clone(True).contiguous()

        num_inputs = torch.sym_int(td["num_inputs"])
        num_outputs = torch.sym_int(td["num_outputs"])
        # Number of already-performed actions: node rows before the slice minus
        # the input rows (and minus the leading row when const_node is kept).
        cur_nodes = slice_idx - num_inputs  # type: ignore

        if self.const_node:
            cur_nodes -= 1
        else:
            td["nodes"] = td["nodes"][1:, :]
            td["left_node"].add_(-1)
            td["right_node"].add_(-1)

        if self.add_target_negation:
            td["target"] = torch.cat([td["target"], ~td["target"]], dim=0)
            num_outputs *= 2
        elif numpy.random.rand() < self.negation_prob:
            td["target"] = ~td["target"]

        # Target row(s) first, followed by the node rows visible at this slice.
        td["nodes"] = torch.cat((td["target"], td["nodes"][:slice_idx]), dim=0)

        td["actions"] = td["actions"][:, :slice_idx, :slice_idx].to(self.dtype)
        edge = td["edge_type"]
        left = td["left_node"]
        right = td["right_node"]
        performed = (edge[:cur_nodes], left[:cur_nodes], right[:cur_nodes])

        # Actions already performed before the slice are no longer targets.
        td["actions"][performed] = 0

        if self.train_value:
            td["reward"] = td["reward"][cur_nodes]  # type: ignore

        if self.return_action_mask:
            td["action_mask"] = self.create_action_mask(slice_idx)
            td["action_mask"][performed] = True

        # Explicit causal mask, for models that take it as an input.
        if self.get_causal_mask:
            td["causal_mask"] = self.create_causal_mask(
                num_inputs,  # type: ignore
                num_outputs,  # type: ignore
                td["nodes"].size(-2),
            )

        # Gated on permutation_prob (not negation_prob) so the two
        # augmentations can be tuned independently.
        if numpy.random.rand() < self.permutation_prob:
            self.apply_rand_perm(td)

        td["nodes"] = td["nodes"].to(self.dtype)

        return td.select(*self.keys)

    def set_device(self, device):
        """Replace the device used by ``__call__``."""
        self.device = device

    def create_causal_mask(self, num_inputs: int, num_ouputs: int, num_nodes: int):
        """Return a ``(1, num_nodes, num_nodes)`` bool lower-triangular mask.

        Entry ``(i, j)`` is True when ``j <= i + diag``, where
        ``diag = num_inputs + num_ouputs`` (minus 1 when ``const_node`` is False).
        The ``num_ouputs`` spelling is kept for keyword-argument compatibility.
        """
        diag = num_inputs + num_ouputs
        if not self.const_node:
            diag -= 1
        mask = torch.ones((num_nodes, num_nodes), dtype=torch.bool).tril_(diag)[
            None, :, :
        ]
        return mask

    def create_action_mask(self, num_nodes):
        """Return a ``(4, num_nodes, num_nodes)`` bool mask.

        The mask is True on and below the diagonal. ``prepare_sample`` also sets
        the performed actions to True, so True appears to mark excluded actions
        (TODO confirm against the loss).
        """
        action_mask = (
            torch.triu(
                torch.ones((num_nodes, num_nodes), dtype=torch.bool),
                diagonal=0,
            ).T
        ).repeat(4, 1, 1)
        return action_mask

    def apply_rand_perm(self, td: TensorDict) -> None:
        """Permute the last axis of ``td["nodes"]`` in place, with one shared permutation."""
        rand_perm = torch.randperm(td["nodes"].size(-1))
        td["nodes"] = td["nodes"][:, rand_perm]


class TTbatch(TensorClass):
    """Truth-table batch built by ``EmbTrainCollate``.

    ``EmbTrainCollate.__call__`` constructs it positionally, so the field
    order below is significant.
    """
    # Stacked truth tables (EmbTrainCollate inserts a singleton axis at dim 1).
    normal: torch.Tensor
    # Element-wise torch.logical_not of ``normal`` (bool).
    negate: torch.Tensor
    # ``normal`` cast to the collate's ``dtype``.
    norm_float: torch.Tensor
    # ``negate`` cast to the collate's ``dtype``.
    neg_float: torch.Tensor


class EmbTrainCollate(torch.nn.Module):
    """Collate a list of truth-table tensors into a ``TTbatch``.

    The batch holds each truth table together with its logical negation, both
    in the stacked form and cast to ``dtype``.

    Args:
        device: Target device or ``None``. If it denotes a CUDA device (as a
            ``torch.device``, a string such as ``"cuda"``, or anything else
            ``torch.device`` accepts), the batch is made contiguous and pinned.
            The data itself is not moved to the device here.
        dtype: Dtype of the ``norm_float`` / ``neg_float`` fields.
    """

    def __init__(self, device, dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype

    @staticmethod
    def _is_cuda_device(device) -> bool:
        """Return True when ``device`` denotes a CUDA device.

        Normalizing through ``torch.device`` accepts strings and indices as well
        as ``torch.device`` objects (plain ``device.type`` fails on ``"cuda"``).
        """
        return device is not None and torch.device(device).type == "cuda"

    def __call__(self, batch_data):
        """Stack ``batch_data`` and build the normal/negated ``TTbatch``."""
        # Shape (B, 1, *item_shape): a singleton axis is inserted at dim 1.
        nor_batch = torch.stack(batch_data).unsqueeze(1)
        neg_batch = torch.logical_not(nor_batch)
        batch = TTbatch(
            nor_batch,
            neg_batch,
            nor_batch.to(self.dtype),
            neg_batch.to(self.dtype),
        )
        # Pinned host memory presumably lets the caller copy with
        # non_blocking=True; no device transfer happens here.
        if self._is_cuda_device(self.device):
            batch = batch.contiguous().pin_memory()
        return batch


class AIGBatch(TensorClass):
    """Batch class for full-graph AIG data.

    Built per sample by ``AIGCollate`` / ``AIGSliceCollate`` and stacked along
    a leading batch axis. Below, ``N`` is the number of rows of ``nodes``
    (row 0 is the target, the tail may be zero padding).
    """
    # (N, T): target row, then one truth table (or pre-computed embedding when
    # requested) per node, zero-padded to the longest sample.
    nodes: torch.Tensor
    # (4, N-1, N-1): 1.0 at [edge_type, left_parent, right_parent] for each
    # target action; node indices exclude the target row.
    actions: torch.Tensor
    # (1, N, N) bool: lower-triangular attention mask with a positive diagonal
    # offset (tril_(num_inputs) in the collates).
    causal_mask: torch.Tensor
    # (4, N-1, N-1) bool: True where an action is allowed; strictly upper
    # triangular (right index > left index), padding columns set to False.
    action_mask: torch.Tensor
    # (N,) bool: True for real node rows, False for padding.
    attention_mask: torch.Tensor


class AIGCollate:
    """Collate full AIG graphs into a padded ``AIGBatch`` for direct training.

    Each record is a mapping with ``num_ands``, ``edges_type_idx``,
    ``left_parent_idx``, ``right_parent_idx`` and either ``target`` /
    ``truth_tables`` or, with ``return_embeddings``, ``target_embedding`` /
    ``truth_table_embeddings``. Samples are zero-padded to the largest graph
    in the batch.

    Args:
        dtype: Dtype of ``nodes`` and ``actions``.
        negation_prob: Probability of negating the target truth table
            (raw truth tables only).
        permutation_prob: Probability of applying one random column
            permutation to the target and to every node truth table
            (raw truth tables only).
        return_embeddings: If True, use the pre-computed embeddings and skip
            augmentation.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.bfloat16,
        negation_prob: float = 0.5,
        permutation_prob: float = 0.5,
        return_embeddings: bool = False,
    ):
        self.dtype = dtype
        self.negation_prob = negation_prob
        self.permutation_prob = permutation_prob
        self.num_inputs = 8
        self.node_start = 10  # NOTE(review): not read anywhere in this class.
        self.return_embeddings = return_embeddings

    def _num_nodes(self, sample: dict[str, Any]) -> int:
        """Return the node rows of ``sample``: inputs + AND gates + constant + target."""
        return self.num_inputs + sample["num_ands"] + 2

    def __call__(self, batch: list[dict[str, Any]]) -> AIGBatch:
        """Collate ``batch`` into a stacked, contiguous ``AIGBatch``.

        Samples are ordered largest-first (so the output order can differ from
        the input order) and padded to the largest graph.
        """
        sorted_batch = sorted(batch, key=self._num_nodes, reverse=True)
        max_length = self._num_nodes(sorted_batch[0])

        processed_samples = [
            self._prepare_sample(sample, max_length) for sample in sorted_batch
        ]
        return torch.stack(processed_samples).contiguous()

    def _prepare_sample(self, sample: dict[str, Any], max_length: int) -> AIGBatch:
        """Convert one record into an ``AIGBatch`` padded to ``max_length`` node rows."""
        num_inputs = self.num_inputs
        total_nodes = self._num_nodes(sample)

        if self.return_embeddings:
            # Pre-computed embeddings are used as-is (no augmentation).
            target_tt = torch.from_numpy(sample["target_embedding"].copy())
            nodes_tt = torch.from_numpy(numpy.vstack(sample["truth_table_embeddings"].copy()))
        else:
            target_tt = torch.from_numpy(sample["target"].copy())
            nodes_tt = torch.from_numpy(numpy.vstack(sample["truth_tables"].copy()))

            # Only the target is negated; node truth tables are left unchanged.
            if random.random() < self.negation_prob:
                target_tt = ~target_tt

            # One column permutation shared by the target and all nodes.
            if random.random() < self.permutation_prob:
                rand_perm = torch.randperm(nodes_tt.size(-1))
                target_tt = target_tt[rand_perm]
                nodes_tt = nodes_tt[:, rand_perm]

        # Cast in both branches so torch.cat below always sees a single dtype.
        target_tt = target_tt.to(self.dtype)
        nodes_tt = nodes_tt.to(self.dtype)

        padding = torch.zeros(
            max_length - total_nodes,
            target_tt.size(-1),
            dtype=self.dtype,
        )
        # Row 0 is the target; padding rows go at the end.
        nodes = torch.cat([target_tt.unsqueeze(0), nodes_tt, padding], dim=0)

        # Node indices in actions exclude the target row: (4 edge types, N-1, N-1).
        actions = torch.zeros((4, max_length - 1, max_length - 1), dtype=self.dtype)

        edges_type_idx = torch.from_numpy(sample["edges_type_idx"].copy()).to(dtype=torch.long)
        left_parent_idx = torch.from_numpy(sample["left_parent_idx"].copy()).to(dtype=torch.long)
        right_parent_idx = torch.from_numpy(sample["right_parent_idx"].copy()).to(dtype=torch.long)

        # Every edge of the record is a target action.
        actions[edges_type_idx, left_parent_idx, right_parent_idx] = 1.0

        # True for real node rows, False for padding.
        attention_mask = torch.tensor(
            [True] * total_nodes + [False] * (max_length - total_nodes),
            dtype=torch.bool,
        )

        causal_mask = self._create_causal_mask(num_inputs, 1, max_length)

        # Strictly upper triangular (right index > left index), with padding
        # columns removed. True marks an allowed action; the loss function is
        # expected to invert it.
        action_mask = (
            torch.ones_like(actions, dtype=torch.bool).triu_(1)
            & attention_mask[None, None, 1:]
        )

        return AIGBatch(
            nodes=nodes,
            actions=actions,
            causal_mask=causal_mask,
            action_mask=action_mask,
            attention_mask=attention_mask,
        )

    def _create_causal_mask(
        self, num_inputs: int, num_outputs: int, num_nodes: int
    ) -> torch.Tensor:
        """Return a ``(1, num_nodes, num_nodes)`` bool lower-triangular mask.

        Entry ``(i, j)`` is True when ``j <= i + num_inputs + num_outputs - 1``.
        """
        diag = num_inputs + num_outputs - 1
        return torch.ones((1, num_nodes, num_nodes), dtype=torch.bool).tril_(diag)


class AIGSliceCollate:
    """Collate prefixes ("slices") of AIG records into a padded ``AIGBatch``.

    Each batch item is ``(record, num_ands)``. Only the node rows up to the
    first ``num_ands`` gates are visible. The gates after the slice become the
    target actions, and the gates already built are removed from
    ``action_mask``. Raw truth tables are used (no pre-computed embeddings).

    Args:
        dtype: Dtype of ``nodes`` and ``actions``.
        negation_prob: Probability of negating the target truth table.
        permutation_prob: Probability of applying one random column
            permutation to the target and to every node truth table.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.bfloat16,
        negation_prob: float = 0.5,
        permutation_prob: float = 0.5,
    ):
        self.dtype = dtype
        self.negation_prob = negation_prob
        self.permutation_prob = permutation_prob
        self.num_inputs = 8

    def __call__(self, batch: list[tuple[AIGRecord, int]]) -> AIGBatch:
        """Collate ``(record, num_ands)`` pairs into a stacked, contiguous ``AIGBatch``.

        Items are ordered by ``num_ands``, largest first (so the output order can
        differ from the input order), and padded to the largest slice.
        """
        sorted_batch = sorted(batch, key=lambda item: item[1], reverse=True)
        max_num_ands = sorted_batch[0][1]

        processed_samples = [
            self._prepare_sample(record, num_ands, max_num_ands)
            for record, num_ands in sorted_batch
        ]
        return torch.stack(processed_samples, dim=0).contiguous()

    def _prepare_sample(
        self,
        sample: AIGRecord,
        num_ands: int,
        max_num_ands: int
    ) -> AIGBatch:
        """Build one padded sample from the first ``num_ands`` gates of ``sample``.

        Args:
            sample: Record providing ``target``, ``truth_tables`` and the
                ``edges_type_idx`` / ``left_parent_idx`` / ``right_parent_idx``
                tensors.
            num_ands: Number of gates visible in this slice.
            max_num_ands: Largest ``num_ands`` in the batch; smaller slices are
                zero-padded up to it.
        """
        # Augmentation: optionally negate the target only...
        target_tt = sample.target
        if random.random() < self.negation_prob:
            target_tt = ~target_tt

        # ...then optionally apply one column permutation to target and nodes.
        nodes_tt = sample.truth_tables
        if random.random() < self.permutation_prob:
            rand_perm = torch.randperm(nodes_tt.size(-1))
            target_tt = target_tt[:, rand_perm]
            nodes_tt = nodes_tt[:, rand_perm]

        # Visible node rows (inputs + one extra row + gates) for this slice and
        # for the largest slice in the batch.
        target_num_nodes = self.num_inputs + 1 + num_ands
        common_num_nodes = self.num_inputs + 1 + max_num_ands

        # Target row(s) first (target_tt is indexed as 2-D, presumably (1, T)),
        # then the visible nodes, then zero padding up to the largest slice.
        nodes = torch.cat(
            [
                target_tt.to(self.dtype),
                nodes_tt[:target_num_nodes, :].to(self.dtype),
                torch.zeros(
                    (
                        max_num_ands - num_ands,
                        target_tt.size(-1),
                    ),
                    dtype=self.dtype
                ),
            ],
            dim=0,
        )

        edges_type_idx = sample.edges_type_idx.to(dtype=torch.long)
        left_parent_idx = sample.left_parent_idx.to(dtype=torch.long)
        right_parent_idx = sample.right_parent_idx.to(dtype=torch.long)

        # Sized for the full record so the unsliced edge indices fit (assuming
        # they are < nodes_tt.size(-2)); cropped to common_num_nodes below.
        max_nodes = max(common_num_nodes, nodes_tt.size(-2))
        actions = torch.zeros((4, max_nodes, max_nodes), dtype=self.dtype)

        # The gates after the slice are the target actions.
        actions[
            edges_type_idx[num_ands:],
            left_parent_idx[num_ands:],
            right_parent_idx[num_ands:],
        ] = 1.0

        # True = allowed: strictly upper triangular (right index > left index),
        # minus the gates already built in this slice. The loss function is
        # expected to invert it.
        action_mask = torch.ones_like(actions, dtype=torch.bool).triu_(1)
        action_mask[
            edges_type_idx[:num_ands],
            left_parent_idx[:num_ands],
            right_parent_idx[:num_ands],
        ] = False

        actions = actions[:, :common_num_nodes, :common_num_nodes]
        action_mask = action_mask[:, :common_num_nodes, :common_num_nodes]

        # True for the target row and the visible nodes, False for padding.
        attention_mask = torch.tensor(
            [True] * (target_num_nodes + 1) + [False] * (max_num_ands - num_ands),
            dtype=torch.bool,
        )

        # Also forbid actions whose right operand is a padding node.
        action_mask = action_mask & attention_mask[None, None, 1:]

        causal_mask = self._create_causal_mask(
            self.num_inputs, 1,
            nodes.size(-2)
        )

        return AIGBatch(
            nodes=nodes,
            actions=actions,
            causal_mask=causal_mask,
            action_mask=action_mask,
            attention_mask=attention_mask,
        )

    def _create_causal_mask(
        self, num_inputs: int, num_outputs: int, num_nodes: int
    ) -> torch.Tensor:
        """Return a ``(1, num_nodes, num_nodes)`` bool lower-triangular mask.

        Entry ``(i, j)`` is True when ``j <= i + num_inputs + num_outputs - 1``.
        """
        diag = num_inputs + num_outputs - 1
        return torch.ones((1, num_nodes, num_nodes), dtype=torch.bool).tril_(diag)