import autogen
from opto.trace import node, bundle, model
import random

class LLMCallable:
    """Holds an ``autogen.OpenAIWrapper`` client and exposes one chat call."""

    def __init__(self, config_list=None, max_tokens=16383, verbose=False):
        """Create the LLM client and remember the call settings.

        Args:
            config_list: Configuration handed to ``autogen.OpenAIWrapper``.
                When ``None``, it is loaded via
                ``autogen.config_list_from_json("OAI_CONFIG_LIST")``.
            max_tokens: ``max_tokens`` value sent with every ``call_llm`` request.
            verbose: When true, ``call_llm`` prints each response it receives.
        """
        resolved_config = (
            autogen.config_list_from_json("OAI_CONFIG_LIST")
            if config_list is None
            else config_list
        )
        self.llm = autogen.OpenAIWrapper(config_list=resolved_config)
        self.max_tokens = max_tokens
        self.verbose = verbose

    @bundle(catch_execution_error=True)
    def call_llm(self, user_prompt):
        """Send ``user_prompt`` to the LLM and return the first choice's text.

        The request carries a fixed system prompt followed by the user prompt.

        Args:
            user_prompt: Content of the user message.

        Returns:
            ``choices[0].message.content`` of the completion response.
        """
        conversation = [
            {"role": "system", "content": "You are a helpful assistant.\n"},
            {"role": "user", "content": user_prompt},
        ]
        completion = self.llm.create(messages=conversation, max_tokens=self.max_tokens)
        reply = completion.choices[0].message.content

        if self.verbose:
            print("LLM response:\n", reply)
        return reply

# NOTE: this helper has been added to Trace as well, but it is not publicly usable there yet.
@bundle()
def str_join(x, *y):
    """Join the positional arguments ``y`` using ``x`` as the separator."""
    pieces = y
    joined = x.join(pieces)
    return joined

