from collections.abc import Mapping
from itertools import groupby
from typing import overload

import torch
from sentence_transformers import SentenceTransformer
from torch.nn.utils.rnn import pad_sequence
from trl import DataCollatorForCompletionOnlyLM

from mow.common.graph import (
    edge_pruning,
    map_observation_to_graph,
    node_pruning,
    pad_adjacency_matrix,
)
from mow.common.tensor import merge_views


def indexing(example, dataset_idx: int):
    """Return a copy of ``example`` whose ``"labels"`` entry is ``dataset_idx``.

    The input mapping is not modified; an existing ``"labels"`` entry is
    overwritten in the returned copy.
    """
    labelled = dict(example)
    labelled["labels"] = dataset_idx
    return labelled


def prepare_graph_representation(
    instruction: str | list[str],
    observation: str | list[str],
    labels: int | None,
    sentence_transformer: SentenceTransformer,
    extra_context: str | list[str] | None = None,
):
    """Embed an instruction/observation pair into graph tensors.

    The observation is converted to a graph with
    ``map_observation_to_graph``; its node and relation texts, the
    instruction(s) and any extra context are embedded with
    ``sentence_transformer``.

    Args:
        instruction: One instruction or a list of instructions. Their
            embeddings are summed into a single context row.
        observation: Observation passed unchanged to
            ``map_observation_to_graph``.
        labels: Optional integer label; when given it is returned as a
            ``torch.long`` tensor under ``"labels"``.
        sentence_transformer: Encoder used for every text embedding.
        extra_context: Optional extra sentence(s) whose embeddings are
            appended after the instruction row of the context. A bare
            string is treated as a one-element list.

    Returns:
        A dict with ``"context"``, ``"nodes"``, ``"adjacency_matrix"`` and
        ``"relation_matrix"``, plus ``"labels"`` when ``labels`` is not
        ``None``.
    """
    # ``encode`` returns a 1-D vector for a bare string but a 2-D matrix for
    # a list of strings. Everything below assumes 2-D, so wrap single
    # strings before encoding.
    instruction = [instruction] if isinstance(instruction, str) else instruction
    if isinstance(extra_context, str):
        # Without this, the concat with the (1, embed_dim) instruction row
        # fails on a rank mismatch.
        extra_context = [extra_context]

    nodes, adj, rels = map_observation_to_graph(observation)

    with torch.no_grad():
        # (num_inst, embed_dim)
        inst = sentence_transformer.encode(instruction, convert_to_tensor=True)
        inst = inst.sum(dim=0, keepdim=True)  # (1, embed_dim)

        if extra_context is not None:
            # (num_ctx, embed_dim)
            ctx = sentence_transformer.encode(
                extra_context, convert_to_tensor=True
            )
            # (1 + num_ctx, embed_dim)
            ctx = torch.concat([inst, ctx], dim=0)
        else:
            ctx = inst  # (1, embed_dim)

        # (num_nodes, embed_dim)
        nodes = sentence_transformer.encode(nodes, convert_to_tensor=True)

        # (num_rels, embed_dim)
        rels = sentence_transformer.encode(rels, convert_to_tensor=True)
        # Zero out the embedding at relation index 0.
        # NOTE(review): presumably index 0 is the "no relation" placeholder
        # emitted by map_observation_to_graph -- confirm.
        rels[0, ...] = torch.zeros(
            (1, *rels.shape[1:]), dtype=rels.dtype, device=rels.device
        )

    ret = {
        "context": ctx,
        "nodes": nodes,
        "adjacency_matrix": adj,
        "relation_matrix": rels,
    }

    if labels is not None:
        ret["labels"] = torch.tensor(labels, dtype=torch.long)

    return ret


