import ast
import dataclasses
import functools
import math
import os
import re
from typing import Callable, List

import torch

from ast import NodeVisitor

from transformers import (
    ViTImageProcessor,
    ViTModel,
    AutoImageProcessor,
    ResNetForImageClassification,
)

from applications.code_force_problems import (
    DR_FUNC,
    DR_FUNC_NAME,
    DR_FUNC_EXPECTED_OUTPUT,
)
from applications.kaggle_problems import (
    LAST_ADD_MIN,
    verify_minimized_number,
    LAST_ADD_MIN_FUNC_NAME,
    LARGEST_INTEGER,
    LARGEST_INT_PARAMS,
)
from applications.original_tasks import (
    SIMPLE_TASK,
    SIMPLE_TASK_FUNC_NAME,
    SIMPLE_TASK_EXPECTED_OUTPUT,
    post_processor_code_creation_for_simple_task,
)

# Name of the entry-point function the generated sorting code must define.
ACTIVATION_FUNCTION_NAME = "sort"
# A deliberately slow (bogosort) implementation the model is asked to speed up.
RUN_CODE = f"""
here is my code, rewrite it to minimize run time:
import random
def is_sorted(data):
    return all(data[i] <= data[i+1] for i in range(len(data) - 1))

def {ACTIVATION_FUNCTION_NAME}(data):
    attempts = 0
    while not is_sorted(data):
        random.shuffle(data)
        attempts += 1
    return data, attempts
"""

# Chat-style request wrapping RUN_CODE. The fence requested here
# ('''python ... ''') must stay in sync with ``code_pattern`` below.
REQUEST = (
    "Return an executable python code that performs the following task:\n"
    f"{RUN_CODE}\n"
    "Use a markdown syntax, i.e. the tags '''python and ''' at the beginning and the end of the code "
    "section. "
    f"The main entry point should be in a function called '{ACTIVATION_FUNCTION_NAME}'.\n"
    "Your answer:\n\n"
)
# Captures everything between the first '''python and the last ''' (greedy .*).
code_pattern = r"\'\'\'python(.*)\'\'\'"
# Completion-style prompt: a signature plus docstring for the model to continue.
REQUEST_FOR_CODE_MODEL = (
    f"def {ACTIVATION_FUNCTION_NAME}(arr):\n"
    f'"""This function takes an array and sort it\n'
    f'The function should run as fast as possible"""'
)
# Name of the function the staircase prompt asks for; it is interpolated into
# the prompt below so the constant and the prompt text cannot drift apart.
MAX_NICE_STAIRCASES_FUNC_NAME = "max_nice_staircases"
# Task prompt: a typed signature plus a docstring describing the problem, with
# no body; the model is expected to supply the implementation.
MAX_NICE_STAIRCASES = f"""
from typing import List
def {MAX_NICE_STAIRCASES_FUNC_NAME}(t: int, test_cases: List[int]) -> List[int]:
    \"\"\"
    Determine the maximal number of different nice staircases that can be built using no more than x cells.

    A staircase is a squared figure consisting of square cells. Each staircase has an arbitrary number of stairs.
    If a staircase has n stairs, then it is made of n columns, where the first column is 1 cell high, the second column is 2 cells high,
    ..., and the n-th column is n cells high. The lowest cells of all stairs must be in the same row.

    A staircase with n stairs is called nice if it may be covered by n disjoint squares made of cells. All squares should fully consist of cells of a staircase.

    The function takes in the number of test cases and a list of integers representing the number of cells for each test case. It returns a list of integers
    where each integer represents the maximal number of different nice staircases that can be built using no more than the given number of cells for each test case.

    Parameters:
    t (int): The number of test cases. (1 <= t <= 1000)
    test_cases (List[int]): A list of integers where each integer x represents the number of cells for building staircases. (1 <= x <= 10^18)

    Returns:
    List[int]: A list of integers representing the number of different nice staircases that can be built for each test case.
    \"\"\"
"""


def add_function_prompt(function_data, code):
    """Return ``function_data.prompt`` followed by a newline and ``code``."""
    prompt_text = function_data.prompt
    return "{}\n{}".format(prompt_text, code)


def extract_code_from_regex(function_data, output_text):
    """Return the contents of the first ```python fenced block in ``output_text``.

    The fence must be ```` ```python ```` immediately followed by a newline and
    closed by the next ```` ``` ````. ``function_data`` is accepted for interface
    compatibility and is not used.

    Returns:
        The text between the fences, or ``""`` when no such block exists.
    """
    first_block = re.search(r"```python\n(.*?)```", output_text, re.DOTALL)
    if first_block is None:
        return ""
    return first_block.group(1)


