import os
import tqdm
import json
import time
import random
from typing import Literal
from core.domain.schema import ProblemDomain, Problem, BinaryProblem
from utils.io_utils import load_file


class Forecasting(ProblemDomain):
    """Forecasting as a problem domain, with access to ground truth.

    On instantiation the dataset is loaded through `load_file`, every binary question
    whose resolution is one of its listed outcomes is converted into a
    :class:`BinaryProblem`, and the resulting list is shuffled and partitioned into a
    "train" and a "test" split.

    The shuffle uses the module-level `random` state, so the partition is only
    reproducible if the caller seeds `random` beforehand.
    """

    def __init__(
        self,
        dataset_file: str = "metaculus_resolved_binary.json",
        train_size: float = 0.8,
    ):
        """Instantiate a forecasting problem set.

        :param dataset_file: dataset filepath relative to `data/questions/`, defaults to "metaculus_resolved_binary.json"
        :type dataset_file: str, optional
        :param train_size: the portion of samples to serve as training samples, must lie in [0, 1], defaults to 0.8
        :type train_size: float, optional
        :raises ValueError: if `train_size` is not within [0, 1]
        """
        # A portion outside [0, 1] would silently yield a nonsensical partition
        # (negative slice indices, or an always-empty test split), so reject it early.
        if not 0 <= train_size <= 1:
            raise ValueError(
                f"train_size must be a portion within [0, 1], got {train_size!r}"
            )
        self.train_size = train_size

        # Access forecasting data
        self.dataset_path = os.path.join("data", "questions", dataset_file)
        self.dataset_content = load_file(self.dataset_path)

        # Parse questions: keep only binary markets whose resolution is one of the
        # listed outcomes, so that `correct_option` is always a valid index.
        self.questions_all = [
            BinaryProblem(
                id=q["id"],
                question=self._compose_question_text(q),
                options=q["outcomes"],
                correct_option=q["outcomes"].index(q["resolution"]),
                aux_info=q,  # keep the raw entry alongside the parsed problem
            )
            for q in self.dataset_content
            if q["marketType"] == "binary" and q["resolution"] in q["outcomes"]
        ]
        random.shuffle(self.questions_all)

        # Partition questions: the first `train_size` portion (rounded down) goes to
        # the training split, the remainder to the test split.
        train_samples = int(len(self.questions_all) * self.train_size)
        self.questions_splits = {
            "train": self.questions_all[:train_samples],
            "test": self.questions_all[train_samples:],
        }
        print(f"Training set size: {len(self.questions_splits['train'])}")
        print(f"Test set size: {len(self.questions_splits['test'])}")

    @staticmethod
    def _compose_question_text(entry: dict) -> str:
        """Build the question text of a raw dataset entry.

        :param entry: a raw dataset entry holding a "title" and optionally a "description"
        :type entry: dict
        :return: the description (with double newlines collapsed to single ones) followed
            by a space and the title, or the title alone if the description is missing or empty
        :rtype: str
        """
        if "description" in entry and entry["description"]:
            return entry["description"].replace("\n\n", "\n") + " " + entry["title"]
        return entry["title"]

    def sample_problems(
        self, n: int = 1, split: Literal["train", "test"] = "train"
    ) -> list[BinaryProblem]:
        """Sample a number of problems from a dataset split. The splitting is performed during instantiation.

        Problems are drawn without replacement, and each drawn problem is passed through
        its `shuffle_options()` method (presumably randomizing the option order -- see
        `BinaryProblem`); the values returned by that method are what is collected.

        :param n: the number of problems to sample, defaults to 1
        :type n: int, optional
        :param split: the split to sample from, defaults to "train"
        :type split: Literal["train", "test"], optional
        :raises ValueError: if `n` is negative or exceeds the number of problems in the split
        :return: the sampled problems
        :rtype: list[BinaryProblem]
        """
        pool = self.questions_splits[split]
        # `random.sample` would raise a generic ValueError here; raise the same
        # exception type with a message that names the split and its size.
        if n > len(pool):
            raise ValueError(
                f"Cannot sample {n} problems from the '{split}' split, "
                f"which only contains {len(pool)}"
            )
        return [problem.shuffle_options() for problem in random.sample(pool, n)]
