# src/datasets.py
import csv
import json
import os
from pathlib import Path
import re
from abc import ABC, abstractmethod
from typing import Iterator, Dict

from src.utils import parse_answer
from src.utils import setup_output_directory
class Dataset(ABC):
    """Abstract base for datasets that can skip tasks finished in an earlier run.

    On construction the output directory is obtained from
    ``setup_output_directory(args)`` and the ``results.jsonl`` file inside it
    (if any) is scanned for ``task_id`` values. Subclasses implement
    ``__iter__`` and use ``_task_exists`` to filter out finished tasks.
    """

    def __init__(self, data_path: str, args):
        """Remember the data path and load the ids of already-recorded tasks.

        Args:
            data_path: Location of the dataset file read by subclasses.
            args: Passed unchanged to ``setup_output_directory``.
        """
        self.data_path = data_path
        self.output_dir = setup_output_directory(args)
        self.finished_tasks = self._exist_tasks()

    @abstractmethod
    def __iter__(self) -> Iterator[Dict]:
        """Yield the task records that still have to be processed."""

    def _exist_tasks(self) -> Dict[str, bool]:
        """Collect the ``task_id`` of every record in ``results.jsonl``.

        Returns:
            A dict mapping each recorded ``task_id`` to ``True``; empty when
            the results file does not exist.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record has no ``"task_id"`` key.
        """
        results_file_path = Path(self.output_dir) / 'results.jsonl'
        print(f"results_file_path: {results_file_path}")
        results_tasks = {}
        if results_file_path.exists():
            # Explicit UTF-8 so decoding does not depend on the platform locale.
            with open(results_file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    # A blank line (e.g. a stray newline at the end of the
                    # file) is not a record; json.loads would reject it.
                    if not line.strip():
                        continue
                    result = json.loads(line)
                    results_tasks[result["task_id"]] = True
        print(f"Finished tasks: {(results_tasks.keys())}")
        return results_tasks

    def _task_exists(self, task_id: str) -> bool:
        """Return whether *task_id* was found in the results file at start-up."""
        return task_id in self.finished_tasks

class CodingDataset(Dataset):
    """Function-completion tasks read from a file with one JSON object per line.

    Each record must provide ``"complete_prompt"`` (the function signature and
    docstring) and ``"task_id"``.
    """

    # Fixed prompt text placed before and after each record's signature.
    # Written as explicit literals so the whitespace-only line inside the
    # example cannot be lost to editor trimming.
    # NOTE(review): the example code fence is never closed in this prompt; the
    # text is kept verbatim so generated prompts stay unchanged - confirm
    # whether the missing closing fence is intended.
    _PROMPT_HEAD = (
        "I will provide a function signature and its docstring as follows:\n"
        "\n"
    )
    _PROMPT_TAIL = (
        "\n"
        "\n"
        "Your task is to:\n"
        "\n"
        "1. Write a full implementation of the function\n"
        "2. Include any necessary package imports\n"
        "3. Restate the function signature exactly as given\n"
        "\n"
        "Please format your entire response as a Python code block, like this:\n"
        "\n"
        "```python\n"
        "# Any necessary imports\n"
        "\n"
        "def function_name(required_parameters):\n"
        "   \n"
        "    # Implementation\n"
        "    return required_output\n"
    )

    def __iter__(self):
        """Yield every unfinished record with its prompt and id rewritten.

        ``"complete_prompt"`` is replaced by the full instruction prompt and
        ``"task_id"`` by the part after its last ``/``. Records whose id is
        already in ``finished_tasks`` are skipped, as are blank lines.
        """
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.data_path, "r", encoding="utf-8") as file:
            for line in file:
                # Blank lines are not records; json.loads would reject them.
                if not line.strip():
                    continue
                data = json.loads(line)
                data["complete_prompt"] = (
                    f"{self._PROMPT_HEAD}{data['complete_prompt']}{self._PROMPT_TAIL}"
                )
                data["task_id"] = data["task_id"].split("/")[-1]
                if not self._task_exists(data["task_id"]):
                    yield data

