import spot
import sys
import numpy as np
from typing import Dict, Any, Tuple, Set, Optional, List

class LTLFormulaToolbox:
    """
    Utility class for LTL formula formatting, cleaning, and validation.
    """

    @staticmethod
    def _clean_single_line(formula: str) -> str:
        """
        Normalize one formula line.

        Strips surrounding whitespace and trailing punctuation / dangling
        characters, then balances parentheses by *position* (not just count):

        * unmatched ')' at the very end of the line are dropped,
        * any other unmatched ')' get a matching '(' prepended,
        * unmatched '(' get a ')' appended.

        The returned string always has balanced parentheses, though it may still
        be invalid LTL (see ``validate_formula``). Empty/None input yields "".
        """
        if not formula:
            return ""

        f = formula.strip().rstrip(".,;:!?&'`\\ ").strip()

        # Scan left to right: ')' seen at depth 0 closes nothing.
        unmatched_close: List[int] = []
        depth = 0
        for idx, ch in enumerate(f):
            if ch == '(':
                depth += 1
            elif ch == ')':
                if depth:
                    depth -= 1
                else:
                    unmatched_close.append(idx)

        # Stray ')' trailing the formula are most likely junk: drop them.
        while unmatched_close and unmatched_close[-1] == len(f) - 1:
            f = f[:-1]
            unmatched_close.pop()

        # Remaining stray ')' sit inside the text, so open a group for each;
        # 'depth' is the number of '(' still waiting to be closed.
        return '(' * len(unmatched_close) + f + ')' * depth

    @staticmethod
    def preprocess_positive(formula_str: str) -> str:
        """Pos LTL: join all lines with spaces into a single cleaned formula."""
        if not formula_str:
            return ""
        flat_formula = " ".join(formula_str.splitlines())
        return LTLFormulaToolbox._clean_single_line(flat_formula)

    @staticmethod
    def preprocess_negative(formula_str: str) -> List[str]:
        """Neg LTL: clean each line separately, dropping lines that end up empty."""
        if not formula_str:
            return []
        cleaned = (LTLFormulaToolbox._clean_single_line(line) for line in formula_str.splitlines())
        return [line for line in cleaned if line]

    @staticmethod
    def validate_formula(formula: str, known_aps: Set[str], strict_aps: bool = False) -> Optional[str]:
        """
        Validates a SINGLE cleaned formula string.

        Args:
            formula: Cleaned LTL formula string.
            known_aps: Collection of valid AP strings (None is treated as empty).
            strict_aps:
                If True, returns error if formula uses APs not in known_aps.
                If False, only returns error on Syntax Error.

        Returns:
            None if valid (or if ``formula`` is empty).
            Error message string if invalid. This method does not raise:
            any exception while parsing/inspecting is reported as a
            "Syntax Error" message.
        """
        if not formula:
            return None

        allowed = set(known_aps) if known_aps else set()
        try:
            parsed = spot.formula(formula)

            if strict_aps:
                used_aps = {ap.to_str() for ap in spot.atomic_prop_collect(parsed)}
                unknown = used_aps - allowed
                if unknown:
                    # Sorted so the message is stable across runs (set order is not).
                    return f"Unknown APs used: {unknown}. Allowed: {sorted(allowed)}..."
        except Exception as e:  # spot reports parse failures as exceptions
            return f"Syntax Error: {str(e)}"

        return None


