import os
from builtins import getattr, exec
from dataclasses import dataclass

import torch
from torch.utils.cpp_extension import load

@dataclass
class Signature:
    """Description of one C++ function marked with ``#pragma torch_expose``.

    Instances are produced by ``parse_file`` and consumed when the compiled
    extension's functions are registered as torch ops.
    """

    # Namespace name detected above the function while scanning the file;
    # empty when none was seen.
    namespace: str
    # C++ function name; used both for the pybind11 binding and the op name.
    function_name: str
    # torch.library schema string for the op, ending in " -> ()",
    # e.g. "(Tensor a, Tensor(Y!) out) -> ()".
    parameters: str
    # Parenthesised parameter list with type names stripped; used verbatim
    # as the parameter list of a generated Python ``def``.
    parameters_naked: str

# C/C++ type spellings rewritten to their torch schema equivalents. The short
# s32/u32/s64 aliases were always supported; the fixed-width integer and
# double spellings map onto schema `int` / `float` the same way.
_SCHEMA_TYPE_ALIASES = {
    's32': 'int',
    'u32': 'int',
    's64': 'int',
    'int32_t': 'int',
    'uint32_t': 'int',
    'int64_t': 'int',
    'double': 'float',
}

# Namespace prefixes dropped from type names (`torch::Tensor` -> `Tensor`).
_STRIPPED_TYPE_PREFIXES = ('torch::', 'at::')


def _to_schema_type(token):
    """Map a single C++ type token to its torch schema spelling."""
    for prefix in _STRIPPED_TYPE_PREFIXES:
        if token.startswith(prefix):
            token = token[len(prefix):]
            break
    return _SCHEMA_TYPE_ALIASES.get(token, token)


def _parse_parameters(params_src):
    """Convert a C++ parameter list into schema and bare-name parameter lists.

    Args:
        params_src: the text between the parentheses of the C++ definition,
            e.g. ``"const torch::Tensor& a, torch::Tensor& out, s32 n"``.

    Returns:
        ``(schema_params, naked_params)``: a parenthesised schema argument
        list such as ``"(Tensor a, Tensor(Y!) out, int n)"`` and the matching
        parenthesised list of bare names such as ``"(a, out, n)"``.

    Raises:
        ValueError: if a parameter is not of the form ``<type> <name>``.
    """
    schema_params = []
    naked_params = []
    for raw in params_src.split(','):
        raw = raw.strip()
        if not raw:
            continue
        decl, has_default, default = raw.partition('=')
        # References and const-qualifiers have no meaning in a schema; work on
        # whole tokens so names like `constant` are left untouched.
        tokens = [t for t in decl.replace('&', ' ').split() if t != 'const']
        if len(tokens) < 2:
            raise ValueError(f'cannot parse parameter {raw!r}: expected "<type> <name>"')
        name = tokens[-1]
        schema_type = ' '.join(_to_schema_type(t) for t in tokens[:-1])
        # Tensors whose name starts with "out" are treated as written by the
        # kernel, so they are annotated as mutable in the schema.
        if schema_type == 'Tensor' and name.startswith('out'):
            schema_type = 'Tensor(Y!)'
        entry = f'{schema_type} {name}'
        if has_default:
            entry += f'={default.strip()}'
        schema_params.append(entry)
        naked_params.append(name)
    return f'({", ".join(schema_params)})', f'({", ".join(naked_params)})'


def parse_file(cpp_file_path):
    """Collect every function marked with ``#pragma torch_expose`` in a C++ file.

    The definition following each pragma may span several lines (up to the
    line containing ``{``); whole-line ``//`` / ``/* */`` comments and
    trailing ``//`` comments inside it are ignored.

    Args:
        cpp_file_path: path of the C++ source file to scan.

    Returns:
        A list of ``Signature`` objects in file order.

    Raises:
        ValueError: if a definition after the pragma cannot be parsed.
        OSError: if the file cannot be read.
    """
    namespace = ''
    signatures = []
    with open(cpp_file_path, 'r', encoding='utf-8') as file:
        lines = file.read().splitlines()

    for line_id, line in enumerate(lines):
        tokens = line.split()
        # Only a line that actually opens a namespace updates it
        # (not `using namespace ...;` or a comment mentioning the word).
        if len(tokens) >= 2 and tokens[0] == 'namespace':
            namespace = tokens[1].rstrip('{')
        if '#pragma torch_expose' not in line:
            continue

        # Gather the (possibly multi-line) definition up to the opening brace.
        pieces = []
        for next_line in lines[line_id + 1:]:
            stripped = next_line.strip()
            if stripped.startswith(('/', '*')):
                continue
            code = next_line.split('//', 1)[0]
            pieces.append(code)
            if '{' in code:
                break

        # Join with spaces so tokens on adjacent lines never fuse, then
        # collapse all whitespace runs to single spaces.
        full_signature = ' '.join(pieces).split('{', 1)[0]
        full_signature = ' '.join(full_signature.split())

        open_idx = full_signature.find('(')
        close_idx = full_signature.rfind(')')
        head = full_signature[:open_idx].split() if open_idx != -1 else []
        if not head or close_idx < open_idx:
            raise ValueError(
                f'{cpp_file_path}:{line_id + 1}: cannot parse the definition '
                f'after #pragma torch_expose: {full_signature!r}'
            )
        # The function name is the token immediately before '(' regardless
        # of how many return-type / specifier tokens precede it.
        function_name = head[-1]
        parameters, parameters_naked = _parse_parameters(full_signature[open_idx + 1:close_idx])

        signatures.append(Signature(
            namespace=namespace,
            function_name=function_name,
            parameters=parameters + ' -> ()',
            parameters_naked=parameters_naked,
        ))
    return signatures

