import sys
from einx._src.util.rwlock import RWLock
import traceback

class InvalidBackendException(Exception):
    """Raised by ``InvalidBackend`` when a backend that failed to initialize is used.

    Args:
        message: Human-readable description of why the backend is invalid.
    """

    def __init__(self, message):
        # Forward to Exception so ``str(exc)`` and ``exc.args`` carry the message.
        # Without this the exception prints as an empty string, and unpickling
        # (which re-calls the class with ``*args``) fails because ``args`` is empty
        # while ``message`` is required.
        super().__init__(message)
        self.message = message

class InvalidBackend:
    """Stand-in registered in place of a backend that could not be constructed.

    ``name``, ``message`` and ``priority`` are ordinary attributes. Looking up any
    other (missing) attribute raises ``InvalidBackendException`` carrying
    ``message``, so any attempt to use the backend reports the stored failure.
    Instances compare equal and hash purely by identity.

    Args:
        name: Name under which the placeholder is registered.
        message: Explanation of the failure, re-raised on attribute access.
        priority: Priority used when several backends match; defaults to 0.
    """

    def __init__(self, name, message, priority=0):
        self.name, self.message, self.priority = name, message, priority

    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)

    def __getattr__(self, name):
        # Python only calls this after normal lookup fails, i.e. for anything
        # that is neither an instance attribute set above nor a class member.
        failure = self.message
        raise InvalidBackendException(failure)

class _OperationNotFoundError(AttributeError, KeyError):
    """Raised when a ``Backend`` has no operation with the requested name.

    Subclasses ``AttributeError`` so ``hasattr``, ``getattr(obj, name, default)``
    and the ``copy``/``pickle`` protocols treat the operation as absent, and
    ``KeyError`` so code catching the error previously raised by ``ops[name]``
    keeps working.
    """

    # KeyError.__str__ would repr() the message (adding quotes); use the plain form.
    __str__ = BaseException.__str__


class Backend:
    """A named collection of tensor operations.

    Entries of ``ops`` are exposed as attributes: ``backend.<op>`` returns
    ``backend.ops["<op>"]``. Backends compare equal and hash only by identity.
    Used as a context manager, the backend is pushed onto the module-level
    ``registry``'s use stack (via ``Use``) for the duration of the ``with`` block.

    Args:
        ops: Mapping from operation name to implementation.
        name: Backend name, used for lookups by name.
        priority: Used to choose between several backends matching the same inputs.
        optimizations: Stored as-is; interpretation is up to the caller.
        compiler: Stored as-is; interpretation is up to the caller.
        tensor_type: Stored as-is; presumably the backend's native tensor type.
    """

    def __init__(self, ops, name, priority, optimizations, compiler, tensor_type):
        self.ops = ops
        self.name = name
        self.priority = priority
        self.optimizations = optimizations
        self.compiler = compiler
        self.tensor_type = tensor_type

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for operation names.
        # Read ``ops`` through __dict__: on a half-initialized instance (e.g. while
        # copy/pickle reconstruct one before restoring state) ``self.ops`` would
        # re-enter __getattr__ and recurse forever.
        instance_dict = self.__dict__
        if "ops" not in instance_dict:
            raise AttributeError(name)
        try:
            return instance_dict["ops"][name]
        except KeyError:
            raise _OperationNotFoundError(
                f"Backend {instance_dict.get('name')!r} has no operation {name!r}."
            ) from None

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    def __enter__(self):
        Use(self, registry).__enter__()
        # Return the backend so ``with backend as b:`` binds it instead of None.
        return self

    def __exit__(self, *args):
        Use(self, registry).__exit__(*args)

class OperationNotSupportedException(Exception):
    """Signals that a backend does not support the requested operation.

    Args:
        message: Optional explanation; a generic message is used when omitted.
    """

    def __init__(self, message=None):
        if message is None:
            message = "This operation is not supported by this backend."
        # Forward to Exception so ``str(exc)``/``exc.args`` carry the message and a
        # pickled exception is re-created with the same message.
        super().__init__(message)
        self.message = message

class Use:
    """Context manager that makes ``backend`` the default backend of ``registry``.

    On entry the backend is pushed onto ``registry.use_stack``; on exit it is
    popped again. Contexts must be exited in LIFO order.

    NOTE(review): ``use_stack`` is a single list on the registry, so overlapping
    ``Use`` contexts in different threads share it and can trip the LIFO check.

    Args:
        backend: The backend to make the default.
        registry: Object providing ``use_stack`` (a list) and ``use_rwlock``.
    """

    def __init__(self, backend, registry):
        self.backend = backend
        self.registry = registry

    def __enter__(self):
        with self.registry.use_rwlock.write():
            self.registry.use_stack.append(self.backend)
        # Returning the backend enables ``with Use(b, registry) as backend: ...``.
        return self.backend

    def __exit__(self, exc_type, exc_value, exc_tb):
        with self.registry.use_rwlock.write():
            stack = self.registry.use_stack
            # An explicit check (unlike ``assert``) is not stripped under
            # ``python -O``, where a mismatched exit would otherwise silently pop
            # another context's backend. Also avoids an opaque IndexError on an
            # empty stack.
            if not stack or stack[-1] is not self.backend:
                raise RuntimeError(
                    "Unbalanced backend context: the backend being exited is not "
                    "the innermost active one."
                )
            stack.pop()

