import copy
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from src.ds.causal_graph import CausalGraph
from src.scm.scm import SCM
from src.scm.distribution.distribution import Distribution
from src.scm.distribution.discrete_distribution import BernoulliDistribution

# Type variable for the generic value type in the Aggregator / RelationalRole signatures.
T = TypeVar("T")
# Seed constant; not referenced in the code shown here — presumably intended for
# reproducible sampling (TODO confirm).
RANDOM_STATE = 42

# ----- Template-level objects -----

@dataclass(frozen=True)
class TemplateEntity:
    """An entity type declared in a relational schema (e.g. ``Student``, ``Course``)."""

    name: str

@dataclass(frozen=True)
class TemplateRelation:
    """
    A relation type declared in a relational schema (e.g. ``Takes``).

    ``signature`` lists the entity types that participate, in argument order,
    e.g. ``(Student, Course)``.
    """

    name: str
    signature: Tuple[TemplateEntity, ...]

@dataclass(frozen=True)
class TemplateAttr:
    """
    An attribute class attached to an entity or relation type
    (e.g. ``Intelligence`` owned by ``Student``).
    """

    name: str
    # Entity type or relation type that owns this attribute.
    owner: Union[TemplateEntity, TemplateRelation]
    # TODO: decide how to encode the attribute's domain of possible values.

@dataclass(frozen=True)
class RelationalSchema:
    """
    Template-level description of a relational domain.

    Skeletons are checked against it by ``validate_skeleton``, and the
    attributes in ``attributes`` are the ones ``gen_ground_attributes`` grounds.
    """

    # declared entity types
    entities: List[TemplateEntity]
    # declared relation types
    relations: List[TemplateRelation]
    # attributes to ground; each attribute's ``owner`` is its entity or relation type
    attributes: List[TemplateAttr]
    # exogenous (latent) attributes, or None when not specified
    exogenous_attributes: Optional[List[TemplateAttr]] = None


# Any template-level owner of attributes: an entity type or a relation type.
TemplateObj = Union[TemplateEntity, TemplateRelation]
# ----- Ground-level objects -----
@dataclass(frozen=True)
class GroundEntity:
    """
    A concrete entity instance (e.g. ``Bob`` of type ``Student``).

    It carries only its type and name; its ground attributes and their values
    are kept separately in a ``RelationalInstance``.
    """

    type: TemplateEntity
    name: str  # e.g., Bob


@dataclass(frozen=True)
class GroundRelation:
    """
    A concrete relation tuple (e.g. ``Takes(Bob, CS101)``).

    It carries only its relation type and participating ground entities; its
    ground attributes and their values are kept separately in a
    ``RelationalInstance``.
    """

    type: TemplateRelation  # e.g., Takes
    signature: Tuple[GroundEntity, ...]  # e.g., (Bob, CS101)
 
# Any ground-level owner of attributes: a concrete entity or a concrete relation tuple.
GroundObj = Union[GroundEntity, GroundRelation]

@dataclass(frozen=True)
class GroundAttribute:
    """
    One attribute instantiated on a specific ground entity or relation tuple.

    ``name`` (and ``str()``) render as ``"<owner>.<attr>"``: the owner is the
    entity name (e.g. ``S1``) or, for a relation tuple, ``Rel(arg1,...,argk)``
    (e.g. ``Nearest(S1,L2)``).
    """

    type: TemplateAttr
    owner: GroundObj

    @property
    def name(self) -> str:
        holder = self.owner
        if not isinstance(holder, GroundEntity):
            # Relation tuple: relation name followed by its members' names.
            members = ",".join([member.name for member in holder.signature])
            holder_label = f"{holder.type.name}({members})"
        else:
            holder_label = holder.name
        return f"{holder_label}.{self.type.name}"

    def __str__(self) -> str:
        return self.name


@dataclass(frozen=True)
class RelationalSkeleton:
    """
    The concrete entities and relation tuples of a domain, typed by a schema.

    Holds structure only; attribute values live in ``RelationalInstance``.
    """

    schema: RelationalSchema  # schema the ground objects are typed against
    entities: List[GroundEntity]  # concrete entity instances
    relations: List[GroundRelation]  # concrete relation tuples

