import os
import sys
# Put the project root (the parent of the directory that contains this file)
# at the front of the import path, so the project-local `synthesizer` package
# imported below resolves even when this module is run as a script.
project_root = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(project_root)
sys.path.insert(0, project_root)

import random
from synthesizer.template import TemplateFactory


class PlaceHolder:
    """Sentinel marking an argument value that has not been filled in yet.

    The class name is rebound to its single instance right after the class
    body, so callers test for it by identity (``value is PlaceHolder``), as the
    expressions' ``parse_compute_args`` methods do.
    """

    def __init__(self):
        self.name = "PLACEHOLDER"

    def __repr__(self):
        return f"PlaceHolder({self.name})"

    def __str__(self):
        return f"<{self.name}>"

    def __reduce__(self):
        # Returning a string tells pickle and the copy module to resolve this
        # object as the module-level global of that name, i.e. the singleton
        # bound below. Without this, copy.deepcopy() or a pickle round-trip of
        # an args dict would produce a distinct instance and silently defeat
        # the ``is PlaceHolder`` identity checks.
        return "PlaceHolder"


# Replace the class with its single shared instance (sentinel singleton).
PlaceHolder = PlaceHolder()


class ConstantExpression:
    """An expression that always evaluates to a fixed ``value``.

    It exposes the same methods as the entity-based expressions in this module
    so it can stand in for them, but any input value it receives is ignored.
    """

    def __init__(self, value, template_factory=None):
        # ``template_factory`` is accepted only so the constructor lines up with
        # the other expression classes; a constant renders without a template.
        self.value = value

    def __repr__(self):
        return f"ConstantExpression({self.value})"

    def __str__(self):
        return str(self.value)

    def __eq__(self, other):
        # Equal only to another constant that holds an equal value.
        return isinstance(other, ConstantExpression) and self.value == other.value

    def nl(self):
        """Natural-language form: the value itself, as text."""
        text = str(self.value)
        return text

    def compute(self, value=None):
        """Return the constant; ``value`` is ignored."""
        return self.value

    def computation_process(self, value=None):
        """Describe the evaluation, which is just the constant; ``value`` is ignored."""
        return str(self.value)

    def computation_process_wo_result(self, value=None):
        """Same text as ``computation_process``: a constant has no separate result step."""
        return str(self.value)

    def parse_compute_args(self, args):
        """A constant needs no inputs: ``args`` is ignored and an empty dict returned."""
        return dict()

    def substitute_entity(self, entity_map):
        """Return an equal copy; constants reference no entity, so ``entity_map`` is unused."""
        return ConstantExpression(self.value)

    @classmethod
    def synthesize_from_interval(cls, interval):
        """
        Build a constant whose value is drawn at random from an inclusive range.
        :param interval: Two-element sequence ``(low, high)`` with ``low <= high``.
        :return: A new instance holding ``random.randint(low, high)``.
        :raises AssertionError: If ``interval`` is malformed.
        """
        assert len(interval) == 2, "Interval must have exactly two elements."
        low, high = interval[0], interval[1]
        assert low <= high, "Invalid interval: start must be less than or equal to end."
        return cls(random.randint(low, high))


class IdentityExpression:
    """Expression that evaluates to an entity attribute's value, unchanged.

    :param entity: Entity whose attribute is referenced.
    :param attribute: Name of the referenced attribute.
    :param template_factory: Either a ``TemplateFactory`` (the
        ``"identity_expression"`` template is fetched from it) or an
        already-resolved template, which is stored as-is and later rendered
        via its ``format`` method.
    """

    def __init__(self, entity, attribute, template_factory):
        self.entity = entity
        self.attribute = attribute
        self.template = (
            template_factory.get_template("identity_expression")
            if isinstance(template_factory, TemplateFactory)
            else template_factory
        )

    def __repr__(self):
        return f"IdentityExpression({self.entity}, {self.attribute})"

    def __str__(self):
        return f"{self.entity}[{self.attribute}]"

    def nl(self):
        """Render the natural-language template with the entity and attribute."""
        fields = {"entity": self.entity, "attribute": self.attribute}
        return self.template.format(**fields)

    def __eq__(self, other):
        if not isinstance(other, IdentityExpression):
            return False
        return (self.entity == other.entity and
                self.attribute == other.attribute)

    def compute(self, value):
        """Identity: return the attribute value as given."""
        return value

    def computation_process(self, value):
        """Trace of the evaluation, which is just the value as text."""
        return str(value)

    def computation_process_wo_result(self, value):
        """Same text as ``computation_process``: there is no separate result step."""
        return str(value)

    def parse_compute_args(self, args):
        """Extract ``args[entity][attribute]`` as the ``value`` keyword for ``compute``.

        :raises AssertionError: If that value is the ``PlaceHolder`` sentinel or None.
        """
        current = args[self.entity][self.attribute]
        assert current is not PlaceHolder, "IdentityExpression argument cannot be a placeholder."
        assert current is not None, "IdentityExpression argument cannot be None."
        return {"value": current}

    def substitute_entity(self, entity_map):
        """Return a copy whose entity is replaced by ``entity_map[self.entity]``."""
        return IdentityExpression(
            entity=entity_map[self.entity],
            attribute=self.attribute,
            template_factory=self.template,
        )

    @classmethod
    def synthesize_from_interval(cls, entity, attribute, template_factory):
        """
        Build an IdentityExpression; no random draw is involved (the name only
        mirrors the other expression classes).
        :param entity: The entity associated with the expression.
        :param attribute: The attribute associated with the expression.
        :param template_factory: A TemplateFactory or an already-resolved template.
        :return: An IdentityExpression object.
        """
        return cls(entity, attribute, template_factory)