class FunctionsFinderCallVisitors(NodeVisitor):
    """AST visitor that records calls to the given plain-name functions.

    Only calls whose callee is a bare name (``foo(...)``) are recorded;
    attribute calls such as ``obj.foo(...)`` are ignored.
    """

    def __init__(self, function_name: List[str]):
        # Names to look for; tested with ``in``.
        self.function_name = function_name
        # Matching callee names in visit order (duplicates kept).
        self.calls = []

    def visit_Call(self, node):
        callee = node.func
        is_target = isinstance(callee, ast.Name) and callee.id in self.function_name
        if is_target:
            self.calls.append(callee.id)
        # Keep descending so calls nested in arguments are inspected too.
        self.generic_visit(node)


def is_allowed_code(node):
    """Return True if ``node`` is a function, async function, class, or import node."""
    kept_kinds = (
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
        ast.Import,
        ast.ImportFrom,
    )
    return isinstance(node, kept_kinds)


def remove_not_function_elements(code):
    """Reduce source code to its top-level definitions and imports.

    Top-level nodes accepted by ``is_allowed_code`` (function, async function
    and class definitions, imports) are kept. Every other top-level statement
    is dropped, e.g. test calls, prints, or ``if __name__ == "__main__":``
    blocks. Kept nodes are re-emitted with ``ast.unparse``, so comments and
    the original formatting are not preserved.

    Args:
        code: Python source text, typically extracted from a model answer.

    Returns:
        The filtered source, or ``""`` if a kept node calls ``input`` (such
        code cannot run unattended). ``code`` is returned unchanged if it
        cannot be parsed.
    """
    try:
        parsed_code = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: source containing null bytes on Python < 3.12.
        return code

    allowed_nodes = [
        node for node in ast.iter_child_nodes(parsed_code) if is_allowed_code(node)
    ]

    # Only the kept nodes are checked: an ``input()`` inside a discarded
    # top-level statement is stripped and never runs. The check lives here,
    # not in a separate function, so the source is parsed only once.
    function_finder = FunctionsFinderCallVisitors(["input"])
    for node in allowed_nodes:
        function_finder.visit(node)
    if function_finder.calls:
        return ""

    # ast.unparse uses "\n" inside each node; join with "\n" as well so the
    # result has consistent line endings on every platform.
    return "\n".join(ast.unparse(node) for node in allowed_nodes)


def add_python_initials(prompt: str):
    """Prefix ``prompt`` with an opening ```python fence and a newline."""
    return "```python\n{}".format(prompt)


def extract_code(text):
    """Return the text captured by ``code_pattern`` in ``text``, or ``""``.

    Matching uses ``re.DOTALL`` so the captured code may span several lines.
    """
    found = re.search(code_pattern, text, re.DOTALL)
    if found is None:
        return ""
    return found.group(1)


# Model identifiers of the code-generation models.
LLAMA_MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
CODE_QUEEN_MODEL_CHAT = "Qwen/CodeQwen1.5-7B-Chat"
CODE_QUEEN_MODEL = "Qwen/CodeQwen1.5-7B"
DEEP_SEEQ_MODEL = "deepseek-ai/DeepSeek-Coder-V2-Lite-Base"

# Inputs for the "sort" task: four small hand-written lists plus 1000 random
# floats in [0, 10). The random list comes from torch's global RNG with no
# fixed seed, so it differs between runs.
PARAMS_FOR_SORT = [
    [11, 4, -2, 5.2, 3.51],
    [2, -2.1, 2, 5.2, 0.5, 2.1, 11.3],
    [1, 4, -2, 5.2, 3.51],
    [10.2, 2, -2, 3.1, 3.23, 11.3],
    torch.rand(1000).mul(10).tolist(),
]
# Length of the longest sort input.
MAX_LOST = max(map(len, PARAMS_FOR_SORT))


def sorting_loss(inputs, arr):
    """Score a candidate sort result; lower is better, 0 means correct.

    Args:
        inputs: The list the candidate sort was given (presumably one entry of
            ``PARAMS_FOR_SORT``).
        arr: The value the candidate returned.

    Returns:
        ``math.inf`` if ``arr`` is not a list. Also ``math.inf`` when ``inputs``
        is a list/tuple and ``arr`` is not a rearrangement of it. Otherwise,
        the number of adjacent pairs in ``arr`` that are out of order.
    """
    if not isinstance(arr, list):
        return math.inf
    # Check that ``arr`` holds exactly the input elements (as a multiset).
    # Without this, returning ``[]`` or any dropped/altered list scores a
    # perfect 0. Sorted copies also tolerate candidates that sorted ``inputs``
    # in place.
    if isinstance(inputs, (list, tuple)) and sorted(arr) != sorted(inputs):
        return math.inf
    return sum(1 for left, right in zip(arr, arr[1:]) if left > right)