@dataclass(frozen=True)
class RelationalInstance:
    """A skeleton together with its ground attributes and the values assigned to them."""

    skel: RelationalSkeleton
    # ground object -> the ground attributes it owns
    ground_attrs: Dict[GroundObj, List[GroundAttribute]] = field(default_factory=dict)
    # ground attribute -> assigned value
    vals: Dict[GroundAttribute, Any] = field(default_factory=dict)

    def get(self, ga: GroundAttribute, default: Any = None) -> Any:
        """Return the value assigned to ``ga``, or ``default`` if it has none."""
        if ga in self.vals:
            return self.vals[ga]
        return default



def validate_skeleton(schema: RelationalSchema, skel: RelationalSkeleton) -> bool:
    """
    Validates a RelationalSkeleton against a RelationalSchema.

    This function ensures that all ground entities and relations in the skeleton
    conform to the schema definitions, including type constraints and signature matching.
    Types are compared by value (dataclass equality), so equal-but-distinct
    template objects are accepted.

    Args:
        schema (RelationalSchema): The relational schema that defines valid entity types,
            relation types, and their signatures. Note that this argument is used,
            not ``skel.schema``.
        skel (RelationalSkeleton): The relational skeleton to validate, containing
            ground entities and relations.

    Returns:
        bool: True if the skeleton is valid.

    Raises:
        ValueError: If any of the following conditions are violated:
            - A ground entity's type is not defined in the schema
            - A ground relation's type is not defined in the schema
            - A ground relation's arity (number of entities) doesn't match the
              relation type's signature in the schema
            - A ground entity in a relation doesn't match the expected type
              specified in the relation type's signature
    """
    # Collect the declared types once so each membership test is O(1).
    schema_entities = set(schema.entities)
    schema_relations = set(schema.relations)

    # entities: type must be in schema
    for ge in skel.entities:
        if ge.type not in schema_entities:
            raise ValueError(f"Ground entity {ge.name} has unknown type {ge.type}")

    # relations: type must be in schema and tuple types must match signature
    for gr in skel.relations:
        if gr.type not in schema_relations:
            raise ValueError(f"Ground relation {gr} has unknown type {gr.type}")
        if len(gr.signature) != len(gr.type.signature):
            raise ValueError(f"Relation {gr.type.name} arity mismatch")
        for ground_ent, expected_cls in zip(gr.signature, gr.type.signature):
            # Compare by value, consistent with the set lookups above; an identity
            # check would reject an equal TemplateEntity built separately.
            if ground_ent.type != expected_cls:
                raise ValueError(
                    f"Relation {gr.type.name} expected {expected_cls} but got {ground_ent.type} in {gr}"
                )
    return True

def gen_ground_attributes(skel: RelationalSkeleton) -> Dict[GroundObj, List[GroundAttribute]]:
    """
    Ground every schema attribute onto the matching objects of a skeleton.

    For each ground entity and each ground relation in ``skel``, one
    ``GroundAttribute`` is created for every attribute in
    ``skel.schema.attributes`` whose owner equals the object's type.
    Attributes listed only in ``schema.exogenous_attributes`` are not grounded here.

    Args:
        skel: The relational skeleton to ground; its ``schema`` supplies the attributes.

    Returns:
        A ``defaultdict(list)`` mapping each ground object that owns at least one
        attribute to its ground attributes, in schema order. Objects whose type
        owns no attributes do not appear as keys.
    """
    schema = skel.schema
    ground_attrs: DefaultDict[GroundObj, List[GroundAttribute]] = defaultdict(list)
    # Both entities and relation tuples can own attributes.
    for obj in [*skel.entities, *skel.relations]:
        for attr in schema.attributes:
            # Value equality (frozen dataclasses), so separately constructed but
            # equal template types still match.
            if obj.type == attr.owner:
                ground_attrs[obj].append(GroundAttribute(type=attr, owner=obj))
    return ground_attrs

