import json
import logging
from pathlib import Path
import pickle
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


# Module-level logger; every function below reports progress through it.
logger: logging.Logger = logging.getLogger(name=__name__)

# Accepted values for `Data(format=...)` and the `dataset_format` arguments.
FORMAT_VARIANTS = [
    "expanded",
    "set_dense",
    "combinatorial_dense",
]


class Data:
    """Load a pykt-preprocessed dataset and build the train/validation loaders.

    Reads ``<data_path>/config.json``, selects the entry for ``dataset`` and
    exposes its dimensions together with train/validation ``DataLoader``s for
    one validation fold and the path of the pykt test file.

    Attributes set here:
        unique_concept_mapping: Loaded mapping for ``"combinatorial_dense"``,
            otherwise ``None``.
        num_questions, num_concepts, max_concepts, max_len: Dataset dimensions
            (adjusted for ``"combinatorial_dense"`` and question/concept swaps).
        original_max_concepts: Config ``max_concepts`` before the
            ``"combinatorial_dense"`` override, otherwise ``None``.
        swap_q_and_c: Whether questions and concepts are swapped.
        format: The requested format.
        train_loader, val_loader: The dataloaders.
        test_file: Path of ``test_window_sequences_quelevel.csv``.

    Raises:
        ValueError: If ``format`` is not one of ``FORMAT_VARIANTS``.
        KeyError: If ``dataset`` has no entry in ``config.json``.
    """

    def __init__(
        self,
        format: str,
        dataset: str,
        batch_size: int,
        batch_size_val: int,
        val_fold_idx: int,
        data_path: str,
    ) -> None:
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if format not in FORMAT_VARIANTS:
            raise ValueError(
                f"Unsupported format {format!r}; expected one of {FORMAT_VARIANTS}."
            )
        dense_formats = ("set_dense", "combinatorial_dense")
        data_dir = Path(data_path)

        with open(data_dir / "config.json", encoding="utf-8") as f:
            # Raises KeyError if the config has no entry for `dataset`.
            dataset_config = json.load(f)[dataset]

        if format == "combinatorial_dense":
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # mapping files produced by trusted preprocessing.
            with open(data_dir / dataset / "unique_concept_mapping.pkl", "rb") as f:
                unique_concept_mapping = pickle.load(f)
        else:
            unique_concept_mapping = None
        # Kept on the instance for evaluation code and `test_file` handling.
        self.unique_concept_mapping = unique_concept_mapping

        if dataset_config["num_q"] == 0 and format in dense_formats:
            self.num_questions = 1  # For `dense` formats, we use a dummy question
        else:
            self.num_questions = dataset_config["num_q"]
        self.num_concepts = dataset_config["num_c"]
        self.max_concepts = dataset_config["max_concepts"]
        self.max_len = dataset_config["maxlen"]

        # NOTE(review): the datasets requiring a question/concept swap are
        # hard-coded here.
        self.swap_q_and_c = format in dense_formats and dataset in ("statics2011", "poj")
        if self.swap_q_and_c:
            logger.info(
                f"Swapping questions and concepts based on combinations of {format=} and {dataset=}"
            )
        self.format = format

        dataset_dir = Path(dataset_config["dpath"])
        self.train_loader, self.val_loader = get_train_valid_dataloaders(
            dataset_path=dataset_dir / "train_valid_sequences.csv",
            dataset_format=format,
            fold_idx=val_fold_idx,
            max_concepts=self.max_concepts,
            batch_size_train=batch_size,
            batch_size_valid=batch_size_val,
            unique_concept_mapping=unique_concept_mapping,
            swap_q_and_c=self.swap_q_and_c,
        )
        # Evaluation currently runs on the test file generated by pykt rather
        # than on a dedicated test dataloader.
        self.test_file = dataset_dir / "test_window_sequences_quelevel.csv"

        # `combinatorial_dense` encodes each concept combination as one id, so a
        # question has exactly one "concept" and the concept vocabulary is the
        # mapping. The original `max_concepts` is kept for the eval code.
        if format == "combinatorial_dense":
            assert self.unique_concept_mapping is not None
            self.original_max_concepts = self.max_concepts
            self.max_concepts = 1
            self.num_concepts = len(self.unique_concept_mapping)
        else:
            self.original_max_concepts = None

        # The data columns are swapped in `get_tensor_dataset_from_pykt_dataset`,
        # so the reported dimensions are swapped as well.
        # NOTE(review): for "combinatorial_dense" this moves len(mapping) into
        # `num_questions`, although the mapping is applied to the (already
        # swapped) concept column -- confirm this is intended.
        if self.swap_q_and_c:
            self.num_concepts, self.num_questions = (
                self.num_questions,
                self.num_concepts,
            )


