import os
import replicate
import re
import ast
from typing import Optional
import time
import random

from rizaio import Riza
from openai import OpenAI
import google.generativeai as genai
from anthropic import Anthropic, APIStatusError, AnthropicError
from dotenv import load_dotenv

from .config import Config

import ast


# Project-level settings object; the provider-dispatching functions in this
# module read config.MODEL_TO_PROVIDER_MAP from it.
config = Config()

# Run python-dotenv's loader before the API keys are read from the environment.
load_dotenv()

# os.getenv returns None for any variable that is not set.
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# NOTE: read from REPLICATE_API_KEY, then re-exported further down under the
# different name REPLICATE_API_TOKEN.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_KEY")
GOOGLE_API_TOKEN = os.getenv("GOOGLE_API_KEY")


# Shared provider clients, created once at import time.
client_anthropic = Anthropic(api_key=ANTHROPIC_API_KEY)
client_openai = OpenAI(api_key=OPENAI_API_KEY)
genai.configure(api_key=GOOGLE_API_TOKEN)
# NOTE(review): os.environ only accepts str values, so this assignment raises
# TypeError at import time when REPLICATE_API_KEY is unset.
os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
api = replicate.Client(api_token=os.environ["REPLICATE_API_TOKEN"])


RIZA_API_KEY = os.getenv("RIZA_API_KEY")
# NOTE(review): same import-time TypeError when RIZA_API_KEY is unset.
os.environ["RIZA_API_KEY"] = RIZA_API_KEY
# Riza() is built without an explicit key; presumably it reads RIZA_API_KEY
# from the environment set on the previous line — TODO confirm.
client_riza = Riza()


import os
import openai
from sklearn.metrics.pairwise import cosine_similarity

# Set the key on the module-level openai API that get_embedding calls.
# os.environ[...] raises KeyError at import time if OPENAI_API_KEY is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]


def get_mc_options(model, question, gt_answer, N_options=10):
    """Ask *model* for multiple-choice options and return them de-duplicated.

    Args:
        model: Model name forwarded to ``get_N_MC_answers``.
        question: The problem the options should answer.
        gt_answer: Ground-truth answer handed to the model as guidance.
        N_options: Minimum number of distinct options required.

    Returns:
        A list of unique options in first-seen order, or ``None`` when the
        reply cannot be parsed into a list/tuple/set, contains unhashable
        items, or holds fewer than ``N_options`` distinct options.

    Errors raised by the underlying API call itself (other than parse
    errors) still propagate to the caller.
    """
    try:
        # get_N_MC_answers parses the reply itself for some providers, so a
        # malformed reply can surface here as ValueError / SyntaxError.
        options = get_N_MC_answers(
            model=model, problem=question, gt_answer=gt_answer, N=N_options
        )
    except (ValueError, SyntaxError) as e:
        print(f"Error processing options: {e}")
        return None

    try:
        if isinstance(options, str):
            # literal_eval only evaluates Python literals, never arbitrary code.
            options = ast.literal_eval(options.strip())

        # A bare string (or number, dict, ...) is not a collection of options;
        # a string would otherwise be split into single characters below.
        if not isinstance(options, (list, tuple, set)):
            print(
                "Error processing options: expected a list, "
                f"got {type(options).__name__}"
            )
            return None

        # dict.fromkeys removes duplicates while keeping the original order,
        # so the result is deterministic (unlike going through a set).
        unique_options = list(dict.fromkeys(options))

        # Check if there are the required number of unique options.
        if len(unique_options) < N_options:
            return None

        return unique_options

    except (ValueError, SyntaxError, TypeError) as e:
        # Raised by ast.literal_eval on malformed input, or by dict.fromkeys
        # when an option is unhashable.
        print(f"Error processing options: {e}")
        return None
    except Exception as e:
        # Best-effort contract: any other parsing problem also yields None.
        print(f"An unexpected error occurred: {e}")
        return None