class LinearExpression:
    """Expression ``coefficient * entity[attribute] + bias``.

    :param coefficient: Multiplier applied to the attribute value.
    :param bias: Constant added after scaling. A negative bias is shown as a
        subtraction (``- |bias|``) in ``__str__`` and the computation traces,
        and selects the ``"linear_expression_negative"`` template.
    :param entity: Entity whose attribute is referenced.
    :param attribute: Name of the referenced attribute.
    :param template_factory: A ``TemplateFactory`` to fetch the template from,
        or an already-resolved template that is stored as-is.
    """

    def __init__(self, coefficient, bias, entity, attribute, template_factory):
        self.coefficient = coefficient
        self.bias = bias
        self.entity = entity
        self.attribute = attribute
        if not isinstance(template_factory, TemplateFactory):
            # Pre-resolved template, e.g. the one forwarded by substitute_entity.
            self.template = template_factory
            return
        template_key = ("linear_expression_negative" if self.bias < 0
                        else "linear_expression_positive")
        self.template = template_factory.get_template(template_key)

    def __repr__(self):
        return f"LinearExpression({self.coefficient}, {self.bias}, {self.entity}, {self.attribute})"

    def __str__(self):
        return (
            f"{self.coefficient} * {self.entity}[{self.attribute}] - {-self.bias}"
            if self.bias < 0
            else f"{self.coefficient} * {self.entity}[{self.attribute}] + {self.bias}"
        )

    def nl(self):
        """Render the natural-language template with this expression's fields."""
        # NOTE(review): the signed bias is passed through even when the
        # "linear_expression_negative" template was selected, whereas __str__
        # and the computation traces pass -self.bias in that case. If that
        # template spells out the subtraction (e.g. "minus {bias}"), this would
        # render a double negative; confirm against the template definitions.
        fields = dict(
            coefficient=self.coefficient,
            bias=self.bias,
            entity=self.entity,
            attribute=self.attribute,
        )
        return self.template.format(**fields)

    def __eq__(self, other):
        if not isinstance(other, LinearExpression):
            return False
        return (self.coefficient == other.coefficient and
                self.bias == other.bias and
                self.entity == other.entity and
                self.attribute == other.attribute)

    def compute(self, value):
        """Evaluate ``coefficient * value + bias``."""
        return self.coefficient * value + self.bias

    def computation_process(self, value):
        """Arithmetic trace including the result, e.g. ``"3 * 4 - 2 = 10"``."""
        return (
            f"{self.coefficient} * {value} - {-self.bias} = {self.compute(value)}"
            if self.bias < 0
            else f"{self.coefficient} * {value} + {self.bias} = {self.compute(value)}"
        )

    def computation_process_wo_result(self, value):
        """Arithmetic trace without the ``= result`` suffix."""
        return (
            f"{self.coefficient} * {value} - {-self.bias}"
            if self.bias < 0
            else f"{self.coefficient} * {value} + {self.bias}"
        )

    def parse_compute_args(self, args):
        """Extract ``args[entity][attribute]`` as the ``value`` keyword for ``compute``.

        :raises AssertionError: If that value is the ``PlaceHolder`` sentinel or None.
        """
        current = args[self.entity][self.attribute]
        assert current is not PlaceHolder, "LinearExpression argument cannot be a placeholder."
        assert current is not None, "LinearExpression argument cannot be None."
        return {"value": current}

    def substitute_entity(self, entity_map):
        """Return a copy whose entity is replaced by ``entity_map[self.entity]``."""
        return LinearExpression(
            coefficient=self.coefficient,
            bias=self.bias,
            entity=entity_map[self.entity],
            attribute=self.attribute,
            template_factory=self.template,
        )

    @classmethod
    def synthesize_from_interval(cls, interval, entity, attribute, template_factory):
        """
        Build a LinearExpression with coefficient and bias drawn from an interval.
        :param interval: Two-element sequence ``(low, high)`` with ``low <= high``;
            both the coefficient and the bias are drawn from it, inclusive.
        :param entity: The entity associated with the expression.
        :param attribute: The attribute associated with the expression.
        :param template_factory: A TemplateFactory or an already-resolved template.
        :return: A LinearExpression object.
        :raises AssertionError: If ``interval`` is malformed.
        """
        assert len(interval) == 2, "Interval must have exactly two elements."
        low, high = interval[0], interval[1]
        assert low <= high, "Invalid interval: start must be less than or equal to end."
        # Draw order (coefficient, then bias) matters for seeded reproducibility.
        coefficient = random.randint(low, high)
        bias = random.randint(low, high)
        return cls(coefficient, bias, entity, attribute, template_factory)