def get_train_valid_dataloaders(
    dataset_path: Path,
    dataset_format: str,
    fold_idx: int,
    max_concepts: int,
    batch_size_train: int,
    batch_size_valid: int,
    unique_concept_mapping: dict[tuple, int] | None,
    swap_q_and_c: bool,
) -> Tuple[DataLoader, DataLoader]:
    """Wrap the train/validation split of one fold in ``DataLoader``s.

    The training loader shuffles; the validation loader keeps row order.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    train_split, valid_split = get_train_valid_tensor_datasets_from_pykt_dataset(
        dataset_path=dataset_path,
        dataset_format=dataset_format,
        max_concepts=max_concepts,
        fold_idx=fold_idx,
        unique_concept_mapping=unique_concept_mapping,
        swap_q_and_c=swap_q_and_c,
    )

    return (
        DataLoader(train_split, batch_size=batch_size_train, shuffle=True),
        DataLoader(valid_split, batch_size=batch_size_valid, shuffle=False),
    )


def get_train_valid_tensor_datasets_from_pykt_dataset(
    dataset_path: Path,
    dataset_format: str,
    max_concepts: int,
    fold_idx: int,
    unique_concept_mapping: dict[tuple, int] | None,
    swap_q_and_c: bool,
) -> tuple[TensorDataset, TensorDataset]:
    """Read a pykt CSV and split it by its ``fold`` column.

    Rows whose ``fold`` differs from ``fold_idx`` form the training set; rows
    whose ``fold`` equals ``fold_idx`` form the validation set.

    Returns:
        ``(train_dataset, valid_dataset)``.
    """
    frame = pd.read_csv(dataset_path)
    shared_options = dict(
        dataset_format=dataset_format,
        max_concepts=max_concepts,
        unique_concept_mapping=unique_concept_mapping,
        swap_q_and_c=swap_q_and_c,
    )

    train_rows = frame[frame["fold"] != fold_idx]
    train_dataset = get_tensor_dataset_from_pykt_dataset(train_rows, **shared_options)

    valid_rows = frame[frame["fold"] == fold_idx]
    valid_dataset = get_tensor_dataset_from_pykt_dataset(valid_rows, **shared_options)

    return train_dataset, valid_dataset


def get_tensor_dataset_from_pykt_dataset(
    df: pd.DataFrame,
    dataset_format: str,
    max_concepts: int,
    unique_concept_mapping: dict[tuple, int] | None,
    swap_q_and_c: bool = False,
) -> TensorDataset:
    """Convert pykt sequence rows into a ``TensorDataset``.

    Cells of the ``responses`` and ``concepts`` columns (and, if present,
    ``questions`` and ``is_repeat``) are comma-separated integer strings; ``-1``
    in ``responses`` marks padding. Each row is formatted by the
    ``format_question_*`` function matching ``dataset_format``, and the
    per-row tensors are stacked along a new first dimension.

    Args:
        df: DataFrame with the pykt sequence columns described above.
        dataset_format: One of ``FORMAT_VARIANTS``.
        max_concepts: Maximum number of distinct concepts per question (used
            by the dense formats).
        unique_concept_mapping: Mapping from sorted concept tuples to combined
            concept ids; required for ``"combinatorial_dense"``.
        swap_q_and_c: If True, the question and concept columns are swapped
            before formatting.

    Returns:
        A ``TensorDataset`` wrapping the single stacked tensor.

    Raises:
        ValueError: If ``dataset_format`` is unsupported or ``df`` is empty.
    """
    # Validate up front instead of on the first row inside the loop.
    if dataset_format not in FORMAT_VARIANTS:
        raise ValueError(f"Dataset format {dataset_format} is not supported.")
    if dataset_format == "combinatorial_dense":
        assert unique_concept_mapping is not None
    if df.empty:
        # Nothing could be stacked, and the dense-format fix-up below would
        # index a 1-D tensor -- fail with a clear message instead.
        raise ValueError(
            f"Cannot build a {dataset_format!r} dataset from an empty DataFrame."
        )

    def parse_sequences(column: pd.Series) -> Tensor:
        # Each cell is a comma-separated string of ints. Assumes all rows are
        # padded to the same length so numpy builds a 2-D integer array.
        rows = column.apply(lambda cell: [int(v) for v in cell.split(",")])
        return torch.from_numpy(np.array(rows.values.tolist()))

    responses = parse_sequences(df["responses"])
    concepts = parse_sequences(df["concepts"])
    mask = responses != -1  # -1 marks padding

    if "questions" in df.columns:
        questions = parse_sequences(df["questions"])
        assert torch.equal(mask, questions != -1)
    else:
        logger.info("No `questions` in dataset. Setting them to `0` everywhere.")
        questions = torch.zeros_like(responses)
        questions[~mask] = -1

    if "is_repeat" in df.columns:
        is_repeat = parse_sequences(df["is_repeat"]) == 1
    else:
        logger.info("No `is_repeat` in dataset. Setting it to `False` everywhere.")
        is_repeat = torch.zeros_like(responses, dtype=torch.bool)

    # Padding positions are never continuations of a question.
    question_is_continued = is_repeat & mask

    # Sequences can be split in the middle of a question, so a row may start
    # with a continuation. The dense formats group interactions per question
    # and need the first entry of every row to start a new question.
    if dataset_format in ("set_dense", "combinatorial_dense"):
        n_split = int(question_is_continued[:, 0].sum())
        logger.info(f"Set first entries in qic to False. Adds {n_split} interactions.")
        question_is_continued[:, 0] = False

    if swap_q_and_c:
        logger.info("Attention: swapping `concepts` and `questions`")
        concepts, questions = questions, concepts

    # Format each sequence (row) according to the dataset format.
    tensor_per_sequence = []
    for q, c, r, qic, m in zip(questions, concepts, responses, question_is_continued, mask):
        assert (q != -1).sum() == m.sum()

        if dataset_format == "expanded":
            complete_tensor = format_question_expanded(q=q, c=c, r=r, qic=qic)
        elif dataset_format == "set_dense":
            complete_tensor = format_question_set_dense(
                q=q[m],
                c=c[m],
                r=r[m],
                qic=qic[m],
                max_concepts=max_concepts,
                seq_len=len(m),
            )
        else:  # "combinatorial_dense" (validated above)
            complete_tensor = format_question_combinatorial_dense(
                q=q[m],
                c=c[m],
                r=r[m],
                qic=qic[m],
                max_concepts=max_concepts,
                seq_len=len(m),
                unique_concept_mapping=unique_concept_mapping,
            )

        tensor_per_sequence.append(complete_tensor)

    return TensorDataset(torch.stack(tensor_per_sequence, dim=0))


def format_question_expanded(q: Tensor, c: Tensor, r: Tensor, qic: Tensor) -> Tensor:
    """Stack one sequence's interactions into rows ``[question, concept, qic, response]``.

    ``qic`` is converted to int64 and set to ``-1`` wherever ``q`` is ``-1``.
    The result has shape ``(len(q), 4)``.
    """
    continued = qic.to(torch.int64)
    continued[q == -1] = -1

    rows = torch.stack([q, c, continued, r.to(torch.int64)], dim=-1)
    assert rows.shape == (len(q), 4)
    return rows


def format_question_set_dense(
    q: Tensor, c: Tensor, r: Tensor, qic: Tensor, max_concepts: int, seq_len: int
) -> Tensor:
    """Collapse one sequence's interactions into one row per question.

    Consecutive interactions of the same question are flagged by ``qic``
    (True = continues the previous question). Each question becomes a row of
    its distinct concepts padded with ``-1`` to ``max_concepts``, followed by
    the question id and the response of the question's first interaction.

    Args:
        q: Question ids (unpadded), one per interaction.
        c: Concept ids aligned with ``q``.
        r: Responses aligned with ``q``.
        qic: Boolean "question is continued" flags aligned with ``q``; the
            first entry must be False.
        max_concepts: Number of concept slots per row.
        seq_len: Number of output rows; rows without a question stay ``-1``.

    Returns:
        Int64 tensor of shape ``(seq_len, max_concepts + 2)`` with columns
        ``[concepts..., question_id, response]``.

    Raises:
        ValueError: If a question has more than ``max_concepts`` distinct concepts.
    """
    complete_tensor = torch.full(
        (seq_len, max_concepts + 2), -1, dtype=torch.int64, device=q.device
    )
    if len(q) == 0:
        # Fully padded sequence: nothing to fill (the `.min()` below would
        # fail on an empty tensor).
        return complete_tensor

    q_reduced = q[~qic]
    r_reduced = r[~qic]
    n_questions = len(q_reduced)

    # 0-based index of the question each interaction belongs to.
    individual_questions = torch.cumsum(~qic, dim=-1) - 1
    assert individual_questions.min() == 0

    # Group concepts by question.
    concept_lists: list[list[int]] = [[] for _ in range(n_questions)]
    for iq, ic in zip(individual_questions.tolist(), c.tolist()):
        concept_lists[iq].append(ic)

    # Deduplicate concepts and pad every question to `max_concepts` slots.
    dense_rows = []
    for concepts in concept_lists:
        distinct = list(set(concepts))
        if len(distinct) > max_concepts:
            raise ValueError("Something wrong with max_concepts")
        dense_rows.append(distinct + [-1] * (max_concepts - len(distinct)))

    # Columns: [:max_concepts] => concepts, -2 => qid, -1 => response
    complete_tensor[:n_questions, :max_concepts] = torch.tensor(
        dense_rows, dtype=torch.int64, device=q.device
    )
    complete_tensor[:n_questions, -2] = q_reduced
    complete_tensor[:n_questions, -1] = r_reduced

    return complete_tensor


def format_question_combinatorial_dense(
    q: Tensor,
    c: Tensor,
    r: Tensor,
    qic: Tensor,
    max_concepts: int,
    seq_len: int,
    unique_concept_mapping: dict[tuple, int],
) -> Tensor:
    """Collapse one sequence into one row per question with a combined concept id.

    Consecutive interactions of the same question are flagged by ``qic``
    (True = continues the previous question). The distinct concepts of each
    question are sorted into a tuple and looked up in
    ``unique_concept_mapping``. The resulting id is stored together with the
    question id and the response of the question's first interaction.

    Args:
        q: Question ids (unpadded), one per interaction.
        c: Concept ids aligned with ``q``.
        r: Responses aligned with ``q``.
        qic: Boolean "question is continued" flags aligned with ``q``; the
            first entry must be False.
        max_concepts: Maximum allowed number of distinct concepts per question.
        seq_len: Number of output rows; rows without a question stay ``-1``.
        unique_concept_mapping: Sorted concept tuple -> combined concept id.

    Returns:
        Int64 tensor of shape ``(seq_len, 3)`` with columns
        ``[combined_concept_idx, question_id, response]``.

    Raises:
        ValueError: If a question has more than ``max_concepts`` distinct concepts.
        KeyError: If a concept combination is missing from the mapping.
    """
    complete_tensor = torch.full((seq_len, 3), -1, dtype=torch.int64, device=q.device)
    if len(q) == 0:
        # Fully padded sequence: nothing to fill (the `.min()` below would
        # fail on an empty tensor).
        return complete_tensor

    q_reduced = q[~qic]
    r_reduced = r[~qic]
    n_questions = len(q_reduced)

    # 0-based index of the question each interaction belongs to.
    individual_questions = torch.cumsum(~qic, dim=-1) - 1
    assert individual_questions.min() == 0

    # Distinct concepts per question.
    concept_sets: list[set[int]] = [set() for _ in range(n_questions)]
    for iq, ic in zip(individual_questions.tolist(), c.tolist()):
        concept_sets[iq].add(ic)

    # Check every question's size before any mapping lookup.
    for concepts in concept_sets:
        if len(concepts) > max_concepts:
            raise ValueError("Something wrong with max_concepts ...")

    combined_idx = []
    for concepts in concept_sets:
        s = tuple(sorted(concepts))  # `s` is also named in the error message
        try:
            combined_idx.append(unique_concept_mapping[s])
        except KeyError as err:
            raise KeyError(
                f"Something is wrong with concept combination preprocessing, {s=} not in mapping"
            ) from err

    concept_tensor = torch.tensor(combined_idx, dtype=torch.int64, device=q.device)
    assert concept_tensor.shape == q_reduced.shape

    # Columns: 0 => combined_concept_idx, -2 => qid, -1 => response
    complete_tensor[:n_questions, 0] = concept_tensor
    complete_tensor[:n_questions, -2] = q_reduced
    complete_tensor[:n_questions, -1] = r_reduced

    return complete_tensor