def is_sufficiently_different(q_embedding, other_embeddings, threshold=0.336):
    """Return True when *q_embedding* is far enough from every other embedding.

    Distance is cosine distance (``1 - cosine_similarity``). The question
    counts as distinct only if its closest neighbour among
    ``other_embeddings`` is still more than ``threshold`` away.

    Args:
        q_embedding: Embedding vector of the candidate question.
        other_embeddings: Collection of embedding vectors to compare against;
            may be empty.
        threshold: Minimum required cosine distance to the nearest neighbour.

    Returns:
        A plain ``bool``.
    """
    # With nothing to compare against the question is trivially distinct;
    # this also avoids handing an empty array to cosine_similarity.
    if len(other_embeddings) == 0:
        return True

    similarities = cosine_similarity([q_embedding], other_embeddings)
    # Smallest distance == distance to the most similar existing embedding.
    min_distance = min(1 - similarities[0])
    return bool(min_distance > threshold)


def get_embedding(question):
    """Embed *question* with the ``text-embedding-3-small`` model.

    Returns the embedding of the first item in the API response.
    """
    api_result = openai.embeddings.create(
        input=question,
        model="text-embedding-3-small",
    )

    # Collect every returned vector, then hand back the first one.
    vectors = []
    for entry in api_result.data:
        vectors.append(entry.embedding)

    return vectors[0]


def get_ground_truth_answer(code: str) -> Optional[str]:
    """Run *code* as Python through the Riza client and return its stdout.

    All newline characters are removed from the captured output. ``None`` is
    returned when the output is empty or when anything goes wrong while
    executing the code (the error is printed).
    """
    try:
        execution = client_riza.command.exec(language="python", code=code)
        flattened = execution.stdout.replace("\n", "")
    except Exception as e:
        # Covers invalid code, network problems, or an unusable response.
        print(f"Error during COP question execution: {e}")
        return None

    # An empty string means the program printed nothing useful.
    return flattened or None


def get_completion(
    model: str,
    messages: list[dict],
    system_message: str = None,
    temperature=0.0,
) -> str:
    match config.MODEL_TO_PROVIDER_MAP[model]:
        case "openai":
            backoff = 2
            for attempt in range(5):
                try:
                    return (
                        client_openai.chat.completions.create(
                            messages=messages, model=model, temperature=temperature
                        )
                        .choices[0]
                        .message.content
                    )
                except Exception as e:
                    print(f"Error with OpenAI API: {e}")
                    time.sleep(backoff)
                    backoff *= 2
            raise Exception("Max retries reached, OpenAI API is still unavailable.")

        case "google":
            backoff = 2
            for attempt in range(5):
                try:
                    model_google = genai.GenerativeModel(model_name=model)
                    response = model_google.generate_content(
                        messages, generation_config={"temperature": temperature}
                    )
                    return response.text
                except Exception as e:
                    print(f"Error with Google API: {e}")
                    time.sleep(backoff)
                    backoff *= 2
            raise Exception("Max retries reached, Google API is still unavailable.")

        case "anthropic":
            backoff = 2
            for attempt in range(5):
                try:
                    if model == "claude-sonnet-4-20250514":
                        output = client_anthropic.messages.create(
                            model=model,
                            system=system_message,
                            messages=messages,
                            thinking={"type": "enabled", "budget_tokens": 1024},
                            max_tokens=2048,
                        ).content
                        for i in output:
                            if i.type == "text":
                                return i.text

                        # If model is stuck in thinking mode - return empty string.
                        print("Model stuck in thinking mode, did not find an answer.")
                        return None
                        # print(
                        #    f"Attempt {attempt + 1}: Anthropic API returned content but no 'text' type. Retrying..."
                        # )
                        # print(f"this is what was returned: {output}")
                        # time.sleep(backoff)
                        # backoff *= 2

                    else:
                        response = client_anthropic.messages.create(
                            model=model,
                            system=system_message,
                            messages=messages,
                            max_tokens=1024,
                            temperature=temperature,
                        )
                        if (
                            response.content
                            and len(response.content) > 0
                            and hasattr(response.content[0], "text")
                        ):
                            return response.content[0].text
                        else:
                            print(
                                f"Attempt {attempt + 1}: Anthropic API returned empty or malformed content. Retrying..."
                            )
                            time.sleep(backoff)
                            backoff *= 2

                except APIStatusError as e:
                    if e.status_code == 529:
                        print(f"Anthropic API overloaded, retrying in {backoff}s...")
                        time.sleep(backoff)
                        backoff *= 2
                    elif e.status_code == 500:
                        print(
                            f"Anthropic Internal Server Error (500), retrying in {backoff}s..."
                        )
                        time.sleep(backoff)
                        backoff *= 2
                    else:
                        print(
                            f"Attempt {attempt + 1}: Anthropic API error (status {e.status_code}): {e}"
                        )
                        raise  # Re-raise if it's not a retriable status code

                except (
                    AnthropicError
                ) as e:  # Catch a broader base error from the Anthropic library
                    print(
                        f"Attempt {attempt + 1}: General Anthropic client error: {e}. Retrying..."
                    )
                    time.sleep(backoff)
                    backoff *= 2

                except Exception as e:
                    # This will catch *any* other unexpected error
                    print(
                        f"Attempt {attempt + 1}: An unexpected error occurred: {type(e).__name__}: {e}"
                    )
                    # Consider if you want to retry for all exceptions or just re-raise
                    time.sleep(backoff)
                    backoff *= 2

            raise Exception("Max retries reached, API is still unavailable.")

        case "meta":
            full_prompt = (
                f"<<SYS>>\n{system_message}\n<</SYS>>\n\n{messages[0]['content']}"
            )
            output = api.run("meta/llama-2-13b-chat", input={"prompt": full_prompt})
            return output

        case _:
            raise NotImplementedError()