class LTLDualTrackEngine:
    """
    Evaluate a trajectory of atomic-proposition (AP) sets against two LTL tracks.

    Negative track: every non-empty cleaned line of ``negative_formula`` is
    translated by spot with the "monitor" and "det" options. When a monitor
    has no outgoing edge accepting the current letter, that step counts as a
    violation and the monitor stays in its previous state.

    Positive track: ``positive_formula`` (flattened to one line) is translated
    with the "BA", "det" and "complete" options. Moving to a state not visited
    since the last ``reset`` counts as a milestone.

    Lines/formulas that fail validation or compilation are skipped and counted
    in ``failed_neg_count`` / ``failed_pos_count``. ``valid`` becomes False
    only if compilation raises outside those per-formula guards; an invalid
    engine yields zero rewards.
    """

    def __init__(self,
                 negative_formula: str,
                 positive_formula: str,
                 known_aps: list,
                 negative_reward: float = -1.0,
                 positive_reward: float = 1.0,
                 trend_reward: float = 0.0,
                 verbose: bool = False,
                 strict_aps: bool = False):
        """
        Args:
            negative_formula: Newline-separated LTL constraints, one monitor per line.
            positive_formula: LTL formula; newlines are flattened into spaces.
            known_aps: Allowed AP names; only enforced when ``strict_aps`` is True.
            negative_reward: Added per violated monitor in ``step``; the per-step
                penalty in ``process_trajectory``.
            positive_reward: Added per newly visited positive state in ``step``;
                the milestone reward in ``process_trajectory``.
            trend_reward: Reward for steps preceding a milestone in ``process_trajectory``.
            verbose: Print warnings/errors for skipped formulas.
            strict_aps: Reject formulas that use APs outside ``known_aps``.
        """
        self.known_aps = set(known_aps) if known_aps else set()
        self.negative_reward_val = negative_reward
        self.positive_reward_val = positive_reward
        self.trend_reward_val = trend_reward
        self.verbose = verbose
        self.strict_aps = strict_aps

        self.neg_monitors = []
        self.pos_automata = []
        self.valid = True
        self.error_msg = ""
        self.failed_neg_count = 0
        self.failed_pos_count = 0

        # 1. Preprocess
        clean_neg_lines = LTLFormulaToolbox.preprocess_negative(negative_formula)
        clean_pos_formula = LTLFormulaToolbox.preprocess_positive(positive_formula)

        # 2. Compile (per-formula failures are handled inside the helpers;
        #    anything escaping them marks the whole engine invalid).
        try:
            self._compile_negative(clean_neg_lines)
            self._compile_positive(clean_pos_formula)
        except Exception as e:
            self.valid = False
            self.error_msg = str(e)
            if self.verbose:
                print(f"[LTL Engine Error] Critical Compilation Failure: {e}")

    def get_compilation_stats(self) -> Dict[str, Any]:
        """Return counters and status describing how compilation went."""
        return {
            "failed_neg_count": self.failed_neg_count,
            "failed_pos_count": self.failed_pos_count,
            "negative_formulas_compiled": len(self.neg_monitors),
            "positive_formulas_compiled": len(self.pos_automata),
            "engine_valid": self.valid,
            "engine_error_msg": self.error_msg,
        }

    @staticmethod
    def _new_track_entry(aut, parsed_formula) -> Dict[str, Any]:
        """Build the mutable per-automaton record used by both tracks."""
        init_state = aut.get_init_state_number()
        return {
            'aut': aut,
            'bdd': aut.get_dict(),
            'state': init_state,
            'init_state': init_state,
            'aps': {ap.to_str() for ap in spot.atomic_prop_collect(parsed_formula)},
        }

    def _compile_negative(self, clean_lines: List[str]):
        """Compile each cleaned negative line into a monitor; skip and count failures."""
        for line in clean_lines:
            # Validate Syntax & APs (based on strictness)
            err = LTLFormulaToolbox.validate_formula(line, self.known_aps, self.strict_aps)
            if err:
                if self.verbose:
                    print(f"[LTL Warning] Skipping invalid negative formula line '{line}': {err}")
                self.failed_neg_count += 1
                continue

            try:
                parsed = spot.formula(line)
                aut = spot.translate(parsed, "monitor", "det")
                self.neg_monitors.append(self._new_track_entry(aut, parsed))
            except Exception as e:
                if self.verbose:
                    print(f"[LTL Error] Failed to compile negative line '{line}': {e}")
                self.failed_neg_count += 1

    def _compile_positive(self, clean_formula: str):
        """Compile the cleaned positive formula into one automaton; skip and count failure."""
        if not clean_formula:
            return

        err = LTLFormulaToolbox.validate_formula(clean_formula, self.known_aps, self.strict_aps)
        if err:
            if self.verbose:
                print(f"[LTL Warning] Skipping invalid positive formula '{clean_formula}': {err}")
            self.failed_pos_count += 1
            return

        try:
            parsed = spot.formula(clean_formula)
            aut = spot.translate(parsed, "BA", "det", "complete")
            entry = self._new_track_entry(aut, parsed)
            entry['visited'] = {entry['init_state']}
            self.pos_automata.append(entry)
        except Exception as e:
            if self.verbose:
                print(f"[LTL Error] Failed to compile positive formula '{clean_formula}': {e}")
            self.failed_pos_count += 1

    def reset(self):
        """Return every automaton to its initial state and clear visited-state history."""
        for monitor in self.neg_monitors:
            monitor['state'] = monitor['init_state']
        for entry in self.pos_automata:
            entry['state'] = entry['init_state']
            entry['visited'] = {entry['init_state']}

    def _get_bdd_input(self, current_aps, required_aps, bdd_dict, aut):
        """
        Encode the current letter as a full assignment over ``required_aps``.

        Each required AP appears positively if it is in ``current_aps`` (any
        container supporting ``in``) and negated otherwise; with no required
        APs the letter is the constant ``1``.
        """
        # Sorted only so the generated formula text is deterministic.
        literals = [ap if ap in current_aps else f"!{ap}" for ap in sorted(required_aps)]
        formula_str = " & ".join(literals) if literals else "1"
        return spot.formula_to_bdd(spot.formula(formula_str), bdd_dict, aut)

    def _next_state(self, entry: Dict[str, Any], current_aps) -> Optional[int]:
        """Destination of the first outgoing edge whose guard admits the letter, or None."""
        letter = self._get_bdd_input(current_aps, entry['aps'], entry['bdd'], entry['aut'])
        for edge in entry['aut'].out(entry['state']):
            if spot.buddy.bdd_implies(letter, edge.cond):
                return edge.dst
        return None

    def _step_events(self, current_aps) -> Tuple[int, int]:
        """
        Advance both tracks by one letter and count the events that occurred.

        Returns:
            ``(violations, milestones)``: number of negative monitors with no
            admissible edge, and number of positive automata that moved to a
            state not yet in their ``visited`` set. Independent of reward values.
        """
        present = set() if current_aps is None else set(current_aps)

        # Negative Track
        violations = 0
        for monitor in self.neg_monitors:
            nxt = self._next_state(monitor, present)
            if nxt is None:
                violations += 1  # monitor stays where it was
            else:
                monitor['state'] = nxt

        # Positive Track
        milestones = 0
        for entry in self.pos_automata:
            nxt = self._next_state(entry, present)
            if nxt is None or nxt == entry['state']:
                continue
            if nxt not in entry['visited']:
                entry['visited'].add(nxt)
                milestones += 1
            entry['state'] = nxt

        return violations, milestones

    def step(self, current_aps: set) -> tuple[float, float]:
        """
        Advance every automaton by one letter and return ``(neg_reward, pos_reward)``.

        ``neg_reward`` is ``negative_reward_val`` per violated monitor and
        ``pos_reward`` is ``positive_reward_val`` per positive automaton that
        reached a not-yet-visited state. Returns ``(0.0, 0.0)`` without touching
        any state when the engine is invalid.
        """
        if not self.valid:
            return 0.0, 0.0

        violations, milestones = self._step_events(current_aps)
        # Conditional avoids returning -0.0 when there were no events.
        neg_r = violations * self.negative_reward_val if violations else 0.0
        pos_r = milestones * self.positive_reward_val if milestones else 0.0
        return neg_r, pos_r

    def process_trajectory(self, aps_stream: list[set]) -> list[tuple[float, float, dict]]:
        """
        Replays the trajectory from a fresh ``reset`` and assigns per-step rewards.

        Events come from the automata, not from the configured reward values, so
        zero-valued rewards do not hide milestones, trends or violations.

        Logic:
        1. Penalty List: Mark index 't' if any negative constraint is violated at 't'.
        2. Milestone List: Mark index 't-1' (clamped to 0) if the positive
           automaton reaches a new state at 't' (the action that caused it).
        3. Trend List: Mark indices '0' to 't-2' for every such milestone.

        Args:
            aps_stream: Iterable of per-step AP collections; None is treated as empty.

        Returns:
            One (neg_reward, pos_reward, info_dict) tuple per step.
            neg_reward is negative_reward_val on penalty steps (at most one
            penalty per step), else 0.0. pos_reward is positive_reward_val on
            milestone steps, else trend_reward_val on trend steps, else 0.0.
            info_dict is a fresh dict per step with boolean flags
            'is_milestone', 'is_trend', 'is_violation'.
            An invalid engine yields (0.0, 0.0, all-False info) for every step.
        """
        steps = [] if aps_stream is None else list(aps_stream)
        if not steps:
            return []
        if not self.valid:
            # A fresh dict per step: callers may mutate one without affecting others.
            return [
                (0.0, 0.0, {"is_milestone": False, "is_trend": False, "is_violation": False})
                for _ in steps
            ]

        self.reset()
        n_steps = len(steps)
        violated = [False] * n_steps
        milestone = [False] * n_steps
        # Union of all trend ranges [0, m_idx) is simply [0, max m_idx).
        trend_end = 0

        for t, current_aps in enumerate(steps):
            violations, milestones = self._step_events(current_aps)
            if violations:
                violated[t] = True
            if milestones:
                m_idx = max(0, t - 1)
                milestone[m_idx] = True
                trend_end = max(trend_end, m_idx)

        aligned_results = []
        for t, (is_violation, is_milestone) in enumerate(zip(violated, milestone)):
            is_trend = t < trend_end
            neg_reward = self.negative_reward_val if is_violation else 0.0

            # Positive: Milestone Priority over Trend
            if is_milestone:
                pos_reward = self.positive_reward_val
            elif is_trend:
                pos_reward = self.trend_reward_val
            else:
                pos_reward = 0.0

            info = {
                "is_milestone": is_milestone,
                "is_trend": is_trend,
                "is_violation": is_violation,
            }
            aligned_results.append((neg_reward, pos_reward, info))

        return aligned_results