# ----- Causal structure -----
@dataclass(frozen=True)
class RelationalConstraint:
    """
    Callable wrapper around a predicate over two ground objects of a skeleton.

    ``RelationalConstraint(const=f)(skel, a, b)`` returns ``f(skel, a, b)`` unchanged.
    """

    const: Callable[[RelationalSkeleton, GroundObj, GroundObj], bool]

    def __call__(self, skel: RelationalSkeleton, a: GroundObj, b: GroundObj) -> bool:
        predicate = self.const
        return predicate(skel, a, b)
    

@dataclass(frozen=True)
class Aggregator:
    """
    Wrapper around an aggregation function applied to a list of value rows.

    Each row holds the values gathered from one related object (see
    ``RelationalRole``). The wrapped function is expected to be
    permutation-invariant, i.e. its result should not depend on row order;
    this wrapper does not enforce that.

    Attributes:
        agg (Callable[[List[List[T]]], List[T]]):
            Callable that takes the list of rows and returns the aggregated result.

    Methods:
        __call__(xs: List[List[T]]) -> List[T]:
            Applies ``agg`` to ``xs`` and returns its result unchanged.

    Example:
        >>> from statistics import mean
        >>> aggregator = Aggregator(agg=lambda xs: tuple(mean(x) for x in zip(*xs)))
        >>> result = aggregator([(1, 2), (3, 4), (5, 6)])
        >>> print(result)
        (3, 4)
    """
    agg: Callable[[List[List[T]]], List[T]]

    def __call__(self, xs: List[List[T]]) -> List[T]:
        return self.agg(xs)


@dataclass(frozen=True)
class RelationalRole:
    """
    A named relational dependency of a target ground object.

    For a target, every other ground object of the parents' owner type that
    satisfies ``const(skel, candidate, target)`` contributes one row: its values
    for the attributes in ``pa``, in ``pa`` order. The rows are then combined by
    ``agg``.

    Attributes:
        name: Key under which this role's value is exposed (e.g. by TemplateRSF).
        pa: Parent template attributes; they must all belong to the same owner
            type, and ``pa[0].owner`` selects the candidate objects.
        const: Constraint deciding whether a candidate is related to the target.
        agg: Aggregator applied to the list of per-object rows.
    """
    name: str
    pa: List[TemplateAttr]  # must all belong to the same type of object
    const: RelationalConstraint
    agg: Aggregator

    def __call__(self,
                 inst: RelationalInstance,
                 target: GroundObj) -> List[T]:
        """
        Evaluate the role for ``target`` in ``inst``.

        Returns:
            ``self.agg(rows)``. ``rows`` has one entry per related object, in
            ``inst.ground_attrs`` iteration order. Each entry lists that object's
            values for ``self.pa`` in ``self.pa`` order, skipping attributes the
            object does not own. With an empty ``pa`` nothing matches and
            ``agg([])`` is returned.
        """
        log = logging.getLogger(__name__)
        skel = inst.skel
        owner_type = self.pa[0].owner if self.pa else None
        rows = []
        for obj, attrs in inst.ground_attrs.items():
            # Only objects of the parents' owner type; the target itself is
            # excluded by value, not identity.
            if obj.type != owner_type or obj == target:
                continue
            if not self.const(skel, obj, target):
                continue
            log.debug("RelationalRole matched: %s -- %s", target, obj)
            by_type = {attr.type: attr for attr in attrs}
            # Order values by self.pa so the aggregator sees a stable column layout.
            rows.append([inst.get(by_type[p]) for p in self.pa if p in by_type])

        log.debug("aggregated variable values %s", rows)
        return self.agg(rows)