@model
class DSLMapperGenerator(LLMCallable):
    """
    A class to generate DSL mappers by modeling decision points.
    Each method corresponds to a decision point and generates code statements.
    """

    def __init__(self, tasks, regions, index_tasks, single_tasks, index_task_specification):
        """Store the application description consumed by the decision methods.

        NOTE(review): ``LLMCallable.__init__`` is not invoked here, so this
        constructor does not set ``self.llm``, ``self.max_tokens`` or
        ``self.verbose`` -- confirm that ``call_llm`` is not expected to work
        on instances built this way.

        Args:
            tasks: Task names, passed to ``task_decision`` and
                ``instance_limit_decision`` by ``generate_mapper``.
            regions: Region names, passed to ``region_decision``.
            index_tasks: Index-launch task names, passed to
                ``index_task_map_decision``.
            single_tasks: Single task names, passed to
                ``single_task_map_decision``.
            index_task_specification: Specification of the index tasks, passed
                to ``index_task_map_decision``.
        """
        (
            self.tasks,
            self.regions,
            self.index_tasks,
            self.single_tasks,
            self.index_task_specification,
        ) = (tasks, regions, index_tasks, single_tasks, index_task_specification)

    @bundle(trainable=True)
    def task_decision(self, tasks):
        """
        Generate Task mapping statements.

        Every statement in the program should end with `;` (like `C`)
        Functions should start with `def FuncName(Arg1Type arg1, ...)` and the function body needs to be wrapped with `{ ... }`. Functions will be used in the IndexTaskMap statement.
        Comments should start with `#` (like `Python`)
        Please do not use `//` for comments, use `#` instead.
        Example generated code:
        ```
        Task foo CPU; # for task named "foo", run on CPU
        Task * GPU,OMP,CPU; # for any other task, by default try running on GPU first (if a GPU variant exists), then try OpenMP Processor, finally try CPU.
        ```
        The task named `foo` will use `CPU`. For other tasks, they will use the default (fallback) strategies.
        The wildcard `*` is a way to describe fallback policies, after which is a priority-list of processors kinds.
        The supported processor kinds are: `CPU`, `GPU`, `OMP`.
        If the task is super small or GPU runs out of memory, then you could try moving the tasks outside GPU.
        For most cases, just generating `Task * GPU,OMP,CPU;` is enough.

        Args:
            tasks: Iterable of task names; each one may receive an extra `Task` statement.

        Returns:
            A list of statement strings that always starts with the default
            `Task * GPU,OMP,CPU;` fallback.
        """
        code_statements = ['Task * GPU,OMP,CPU;']  # Default statement
        for task in tasks:
            # Decide whether to include a statement that moves this task onto CPU or OMP; you can change the weights
            include_statement = random.choices([True, False], weights=[50, 50], k=1)[0]
            if include_statement:
                # Decide the processor type (OMP or CPU); you can change the weights
                processor = random.choices(['OMP', 'CPU'], weights=[40, 60], k=1)[0]
                # Generate the code statement
                statement = f'Task {task} {processor};'
                code_statements.append(statement)

        return code_statements

    @bundle(trainable=True)
    def region_decision(self, regions):
        """
        Generate Region mapping statements.
        Default strategy code:
        ```
        Region * * GPU FBMEM; # for any tasks (first *), any regions (second *), if mapped onto GPU, use GPU FrameBuffer Memory as default
        Region * * * SOCKMEM,SYSMEM; # for any tasks, any regions, if mapped onto CPU or OpenMP, use socket memory or system memory as default
        ```

        Application-specific strategy code:
        ```
        Region * region_name1 GPU ZCMEM;
        Region * region_name2 GPU ZCMEM;
        ```
        The first is the task name, here "*" means "for any tasks" (that use those regions). The second argument is the region names.
        Here the region_name1 and region_name2 are the names of the regions. They will be mapped to GPU ZeroCopy memory, overwriting the default strategy.
        If the GPU FBMEM runs out of memory, then consider moving some regions to GPU ZCMEM using `Region * region_name GPU ZCMEM;`.

        Args:
            regions: Iterable of region names; each one may receive a ZCMEM override statement.

        Returns:
            A list of statement strings that always starts with the two default
            `Region` statements, followed by one override per selected region.
        """
        code_statements = ['Region * * GPU FBMEM;', 'Region * * * SOCKMEM,SYSMEM;']  # Default statements, please do not modify this.
        for region in regions:
            # Decide whether to include the statement, feel free to change the weights here, or even make it deterministically true or false. If false, then all regions will be placed onto GPU FBMEM.
            include_statement = random.choices([True, False], weights=[90, 10], k=1)[0]
            if include_statement:
                # The override currently always targets GPU ZeroCopy memory
                processor = 'GPU'
                memory = 'ZCMEM'
                # Generate the code statement
                statement = f'Region * {region} {processor} {memory};'
                code_statements.append(statement)
        return code_statements

    @bundle(trainable=True)
    def layout_decision(self):
        """
        Generate Layout mapping statements.
        Examples:
        ```
        Layout * * * C_order SOA;
        ```
        The above code suggests that all data should be in C_order and Struct-Of-Array (SOA) layout.
        Another example is:
        ```
        Layout * * * F_order AOS;
        ```
        You should list all the constraints in the Layout statement, e.g., `Layout * * * Align==128 F_order` specifies that the memory should align to 128 bytes while using Fortran order.

        Returns:
            A list that is either empty or holds a single `Layout` statement string.
        """
        code_statements = []
        # Decide whether to include the statement, you can change the weights here
        include_statement = random.choices([True, False], weights=[50, 50], k=1)[0]
        if include_statement:
            # Decide layout parameters, you can change the weights here
            order = random.choices(['F_order', 'C_order'], weights=[50, 50], k=1)[0]
            structure = random.choices(['SOA', 'AOS'], weights=[50, 50], k=1)[0]
            # The empty string means "no alignment constraint"
            alignment = random.choices(['', 'Align==64', 'Align==128'], weights=[50, 25, 25], k=1)[0]
            # Generate the code statement
            components = ['Layout', '*', '*', '*', order, structure]
            if alignment:
                components.append(alignment)
            statement = ' '.join(components) + ';'
            code_statements.append(statement)
        return code_statements

    @bundle(trainable=True)
    def instance_limit_decision(self, tasks):
        """
        Generate InstanceLimit mapping statements.
        Example:
        ```
        InstanceLimit task_name 1;
        ```
        On each node, only one `task_name` can be mapped at the same time.
        Typically, tasks are mapped ahead of time, and using InstanceLimit can avoid mapping too many tasks at the same time ahead of time, which can avoid consuming too much memory.
        This statement is useful only when some hanging happens. Otherwise, please do not use this.

        Args:
            tasks: Sequence of task names to pick the limited task from. May be empty.

        Returns:
            A list that is either empty or holds a single `InstanceLimit` statement string.
        """
        code_statements = []
        # random.choice raises IndexError on an empty sequence, so only the empty
        # case is skipped; a single task is a valid candidate.
        if tasks:
            task = random.choice(tasks)
            # It's almost never used, so we set the probability of including the statement to be low
            # Feel free to change the weights here, or even make it deterministically true or false
            include_statement = random.choices([True, False], weights=[20, 80], k=1)[0]
            if include_statement:
                limit_value = random.randint(1, 2)
                # Generate the code statement. Please do not generate it unless you know what you are doing.
                statement = f'InstanceLimit {task} {limit_value};'
                code_statements.append(statement)
        return code_statements

    @bundle(trainable=True)
    def index_task_map_decision(self, index_tasks, index_task_specification):
        """
        The function index_task_map_decision generates IndexTaskMap mapping statements.
        The variable `index_tasks` is a list of index launch tasks that need to be mapped.
        The variable `index_task_specification` is a string that contains the specification of the index tasks.
        In general, index task map specifies how to map the index tasks to the processor space.
        The index tasks are tasks that are launched with an index space, and the index space is a set of points.
        Every point in the processor space specifies a processor.
        So this function can decide where to put each point task of the index tasks in the processor space.

        The index tasks are a set of point tasks that are launched with an index space.
        Each point task is task.ipoint (a multi-dimensional tuple), and the index space is task.ispace (a multi-dimensional tuple).
        If the index space is (2, 2), then task.ispace[0] = 2, task.ispace[1] = 2.
        And task.ipoint will be (0, 0), (0, 1), (1, 0), (1, 1) as the iteration space.
        The processor space is initialized as a 2-dim tuple with respect to a processor kind, e.g., mgpu = Machine(GPU);

        Example generated code:
        ```
        mgpu = Machine(GPU);
        def block1d(Task task) {
            ip = task.ipoint;
            return mgpu[ip[0] % mgpu.size[0], ip[0] % mgpu.size[1]];
        }
        IndexTaskMap task1 block1d;
        ```
        Please note that the first index of mgpu should end with % mgpu.size[0], and the second element should end with %mgpu.size[1]
        In this way, the returned point is always within the processor space.

        Args:
            index_tasks: Iterable of index-launch task names; each one may receive an `IndexTaskMap` statement.
            index_task_specification: Specification of the index tasks. It is not read by the current body.

        Returns:
            A list of DSL snippets: the `mgpu` machine definition, each selected
            function definition (emitted once, before its first use), and the
            `IndexTaskMap` statements.
        """

        # Base statement that is always included
        # Please do not use or generate m_2d in index_task_map_decision function.
        code_statements = ['mgpu = Machine(GPU);'] 
        
        #  Function definitions
        # Note that the functions in the generated DSL are in this format: def func_name(Task task) { function_body here }
        # There should be no `:` in the generated program. The generated code is not Python, it's a DSL.
        # Also, there is no ; after the final }
        # Only task.ipoint is allowed to be used. There are no other fields in the Task object.
        function_definition1 = '''
    def block1d(Task task) {
        ip = task.ipoint;
        return mgpu[ip[0] % mgpu.size[0], ip[0] % mgpu.size[1]];
    }
    '''
        function_definition2 = '''
    def cyclic1d(Task task) {
        ip = task.ipoint;
        linearize = ip[0] * 2 + ip[1];
        return mgpu[ip[0] % mgpu.size[0], linearize % mgpu.size[1]];
    }
    '''
        function_definition3 = '''
    def cyclic2d(Task task) {
        ip = task.ipoint;
        linearize = ip[0] + ip[1] * 2;
        return mgpu[ip[0] % mgpu.size[0], linearize % mgpu.size[1]];
    }
    '''
        function_definition4 = '''
    def linearize3D(Task task) {
        ip = task.ipoint;
        linearize = ip[0] + ip[1] + ip[2];
        return mgpu[linearize % mgpu.size[0], linearize % mgpu.size[1]];
    }
    '''
        # NOTE(review): linearize2D reads ip[2] although its name suggests a 2-D index point -- confirm whether ip[1] was intended.
        function_definition5 = '''
    def linearize2D(Task task) {
        ip = task.ipoint;
        linearize = ip[0] * 2 + ip[2];
        return mgpu[linearize % mgpu.size[0], linearize % mgpu.size[1]];
    }
    '''

        function_definitions = [function_definition1, function_definition2, function_definition3, function_definition4, function_definition5]
        # Definitions already appended to the output, so each one is emitted at most once
        used_definitions = set()

        # Iterate over each task and randomly decide to generate an IndexTaskMap statement
        for task in index_tasks:
            # feel free to change the weights here, or even make it deterministic
            include_statement = random.choices([True, False], weights=[10, 90], k=1)[0]
            if include_statement:
                # Randomly select one of the 5 function definitions above with weights, feel free to change the weights here
                chosen_function_definition = random.choices(function_definitions, weights=[0.2, 0.2, 0.2, 0.2, 0.2], k=1)[0]

                # Please do not modify this line
                # (it makes sure a function definition is emitted only once, before the first statement that uses it)
                if chosen_function_definition not in used_definitions:
                    code_statements.append(chosen_function_definition)
                    used_definitions.add(chosen_function_definition)

                # Extract the function name from the chosen definition. Do NOT modify this line!
                # The text before the first `(` is `def <name>` plus surrounding whitespace, so its last token is the name.
                function_name = chosen_function_definition.split('(')[0].strip().split()[-1]

                # Generate the code statement for the task using the chosen function
                statement = f'IndexTaskMap {task} {function_name};'
                code_statements.append(statement)

        return code_statements

    @bundle(trainable=True)
    def single_task_map_decision(self, single_tasks):
        """
        Generate SingleTaskMap mapping statements.
        ```
        m_2d = Machine(GPU);
        def same_point(Task task) {
            return m_2d[*task.parent.processor(m_2d)];
        }
        SingleTaskMap task_4 same_point;
        ```
        The above code suggests that the task named `task_4` will be mapped to the same point as its parent task.

        Args:
            single_tasks: Iterable of task names; each one may receive a `SingleTaskMap` statement.

        Returns:
            A list that always starts with the `m_2d` / `same_point` definition
            snippet, followed by one `SingleTaskMap` statement per selected task.
        """
        code_statements = []
        # Please pay close attention on how to write multi-line strings in Python, and the function definition syntax.
        # Please do not modify the following function_definition, also please do not write your own function definition for SingleTaskMap.
        function_definition = '''
m_2d = Machine(GPU);
def same_point(Task task) {
    return m_2d[*task.parent.processor(m_2d)];
}
'''
        # The definition is appended unconditionally, even if no task is selected below
        code_statements.append(function_definition)
        for task in single_tasks:
            # Decide whether to include the statement, you can change the weights here
            include_statement = random.choices([True, False], weights=[10, 90], k=1)[0]
            if include_statement:
                # Generate the code statement
                statement = f'SingleTaskMap {task} same_point;'
                code_statements.append(statement)
        return code_statements

    def generate_mapper(self):
        """
        Build the complete mapper program from every decision method.

        The decision methods are invoked in a fixed order (task, region, layout,
        instance limit, index task map, single task map) and their results are
        concatenated in that same order.

        Returns:
            The result of `str_join`: all statements joined with newline separators.
        """
        decision_outputs = (
            self.task_decision(self.tasks),
            self.region_decision(self.regions),
            self.layout_decision(),
            self.instance_limit_decision(self.tasks),
            self.index_task_map_decision(self.index_tasks, self.index_task_specification),
            self.single_task_map_decision(self.single_tasks),
        )

        # Concatenate left to right with `+`, exactly like a chained a + b + c + ...
        all_statements = decision_outputs[0]
        for more_statements in decision_outputs[1:]:
            all_statements = all_statements + more_statements

        return str_join(node('\n'), *all_statements)


# Example usage
if __name__ == "__main__":
    # Build a generator for a small made-up application and emit one mapper.
    demo_generator = DSLMapperGenerator(
        ["task1", "task2"],
        ["region1", "region2"],
        ["index_task1"],
        ["single_task1"],
        "",
    )
    mapper_result = demo_generator.generate_mapper()

    print("Generated Mapper Code:")
    print(mapper_result)  # the returned object itself
    print(mapper_result.data)  # the value it carries
    # backward(visualize=True) returns an object with a view() method -- presumably a graph figure; confirm.
    mapper_result.backward(visualize=True).view()