def prepare_batch_data(
    batch,
    graph_augmentation: bool = False,
    data_collator_lm: DataCollatorForCompletionOnlyLM | None = None,
) -> dict[str, torch.Tensor]:
    """Collate graph examples into padded batch tensors.

    Args:
        batch: Either a sequence of per-example dicts (row-oriented) or a
            mapping from field name to a per-example sequence
            (column-oriented). The fields read are ``"context"``,
            ``"nodes"``, ``"adjacency_matrix"``, ``"relation_matrix"`` and
            optionally ``"labels"``, ``"input_ids"`` and
            ``"attention_mask"``.
        graph_augmentation: When true, build two views of the batch (one
            from ``node_pruning``, one from ``edge_pruning``) and combine
            every field with ``merge_views``.
        data_collator_lm: Collator applied to ``input_ids`` and
            ``attention_mask``; required when the batch carries
            ``"input_ids"``.

    Returns:
        A dict with ``"context"``, ``"hidden_states"``,
        ``"adjacency_matrix"``, ``"relation_matrix"`` and ``"labels"``
        (``None`` when the batch has no labels), extended with the
        language-model collator's output when ``"input_ids"`` is present.

    Raises:
        ValueError: If the batch contains ``"input_ids"`` but
            ``data_collator_lm`` is ``None``.
    """
    # A mapping is column-oriented. Indexing it with 0 would raise KeyError,
    # so it has to be recognised before peeking at the first row.
    columnar = isinstance(batch, Mapping)

    if not columnar and isinstance(batch[0], dict):
        context = [x["context"] for x in batch]
        hidden_states = [x["nodes"] for x in batch]
        adjacency_matrix = [x["adjacency_matrix"] for x in batch]
        relation_matrix = [x["relation_matrix"] for x in batch]
        labels = [x["labels"] for x in batch] if "labels" in batch[0] else None
    else:
        context = batch["context"]
        hidden_states = batch["nodes"]
        adjacency_matrix = batch["adjacency_matrix"]
        relation_matrix = batch["relation_matrix"]
        labels = batch["labels"] if "labels" in batch else None

    # (batch_size, num_context, embed_dim)
    context = pad_sequence(context, batch_first=True)

    # (batch_size, num_nodes, embed_dim)
    hidden_states = pad_sequence(hidden_states, batch_first=True)

    # (batch_size, num_nodes, num_nodes)
    adjacency_matrix = pad_adjacency_matrix(adjacency_matrix)

    # (batch_size, num_rels, embed_dim)
    relation_matrix = pad_sequence(relation_matrix, batch_first=True)

    labels = torch.stack(labels, dim=0) if labels is not None else None

    if graph_augmentation:
        # View 1: node pruning changes both node features and adjacency.
        hidden_states_1, adjacency_matrix_1 = node_pruning(
            hidden_states=hidden_states,
            adjacency_matrix=adjacency_matrix,
        )

        # View 2: edge pruning only changes the adjacency; node features
        # are reused as-is.
        hidden_states_2 = hidden_states
        adjacency_matrix_2 = edge_pruning(
            adjacency_matrix=adjacency_matrix,
        )

        # Every view gets its own detached copy so the merged tensors do
        # not share storage with each other or with the inputs.
        hidden_states = merge_views(
            hidden_states_1.clone().detach(),
            hidden_states_2.clone().detach(),
        )
        adjacency_matrix = merge_views(
            adjacency_matrix_1.clone().detach(),
            adjacency_matrix_2.clone().detach(),
        )

        # Fields untouched by the augmentations are duplicated, one copy
        # per view.
        relation_matrix = merge_views(
            relation_matrix.clone().detach(),
            relation_matrix.clone().detach(),
        )
        context = merge_views(
            context.clone().detach(),
            context.clone().detach(),
        )
        if labels is not None:
            labels = merge_views(
                labels.clone().detach(),
                labels.clone().detach(),
            )

    ret = {
        "context": context,
        "hidden_states": hidden_states,
        "adjacency_matrix": adjacency_matrix,
        "relation_matrix": relation_matrix,
        "labels": labels,
    }

    # Field names live on the mapping itself for column-oriented batches
    # and on each row otherwise.
    field_holder = batch if columnar else batch[0]
    if "input_ids" not in field_holder:
        return ret

    if data_collator_lm is None:
        raise ValueError(
            "data_collator_lm must be provided if input_ids are present in "
            "the batch. This error usually occurs when the model is a "
            "language model and requires input_ids for training."
        )

    if columnar:
        # Re-assemble per-example features from the two columns.
        features = [
            {"input_ids": input_ids, "attention_mask": attention_mask}
            for input_ids, attention_mask in zip(
                batch["input_ids"], batch["attention_mask"]
            )
        ]
    else:
        features = [
            {
                "input_ids": x["input_ids"],
                "attention_mask": x["attention_mask"],
            }
            for x in batch
        ]

    others = data_collator_lm(features)

    return {
        **ret,
        **others.data,
    }


def collapse_data_by_trajectory(batch):
    """Group consecutive examples that share an instruction.

    Each run of adjacent examples with equal ``"instruction"`` values
    becomes one entry in the result. Non-adjacent repeats of an instruction
    stay separate runs.

    Args:
        batch: Column-oriented mapping from field name to a per-example
            sequence. It must contain ``"instruction"``; every column is
            indexed by the example positions of that column.

    Returns:
        A dict with the same keys, in the same order, as ``batch``. Under
        ``"instruction"`` there is one instruction per run; under every
        other key there is one list per run holding that run's values. An
        empty batch yields an empty list for every key.
    """
    instructions = batch["instruction"]
    columns = {key: batch[key] for key in batch.keys()}
    collapsed = {key: [] for key in columns}

    start = 0
    # groupby only merges *adjacent* equal items, which is exactly the
    # run-detection needed here.
    for inst, run in groupby(instructions):
        stop = start + sum(1 for _ in run)
        for key, column in columns.items():
            if key == "instruction":
                # A run shares one instruction, so store it once.
                collapsed[key].append(inst)
            else:
                collapsed[key].append([column[i] for i in range(start, stop)])
        start = stop

    return collapsed