def fib_arr(n, results):
    """Score a candidate Fibonacci answer against the exact value.

    The expected answer is the ``n``-th term of the sequence 0, 1, 1, 2, 3,
    5, ... counted from 1 (``n=1 -> 0``, ``n=2 -> 1``, ``n=10 -> 34``).

    Args:
        n: Position (>= 1) of the requested term.
        results: The candidate's answer.

    Returns:
        ``math.inf`` if ``results`` is not an int, or if it is so far off that
        the error does not fit in a float. Otherwise,
        ``abs(expected - results)`` as a float.

    Raises:
        ValueError: if ``n`` < 1 (no term is defined).
    """
    if not isinstance(results, int):
        return math.inf
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n!r}")
    # Two rolling terms instead of materialising the whole sequence.
    current, following = 0, 1
    for _ in range(n - 1):
        current, following = following, current + following
    try:
        return float(abs(current - results))
    except OverflowError:
        # An astronomically wrong answer: score it like an invalid one.
        return math.inf


def solve(n):
    """Count how many nice staircases fit into ``n`` cells, smallest first.

    Staircase sizes follow the ``stairs -> 2 * stairs + 1`` progression
    (1, 3, 7, ...), each costing ``stairs * (stairs + 1) // 2`` cells. They are
    taken greedily while the remaining budget covers the next one.
    """
    built = 0
    stairs = 1
    cost = 1
    while n >= cost:
        n -= cost
        built += 1
        stairs = stairs * 2 + 1
        cost = stairs * (stairs + 1) // 2
    return built


def max_nice_staircases(t, test_cases):
    """Answer every query in ``test_cases`` with ``solve``.

    ``t`` (the declared number of test cases) is accepted for signature
    compatibility but not used; every element of ``test_cases`` is answered.
    Each element is converted with ``int`` first.
    """
    return [solve(int(cells)) for cells in test_cases]


def normal_distance(params, results, func):
    """Compare a candidate's output with a reference implementation's output.

    Args:
        params: Input for ``func``. A tuple is unpacked into positional
            arguments; anything else is passed as a single argument.
        results: The candidate's output for ``params``.
        func: Reference implementation producing the expected output.

    Returns:
        2 if ``results`` is not of the expected output's type; 1 if the type
        matches but the value differs; 0 if it matches.
    """
    expected = func(*params) if isinstance(params, tuple) else func(params)
    # A type mismatch is the worst outcome. The test must be negated:
    # otherwise every well-typed answer, correct ones included, scores 2.
    if not isinstance(results, type(expected)):
        return 2
    if expected != results:
        return 1
    return 0


def verify_against_dict(params, results, data):
    """Compare a candidate's output with a precomputed expected value.

    Args:
        params: Key into ``data`` identifying the test input. A missing key
            yields ``None`` as the expected value.
        results: The candidate's output for ``params``.
        data: Mapping from test inputs to expected outputs.

    Returns:
        2 if ``results`` is not of the expected value's type; 1 if the type
        matches but the value differs; 0 if it matches.
    """
    expected = data.get(params)
    # A type mismatch is the worst outcome. The test must be negated:
    # otherwise every well-typed answer, correct ones included, scores 2.
    if not isinstance(results, type(expected)):
        return 2
    if expected != results:
        return 1
    return 0


@dataclasses.dataclass
class FunctionToCreate:
    """Description of one code-generation task.

    Instances are built positionally, so the field order below is part of
    the interface.
    """

    # Prompt text given to the model.
    prompt: str
    # Scoring callable; the losses in this module take (params, results).
    loss: Callable
    # Inputs the generated function is exercised with.
    params: list
    # Per-task loss bound; how it is interpreted is up to the caller — TODO confirm.
    max_lost: float
    # Name of the function the generated code must define.
    activation_func_name: str
    # Presumably applied to the generated code before use; identity by default.
    post_code_creator_processor: Callable = dataclasses.field(default=lambda x: x)


@dataclasses.dataclass
class ModelsData:
    """Configuration for one code-generation model."""

    # Model identifier, e.g. "Qwen/CodeQwen1.5-7B-Chat".
    model_name: str
    # Transform applied to the prompt (e.g. opening a code fence).
    manipulate_input: Callable[..., str]
    # Extracts the generated code from the model output.
    extract_code: Callable[..., str]


