import logging

from generators.model import ModelBase, Message
import random

from typing import Union, List, Optional, Callable

############################# simple
def generic_generate_func_impl(
    func_sig: str,
    model: ModelBase,
    strategy: str,
    prev_func_impl,
    feedback,
    self_reflection,
    num_comps,
    temperature,
    reflection_chat_instruction: str,
    reflection_few_shot: str,
    simple_chat_instruction: str,
    reflection_completion_instruction: str,
    simple_completion_instruction: str,
    code_block_instruction: str,
    parse_code_block: Callable[[str], str],
    add_code_block: Callable[[str], str]
) -> Union[str, List[str]]:
    """Ask `model` for one or more implementations of `func_sig`.

    With ``strategy="simple"`` only the signature and the instruction texts
    are sent. With ``strategy="reflexion"`` the previous implementation, its
    feedback and the self-reflection are included as well. Chat models
    (``model.is_chat``) are queried through ``model.generate_chat`` with a
    list of ``Message`` objects; every other model receives one prompt string
    through ``model.generate``.

    Returns:
        When ``num_comps == 1``: a two-element list holding the parsed code of
        the first element of the model result, followed by the second element
        of that result unchanged. Otherwise: a list in which every element of
        the model result has been passed through ``parse_code_block``.

    Raises:
        ValueError: if ``strategy`` is neither ``"reflexion"`` nor
            ``"simple"``, or if it is ``"reflexion"`` and any of
            ``prev_func_impl``, ``feedback`` or ``self_reflection`` is None.
    """
    if strategy not in ("reflexion", "simple"):
        raise ValueError(
            f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")

    use_reflexion = strategy == "reflexion"
    if use_reflexion and any(
        arg is None for arg in (prev_func_impl, feedback, self_reflection)
    ):
        raise ValueError(
            f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None")

    if model.is_chat:
        if use_reflexion:
            # Flattened view of the conversation, used only for the console dump.
            preview = f"{reflection_few_shot}\n[previous impl]:\n{add_code_block(prev_func_impl)}\n\n[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{self_reflection}\n\n[improved impl]:\n{func_sig}"
            system_text = f"{reflection_chat_instruction}\n{code_block_instruction}"
            print_messages(system_text, preview)
            # The previous attempt and its reflection are replayed as
            # assistant turns, with feedback and the new request as user turns.
            conversation = [
                Message(role="system", content=system_text),
                Message(role="user", content=reflection_few_shot),
                Message(role="assistant", content=add_code_block(prev_func_impl)),
                Message(
                    role="user",
                    content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:",
                ),
                Message(role="assistant", content=self_reflection),
                Message(role="user", content=f"[improved impl]:\n{func_sig}"),
            ]
        else:
            system_text = f"{simple_chat_instruction}\n{code_block_instruction}"
            print_messages(system_text, func_sig)
            conversation = [
                Message(role="system", content=system_text),
                Message(role="user", content=func_sig),
            ]
        raw_result = model.generate_chat(
            messages=conversation, num_comps=num_comps, temperature=temperature)
    else:
        if use_reflexion:
            completion_prompt = f"{reflection_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
        else:
            completion_prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
        raw_result = model.generate(
            completion_prompt, num_comps=num_comps, temperature=temperature)

    if num_comps != 1:
        # Several completions: parse each one and return them all.
        parsed_bodies = [parse_code_block(candidate) for candidate in raw_result]
        print_generated_func_body("\n\n".join(parsed_bodies))
        logging.info(f"SAMPLED ACTION: {parsed_bodies}")
        return parsed_bodies

    # Single completion: the result is indexed as a pair whose first element
    # is the generated text; the second element is passed through untouched.
    generated_text = raw_result[0]
    assert isinstance(generated_text, str)
    parsed_body = parse_code_block(generated_text)
    print_generated_func_body(parsed_body)
    logging.info(f"SAMPLED ACTION: {raw_result}")
    return [parsed_body, raw_result[1]]

#####################
def generate_with_accumulated_context(
    func_sig: str,
    model: ModelBase,
    strategy: str,
    prev_func_impl,
    accumulated_feedback,
    accumulated_reflection,
    num_comps,
    temperature,
    reflection_chat_instruction: str,
    reflection_few_shot: str,
    simple_chat_instruction: str,
    reflection_completion_instruction: str,
    simple_completion_instruction: str,
    code_block_instruction: str,
    parse_code_block: Callable[[str], str],
    add_code_block: Callable[[str], str]
) -> Union[str, List[str]]:
    """Generate implementation(s) of `func_sig` using the history of earlier attempts.

    With ``strategy="reflexion"``, ``prev_func_impl``, ``accumulated_feedback``
    and ``accumulated_reflection`` are iterated in lockstep with ``zip`` (so
    iteration stops at the shortest of the three) and every
    (implementation, feedback, reflection) triple is added to the prompt.
    With ``strategy="simple"`` the history arguments are not used and may be
    None.

    Chat models (``model.is_chat``) are queried through
    ``model.generate_chat``; every other model receives one prompt string
    through ``model.generate``.

    Returns:
        When ``num_comps == 1``: a two-element list holding the parsed code of
        the first element of the model result, followed by the second element
        of that result unchanged. Otherwise: a list in which every element of
        the model result has been passed through ``parse_code_block``.

    Raises:
        ValueError: if ``strategy`` is neither ``"reflexion"`` nor
            ``"simple"``, or if it is ``"reflexion"`` and any of the three
            history arguments is None.
    """
    # Ensure that the strategy is valid
    if strategy != "reflexion" and strategy != "simple":
        raise ValueError(
            f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
    is_reflexion = strategy == "reflexion"
    if is_reflexion and (prev_func_impl is None or accumulated_feedback is None or accumulated_reflection is None):
        raise ValueError(
            "Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, "
            "`accumulated_feedback`, or `accumulated_reflection` is None")

    # Build the flattened history only for the reflexion strategy. The simple
    # strategy is allowed to pass None for the history arguments, and zipping
    # over None would raise TypeError before any prompt is built.
    accumulated_context = ""
    if is_reflexion:
        accumulated_context = "\n\n".join(
            f"[previous impl {i+1}]:\n{add_code_block(impl)}\n[unit test results from previous impl {i+1}]:\n{feedback}\n[reflection on previous impl {i+1}]:\n{reflection}"
            for i, (impl, feedback, reflection) in enumerate(
                zip(prev_func_impl, accumulated_feedback, accumulated_reflection))
        )

    if model.is_chat:
        if is_reflexion:
            messages = [
                Message(role="system", content=f"{reflection_chat_instruction}\n{code_block_instruction}"),
                Message(role="user", content=reflection_few_shot)
            ]

            # Replay every earlier attempt as an assistant turn followed by a
            # user turn carrying its test results and reflection.
            for impl, feedback, reflection in zip(prev_func_impl, accumulated_feedback, accumulated_reflection):
                messages.append(Message(role="assistant", content=add_code_block(impl)))
                messages.append(Message(role="user", content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{reflection}"))

            messages.append(Message(role="user", content=f"[improved impl]:\n{func_sig}"))

            # Console dump only: the full conversation text plus the flattened
            # history; neither string is sent to the model.
            conversation_text = "\n".join([entry.content for entry in messages])
            flattened_history = f"{reflection_few_shot}\n{accumulated_context}\n\n[improved impl]:\n{func_sig}"
            print_messages(conversation_text, flattened_history)

            func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
        else:
            system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
            print_messages(system_prompt, func_sig)
            messages = [
                Message(role="system", content=system_prompt),
                Message(role="user", content=func_sig)
            ]
            func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
    else:
        if is_reflexion:
            prompt = f"{reflection_completion_instruction}\n{accumulated_context}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
        else:
            prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
        func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature)
        print_messages(prompt, "")

    if num_comps == 1:
        # Single completion: the result is indexed as a pair whose first
        # element is the generated text; the second is passed through as-is.
        assert isinstance(func_bodies[0], str)
        func_body_str = parse_code_block(func_bodies[0])
        print_generated_func_body(func_body_str)
        return [func_body_str, func_bodies[1]]

    func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
    print_generated_func_body("\n\n".join(func_bodies))
    return func_bodies
    

def generic_generate_internal_tests(
        func_sig: str,
        model: ModelBase,
        max_num_tests: int,
        test_generation_few_shot: str,
        test_generation_chat_instruction: str,
        test_generation_completion_instruction: str,
        parse_tests: Callable[[str], List[str]],
        is_syntax_valid: Callable[[str], bool],
        is_react: bool = False
) -> List[str]:
    """Ask `model` for unit tests of `func_sig` and return a filtered sample.

    Completion models get one prompt via ``model.generate``. Chat models get
    a system and a user ``Message`` via ``model.generate_chat``; with
    ``is_react`` the few-shot text goes into the user message and the raw
    model output is printed, otherwise the few-shot text is appended to the
    system message. All requests use ``max_tokens=1024``.

    The model output is split by ``parse_tests``, tests rejected by
    ``is_syntax_valid`` are dropped, and ``sample_n_random`` reduces the
    remainder to at most ``max_num_tests`` entries.
    """
    if not model.is_chat:
        completion_prompt = f'{test_generation_completion_instruction}\n\nfunc signature:\n{func_sig}\nThere cannot be any data that is unreasonable for this method\nunit tests:'
        output = model.generate(completion_prompt, max_tokens=1024)
    elif is_react:
        conversation = [
            Message(role="system", content=test_generation_chat_instruction),
            Message(
                role="user",
                content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\nThere cannot be any data that is unreasonable for this method\n[think]:",
            ),
        ]
        output = model.generate_chat(messages=conversation, max_tokens=1024)
        print(f'React test generation output: {output}')
    else:
        conversation = [
            Message(
                role="system",
                content=f"{test_generation_chat_instruction}\n\n{test_generation_few_shot}",
            ),
            Message(
                role="user",
                content=f"[func signature]:\n{func_sig}\nThere cannot be any data that is unreasonable for this method\n[unit tests]:",
            ),
        ]
        output = model.generate_chat(messages=conversation, max_tokens=1024)

    # NOTE(review): `output` is handed to parse_tests exactly as the model
    # returned it -- confirm that shape is what parse_tests expects.
    candidate_tests = parse_tests(output)
    syntactically_valid = list(filter(is_syntax_valid, candidate_tests))
    return sample_n_random(syntactically_valid, max_num_tests)


def generic_generate_self_reflection(
        func: str,
        feedback: str,
        model: ModelBase,
        self_reflection_chat_instruction: str,
        self_reflection_completion_instruction: str,
        add_code_block: Callable[[str], str],
        self_reflection_few_shot: Optional[str] = None,
) -> str:
    """Ask `model` to reflect on `func` given its unit test `feedback`.

    Completion models receive a single prompt through ``model.generate``.
    Chat models receive a system message with the chat instruction and a user
    message with the wrapped implementation and the feedback, sent through
    ``model.generate_reflection``. When ``self_reflection_few_shot`` is given
    it is prepended to the user message and the model's answer is printed.

    Returns:
        Whatever the model call returned.
    """
    if not model.is_chat:
        return model.generate(
            f'{self_reflection_completion_instruction}\n{add_code_block(func)}\n\n{feedback}\n\nExplanation:')  # type: ignore

    with_few_shot = self_reflection_few_shot is not None
    system_turn = Message(role="system", content=self_reflection_chat_instruction)
    if with_few_shot:
        user_text = f'{self_reflection_few_shot}\n\n[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:'
    else:
        user_text = f'[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:'
    user_turn = Message(role="user", content=user_text)

    reflection = model.generate_reflection(messages=[system_turn, user_turn])
    if with_few_shot:
        # Only the few-shot variant echoes the answer to the console.
        print(f'Self reflection output: {reflection}')
    return reflection  # type: ignore


def sample_n_random(items: List[str], n: int) -> List[str]:
    """Return min(n, len(items)) randomly chosen items as a new list.

    When ``n`` is at least ``len(items)`` every item is returned in its
    original order; otherwise ``random.sample`` picks ``n`` distinct
    positions. The result is always a fresh list, so mutating it never
    changes ``items``.

    Raises:
        AssertionError: if ``n`` is negative.
    """
    # Kept as an assertion so the exception type raised for a negative `n`
    # stays what it was before.
    assert n >= 0
    if n >= len(items):
        # Copy instead of handing back the caller's own list object.
        return list(items)
    return random.sample(items, n)

def print_messages(system_message_text: str, user_message_text: str) -> None:
    """Print the system and user message texts framed by arrow banners.

    Each text is written between an opening and a closing banner line; the
    output is followed by one empty line and stdout is flushed.
    """
    system_header = "----↓↓↓↓↓↓↓↓↓↓↓↓↓------------------- SYSTEM MESSAGE --------↓↓↓↓↓↓↓↓↓↓↓↓↓---------------)"
    system_footer = "---------↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑------------"
    user_header = "------↓↓↓↓↓↓↓↓↓↓↓↓----------------- USER MESSAGE --------↓↓↓↓↓↓↓↓↓↓↓↓---------------"
    user_footer = "--------↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑------------"
    framed_lines = (
        system_header,
        f"{system_message_text}",
        system_footer,
        user_header,
        f"{user_message_text}",
        user_footer,
        "",  # the block ends with a newline of its own, before print's newline
    )
    print("\n".join(framed_lines), flush=True)

def print_generated_func_body(func_body_str: str) -> None:
    """Print `func_body_str` between a header banner and a footer banner."""
    header = "------↓↓↓↓↓↓↓↓↓↓↓↓↓-------- GENERATED FUNC BODY --------↓↓↓↓↓↓↓↓↓↓↓↓↓-------"
    footer = "---------↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑-----------"
    print(header, f"{func_body_str}", footer, sep="\n")
