from core.utils import ChatClient, build_chat_client
from typing import List, Dict, Any, Optional, Literal
import json
import requests
import datetime
from dataclasses import dataclass

from tasks.mult.utils import random_input



@dataclass
class ChatConfig:
    """Settings for one batch of chat-model queries.

    Fields are consumed elsewhere in this module:

    * ``model_name`` / ``config_file`` -- passed to ``build_chat_client``.
    * ``is_reasoning_model`` -- when False, ``build_prompt`` adds a
      "think step by step" instruction.
    * ``num_threads`` -- forwarded to the client's ``completions`` call.
    * ``output_format`` -- ``"json"`` makes ``build_prompt`` request a JSON
      answer with key ``'result'``; other values add no format instruction.
    * ``question_type`` -- selects the prompt wording in ``build_prompt``.

    Raises:
        ValueError: If ``question_type`` is not ``"sudoku"`` or
            ``"multiplication"``.
    """

    model_name: str = "DeepSeek-R1"
    is_reasoning_model: bool = True
    config_file: str = "configs/chat.json"
    num_threads: int = 10
    output_format: str = "json"
    question_type: Literal["sudoku", "multiplication"] = "multiplication"

    def __post_init__(self) -> None:
        # `Literal` is only a static hint; reject unsupported question types
        # at construction time rather than deep inside prompt building.
        if self.question_type not in ("sudoku", "multiplication"):
            raise ValueError(
                f"Unsupported question_type: {self.question_type!r} "
                "(expected 'sudoku' or 'multiplication')"
            )


def build_prompt(question: str, config: ChatConfig) -> str:
    """Build the prompt text sent to the chat model for one question.

    Args:
        question: The puzzle or expression to embed in the prompt.
        config: Supplies ``question_type`` (selects the instruction wording),
            ``output_format`` (``"json"`` appends a JSON-answer instruction)
            and ``is_reasoning_model`` (when False, a step-by-step instruction
            is appended).

    Returns:
        The assembled prompt string.

    Raises:
        ValueError: If ``config.question_type`` is neither ``"sudoku"`` nor
            ``"multiplication"``.
    """
    if config.question_type == "sudoku":
        prompt = f"Please solve the following Sudoku puzzle:\n{question}"
    elif config.question_type == "multiplication":
        prompt = f"Please calculate the following multiplication:\n{question}"
    else:
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        raise ValueError(f"Unsupported question_type: {config.question_type!r}")

    if config.output_format == "json":
        prompt += "\n\nPut your answer in a json object with key 'result'."

    # Only non-reasoning models get the explicit step-by-step instruction.
    if not config.is_reasoning_model:
        prompt += "\n\nThink step by step before answering."

    return prompt

def _build_response_record(
    index: int, question: Any, response: Any, config: ChatConfig
) -> Dict[str, Any]:
    """Convert one completion result into a log record and echo it to stdout."""
    record: Dict[str, Any] = {
        "question": question,
        "model": config.model_name,
        "timestamp": datetime.datetime.now().isoformat(),
    }

    # The client appears to return an Exception object in place of a response
    # for a prompt that failed, rather than raising.
    if isinstance(response, Exception):
        record["error"] = str(response)
        print(f"Response {index} (Error): {str(response)}")
    else:
        content = response.choices[0].message.content
        record["response"] = content
        print(f"Response {index}: {content}")
    return record


def _append_to_response_log(path: str, records: List[Dict[str, Any]]) -> None:
    """Append ``records`` to the JSON list stored at ``path``, creating it if needed.

    A missing or unparseable file is treated as an empty log (and therefore
    overwritten), as before. A file holding valid JSON that is not a list is
    left untouched and a ValueError is raised instead.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            existing = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        existing = []

    if not isinstance(existing, list):
        raise ValueError(f"{path} does not contain a JSON list; refusing to overwrite it")

    existing.extend(records)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing, f, indent=2)


def process_questions(
    questions: List[Any],
    config: ChatConfig,
    output_path: str = "responses.json",
) -> None:
    """Send every question to the chat model and append the results to a JSON log.

    Each question is turned into a prompt with ``build_prompt``, and the prompts
    are submitted together via the client's ``completions`` call. Each result
    becomes a record with ``question``, ``model``, ``timestamp``, and either
    ``response`` (the first choice's message content) or ``error``. It is
    printed and appended to the JSON list at ``output_path``.

    The log file is read and written once per call, not once per response.
    Records gathered before an unexpected failure are still saved.

    Args:
        questions: Question payloads; they are embedded in prompts and stored
            verbatim, so they should be JSON-serializable.
        config: Model, client-config, threading, and prompt settings.
        output_path: JSON log file to append to (default ``"responses.json"``).

    Errors from the client or while handling responses are printed rather
    than raised. This function returns None in all cases.
    """
    chat_model = build_chat_client(config.model_name, config.config_file)
    prompts = [build_prompt(q, config) for q in questions]

    new_records: List[Dict[str, Any]] = []
    try:
        responses = chat_model.completions(prompts, num_threads=config.num_threads)

        # NOTE(review): assumes `responses` is in prompt order; zip() also
        # silently drops questions if fewer responses come back — confirm.
        for i, (question, response) in enumerate(zip(questions, responses)):
            new_records.append(_build_response_record(i, question, response, config))
            print("---------------------------------------")

    except requests.exceptions.ConnectionError:
        print("Network connection error occurred. Please check your internet connection.")
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)}")
    finally:
        # Persist whatever was collected, even if a later response failed.
        if new_records:
            try:
                _append_to_response_log(output_path, new_records)
            except (OSError, ValueError) as e:
                print(f"Failed to save responses to {output_path}: {e}")

def get_questions(
    max_digits: int = 10,
    samples_per_pair: int = 10,
    include_equal_digits: bool = False,
) -> List[str]:
    """Generate multiplication questions formatted as ``"x * y"``.

    Digit counts run over every ordered pair ``(x_digits, y_digits)`` with both
    values in ``1..max_digits``. For each pair, ``random_input(x_digits, y_digits)``
    is called ``samples_per_pair`` times to obtain operands (presumably random
    numbers of those digit lengths; see ``tasks.mult.utils``).

    Pairs with equal digit counts are skipped unless ``include_equal_digits``
    is true. With the default arguments this yields 90 pairs x 10 samples =
    900 questions, in x-digit-major order.

    Returns:
        The list of question strings.
    """
    questions: List[str] = []
    for x_digits in range(1, max_digits + 1):
        for y_digits in range(1, max_digits + 1):
            if x_digits == y_digits and not include_equal_digits:
                continue
            for _ in range(samples_per_pair):
                x, y = random_input(x_digits, y_digits)
                questions.append(f"{x} * {y}")
    return questions


def main():
    """Generate the multiplication question set and submit it to DeepSeek-R1.

    Uses 3 client threads and requests JSON-formatted answers; all other
    settings come from ``ChatConfig`` defaults.
    """
    run_config = ChatConfig(
        model_name="DeepSeek-R1",
        output_format="json",
        num_threads=3,
    )
    process_questions(get_questions(), run_config)