def get_task(
    model: str,
    prompt: str,
    system_message=(
        "You are an AI participating in a competition. Your job is to set questions."
    ),
    temperature=0.7,
) -> str:
    """Have model generate a problem it thinks it will excel at."""

    match config.MODEL_TO_PROVIDER_MAP[model]:
        case "openai":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]

            response = get_completion(
                model=model,
                messages=messages,
                system_message=None,
                temperature=temperature,
            )

            return response

        case "google":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]
            messages = format_messages_openai_to_google(messages)
            response = get_completion(
                model=model,
                messages=messages,
                system_message=None,
                temperature=temperature,
            )

            return response

        case "anthropic":
            messages = [{"role": "user", "content": prompt}]
            response = get_completion(
                model=model,
                messages=messages,
                system_message=system_message,
                temperature=temperature,
            )

    return response


def get_N_verifiable_problems(
    model: str,
    prompt: str,
    system_message=(
        "You are an AI participating in a competition. Your job is to set questions."
    ),
    temperature=0.0,
) -> str:
    """Have model generate a problem it thinks it will excel at."""

    match config.MODEL_TO_PROVIDER_MAP[model]:
        case "openai":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]

            response = get_completion(model, messages, temperature)

            return response

        case "google":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]
            messages = format_messages_openai_to_google(messages)
            response = get_completion(model, messages, temperature)

            return response

        case "anthropic":
            messages = [{"role": "user", "content": prompt}]
            response = get_completion(model, messages, system_message)

    return response


def format_messages_openai_to_google(messages):
    """Flatten OpenAI-style chat messages into one newline-separated string.

    Every message dict contributes one ``"<role>: <content>"`` entry, in the
    order given.
    """
    rendered = []
    for message in messages:
        role = message["role"]
        content = message["content"]
        rendered.append(f"{role}: {content}")
    return "\n".join(rendered)


# def get_verifiable_problem(
#    model: str,
#    prompt: str,
#    system_message=(
#        "You are an AI participating in a competition. Your job is to set questions."
#    ),
#    temperature=0.0,
# ) -> str:
#    """Have model generate a problem it thinks it will excel at."""