class BinaryExpression:
    """Combine two (non-binary) sub-expressions with a registered operation.

    ``compute(value1, value2)`` evaluates ``expr1`` on ``value1`` and ``expr2``
    on ``value2``, then applies the operation to the two results.
    """

    # Operation name -> callable applied to the two sub-expression results.
    registered_operations = {
        "max": max,
        "min": min,
        "addition": lambda x, y: x + y,
        "subtraction": lambda x, y: x - y,
    }

    def __init__(self, expr1, expr2, operation, template_factory):
        """
        :param expr1: Left operand expression.
        :param expr2: Right operand expression.
        :param operation: A key of ``registered_operations``.
        :param template_factory: A ``TemplateFactory`` (the
            ``"binary_expression_<operation>"`` template is fetched from it) or
            an already-resolved template that is stored as-is.
        :raises ValueError: If ``operation`` is not registered.
        """
        # Validate first, so an unknown operation always yields this descriptive
        # ValueError instead of depending on how the template lookup handles a
        # non-existent "binary_expression_<operation>" key.
        if operation not in self.registered_operations:
            raise ValueError(
                f"Operation '{operation}' is not registered. Available operations: {list(self.registered_operations.keys())}")
        self.expr1 = expr1
        self.expr2 = expr2
        self.operation = operation
        self.operation_func = self.registered_operations[operation]
        if isinstance(template_factory, TemplateFactory):
            self.template = template_factory.get_template(
                f"binary_expression_{operation}")
        else:
            self.template = template_factory

    def __repr__(self):
        return f"BinaryExpression({self.operation}, {repr(self.expr1)}, {repr(self.expr2)})"

    def __str__(self):
        return f"{self.operation}({str(self.expr1)}, {str(self.expr2)})"

    def nl(self):
        """Render the template with the operands' natural-language forms."""
        return self.template.format(
            expr1=self.expr1.nl(),
            expr2=self.expr2.nl()
        )

    def __eq__(self, other):
        if isinstance(other, BinaryExpression):
            return (self.operation == other.operation and
                    self.expr1 == other.expr1 and
                    self.expr2 == other.expr2)
        return False

    def compute(self, value1, value2):
        """Apply the operation to ``expr1.compute(value1)`` and ``expr2.compute(value2)``."""
        return self.operation_func(self.expr1.compute(value1), self.expr2.compute(value2))

    def computation_process_wo_result(self, value1, value2):
        """Trace such as ``"max(3 * 2 + 1, 4)"``, without the ``= result`` suffix.

        Mirrors the method of the same name on the leaf expressions.
        """
        return f"{self.operation}({self.expr1.computation_process_wo_result(value1)}, {self.expr2.computation_process_wo_result(value2)})"

    def computation_process(self, value1, value2):
        """Trace including the final result, e.g. ``"max(3 * 2 + 1, 4) = 7"``."""
        return f"{self.computation_process_wo_result(value1, value2)} = {self.compute(value1, value2)}"

    def parse_compute_args(self, args):
        """Build the ``value1``/``value2`` keywords for ``compute`` from ``args``.

        Operands that need no input (their ``parse_compute_args`` returns no
        ``"value"`` key, as ConstantExpression's does) get None.

        :raises AssertionError: If either operand rejects its argument; the
            original error is chained as ``__cause__``.
        """
        try:
            args1 = self.expr1.parse_compute_args(args)
            args2 = self.expr2.parse_compute_args(args)
        except AssertionError as e:
            raise AssertionError(
                f"BinaryExpression requires valid values for both expressions: {e}") from e
        return {"value1": args1.get("value", None), "value2": args2.get("value", None)}

    def substitute_entity(self, entity_map):
        """Return a copy with entities substituted in both operands."""
        return BinaryExpression(
            self.expr1.substitute_entity(entity_map),
            self.expr2.substitute_entity(entity_map),
            self.operation,
            self.template
        )

    @classmethod
    def synthesize_from_interval(cls, expr1, expr2, template_factory):
        """
        Synthesize a BinaryExpression with a randomly chosen registered operation.
        :param expr1: The first expression.
        :param expr2: The second expression.
        :param template_factory: A TemplateFactory or an already-resolved template
            for the natural-language representation.
        :return: A BinaryExpression object.
        :raises AssertionError: If either operand is itself a BinaryExpression.
        """
        # compute() hands each operand a single value, so a nested
        # BinaryExpression (whose compute takes two) cannot be evaluated.
        # Check before drawing so a rejected call leaves the shared RNG untouched.
        assert not (isinstance(expr1, BinaryExpression) or isinstance(
            expr2, BinaryExpression)), "Nested BinaryExpressions are not allowed."
        operation = random.choice(list(cls.registered_operations.keys()))
        return cls(expr1, expr2, operation, template_factory)
