import logging
from pathlib import Path
from typing import cast

import numpy as np
import pandas as pd
import torch
from sklearn import metrics

from pkg.model import BaseModel

# Module-level logger named after this module; used below to report progress
# while building the test tensors and while iterating over evaluation batches.
logger: logging.Logger = logging.getLogger(__name__)


def _encode_row(
    row: pd.Series,
    has_questions: bool,
    original_max_concepts: int,
    unique_concept_mapping: dict[tuple, int],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn one CSV row into a ``(seq_len, 3)`` data tensor and its select mask.

    The row's ``responses``, ``concepts``, ``selectmasks`` and (optionally)
    ``questions`` cells are comma-separated strings of integers. A concept
    entry may itself combine several concept ids joined by ``"_"``.

    Args:
        row: One row of the test dataframe.
        has_questions: Whether the dataframe has a ``questions`` column. When
            it does not, every question id is set to 0.
        original_max_concepts: Upper bound on the number of distinct concepts
            a single interaction may carry.
        unique_concept_mapping: Maps a sorted tuple of distinct concept ids to
            its combined concept index.

    Returns:
        A pair ``(data, selectmask)``. ``data`` is an int64 tensor whose
        columns are ``[combined_concept_idx, question_id, response]``;
        ``selectmask`` holds the raw integer select flags of the row.

    Raises:
        ValueError: If the per-row sequences differ in length, or if an
            interaction has more than ``original_max_concepts`` concepts.
        KeyError: If a concept combination is missing from
            ``unique_concept_mapping``.
    """
    responses = row["responses"].split(",")
    seq_len = len(responses)
    questions = row["questions"].split(",") if has_questions else ["0"] * seq_len
    concepts = row["concepts"].split(",")
    selectmasks = row["selectmasks"].split(",")
    # Raise instead of assert: this validates input data and must not vanish
    # when Python runs with -O.
    if not (len(questions) == len(concepts) == len(selectmasks) == seq_len):
        raise ValueError(
            "Row has inconsistent sequence lengths: "
            f"responses={seq_len}, questions={len(questions)}, "
            f"concepts={len(concepts)}, selectmasks={len(selectmasks)}"
        )

    # One sorted tuple of distinct concept ids per interaction; this is the
    # key format looked up in unique_concept_mapping.
    concept_keys = [
        tuple(sorted({int(x) for x in c.split("_")})) for c in concepts
    ]
    if any(len(key) > original_max_concepts for key in concept_keys):
        raise ValueError("Something wrong with max_concepts ...")

    combined_concept_idx = []
    for key in concept_keys:
        if key == (-1,):
            # Padding positions keep the sentinel -1 and bypass the mapping.
            combined_concept_idx.append(-1)
            continue
        try:
            unique_c_idx = unique_concept_mapping[key]
        except KeyError as err:
            raise KeyError(
                "Something is wrong with concept combination preprocessing, "
                f"s={key!r} not in mapping"
            ) from err
        combined_concept_idx.append(unique_c_idx)

    c = torch.tensor(combined_concept_idx, dtype=torch.int64)
    q = torch.tensor([int(s) for s in questions], dtype=torch.int64)
    r = torch.tensor([int(s) for s in responses], dtype=torch.int64)
    selectmask = torch.tensor([int(s) for s in selectmasks])

    # Column layout: 0 => combined_concept_idx, 1 (-2) => qid, 2 (-1) => response
    return torch.stack([c, q, r], dim=-1), selectmask


def eval(
    model: BaseModel,
    test_file: Path,
    original_max_concepts: int,
    unique_concept_mapping: dict[tuple, int],
    swap_q_and_c: bool = False,
    batch_size: int = 100,
    window_len: int = 200,
) -> tuple[float, float, int]:
    """Evaluate ``model`` on the sequences stored in ``test_file``.

    The CSV is read with pandas and every row is converted into a
    ``(seq_len, 3)`` tensor (combined concept index, question id, response).
    All rows are stacked into one dataset, so every row must have the same
    sequence length; ``torch.stack`` raises otherwise. The dataset is fed
    through ``model.forward`` in batches, and AUC and accuracy are computed
    over the positions whose select mask equals 1. The first position of each
    sequence is always excluded.

    Gradients are disabled during the forward passes. The model's
    train/eval mode is NOT touched here; the caller is expected to set it.

    Args:
        model: Model to evaluate. Its first parameter's device is used for
            the batches. ``model.forward(data=..., padding_mask=...)`` must
            return a sequence whose first element holds the predictions
            (presumably shaped ``(batch, seq_len)`` -- TODO confirm).
        test_file: CSV file with ``responses``, ``concepts``, ``selectmasks``
            and optionally ``questions`` columns.
        original_max_concepts: Maximum number of distinct concepts allowed
            per interaction.
        unique_concept_mapping: Maps sorted tuples of concept ids to combined
            concept indices.
        swap_q_and_c: Currently unused; kept for interface compatibility.
        batch_size: Number of sequences per forward pass.
        window_len: Currently unused; kept for interface compatibility.

    Returns:
        ``(auc, acc, num_evaluated_questions)`` where predictions are
        thresholded at 0.5 for the accuracy.

    Raises:
        ValueError: If a row is malformed (see ``_encode_row``).
        KeyError: If a concept combination is missing from the mapping.
    """
    # Only the first parameter is needed; avoid materializing all of them.
    device = next(iter(model.parameters())).device

    df = pd.read_csv(test_file)
    num_rows = len(df)
    has_questions = "questions" in df.columns

    ### construct test dataset
    complete_tensors = []
    complete_selectmasks = []
    logger.info(f"{len(df)=}")
    for num_row, row in df.iterrows():
        if num_row % 1000 == 0:
            logger.info(
                f"Constructing test tensors on the fly progress: {num_row / num_rows * 100:.2f}%"
            )
        complete_tensor, selectmask = _encode_row(
            row, has_questions, original_max_concepts, unique_concept_mapping
        )
        complete_tensors.append(complete_tensor)
        complete_selectmasks.append(selectmask)

    tensor_dataset = torch.stack(complete_tensors, dim=0)
    tensor_selectmasks = torch.stack(complete_selectmasks, dim=0)

    ### run eval ###
    # always skip first prediction (in line with what pykt does)
    tensor_selectmasks[:, 0] = -1
    tensor_selectmasks = tensor_selectmasks == 1  # to bool

    dataset_batches = torch.split(
        tensor_dataset, split_size_or_sections=batch_size, dim=0
    )
    selectmasks_batches = torch.split(
        tensor_selectmasks, split_size_or_sections=batch_size, dim=0
    )
    logger.info(f"{len(dataset_batches)=}")

    y_trues_, y_scores_ = [], []
    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for iteration, (batch, selectmasks_batch) in enumerate(
            zip(dataset_batches, selectmasks_batches)
        ):
            if iteration % 10 == 0:
                logger.info(f"{iteration=}")
            batch = batch.to(device)
            selectmasks_batch = selectmasks_batch.to(device)

            # A position counts as padding when all three features are -1.
            # NOTE(review): without a "questions" column the question id is 0
            # everywhere, so this test may never fire for such files -- confirm
            # that padded positions are handled as intended in that case.
            mask = ~(batch == -1).all(dim=-1)
            y_pred_, *_ = model.forward(data=batch, padding_mask=mask)
            y_true_ = batch[:, :, -1]

            # subsetting a la pykt: selected positions must not be padding
            assert (selectmasks_batch == (selectmasks_batch & mask)).all()
            y_pred_ = torch.masked_select(y_pred_, selectmasks_batch)
            y_true_ = torch.masked_select(y_true_, selectmasks_batch)

            y_trues_.append(y_true_.detach().cpu().numpy())
            y_scores_.append(y_pred_.detach().cpu().numpy())

    y_trues_ = np.concatenate(y_trues_, axis=0)
    y_scores_ = np.concatenate(y_scores_, axis=0)
    # Vectorized 0.5 threshold instead of a Python-level loop.
    y_preds_ = (y_scores_ >= 0.5).astype(int)

    auc = metrics.roc_auc_score(y_true=y_trues_, y_score=y_scores_)
    acc = metrics.accuracy_score(y_true=y_trues_, y_pred=y_preds_)

    num_evaluated_questions = len(y_preds_)

    return auc, acc, num_evaluated_questions