@dataclass(frozen=True)
class TemplateRSF:
    """
    Template-level relational structural function for one attribute class.

    Calling it on a ground ``target`` gathers four groups of inputs and passes
    them to ``f`` positionally, each as a dictionary:

    * ``pa_V``  - values of the ``pa`` attributes on the target, keyed by attribute name;
    * ``u_V``   - values of the ``u`` attributes on the target, keyed by attribute name;
    * ``pa+_V`` - results of each role in ``pa_roles``, keyed by role name;
    * ``u+_V``  - results of each role in ``u_roles``, keyed by role name.

    Attribute values come from ``RelationalInstance.get``, so unset values are ``None``.
    """
    # direct endogenous parents (same owner type as the target)
    pa: Sequence[TemplateAttr]
    # direct latent parents (same owner type as the target)
    u: Sequence[TemplateAttr]
    # endogenous relational roles (each role returns its aggregated value)
    pa_roles: Sequence[RelationalRole]
    # exogenous relational roles
    u_roles: Sequence[RelationalRole]

    # f(pa_V, u_V, pa+_V, u+_V) -> value of the target attribute
    f: Callable[
        [Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]],
        Any,
    ]

    def __call__(self,
                 inst: RelationalInstance,
                 target: GroundObj) -> Any:
        """Evaluate the structural function for ``target`` using the values in ``inst``."""

        def own_value(attr):
            # Value of ``attr`` grounded on the target itself.
            return inst.get(GroundAttribute(type=attr, owner=target))

        pa_V_vals = {attr.name: own_value(attr) for attr in self.pa}
        u_V_vals = {attr.name: own_value(attr) for attr in self.u}
        pa_plus_V_vals = {role.name: role(inst, target) for role in self.pa_roles}
        u_plus_V_vals = {role.name: role(inst, target) for role in self.u_roles}

        return self.f(pa_V_vals, u_V_vals, pa_plus_V_vals, u_plus_V_vals)



class TemplateRSCM:
    """
    Relational extension of SCM that grounds a schema/skeleton into a standard
    SCM while tying parameters by attribute class.

    Attributes:
        schema: Copy of the input schema with ``exogenous_attributes`` set to ``U``.
        V: Endogenous template attributes (``schema.attributes``).
        U: Exogenous template attributes.
        F: Template structural function for each endogenous attribute.
        PU: Exogenous distribution per template object type; deep-copied once
            per ground object when grounding.
    """

    schema: RelationalSchema
    V: List[TemplateAttr]
    U: List[TemplateAttr]
    F: Dict[TemplateAttr, TemplateRSF]
    PU: Dict[TemplateObj, Distribution]

    def __init__(
        self,
        schema: RelationalSchema,
        U: List[TemplateAttr],
        F: Dict[TemplateAttr, TemplateRSF],
        PU: Dict[TemplateObj, Distribution],
    ):
        # Create a copy of the schema but include the exogenous attributes
        self.schema = RelationalSchema(
            entities=schema.entities,
            relations=schema.relations,
            attributes=schema.attributes,
            exogenous_attributes=U,
        )
        self.V = self.schema.attributes
        self.U = U
        self.F = F
        self.PU = PU

    @staticmethod
    def _ground_obj_name(obj: GroundObj) -> str:
        """Name of a ground object: ``S1`` for an entity, ``Nearest(S1,L2)`` for a relation."""
        if isinstance(obj, GroundEntity):
            return obj.name
        args = ",".join(ent.name for ent in obj.signature)
        return f"{obj.type.name}({args})"

    def get_ground_rscm(self, skeleton: RelationalSkeleton) -> SCM:
        """
        Ground this template model on ``skeleton`` and build a standard SCM.

        Every ground object that owns attributes receives a deep copy of
        ``PU[obj.type]``. Every ground endogenous attribute receives a deep copy
        of ``F[attr.type]``, keyed by the ground attribute's name.

        Raises:
            KeyError: If ``PU`` has no entry for a grounded object's type, or
                ``F`` has no entry for a grounded endogenous attribute.
        """
        ground_attrs = gen_ground_attributes(skeleton)
        U_ground: List[str] = []
        V_ground: List[str] = []
        F_ground: Dict[str, TemplateRSF] = {}
        PU_ground: Dict[str, Distribution] = {}
        for obj, attrs in ground_attrs.items():
            obj_name = self._ground_obj_name(obj)
            try:
                template_pu = self.PU[obj.type]
            except KeyError:
                raise KeyError(
                    f"No exogenous distribution in PU for type {obj.type.name!r} (object {obj_name})"
                ) from None
            # Deep copy so ground objects do not share one Distribution instance.
            PU_ground[obj_name] = copy.deepcopy(template_pu)
            for attr in attrs:
                if attr.type in self.U:
                    U_ground.append(attr.name)
                else:
                    V_ground.append(attr.name)
                    # ground the RSF for this attribute on this object
                    F_ground[attr.name] = copy.deepcopy(self.F[attr.type])

        # NOTE(review): U_ground is collected but not passed to SCM — confirm
        # whether the SCM constructor should receive the exogenous variable names.
        rscm = SCM(
            v=V_ground,
            f=F_ground,
            pu=PU_ground,
        )
        return rscm