# Registry of code-generation tasks, keyed by task id.
# NOTE(review): max_lost is -2 for the last four tasks; how max_lost is
# interpreted is decided outside this section — confirm.
FUNCTIONS = {
    "sort": FunctionToCreate(
        prompt=REQUEST_FOR_CODE_MODEL,
        loss=sorting_loss,
        params=PARAMS_FOR_SORT,
        max_lost=MAX_LOST,
        activation_func_name=ACTIVATION_FUNCTION_NAME,
    ),
    "fib": FunctionToCreate(
        prompt="def fib(n):",
        loss=fib_arr,
        params=[10, 3, 15, 32],
        max_lost=1000,
        activation_func_name="fib",
    ),
    "max_nice": FunctionToCreate(
        prompt=MAX_NICE_STAIRCASES,
        loss=functools.partial(normal_distance, func=max_nice_staircases),
        params=[
            # Disabled case (reason not recorded here):
            # (4, [1, 8, 6, 1000000000000000000]),
            (3, [15, 28, 3]),
            (2, [10, 55]),
            (5, [1, 2, 3, 4, 5]),
            (1, [21]),
        ],
        max_lost=1000,
        activation_func_name=MAX_NICE_STAIRCASES_FUNC_NAME,
    ),
    "min_last": FunctionToCreate(
        prompt=LAST_ADD_MIN,
        loss=verify_minimized_number,
        params=[3, 4, 5, 6, 8, 10],
        max_lost=-2,
        activation_func_name=LAST_ADD_MIN_FUNC_NAME,
    ),
    "largest_int": FunctionToCreate(
        prompt=LARGEST_INTEGER,
        loss=functools.partial(verify_against_dict, data=LARGEST_INT_PARAMS),
        params=list(LARGEST_INT_PARAMS.keys()),
        max_lost=-2,
        activation_func_name="find_largest_displayable_integer",
    ),
    "dr_fun": FunctionToCreate(
        prompt=DR_FUNC,
        loss=functools.partial(verify_against_dict, data=DR_FUNC_EXPECTED_OUTPUT),
        params=list(DR_FUNC_EXPECTED_OUTPUT.keys()),
        max_lost=-2,
        activation_func_name=DR_FUNC_NAME,
    ),
    "simple_task": FunctionToCreate(
        prompt=SIMPLE_TASK,
        loss=functools.partial(verify_against_dict, data=SIMPLE_TASK_EXPECTED_OUTPUT),
        params=list(SIMPLE_TASK_EXPECTED_OUTPUT.keys()),
        max_lost=-2,
        activation_func_name=SIMPLE_TASK_FUNC_NAME,
        post_code_creator_processor=post_processor_code_creation_for_simple_task,
    ),
}


# Code-generation models in use. Both open a ```python fence in the prompt and
# read back the first fenced block from the output. LLAMA_MODEL_NAME and
# CODE_QUEEN_MODEL are defined above but not registered here.
MODELS = {
    "code_queen_chat": ModelsData(
        model_name=CODE_QUEEN_MODEL_CHAT,
        manipulate_input=add_python_initials,
        extract_code=extract_code_from_regex,
    ),
    "deep_seek": ModelsData(
        model_name=DEEP_SEEQ_MODEL,
        manipulate_input=add_python_initials,
        extract_code=extract_code_from_regex,
    ),
}


@dataclasses.dataclass
class ImageModel:
    """Configuration for one image model."""

    # Model identifier, e.g. "microsoft/resnet-50".
    model_name: str
    # Model class to instantiate, e.g. ``ViTModel``.
    model_class: type
    # Matching image-processor class.
    processor_class: type
    # Maps the model output to the value that is kept (e.g. ``.logits``);
    # identity by default.
    post_processor: Callable = dataclasses.field(default=lambda x: x)


# Image models, keyed by short id. The ViT entry keeps the pooled output; the
# ResNet classifiers keep their logits.
IMG_MODELS = {
    "vit": ImageModel(
        model_name="google/vit-base-patch16-224-in21k",
        model_class=ViTModel,
        processor_class=ViTImageProcessor,
        post_processor=lambda x: x.pooler_output,
    ),
    "imagenet": ImageModel(
        model_name="microsoft/resnet-50",
        model_class=ResNetForImageClassification,
        processor_class=AutoImageProcessor,
        post_processor=lambda x: x.logits,
    ),
    "resnet18": ImageModel(
        model_name="microsoft/resnet-18",
        model_class=ResNetForImageClassification,
        processor_class=AutoImageProcessor,
        post_processor=lambda x: x.logits,
    ),
}