class BackendRegistry:
    """Registry of available backends.

    Backends are looked up by name, by tensor type (including subclasses of a
    registered type), or taken from the top of ``use_stack`` (pushed by ``Use``).
    They can be registered immediately with :meth:`register`, or lazily with
    :meth:`register_on_import`, in which case the factory runs once :meth:`_update`
    notices that the trigger module has appeared in ``sys.modules``.

    Locks: ``seen_modules_rwlock`` is taken around ``seen_module_names``;
    ``backend_rwlock`` around the backend tables (``uninitialized_backends``,
    ``backends``, ``tensortype_to_backend``, ``name_to_backend``); ``use_rwlock``
    around ``use_stack``.
    """

    def __init__(self):
        self.seen_modules_rwlock = RWLock()
        self.seen_module_names = set()

        self.backend_rwlock = RWLock()

        self.uninitialized_backends = {}  # module name -> list of backend factories

        self.backends = []
        self.tensortype_to_backend = {}
        self.name_to_backend = {}

        self.use_stack = []
        self.use_rwlock = RWLock()

    def _get_new_module_names(self):
        """Return module names not seen by a previous call, and mark them as seen.

        A cheap scan under the read lock runs first, so the write lock is only
        taken when something new has been imported.
        """
        # Iterate snapshots of sys.modules: iterating the live dict raises
        # "dictionary changed size during iteration" if another thread imports
        # a module concurrently.
        with self.seen_modules_rwlock.read():
            has_new_modules = any(
                module_name not in self.seen_module_names for module_name in tuple(sys.modules)
            )
        if not has_new_modules:
            return []

        with self.seen_modules_rwlock.write():
            # Re-scan under the write lock: another thread may have claimed some
            # names in the meantime, and more modules may have been imported.
            new_module_names = [
                module_name
                for module_name in tuple(sys.modules)
                if module_name not in self.seen_module_names
            ]
            self.seen_module_names.update(new_module_names)
        return new_module_names

    def _update(self):
        """Construct all deferred backends whose trigger module is now imported."""
        module_names = self._get_new_module_names()
        if not module_names:
            return
        # Only runs when new modules appeared, so taking the write lock is affordable.
        with self.backend_rwlock.write():
            for module_name in module_names:
                backend_factories = self.uninitialized_backends.pop(module_name, None)
                if backend_factories is None:
                    continue
                for backend_factory in backend_factories:
                    self._run_factory(module_name, backend_factory)

    def _run_factory(self, module_name, backend_factory):
        """Call ``backend_factory`` and register what it returns.

        The factory must return ``(backend, tensor_types)``. If it raises, an
        ``InvalidBackend`` holding the formatted traceback is registered instead,
        so the failure can be reported later (see ``_invalid_backend_reasons``).

        The caller must hold ``backend_rwlock`` for writing (both call sites do).
        """
        try:
            backend, tensor_types = backend_factory()
        except Exception:
            backend = InvalidBackend(
                f"error-backend-for-{module_name}",
                f"Failed to import backend for module {module_name} due to the following error:\n{traceback.format_exc()}",
            )
            tensor_types = []
        self._register_locked(backend, tensor_types)

    def _invalid_backend_reasons(self):
        """Return a text block describing every registered ``InvalidBackend``, or ``""``."""
        invalid_backends = [b for b in self.backends if isinstance(b, InvalidBackend)]
        if not invalid_backends:
            return ""
        parts = ["\n\nThe following backends could not be initialized:\n"]
        for invalid in invalid_backends:
            parts.append(f"\n############################ {invalid.name} ############################\n")
            parts.append(f"{invalid.message}\n")
        return "".join(parts)

    def _register_locked(self, backend, tensor_types):
        """Body of :meth:`register`; the caller must hold ``backend_rwlock`` for writing.

        Split out so callers that already hold the write lock do not have to
        acquire it a second time.
        """
        if not isinstance(backend, (Backend, InvalidBackend)):
            raise ValueError("Backend must be an instance of Backend or InvalidBackend.")
        tensor_types = list(tensor_types)  # scanned twice below; accept any iterable
        # Validate everything before mutating so a conflict does not leave a
        # half-registered backend behind. This also rejects types that were added
        # to the mapping by subclass inference in get_by_tensortype.
        for tensor_type in tensor_types:
            if tensor_type in self.tensortype_to_backend:
                raise ValueError(f"Backend {self.tensortype_to_backend[tensor_type]} has already been registered for tensor type {tensor_type}.")
        self.backends.append(backend)
        self.name_to_backend[backend.name] = backend
        for tensor_type in tensor_types:
            self.tensortype_to_backend[tensor_type] = backend

    def register(self, backend, tensor_types=()):
        """Register ``backend`` under its name and for each type in ``tensor_types``.

        A later backend with the same name replaces the earlier one in name lookups.

        Raises:
            ValueError: if ``backend`` is not a ``Backend``/``InvalidBackend``, or if
                any of ``tensor_types`` is already mapped to a backend (in which
                case nothing is registered).
        """
        with self.backend_rwlock.write():
            self._register_locked(backend, tensor_types)

    def register_on_import(self, module_name, backend_factory):
        """Run ``backend_factory`` once ``module_name`` is imported.

        If the module is already in ``sys.modules`` the factory runs immediately;
        otherwise it is stored and run by :meth:`_update` once the module appears.
        """
        with self.backend_rwlock.write():
            if module_name in sys.modules:
                self._run_factory(module_name, backend_factory)
            else:
                self.uninitialized_backends.setdefault(module_name, []).append(backend_factory)

    def get_by_name(self, name, update=True):
        """Return the backend registered under ``name``.

        Raises:
            ValueError: if no backend with that name is registered; the message
                lists registered names and why any backends failed to initialize.
        """
        if update:
            self._update()
        with self.backend_rwlock.read():
            if name not in self.name_to_backend:
                raise ValueError(f"Backend with name {name} not found. Currently registered backends are: {list(self.name_to_backend.keys())}{self._invalid_backend_reasons()}")
            return self.name_to_backend[name]

    def get_by_tensortype(self, tensor_type, update=True):
        """Return the backend for ``tensor_type``, or ``None`` if none matches.

        If ``tensor_type`` itself is not registered, it is matched against the
        registered types with ``issubclass``; when several match, the last one in
        the mapping's iteration order wins. A successful subclass match is cached.
        """
        if update:
            self._update()
        with self.backend_rwlock.read():
            backend = self.tensortype_to_backend.get(tensor_type)
            if backend is not None:
                return backend
            for registered_tensor_type, registered_backend in self.tensortype_to_backend.items():
                if issubclass(tensor_type, registered_tensor_type):
                    backend = registered_backend
        if backend is None:
            return None

        # Cache the subclass match. setdefault keeps any mapping another thread
        # registered between releasing the read lock and acquiring the write lock.
        with self.backend_rwlock.write():
            return self.tensortype_to_backend.setdefault(tensor_type, backend)

    def get_by_tensortypes(self, tensor_types, update=True):
        """Return the highest-priority backends matching any of ``tensor_types``.

        Duplicates are removed keeping first-seen order, and types without a
        backend are ignored. The result may be empty or contain several backends
        of equal priority.
        """
        if update:
            self._update()
        matches = (self.get_by_tensortype(t, update=False) for t in tensor_types)
        # dict.fromkeys deduplicates (by identity, via Backend.__eq__/__hash__)
        # while keeping a deterministic order; set() ordering depends on id().
        backends = list(dict.fromkeys(b for b in matches if b is not None))

        if len(backends) > 1:
            max_priority = max(b.priority for b in backends)
            backends = [b for b in backends if b.priority == max_priority]
        return backends

    def get(self, backend=None, tensor_types=(), update=True):
        """Resolve the backend to use for an operation.

        Resolution order:
          1. ``backend`` is a ``Backend``/``InvalidBackend`` -> returned as-is.
          2. ``backend`` is a string -> looked up with :meth:`get_by_name`.
          3. A backend is active via ``Use`` -> the innermost one.
          4. Otherwise the backend is inferred from ``tensor_types``.

        Raises:
            ValueError: if ``backend`` has an unsupported type, the name is unknown,
                or zero or several backends match ``tensor_types``.
        """
        if isinstance(backend, (Backend, InvalidBackend)):
            return backend

        # Validate before consulting the use stack so an invalid argument is
        # reported instead of being silently ignored while a Use context is active.
        if backend is not None and not isinstance(backend, str):
            raise ValueError(
                "Backend must be either a Backend instance, a string, or None."
            )

        if update:
            self._update()

        if isinstance(backend, str):
            return self.get_by_name(backend, update=False)

        with self.use_rwlock.read():
            if self.use_stack:
                return self.use_stack[-1]

        tensor_types = tuple(tensor_types)  # iterated twice below
        backends = self.get_by_tensortypes(tensor_types, update=False)
        if len(backends) == 1:
            return backends[0]

        header = (
            "Failed to determine which backend to use for this operation:\n"
            " - The 'backend' parameter is not specified.\n"
            " - No global default backend is specified using 'einx.backend.use'.\n"
        )
        if backends:
            raise ValueError(
                header
                + " - Multiple registered backends match the tensor types of the arguments: "
                + ", ".join(b.name for b in backends)
            )
        message = (
            header
            + " - No registered backends match the tensor types of the arguments: "
            + ", ".join(str(t) for t in tensor_types)
        )
        message += self._invalid_backend_reasons()
        raise ValueError(message)

# Module-level registry instance. Backend.__enter__/__exit__ refer to this object
# to push onto / pop from its use stack.
registry: BackendRegistry = BackendRegistry()