#    match config.MODEL_TO_PROVIDER_MAP[model]:
#        case "google":
#            messages = [
#                {"role": "system", "content": system_message},
#                {"role": "user", "content": prompt},
#            ]
#            messages = format_messages_openai_to_google(messages)

#            response = get_completion(model, messages, temperature)

#            match = re.search(r"```python(.*?)```", response, re.DOTALL)

#        case "openai":
#            messages = [
#                {"role": "system", "content": system_message},
#                {"role": "user", "content": prompt},
#            ]

#            response = get_completion(model, messages, temperature)

#            match = re.search(r"```python(.*?)```", response, re.DOTALL)

#        case "anthropic":
#            messages = [{"role": "user", "content": prompt}]
#            response = get_completion(model, messages, system_message)
#            response = response[0].text#

#            match = re.search(r"```python(.*?)```", response, re.DOTALL)
#            match2 = re.search(r"```\n(.*?)```", response, re.DOTALL)
#            match3 = re.search(r"\n(def .*?)$", response, re.DOTALL)

#            if match:
#                return match.group(1).strip()
#            elif match2:
#                return match2.group(1).strip()
#            elif match3:
#                return match3.group(1).strip()
#            else:
#                return response

#        case "meta":
#            messages = [{"role": "user", "content": prompt}]
#            response = get_completion(model, messages, system_message)

#            response = "".join(response)

#            match = re.search(r"```python(.*?)```", response, re.DOTALL)
#            match2 = re.search(r"```\n(.*?)```", response, re.DOTALL)

#            if match:
#                return match.group(1).strip()
#            elif match2:
#                return match2.group(1).strip()
#            else:
#                return response
#    # output processing - keep only the executable python code
#    if match:
#        return match.group(1).strip()  # Remove leading/trailing whitespace
#    return ""  # Return empty string if no match is found


def get_answer_to_verifiable_problem(model: str, problem: str) -> str:
    # TODO: Implement adequate prompt for the model to solve the problem
    prompt = f"""Solve this problem designed to test your capabilities.
               Respond with only your answer. Do not include additional output or explanation.

               PROBLEM: {problem}"""

    system_message = "You are an AI participating in a reasoning competition."
    match config.MODEL_TO_PROVIDER_MAP[model]:
        case "openai":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]

            response = get_completion(model, messages)

        case "anthropic":
            messages = [{"role": "user", "content": prompt}]
            response = get_completion(model, messages, system_message)

        case "meta":
            messages = [{"role": "user", "content": prompt}]
            response = get_completion(model, messages, system_message)
            response = "".join(response)

    # TODO: Implement any postprocessing of the response to extract the answer
    return response


# def get_answer_to_verifiable_MC_problem_with_prior_q(
#    model: str,
#    problem: str,
#    options: list[str],
#    second_q: str,
#    gt_answer_to_second_q: str,
#    N_samples: int = 1,
# ) -> str:
#    prompt = f"""Solve this problem designed to test your capabilities.
#               Respond with only your answer. Do not include additional output or explanation.
#
#               To help you with your reasoning, here is a prior question:
#               {second_q}
#               The correct answer to the prior question is: {gt_answer_to_second_q}.##

#               Pick the correct option from the following choices: {options}
#               Remember to respond only with one of these options. #

#               PROBLEM: {problem}"""

#    if N_samples > 1:
#        system_message = [
#            {
#                "type": "text",
#                "text": "You are an AI participating in a reasoning competition"
#                + prompt,
#                "cache_control": {"type": "ephemeral"},
#            }
#        ]
#    else:
#        system_message = (
#            "You are an AI participating in a reasoning competition." + prompt
#        )
#    basic_prompt = "Answer the question."

#    match config.MODEL_TO_PROVIDER_MAP[model]:
#        case "openai":
#            messages = [
#                {"role": "system", "content": system_message},
#                {"role": "user", "content": prompt},
#            ]

#            response = get_completion(model, messages)