class Sources:
    """Compiles a torch C++/CUDA extension and exposes its functions as torch ops.

    Source paths given to :meth:`register_sources` are resolved relative to
    ``base_folder``.
    """

    def __init__(self, base_folder):
        self.base_folder = base_folder
        # Set by register_sources() to the module returned by
        # torch.utils.cpp_extension.load().
        self.cpp_module = None

    def register_sources(self, module_name, exposed_sources, sources, cuda_flags, cpp_flags, global_dictionary):
        """Build the extension and register every exposed function as a torch op.

        Args:
            module_name: extension name; also used as the torch.library
                namespace (ops become ``torch.ops.<module_name>.<fn>``).
            exposed_sources: iterable of source paths scanned for
                ``#pragma torch_expose`` functions (and compiled).
            sources: iterable of additional source paths that are only compiled.
            cuda_flags: single string passed as one extra CUDA compiler flag entry.
            cpp_flags: string prepended to the generated C++ compiler flag entry.
            global_dictionary: namespace (e.g. a module's ``globals()``) that
                receives a ``call_<fn>`` wrapper per exposed function; it must
                provide ``torch`` because the wrappers are compiled against it.

        Raises:
            AttributeError: if the compiled module lacks an exposed function.
        """
        exposed_paths = [os.path.join(self.base_folder, source) for source in exposed_sources]
        other_paths = [os.path.join(self.base_folder, source) for source in sources]

        # Parse each exposed file once; the result drives both the generated
        # pybind11 bindings and the torch.library registrations below.
        all_signatures = [
            signature
            for path in exposed_paths
            for signature in parse_file(path)
        ]

        # The pybind11 module body is injected through a preprocessor macro;
        # the inner quotes are backslash-escaped for the compiler command line.
        bindings = ''.join(
            f'm.def(\\"{signature.function_name}\\",&{signature.function_name},\\"\\");'
            for signature in all_signatures
        )
        torch_def_macro = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){' + bindings + '}'

        self.cpp_module = load(
            name=module_name,
            sources=[str(path) for path in exposed_paths + other_paths],
            extra_cflags=[f'{cpp_flags} -DTORCH_EXPOSE_DEFINITIONS="{torch_def_macro}"'],
            extra_cuda_cflags=[cuda_flags],
        )

        for signature in all_signatures:
            qualified_name = f'{module_name}::{signature.function_name}'
            torch.library.define(qualified_name, signature.parameters)

            kernel = getattr(self.cpp_module, signature.function_name, None)
            if kernel is None:
                # torch.library.impl(..., func=None) would silently return a
                # decorator instead of registering anything.
                raise AttributeError(
                    f'compiled extension {module_name!r} has no function {signature.function_name!r}'
                )
            torch.library.impl(qualified_name, 'default', kernel)

            # Fake (meta) kernel: the ops return nothing, so it only needs to
            # accept the op's arguments.
            exec(f'@torch.library.register_fake("{qualified_name}")\n'
                 f'def {signature.function_name}_meta{signature.parameters_naked}: return',
                 {'torch': torch})

            exec_locals = {}
            function_name = f'call_{signature.function_name}'
            exec(f'def {function_name}(*args): return torch.ops.{module_name}.{signature.function_name}(*args)', global_dictionary, exec_locals)
            func = exec_locals[function_name]
            print(f'exposing {function_name} as {func}')
            global_dictionary[function_name] = func

