import logging
import pandas as pd
import torch
from torch.utils.data import DataLoader
import itertools

from pkg.data.pykt import get_tensor_dataset_from_pykt_dataset

# Fold index held out as the validation split. Students in any other
# non-negative fold are used for training; fold -1 marks test students.
STANDARD_VAL_FOLD_IDX = 4
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)


class Data:
    """Synthetic knowledge-tracing dataset with train/val/test data loaders.

    Student responses are simulated with a multi-concept logistic model that
    includes a guessing floor and a fixed per-practice learning increment
    (see ``_generate``). The simulated sequences are stored in ``self.data``
    (one row per student, comma-separated string columns) and converted to
    tensor datasets via ``get_tensor_dataset_from_pykt_dataset``.

    Attributes set by ``__init__``:
        data: DataFrame with columns ``fold``, ``uid``, ``questions``,
            ``concepts``, ``responses``, ``is_repeat``, ``_response_probs``.
        unique_concept_mapping: for ``format == "combinatorial_dense"`` a dict
            mapping every sorted concept combination of size
            ``num_concepts_per_question`` to a running index; otherwise None.
        num_questions: number of distinct questions.
        num_concepts: number of concepts (for ``"combinatorial_dense"`` the
            number of concept combinations instead).
        max_concepts: concepts per question (``1`` for
            ``"combinatorial_dense"``).
        original_max_concepts: concepts per question before the
            ``"combinatorial_dense"`` overwrite; None for other formats.
        max_len: ``num_responses_per_student * num_concepts_per_question``.
        train_loader, val_loader, test_loader: ``DataLoader`` objects; only
            the training loader shuffles.
        test_file: rows of ``data`` belonging to test students (fold -1).
    """

    def __init__(
        self,
        format: str,
        num_students_train_val: int,
        num_students_test: int,
        num_questions: int,
        num_responses_per_student: int,  # sequence length in questions
        num_concepts: int,
        num_concepts_per_question: int,
        seed: int,
        guessing_prob: float,
        learning_increment: float,
        batch_size: int,
        batch_size_val_and_test: int,
        data_path: str,  # for compatibility with standard `train.py`
        student_offset: float = 0.0,
    ) -> None:
        """Simulate the data and build the train/val/test loaders.

        Args:
            format: dataset format forwarded as ``dataset_format`` to
                ``get_tensor_dataset_from_pykt_dataset``;
                ``"combinatorial_dense"`` additionally enables the concept
                combination mapping.
            num_students_train_val: number of train + validation students;
                must be divisible by 5 (five equally sized folds).
            num_students_test: number of test students (fold -1).
            num_questions: number of distinct questions.
            num_responses_per_student: sequence length in questions; must not
                exceed ``num_questions`` (questions are drawn without
                replacement).
            num_concepts: number of concepts.
            num_concepts_per_question: concepts attached to every question;
                must be in ``1..num_concepts``.
            seed: seed passed to ``torch.manual_seed``.
            guessing_prob: lower bound of the probability of a correct answer.
            learning_increment: amount added to a student's skill on each
                concept of a question after every response.
            batch_size: batch size of the training loader.
            batch_size_val_and_test: batch size of the val and test loaders.
            data_path: unused; accepted for compatibility with the standard
                `train.py`.
            student_offset: constant added to the initial student skills.

        Raises:
            ValueError: if the size arguments are inconsistent
                (see ``_generate``).
        """
        # Validates the size arguments, simulates responses, fills self.data.
        self._generate(
            num_students_train_val=num_students_train_val,
            num_students_test=num_students_test,
            num_questions=num_questions,
            num_responses_per_student=num_responses_per_student,
            num_concepts=num_concepts,
            num_concepts_per_question=num_concepts_per_question,
            guessing_prob=guessing_prob,
            learning_increment=learning_increment,
            seed=seed,
            student_offset=student_offset,
        )

        if format == "combinatorial_dense":
            combos = itertools.combinations(
                range(num_concepts), num_concepts_per_question
            )
            self.unique_concept_mapping = {
                combo: idx for idx, combo in enumerate(combos)
            }
        else:
            self.unique_concept_mapping = None

        self.num_concepts = num_concepts
        self.num_questions = num_questions
        self.max_concepts = num_concepts_per_question
        self.max_len = num_responses_per_student * num_concepts_per_question

        def build_dataset(frame):
            # Reads self.max_concepts at call time, i.e. before the
            # `combinatorial_dense` overwrite further below.
            return get_tensor_dataset_from_pykt_dataset(
                df=frame,
                dataset_format=format,
                max_concepts=self.max_concepts,
                unique_concept_mapping=self.unique_concept_mapping,
                swap_q_and_c=False,
            )

        is_test = self.data["fold"] == -1
        train_val = self.data[~is_test]
        is_val = train_val["fold"] == STANDARD_VAL_FOLD_IDX

        train_dataset = build_dataset(train_val[~is_val])
        val_dataset = build_dataset(train_val[is_val])
        test_dataset = build_dataset(self.data[is_test])

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size_val_and_test, shuffle=False
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size_val_and_test, shuffle=False
        )

        # Separate selection so this frame is not the object handed to the
        # dataset builder above.
        self.test_file = self.data[is_test]

        # overwrite self.max_concepts and self.num_concepts for
        # `combinatorial_dense` + keep original `max_concepts` for eval code
        if format == "combinatorial_dense":
            self.original_max_concepts = self.max_concepts
            self.max_concepts = 1
            self.num_concepts = len(self.unique_concept_mapping)
        else:
            self.original_max_concepts = None

    def _generate(
        self,
        num_students_train_val: int,
        num_students_test: int,
        num_questions: int,
        num_responses_per_student: int,
        num_concepts: int,
        num_concepts_per_question: int,
        guessing_prob: float,
        learning_increment: float,
        seed: int,
        student_offset: float,
    ) -> None:
        """Simulate student responses and store them in ``self.data``.

        Model: every question ``q`` is attached to exactly
        ``num_concepts_per_question`` randomly chosen concepts and has a
        normally distributed difficulty per concept. A student ``s`` with
        skills ``theta`` answers correctly with probability
        ``c + (1 - c) * sigmoid(sum_k (theta[s, k] - b[q, k]))`` where the sum
        runs over the concepts of ``q`` and ``c`` is ``guessing_prob``. After
        every response (correct or not) the student's skill on each concept
        of the question grows by ``learning_increment``.

        The first ``num_students_train_val`` students are assigned to folds
        0..4 in five contiguous, equally sized blocks; all remaining students
        get fold -1 (test).

        Side effects: seeds the global torch RNG with ``seed`` and sets
        ``self.data``.

        Raises:
            ValueError: if ``num_students_train_val`` is not divisible by 5,
                if ``num_concepts_per_question`` is not in
                ``1..num_concepts``, or if ``num_responses_per_student``
                exceeds ``num_questions``.
        """
        # Explicit exceptions instead of `assert`: asserts are stripped under
        # `python -O`, which would silently push train/val students into the
        # test fold or produce rows whose columns have mismatched lengths.
        if num_students_train_val % 5 != 0:
            raise ValueError(
                "num_students_train_val must be divisible by 5, "
                f"got {num_students_train_val}"
            )
        if not 0 < num_concepts_per_question <= num_concepts:
            raise ValueError(
                "num_concepts_per_question must be between 1 and num_concepts "
                f"({num_concepts}), got {num_concepts_per_question}"
            )
        if num_responses_per_student > num_questions:
            raise ValueError(
                "num_responses_per_student must not exceed num_questions "
                f"({num_questions}), got {num_responses_per_student}"
            )

        # NOTE: this seeds the *global* torch RNG. The order of the sampling
        # calls below determines the data generated for a given seed, so do
        # not reorder them.
        torch.manual_seed(seed)

        num_students = num_students_train_val + num_students_test

        # Sample item parameters
        # question to concept mapping: each row has exactly
        # `num_concepts_per_question` ones at randomly permuted positions
        a = torch.zeros((num_questions, num_concepts), dtype=torch.int64)
        a[:, :num_concepts_per_question] = 1
        for j in range(num_questions):
            a[j] = a[j, torch.randperm(n=num_concepts)]
        b = torch.randn((num_questions, num_concepts))  # question diffic. per concept
        c = guessing_prob

        # Loop-invariant views of the mapping used inside the simulation.
        a_float = a.type(torch.float32)
        a_mask = a.type(torch.bool)

        # Sample init students
        students = torch.randn((num_students, num_concepts)) + student_offset

        # Sample question sequence per student (without replacement)
        student_question_seq = torch.zeros(
            (num_students, num_responses_per_student), dtype=torch.int64
        )
        for j in range(num_students):
            student_question_seq[j] = torch.randperm(n=num_questions)[
                :num_responses_per_student
            ]

        # Init responses table
        responses = torch.zeros(
            (num_students, num_responses_per_student),
            dtype=torch.int64,
        )
        response_probs = torch.zeros(
            (num_students, num_responses_per_student),
            dtype=torch.float32,
        )

        # Each question is expanded into one entry per concept; only the
        # first entry of every question is marked as "not a repeat" (0).
        is_repeat = torch.ones((1, num_concepts_per_question), dtype=torch.int64)
        is_repeat[:, 0] = 0
        is_repeat = is_repeat.repeat((num_students, num_responses_per_student))
        assert is_repeat.shape == (
            num_students,
            num_responses_per_student * num_concepts_per_question,
        )

        # Run simulation
        for s in range(num_students):
            for t in range(num_responses_per_student):
                q = student_question_seq[s, t]
                logit = torch.inner(a_float[q], students[s] - b[q])
                response_prob = c + (1 - c) / (1 + torch.exp(-logit))
                responses[s, t] = torch.rand(1) < response_prob
                response_probs[s, t] = response_prob
                # Practice effect: skills on the question's concepts grow
                # regardless of whether the answer was correct.
                students[s, a_mask[q]] = students[s, a_mask[q]] + learning_increment

        # Construct dataframe
        def to_str(values: torch.Tensor) -> str:
            return ",".join([str(v) for v in values.tolist()])

        records = []
        for s in range(num_students):
            _questions = student_question_seq[s].repeat_interleave(
                num_concepts_per_question
            )
            # Column indices of the ones in each selected mapping row, in
            # question order.
            _concepts = a[student_question_seq[s]].nonzero(as_tuple=True)[1]
            _responses = responses[s].repeat_interleave(num_concepts_per_question)
            _response_probs = response_probs[s].repeat_interleave(
                num_concepts_per_question
            )

            records.append(
                {
                    # placeholder that fixes the column order; the real fold
                    # assignment is written below
                    "fold": -1,
                    "uid": s,
                    "questions": to_str(_questions),
                    "concepts": to_str(_concepts),
                    "responses": to_str(_responses),
                    "is_repeat": to_str(is_repeat[s]),
                    "_response_probs": ",".join(
                        [f"{x:.2f}" for x in _response_probs.tolist()]
                    ),
                }
            )

        self.data = pd.DataFrame(records)

        # Folds 0..4 in contiguous equal blocks for train/val students, -1 for
        # the remaining (test) students.
        fold = torch.full((len(self.data),), -1, dtype=torch.int64)
        train_folds = torch.arange(5).repeat_interleave(num_students_train_val // 5)
        fold[: len(train_folds)] = train_folds
        # Convert explicitly: how pandas coerces a raw torch tensor depends on
        # the pandas version, a NumPy array always yields an int64 column.
        self.data["fold"] = fold.numpy()


if __name__ == "__main__":
    # Manual smoke run: build a small synthetic dataset and pull one training
    # batch so the objects can be inspected (e.g. from a debugger).
    import numpy as np
    import matplotlib.pyplot as plt

    num_concepts_per_question = 3
    settings = {
        "format": "set_dense",
        "num_students_train_val": 100,
        "num_students_test": 100,
        "num_questions": 1000,
        "num_responses_per_student": 40,
        "num_concepts": 10,
        "num_concepts_per_question": num_concepts_per_question,
        "seed": 0,
        "guessing_prob": 0.0,
        "learning_increment": 0.1,
        "batch_size": 100,
        "batch_size_val_and_test": 100,
        "data_path": "foo",
    }
    data = Data(**settings)

    # Test-split rows and the first element of the first training batch.
    test_frame = data.test_file
    first_batch = next(iter(data.train_loader))
    first_field = first_batch[0]
    print("debug")