#        case "anthropic":
#            if N_samples > 1:
#                messages = [{"role": "user", "content": basic_prompt}]
#                response = [
#                    get_completion(model, messages, system_message)
#                    for _ in range(N_samples)
#                ]
#                response = [normalize_answer(i[0].text) for i in response]
#            else:
#                messages = [{"role": "user", "content": prompt}]
#                response = get_completion(model, messages, system_message)
#                response = normalize_answer(response[0].text)

#        case "meta":
#            messages = [{"role": "user", "content": prompt}]
#            response = get_completion(model, messages, system_message)
#            response = "".join(response)

#    # TODO: Implement any postprocessing of the response to extract the answer
#    return response


def get_answer_to_verifiable_MC_problem_with_ABCD(
    model: str, problem: str, options: list[str], N_samples: int = 1
) -> str:
    prompt = f"""Solve this problem designed to test your capabilities.
               Respond with only your answer. Do not include additional output or explanation.

               Pick the correct option from the following choices: 
               A: {options[0]} B: {options[1]} C: {options[2]} D: {options[3]}

               Return only A, B, C, D. Nothing else.
               PROBLEM: {problem}"""

    if N_samples > 1:
        system_message = [
            {
                "type": "text",
                "text": "You are an AI participating in a reasoning competition"
                + prompt,
                "cache_control": {"type": "ephemeral"},
            }
        ]
    else:
        system_message = (
            "You are an AI participating in a reasoning competition." + prompt
        )
    basic_prompt = "Answer the question."

    match config.MODEL_TO_PROVIDER_MAP[model]:
        case "openai":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]

            response = get_completion(model, messages)

        case "google":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]
            messages = format_messages_openai_to_google(messages)
            response = get_completion(model, messages)

        case "anthropic":
            if N_samples > 1:
                messages = [{"role": "user", "content": basic_prompt}]
                response = [
                    get_completion(model, messages, system_message)
                    for _ in range(N_samples)
                ]
                response = [normalize_answer(i) for i in response]
            else:
                messages = [{"role": "user", "content": prompt}]
                response = get_completion(model, messages, system_message)
                if response:
                    response = normalize_answer(response)

    return response


# def get_answer_to_verifiable_MC_problem(
#    model: str, problem: str, options: list[str], N_samples: int = 1
# ) -> str:
#    # TODO: Implement adequate prompt for the model to solve the problem
#    prompt = f"""Solve this problem designed to test your capabilities.
#               Respond with only your answer. Do not include additional output or explanation.#

#               Pick the correct option from the following choices: {options}

#               PROBLEM: {problem}"""

#    if N_samples > 1:
#        system_message = [
#            {
#                "type": "text",
#                "text": "You are an AI participating in a reasoning competition"
#                + prompt,
#                "cache_control": {"type": "ephemeral"},
#            }
#        ]
#    else:
#        system_message = (
#            "You are an AI participating in a reasoning competition." + prompt
#        )
#    basic_prompt = "Answer the question."

#    match config.MODEL_TO_PROVIDER_MAP[model]:
#        case "openai":
#            messages = [
#                {"role": "system", "content": system_message},
#                {"role": "user", "content": prompt},
#            ]

#            response = get_completion(model, messages)

#        case "google":
#            messages = [
#                {"role": "system", "content": system_message},
#                {"role": "user", "content": prompt},
#            ]
#            messages = format_messages_openai_to_google(messages)
#            response = get_completion(model, messages)

#        case "anthropic":
#            if N_samples > 1:
#                messages = [{"role": "user", "content": basic_prompt}]
#                response = [
#                    get_completion(model, messages, system_message)
#                    for _ in range(N_samples)
#                ]
#                response = [normalize_answer(i[0].text) for i in response]
#            else:
#                messages = [{"role": "user", "content": prompt}]
#                response = get_completion(model, messages, system_message)
#                response = normalize_answer(response[0].text)

#        case "meta":
#            messages = [{"role": "user", "content": prompt}]
#            response = get_completion(model, messages, system_message)
#            response = "".join(response)

#    # TODO: Implement any postprocessing of the response to extract the answer
#    return response