if __name__ == "__main__":
    # Demo: a toy domain of lights, bars and sprites used to exercise the
    # relational SCM machinery defined above.

    # =========================
    # Schema (types)
    # =========================
    Light  = TemplateEntity("Light")
    Bar    = TemplateEntity("Bar")
    Sprite = TemplateEntity("Sprite")

    OnTopOf  = TemplateRelation("OnTopOf",  signature=(Bar, Sprite))
    Nearest  = TemplateRelation("Nearest",  signature=(Sprite, Light))

    SColor = TemplateAttr("SColor", owner=Sprite)
    Shape  = TemplateAttr("SShape", owner=Sprite)
    LColor = TemplateAttr("LColor", owner=Light)

    # Exogenous attributes
    U_l = TemplateAttr("U_l", owner=Light)
    U_s = TemplateAttr("U_s", owner=Sprite)
    U_sc = TemplateAttr("U_sc", owner=Sprite)
    U_ss = TemplateAttr("U_ss", owner=Sprite)

    schema = RelationalSchema(
        entities=[Light, Bar, Sprite],
        relations=[OnTopOf, Nearest],
        attributes=[SColor, Shape, LColor],
        # attributes=[SColor, Shape, LColor, U_l, U_s, U_sc, U_ss],
    )

    # =========================
    # Skeleton (instances)
    # =========================
    # Ground entities
    L1 = GroundEntity(Light,  "L1")
    L2 = GroundEntity(Light,  "L2")

    B1 = GroundEntity(Bar,    "B1")
    B2 = GroundEntity(Bar,    "B2")

    S1 = GroundEntity(Sprite, "S1")
    S2 = GroundEntity(Sprite, "S2")
    S3 = GroundEntity(Sprite, "S3")

    # Ground relations
    R1 = GroundRelation(OnTopOf,  signature=(B1, S1))
    R2 = GroundRelation(OnTopOf,  signature=(B2, S2))
    R3 = GroundRelation(Nearest,  signature=(S1, L1))
    R4 = GroundRelation(Nearest,  signature=(S2, L2))
    R5 = GroundRelation(Nearest,  signature=(S3, L2))
    R6 = GroundRelation(Nearest,  signature=(S1, L2))

    skel = RelationalSkeleton(
        schema=schema,
        entities=[L1, L2, B1, B2, S1, S2, S3],
        relations=[R1, R2, R3, R4, R5, R6],
    )

    # Optional sanity check:
    # print("Is skeleton valid?", validate_skeleton(schema, skel))

    # =========================
    # Relational role: light colors reaching a sprite
    # =========================
    def near_and_unblocked(skel, light, sprite):
        """True iff ``sprite`` is Nearest to ``light`` and no Bar is OnTopOf ``sprite``."""
        if GroundRelation(Nearest, signature=(sprite, light)) in skel.relations:
            for e in skel.entities:
                if e.type == Bar and GroundRelation(OnTopOf, signature=(e, sprite)) in skel.relations:
                    return False
            return True
        return False

    # Optional constraint check (expected: False False True False):
    # rc = RelationalConstraint(const=near_and_unblocked)
    # print(rc(skel, L1, S1), rc(skel, L2, S2), rc(skel, L2, S3), rc(skel, L1, S2))

    def any_agg(xs: List[List[bool]]) -> List[bool]:
        # Column-wise OR over the matched objects' value rows.
        return [any(x) for x in zip(*xs)]

    role = RelationalRole(
        name="NearLights",
        pa=[LColor],
        const=RelationalConstraint(
                const=near_and_unblocked,
            ),
        agg=Aggregator(agg=any_agg),
    )

    ground_attrs = gen_ground_attributes(skel)
    for obj, attrs in ground_attrs.items():
        print(f"{obj}:")
        for attr in attrs:
            print(f"  {attr}")

    # Toy endogenous values: every ground attribute is True except L2's light color.
    vals = {g: True for e in ground_attrs.keys() for g in ground_attrs[e]}
    vals[GroundAttribute(type=LColor, owner=L2)] = False

    inst = RelationalInstance(skel=skel,
                              ground_attrs=ground_attrs,
                              vals=vals)

    # Optional role check (aggregated LColor values of lights reaching the sprite):
    # print("NearLights(S1):", role(inst, S1))
    # print("NearLights(S3):", role(inst, S3))

    # =========================
    # Relational structural functions
    # =========================
    def light_color_func(pa_V, u_V, pa_plus_V, u_plus_V):
        return u_V.get("U_l")  # just return the latent variable for LColor

    def sprite_shape_func(pa_V, u_V, pa_plus_V, u_plus_V):
        return u_V.get("U_ss") ^ u_V.get("U_s")

    def sprite_color_func(pa_V, u_V, pa_plus_V, u_plus_V):
        # XOR of the sprite's latent values U_sc and U_s ...
        base = u_V.get("U_sc") ^ u_V.get("U_s")
        near_lights = pa_plus_V.get("NearLights")
        # ... OR-ed with the first aggregated light value when any light reaches it.
        if near_lights:
            return base | near_lights[0]
        return base

    # Build the template RSFs (one per endogenous attribute class).
    rsf_light_color = TemplateRSF(
        pa=[],
        u=[U_l],
        pa_roles=[],
        u_roles=[],
        f=light_color_func,
    )

    rsf_sprite_shape = TemplateRSF(
        pa=[],
        u=[U_ss, U_s],
        pa_roles=[],
        u_roles=[],
        f=sprite_shape_func,
    )

    rsf_sprite_color = TemplateRSF(
        pa=[],
        u=[U_sc, U_s],
        pa_roles=[role],
        u_roles=[],
        f=sprite_color_func,
    )

    # Optional direct RSF evaluation. NOTE: ``inst.vals`` holds no exogenous
    # values, so the sprite RSFs would hit ``None ^ None`` (TypeError) until
    # U values are added.
    # print("L1.LColor:", rsf_light_color(inst, L1))
    # print("S1.SShape:", rsf_sprite_shape(inst, S1))
    # print("S3.SColor:", rsf_sprite_color(inst, S3))

    # Construct TemplateRSCM
    U=[U_l, U_s, U_sc, U_ss]
    template_rscm = TemplateRSCM(
        schema=schema,
        U=U,
        F={
            LColor: rsf_light_color,
            Shape: rsf_sprite_shape,
            SColor: rsf_sprite_color,
        },
        PU={
            Light: BernoulliDistribution(u_names=[u.name for u in U if u.owner == Light], 
                                         sizes={u.name: 2 for u in U if u.owner == Light},
                                         p=0.7,
                                        ),       
            Bar: BernoulliDistribution(u_names=[u.name for u in U if u.owner == Bar], 
                                         sizes={u.name: 2 for u in U if u.owner == Bar},
                                         p=0.6,
                                        ),
            Sprite: BernoulliDistribution(u_names=[u.name for u in U if u.owner == Sprite], 
                                         sizes={u.name: 2 for u in U if u.owner == Sprite},
                                         p=0.2,
                                        ),
        },
    )

    # Get grounded RSCM
    grounded_rscm = template_rscm.get_ground_rscm(skel)
    print("\nGrounded RSCM:")
    print("Variables (V):", grounded_rscm.v)
    print("Functions (F):", grounded_rscm.f.keys())
    print("Exogenous Distribution (PU):", grounded_rscm.pu)

    # Sample from grounded RSCM
    sample = grounded_rscm.sample(n=5)
    



    
        

    