class ChessDataset(Dataset):
    """Chess move-completion tasks read from a file with one JSON object per line.

    Each record must provide ``"input"`` (the move sequence inserted into the
    prompt) and ``"task_id"``.
    """

    # Fixed prompt text placed before and after each record's move sequence.
    # Written as explicit literals so the trailing spaces that follow the game
    # line cannot be lost to editor trimming; the wording is kept verbatim.
    _PROMPT_HEAD = (
        "I will give you a (in progress) chess games, please complete the notation for the last shown move by filling in the destination square.\n"
        "# Chess Game: "
    )
    _PROMPT_TAIL = (
        "        \n"
        "## Context\n"
        "- Chess moves are typically notated using the algebraic notation system.\n"
        "- In this system, each square on the chessboard is identified by a letter (a-h) for the file (column) and a number (1-8) for the rank (row).\n"
        "- Moves are represented by the piece's starting square followed by its destination square.\n"
        "\n"
        "## Task Instructions\n"
        "1. You will be presented with a sequence of chess moves in algebraic notation.\n"
        "2. The last move is incomplete, showing only the starting square.\n"
        "3. Your task is to determine the most likely destination square for the last move.\n"
        "4. Provide only one the destination square using two-letter [a-h][1-8], if multiple legal moves are possible, choose the most likely one based on the game's context.\n"
        "\n"
        "## Example\n"
        "\n"
        "Input:\n"
        "```\n"
        "d2d4 g7g6 c2c4 f8g7 e2e4 g8f6 b1c3 e8g8 c1e3 f6e8 f2f3 b8c6 d1d2 e7e6 h2h4 d7d5 c4d5 e6d5 c3d5 e8f6 d5f6 g7f6 e4e5 f6g7 f1b5 c8d7 a1c1 a7a6 b5c4 b7b5 c4b3 d7f5 c1c6 d8d7 c6c5 c7c6 g1e2 a8d8 e2g3 f5e6 b3e6 d7e6 g3e4 d8d5 e1g1 f7f5 c5d5 c6d5 e4c5 e6c6 e3h6 f5f4 h6g7 g8g7 f1e1 c6e8 e5e6 e8e7 d2a5 e7h4 e6e7 f8e8 c5a6 e8e7 e1e7 h4e7 a6c7 e7e3 g1h2 e3f2 c7d5 f2\n"
        "```\n"
        "\n"
        "Output:\n"
        "```\n"
        "h4\n"
        "```\n"
        'In this example, the last move is "f2" (the starting square), completing the move as "f2h4". Note that the final answer should be two-letter following the regex [a-h][1-8].\n'
    )

    def __iter__(self):
        """Yield every unfinished record with ``"complete_prompt"`` filled in.

        Records whose ``"task_id"`` is already in ``finished_tasks`` are
        skipped, as are blank lines.
        """
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.data_path, "r", encoding="utf-8") as file:
            for line in file:
                # Blank lines are not records; json.loads would reject them.
                if not line.strip():
                    continue
                data = json.loads(line)
                data["complete_prompt"] = (
                    f"{self._PROMPT_HEAD}{data['input']}{self._PROMPT_TAIL}"
                )
                if not self._task_exists(data["task_id"]):
                    yield data


class MathDataset(Dataset):
    r"""Math problems read from a file with one JSON object per line.

    Each record must provide ``"question"`` and ``"correct_answer"``. The
    record's ``"task_id"`` is set to its zero-based line index in the file.
    """

    # Instruction text prepended to every problem. The final piece reproduces
    # the newline and indentation that end the original prompt literal.
    _PROMPT_PREFIX = (
        r"There is a math problem. The answer must be printed out and "
        r"highlighted in the format of \boxed{answer}, e.g. \boxed{1}, "
        r"\boxed{\frac{1}{2}}, \boxed{\sqrt{2}}, \boxed{\pi}, "
        r"\boxed{\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}}, "
        r"otherwise the answer is invalid."
        "\n        "
    )

    def parse_answer(self, input_string):
        r"""Return the text inside the first balanced ``\boxed{...}``.

        Braces are matched by counting depth, so nesting of any depth is
        handled (e.g. ``\boxed{\frac{1}{\sqrt{\frac{a}{b}}}}``).

        Args:
            input_string: Text that may contain a ``\boxed{...}`` expression.

        Returns:
            The contents of the first ``\boxed{...}`` whose braces balance, or
            an empty string when there is none.
        """
        marker = "\\boxed{"
        start = input_string.find(marker)
        while start != -1:
            content_start = start + len(marker)
            depth = 1  # the brace that belongs to the marker is already open
            for pos in range(content_start, len(input_string)):
                char = input_string[pos]
                if char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                    if depth == 0:
                        return input_string[content_start:pos]
            # This occurrence is never closed; try the next one, if any.
            start = input_string.find(marker, start + 1)
        return ""

    def __iter__(self):
        r"""Yield every unfinished record with prompt, id and answer normalised.

        ``"complete_prompt"`` is the instruction prefix followed by the
        question, and ``"correct_answer"`` is re-wrapped as ``\boxed{...}``
        around the parsed answer. Records whose ``"task_id"`` is already in
        ``finished_tasks`` are skipped, as are blank lines.
        """
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.data_path, "r", encoding="utf-8") as file:
            # The id is the raw line index, so skipping a blank line never
            # renumbers the records that follow it.
            for i, line in enumerate(file):
                if not line.strip():
                    continue
                data = json.loads(line)
                data["task_id"] = i
                data["complete_prompt"] = self._PROMPT_PREFIX + "Problem: " + data["question"]
                data["correct_answer"] = "\\boxed{" + self.parse_answer(data["correct_answer"]) + "}"
                if not self._task_exists(data["task_id"]):
                    yield data

class MultiChoiceDataset(Dataset):
    """Multiple-choice questions read from a file with one JSON object per line.

    Each record must provide ``"question"``. The record's ``"task_id"`` is set
    to its zero-based line index in the file.
    """

    # Instruction text prepended to every question. Written as explicit
    # literals so its leading newline and indentation are preserved exactly.
    _PROMPT_PREFIX = (
        "\n"
        "        There is a question with options, and only one option is correct. You must choose the correct option.\n"
        "        Your response must be one of the options.\n"
        "        Only one option is allowed in your answer.\n"
        "        "
    )

    def __iter__(self):
        """Yield every unfinished record with ``"complete_prompt"`` filled in.

        Records whose ``"task_id"`` is already in ``finished_tasks`` are
        skipped, as are blank lines.
        """
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.data_path, "r", encoding="utf-8") as file:
            # The id is the raw line index, so skipping a blank line never
            # renumbers the records that follow it.
            for i, line in enumerate(file):
                if not line.strip():
                    continue
                data = json.loads(line)
                data["task_id"] = i
                data["complete_prompt"] = self._PROMPT_PREFIX + "Question: " + data["question"]
                if not self._task_exists(data["task_id"]):
                    yield data