# def get_MC_answers(model: str, problem: str) -> str:
#    prompt = f"""
#    Generate exactly four plausible multiple-choice answers to the following question.

#    - The answers should be diverse and reasonable.
#    - Respond **only** with the four answers in this format: ["answer1", "answer2", "answer3", "answer4"]
#    - Each answer should be in *exactly the format* that would be expected as output if you ran the code.
#    - Do not include any explanations or additional text.

#    PROBLEM: {problem}
#    """

#    system_message = "You are an AI participating in a reasoning competition."
#    match config.MODEL_TO_PROVIDER_MAP[model]:
#        case "openai":
#            messages = [
#                {"role": "system", "content": system_message},
#                {"role": "user", "content": prompt},
#            ]

#            response = get_completion(model, messages)

#        case "anthropic":
#            messages = [{"role": "user", "content": prompt}]
#            response = get_completion(model, messages, system_message)
#            response = response[0].text

#        case "meta":
#            messages = [{"role": "user", "content": prompt}]
#            response = get_completion(model, messages, system_message)
#            response = "".join(response)

#    response = ast.literal_eval(response)
#    return response


def normalize_answer(answer):
    """Strip whitespace and one pair of matching surrounding quotes.

    Args:
        answer: Raw answer string.

    Returns:
        The stripped answer. If, after stripping, it starts and ends with the
        same quote character (``'`` or ``"``), that outer pair is removed.
        A lone quote character is left untouched, since it does not surround
        anything.
    """
    answer = answer.strip()
    # The length guard keeps a single "'" or '"' from matching itself as both
    # the opening and the closing quote.
    if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in ("'", '"'):
        return answer[1:-1]
    return answer


def get_N_MC_answers(model: str, problem: str, gt_answer, N=10) -> str:
    prompt = f"""
    Generate exactly {N} plausible multiple-choice answers to the following question.  

    - The answers should be diverse and reasonable.  
    - Respond **only** with the {N} answers in this format: ["answer1", "answer2", "answer3", "answer4",..etc] 
    - Each answer should be in *exactly the format* that would be expected as output if you ran the code.
    - Do not include any explanations or additional text.
    - Use the ground truth answer to craft your options  
    - Every option should be different
    - If you can't think of plausible options, just pick random possible answers.

    PROBLEM: {problem}
    GROUND_TRUTH: {gt_answer}
    """

    system_message = "You are an AI participating in a reasoning competition."
    match config.MODEL_TO_PROVIDER_MAP[model]:
        case "openai":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]

            response = get_completion(model, messages)

        case "google":
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ]
            messages = format_messages_openai_to_google(messages)
            response = get_completion(model, messages)
            return response

        case "anthropic":
            messages = [{"role": "user", "content": prompt}]
            response = get_completion(model, messages, system_message)
            return response

        case "meta":
            messages = [{"role": "user", "content": prompt}]
            response = get_completion(model, messages, system_message)
            response = "".join(response)

    response = ast.literal_eval(response)
    return response


def generate_random_subsets(mc_answers, gt_answer, num_subsets=5, num_distractors=3):
    """Build shuffled answer sets that each contain the ground truth once.

    Args:
        mc_answers: Candidate answers; entries equal to ``gt_answer`` are
            excluded from the distractor pool.
        gt_answer: The correct answer, appended to every subset.
        num_subsets: How many subsets to generate.
        num_distractors: How many wrong answers to draw per subset.

    Returns:
        A list of ``num_subsets`` lists, each holding ``num_distractors``
        randomly sampled distractors plus ``gt_answer`` in random order.

    Raises:
        ValueError: raised by ``random.sample`` when fewer than
            ``num_distractors`` distractors are available.
    """
    # The pool does not change between subsets, so build it once.
    distractors = [ans for ans in mc_answers if ans != gt_answer]

    subsets = []
    for _ in range(num_subsets):
        subset = random.sample(distractors, num_distractors)
        subset.append(gt_answer)
        random.shuffle(subset)
        subsets.append(subset)
    return subsets
