import copy
from typing_extensions import Self

import numpy as np

import torch
from torch_geometric.data import Batch

from mas_sat.graph.base import BaseGraph
from mas_sat.utils.scatter import scatter_reduce

class VCGLGraph(BaseGraph):
    """
    Variable-Clause Graph w/ Learned Clauses

    Heterogeneous graph built from the node types ``"variable"``,
    ``"original_clause"``, ``"learned_clause"`` and ``"instance"``, connected
    by ``("variable", "in", <clause>)`` edges and their reversed
    ``(<clause>, "has", "variable")`` counterparts.

    Literals are addressed as ``2 * variable_index + polarity`` where polarity
    is 0 for the positive and 1 for the negative literal.
    """
    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def from_observation(cls, observation: dict, dim: int, original=False) -> Self:
        """
        Build a graph from a solver observation.

        Args:
            observation: mapping that provides ``"literal_values"``,
                ``"clause_refs"`` (only its length is used),
                ``"literal_indices"`` and ``"clause_indices"``; when
                ``original`` is false ``"literal_candidates"`` is read as
                well. The index/value entries must support ``.astype``
                (presumably numpy arrays -- confirm against the caller).
            dim: width of the zero-initialised ``latent`` tensors.
            original: if true, clauses are stored under ``"original_clause"``
                and the static variable attributes (``x``, ``latent``,
                ``candidate_indices``, ``indices``) are filled in; otherwise
                clauses are stored under ``"learned_clause"`` and the
                per-step variable attributes (``value``, ``candidate``) are
                filled in.

        Returns:
            The newly created graph.
        """
        data = cls()
        clause_name = "original_clause" if original else "learned_clause"

        # numbers
        n_literal = len(observation["literal_values"])
        n_variable = n_literal // 2  # two literals (positive/negative) per variable
        n_clause = len(observation["clause_refs"])
        n_edge = len(observation["literal_indices"])

        # variable
        data["variable"].num_nodes = n_variable
        if original:
            # one-hot node-type feature: [1, 0] marks a variable node
            data["variable"].x = torch.tile(torch.tensor([1, 0]), (n_variable, 1)).float()
            data["variable"].latent = torch.zeros((n_variable, dim))
            # row i holds the two literal ids (2i, 2i + 1) of variable i
            data["variable"].candidate_indices = torch.arange(n_literal).reshape(n_variable, 2)
            # position of each node in the full graph; used to map a subgraph back
            data["variable"].indices = torch.arange(n_variable).reshape(n_variable, 1)
        else:
            data["variable"].value = torch.tensor(observation["literal_values"].astype(float)).float().reshape(n_variable, 2)
            data["variable"].candidate = torch.tensor(observation["literal_candidates"]).reshape(n_variable, 2)

        # clause
        data[clause_name].num_nodes = n_clause
        # one-hot node-type feature: [0, 1] marks a clause node
        data[clause_name].x = torch.tile(torch.tensor([0, 1]), (n_clause, 1)).float()
        data[clause_name].indices = torch.arange(n_clause).reshape(n_clause, 1)
        data[clause_name].latent = torch.zeros((n_clause, dim))

        # instance
        data["instance"].num_nodes = 1
        data["instance"].x = torch.zeros((1, 1))
        data["instance"].latent = torch.zeros((1, dim))

        # edge indices
        literal_indices = torch.tensor(observation["literal_indices"].astype(int)).flatten()
        clause_indices = torch.tensor(observation["clause_indices"].astype(int)).flatten()
        variable_indices = torch.divide(literal_indices, 2, rounding_mode="floor")
        polarity = torch.remainder(literal_indices, 2).reshape(n_edge, 1) # 0 for positive, 1 for negative

        # variable in clause
        attr_name = ("variable", "in", clause_name)
        data[attr_name].edge_attr = torch.cat((polarity, 1-polarity), dim=1).float()
        data[attr_name].latent = torch.zeros((n_edge, dim))
        data[attr_name].edge_index = torch.stack((variable_indices, clause_indices))
        data[attr_name].polarity = polarity
        data[attr_name].indices = torch.arange(n_edge).reshape(n_edge, 1)

        # clause has variable (same edges, reversed direction)
        attr_name = (clause_name, "has", "variable")
        data[attr_name].edge_attr = torch.cat((polarity, 1-polarity), dim=1).float()
        data[attr_name].latent = torch.zeros((n_edge, dim))
        data[attr_name].edge_index = torch.stack((clause_indices, variable_indices))

        return data

    @staticmethod
    def _unsatisfied_clause_mask(graph, clause_name: str, literal_true: torch.Tensor) -> torch.Tensor:
        """
        Return a boolean mask over ``clause_name`` nodes that is true for
        every clause containing no true literal.

        Args:
            graph: graph holding the ``("variable", "in", clause_name)`` edges.
            clause_name: clause node type to evaluate.
            literal_true: boolean mask indexed by literal id
                (``2 * variable_index + polarity``).
        """
        edge_store = graph["variable", "in", clause_name]
        variable_indices, clause_indices = edge_store.edge_index
        literal_indices = 2 * variable_indices + edge_store.polarity.flatten()
        # 1 where the literal on this edge is currently true, else 0
        edge_true = literal_true[literal_indices].int()
        n_clause = graph[clause_name].num_nodes
        # a clause is satisfied as soon as any of its literals is true
        clause_satisfied = scatter_reduce(edge_true, clause_indices, reduce="amax", dim=0, dim_size=n_clause).bool()
        return ~clause_satisfied

    @classmethod
    def combine_graph(cls, original_graph: Self, learned_graph: Self) -> Self:
        """
        Merge the static original graph into the per-step learned graph and
        prune it to candidate variables and unsatisfied clauses.

        ``learned_graph`` is modified in place: the attributes of
        ``original_graph`` are assigned to it by reference (not cloned).

        Args:
            original_graph: graph built with ``original=True``.
            learned_graph: graph built with ``original=False``.

        Returns:
            A subgraph of the merged graph if anything has to be dropped,
            otherwise a clone of the merged graph.
        """
        # step 1: copy from original graph to learned graph
        ## variable
        learned_graph["variable"].x = original_graph["variable"].x
        learned_graph["variable"].candidate_indices = original_graph["variable"].candidate_indices
        learned_graph["variable"].indices = original_graph["variable"].indices
        learned_graph["variable"].latent = original_graph["variable"].latent

        ## original_clause
        learned_graph["original_clause"].x = original_graph["original_clause"].x
        learned_graph["original_clause"].indices = original_graph["original_clause"].indices
        learned_graph["original_clause"].latent = original_graph["original_clause"].latent

        ## instance
        learned_graph["instance"].x = original_graph["instance"].x
        learned_graph["instance"].latent = original_graph["instance"].latent

        ## variable in original_clause
        attr_name = ("variable", "in", "original_clause")
        learned_graph[attr_name].edge_attr = original_graph[attr_name].edge_attr
        learned_graph[attr_name].edge_index = original_graph[attr_name].edge_index
        learned_graph[attr_name].polarity = original_graph[attr_name].polarity
        learned_graph[attr_name].indices = original_graph[attr_name].indices
        learned_graph[attr_name].latent = original_graph[attr_name].latent

        ## original_clause has variable
        attr_name = ("original_clause", "has", "variable")
        learned_graph[attr_name].edge_attr = original_graph[attr_name].edge_attr
        learned_graph[attr_name].edge_index = original_graph[attr_name].edge_index
        learned_graph[attr_name].latent = original_graph[attr_name].latent

        # step 2: only keep candidate variables and unsatisfied clauses
        subset_dict = {}

        ## candidate variables (only restrict when at least one is not a candidate)
        variable_candidate = learned_graph["variable"].candidate[:,0]
        if not torch.all(variable_candidate):
            subset_dict["variable"] = variable_candidate

        ## unsatisfied clauses
        literal_values = learned_graph["variable"].value.flatten()
        literal_true = literal_values > 0.5
        if torch.any(literal_true):
            # The per-literal mask must stay intact between the two calls:
            # both clause types index it by literal id.
            for clause_name in ("original_clause", "learned_clause"):
                subset_dict[clause_name] = cls._unsatisfied_clause_mask(learned_graph, clause_name, literal_true)

        ## create subgraph
        if len(subset_dict) > 0:
            combined_graph = learned_graph.subgraph(subset_dict)
        else:
            combined_graph = learned_graph.clone()

        return combined_graph

    @classmethod
    def update_graph(cls, original_graph: Self, updated_graph: Self) -> Self:
        """
        Write the latents of a (possibly pruned) graph back into a copy of
        the original graph.

        The stored ``indices`` of ``updated_graph`` select which rows of the
        original latents are overwritten; ``original_graph`` itself is left
        untouched because it is cloned first.

        Args:
            original_graph: full graph whose latents serve as the base.
            updated_graph: graph carrying the new latents and the ``indices``
                that locate its nodes/edges in the full graph.

        Returns:
            A clone of ``original_graph`` with the updated latents.
        """
        original_graph = original_graph.clone()

        # variable
        indices = updated_graph["variable"].indices.flatten()
        original_graph["variable"].latent[indices] = updated_graph["variable"].latent.clone()

        # original_clause
        indices = updated_graph["original_clause"].indices.flatten()
        original_graph["original_clause"].latent[indices] = updated_graph["original_clause"].latent.clone()

        # instance (single node, replaced wholesale)
        original_graph["instance"].latent = updated_graph["instance"].latent.clone()

        # variable in original_clause
        # The reversed edge type stores no indices of its own, so the indices of
        # the forward edge type are reused for it.
        indices = updated_graph["variable", "in", "original_clause"].indices.flatten()
        original_graph["variable", "in", "original_clause"].latent[indices] = updated_graph["variable", "in", "original_clause"].latent.clone()
        original_graph["original_clause", "has", "variable"].latent[indices] = updated_graph["original_clause", "has", "variable"].latent.clone()

        return original_graph

    def get_candidate_num(self) -> int:
        """Return the sum over the ``candidate`` tensor of the variable nodes."""
        return self["variable"].candidate.sum().item()

    def get_candidate_indices(self) -> torch.Tensor:
        """Return the flattened ``candidate_indices`` entries selected by ``candidate``."""
        indices = self["variable"].candidate_indices[self["variable"].candidate]
        return indices.flatten()

    def get_candidate_ptr(self) -> torch.Tensor:
        """
        Return the per-graph start offsets in literal space.

        Reads ``self["variable"].ptr`` (presumably only present on batched
        graphs -- confirm) and doubles it since each variable has two literals.
        """
        return self["variable"].ptr[:-1] * 2

    def get_candidate_batch(self) -> torch.Tensor:
        """Return the variable ``batch`` vector with every entry repeated twice (one per literal)."""
        return torch.repeat_interleave(self["variable"].batch, 2)

    def get_score(self, assignment:torch.Tensor, hard: bool = False, clause_level: bool = False, eps: float = 1e-6):
        """
        Score an assignment against the original clauses.

        Args:
            assignment: tensor reshaped to ``(n_variable, 2, -1)``; a softmax
                over the second axis turns it into literal scores.
            hard: if true, literal scores are snapped to ``eps`` / ``1 - eps``
                and the result says whether all clauses score above 0.5;
                otherwise a soft penalty ``-log(s) * (1 - s)`` is averaged
                over the clauses.
            clause_level: if true, return the per-clause scores
                ``1 - prod(1 - literal_score)`` directly (not thresholded,
                even when ``hard`` is set).
            eps: offset from 0 and 1 used for the snapped literal scores in
                hard mode; unused in soft mode.

        Returns:
            Per-clause scores if ``clause_level`` is set; otherwise one score
            per instance when the clause nodes carry a ``batch`` attribute, or
            a reduction over all clauses when they do not.
        """
        n_variable = self["variable"].num_nodes
        n_clause = self["original_clause"].num_nodes
        n_instance = len(self)
        variable_indices, clause_indices = self["variable", "in", "original_clause"].edge_index
        polarity = self["variable", "in", "original_clause"].polarity.flatten()
        literal_indices = variable_indices * 2 + polarity

        # assignment -> literal_scores
        assignment = assignment.reshape(n_variable, 2, -1)
        literal_scores = assignment.softmax(1).reshape(n_variable*2, -1)
        if hard:
            literal_scores = torch.where(literal_scores > 0.5, 1-eps, eps)

        # literal_scores -> clause_scores
        literal_scores = literal_scores[literal_indices, :]
        # probability-style OR: a clause fails only if all of its literals fail
        clause_scores = 1 - scatter_reduce(1-literal_scores, clause_indices, dim=0, reduce="prod", dim_size=n_clause)
        if clause_level:
            return clause_scores
        if hard:
            clause_scores = (clause_scores > 0.5).float()
        else:
            # NOTE(review): a clause score of exactly 0 yields an infinite
            # penalty here since eps is not applied in soft mode -- confirm
            # whether clamping is wanted.
            clause_scores_log = -clause_scores.log()
            clause_scores = clause_scores_log * (1-clause_scores)

        # clause_scores -> instance_scores
        if hasattr(self["original_clause"], "batch"):
            if hard:
                # an instance is solved only if its worst clause is satisfied
                instance_scores = scatter_reduce(clause_scores, self["original_clause"].batch, dim=0, reduce="amin", dim_size=n_instance)
            else:
                instance_scores = scatter_reduce(clause_scores, self["original_clause"].batch, dim=0, reduce="mean", dim_size=n_instance)
        else:
            if hard:
                instance_scores = clause_scores.min(dim=0).values
            else:
                instance_scores = clause_scores.mean(dim=0)
        return instance_scores

    def is_trivial(self) -> bool:
        """Return True if the graph has no original clauses or no variable-to-original-clause edges."""
        return (
            self["original_clause"].num_nodes == 0
            or self["variable", "in", "original_clause"].num_edges == 0
        )