module ConstructPlanningInstance

open System.Collections.Generic

open FsOmegaLib.LTL
open FsOmegaLib.SAT
open FsOmegaLib.DPA
open FsOmegaLib.Operations

open TransitionSystemLib.TransitionSystem

open Util
open SolverConfiguration
open HyperLTL
open PlanningInstance

/// Selects how the acceptance of the property automaton is turned into a planning goal.
type AutomatonSemantics = 
    /// The goal is the `win` predicate, which an automaton-update action may set
    /// non-deterministically whenever the automaton is in a non-empty state.
    | SAFE
    /// The goal is to reach one of the universal states of the automaton.
    | REACH

/// Translates a HyperLTL property over a family of transition systems into a
/// planning instance, returned as a pair `(domain, problem)`.
///
/// Steps performed:
///   1. Validate the systems (via `TransitionSystem.findError`, exactly one initial state each)
///      and the formula (via `HyperLTL.findError`, every atom refers to a known system and AP).
///   2. Require a quantifier prefix of the shape forall^* exists^*.
///   3. Convert the LTL matrix into a DPA.
///   4. Emit the domain (types, move constants, predicates, actions) and the
///      problem (state objects, initial facts, goal).
///
/// The generated plan cycles through stages: `update-stage` (one automaton step), then
/// `move-select-stage-<pi>` for every trace in order, then `move-apply-stage-<pi>` for
/// every trace in order, and back to `update-stage`. The initial state is in `update-stage`.
///
/// Parameters:
///   config - supplies the logger, the debug flag and the solver paths for the LTL-to-DPA conversion.
///   tsMap  - the transition system used for each trace variable.
///   prop   - the HyperLTL property.
///   sem    - whether the goal is the `win` predicate (SAFE) or reaching a universal automaton state (REACH).
///
/// Raises `HyPlanException` when validation fails, when the prefix is not forall^* exists^*,
/// or when the DPA conversion fails.
let constructPlanningProblem (config : Configuration) (tsMap : Map<TraceVariable, TransitionSystem<string>>) (prop : HyperLTL<string>) (sem : AutomatonSemantics) = 
    let sw = System.Diagnostics.Stopwatch()
    sw.Start()

    // ============== Input validation ==============

    tsMap
    |> Map.iter (fun pi ts ->
        match TransitionSystem.findError ts with 
        | None -> ()
        | Some msg -> 
            raise <| HyPlanException $"Found error in the system for %s{pi}: %s{msg}"
            
        if ts.InitialStates |> Set.count <> 1 then 
            raise <| HyPlanException $"Only supports TSs with a single initial state"
    )
    
    match HyperLTL.findError prop with 
    | None -> ()
    | Some msg -> raise <| HyPlanException $"Found error in the formula: %s{msg}"

    // Every atom (x, pi) of the formula must refer to an existing system that declares x
    prop.LTLMatrix
    |> LTL.allAtoms
    |> Set.toList
    |> List.iter (fun (x, pi) ->
        if Map.containsKey pi tsMap |> not then 
            raise <| HyPlanException $"AP (%A{x}, %s{pi}) is used in the HyperLTL property but no system defined for %s{pi}"

        if List.contains x tsMap.[pi].APs |> not then
            raise <| HyPlanException $"AP (%A{x}, %s{pi}) is used in the HyperLTL property but AP %A{x} does not exists in the transition system for trace %s{pi}"
    )

    let traceVariables = HyperLTL.quantifiedTraceVariables prop

    // Convert the prefix to a block prefix
    let blockPrefix = HyperLTL.extractBlocks prop.QuantifierPrefix

    // Without any quantifier block the indexing below (and traceVariables.[0] later on)
    // would fail with a raw index exception, so report it like the other input errors
    if List.isEmpty blockPrefix then 
        raise <| HyPlanException "The HyperLTL property must quantify over at least one trace variable"

    if List.length blockPrefix > 2 then 
        raise <| HyPlanException "Only applicable to \\forall^*\\exists^* properties"

    let uTraceVariables, eTraceVariables = 
        if fst blockPrefix.[0] = FORALL then 
            snd blockPrefix.[0], if List.length blockPrefix = 1 then [] else snd blockPrefix.[1]
        else 
            // A leading existential block is only allowed when it is the sole block
            if List.length blockPrefix <> 1 then 
                raise <| HyPlanException "Only applicable to \\forall^*\\exists^* properties"
            [], snd blockPrefix.[0]

    // ============== Automaton construction ==============

    let dpa =
        match FsOmegaLib.Operations.LTLConversion.convertLTLtoDPA config.Debug config.SolverConfig.MainPath config.SolverConfig.Ltl2tgbaPath prop.LTLMatrix with 
        | Success aut -> aut 
        | Fail err -> 
            config.Logger.LogN err.DebugInfo
            raise <| HyPlanException err.Info

    config.Logger.LogN $"Converted to DPA with %i{dpa.States.Count} in %i{sw.ElapsedMilliseconds}ms (~=%.2f{double(sw.ElapsedMilliseconds) / 1000.0}s)"
    sw.Restart()

    let nonEmptyStates, universalStates = AutomataUtil.findNonEmptyAndUniversalStates dpa

    // For the i-th AP (ap, pi) of the DPA: the trace variable pi and the index of ap within the APs of pi's system
    let apLookupList = 
        dpa.APs
        |> List.map (fun (ap, pi) -> 
            pi, List.findIndex ((=) ap) tsMap.[pi].APs
            )

    // We only encode distinct systems in our encoding to make the encoding smaller
    let distinctSystems = 
        tsMap
        |> Map.values
        |> Seq.distinct
        |> Seq.toList

    // Maps each trace variable to the index of its system within distinctSystems
    let tsIndexMap = 
        tsMap
        |> Map.map (fun _ ts -> 
            List.findIndex ((=) ts) distinctSystems
            )

    // The maximal out-degree of each distinct system, i.e., the number of moves needed for it
    let tsDegrees = 
        distinctSystems
        |> List.map (fun ts -> 
            ts.Edges
            |> Map.values
            |> Seq.map Set.count
            |> Seq.max
            )

    // ============== PDDL Types ==============

    let systemStateTypes = 
        List.init distinctSystems.Length (fun i -> "system-state-" + string i)

    let moveTypes = 
        List.init distinctSystems.Length (fun i -> "move-" + string i)

    // ============== PDDL Constants ==============

    // For each distinct system, one move constant (name, type) per possible successor index
    let moveConstants = 
        tsDegrees
        |> List.mapi (fun i maxDegree -> 
            [0..maxDegree-1]
            |> List.map (fun s -> "m-o-" + string s + "-" + string i, moveTypes.[i])
            )

    // ============== Predicates ==============

    let moveSelectionStagePredicates = 
        traceVariables
        |> List.map (fun pi -> 
            pi, {
                PredicateDefinition.Name = "move-select-stage-" + pi
                Parameters = []
            }
            )
        |> Map.ofList

    let moveApplyStagePredicates = 
        traceVariables
        |> List.map (fun pi -> 
            pi, {
                PredicateDefinition.Name = "move-apply-stage-" + pi
                Parameters = []
            }
            )
        |> Map.ofList

    let updateStagePredicate = 
        {
            PredicateDefinition.Name = "update-stage"
            Parameters = []
        }

    // A predicate that can be set non-deterministically at every automaton step. It is used for
    // SAFE objectives to ensure that a strong cyclic plan retains the ability to win
    let winPredicate = 
        {
            PredicateDefinition.Name = "win"
            Parameters = []
        }

    let automatonStatePredicates = 
        dpa.States
        |> Seq.toList
        |> List.map (fun q -> 
            q, {
                PredicateDefinition.Name = "automaton-at-" + string q
                Parameters = []
            }
        )
        |> Map.ofList

    let atPredicatesMap = 
        traceVariables
        |> List.map (fun pi -> 
            pi, {
                PredicateDefinition.Name = pi + "-at"
                Parameters = [("?s", systemStateTypes.[tsIndexMap.[pi]])]
            }
            )
        |> Map.ofList

    // Static fact predicates giving the AP-labelling of every system state (one parameter: the state)
    let apEvaluationPredicates = 
        dpa.APs
        |> List.map (fun (ap, pi) -> 
            {
                PredicateDefinition.Name = "ap-" + ap + "-at-" + pi
                Parameters = [("?s", systemStateTypes.[tsIndexMap.[pi]])]
            }
        )

    // Zero-arity predicates giving the AP-labelling of the current state, used to encode the
    // transitions of the automaton symbolically
    let apCurrentPredicates = 
        dpa.APs
        |> List.map (fun (ap, pi) -> 
            {
                PredicateDefinition.Name = "ap-" + ap + "-" + pi
                Parameters = []
            }
        )

    let movePredicatesMap = 
        traceVariables
        |> List.map (fun pi -> 
            pi, {
                PredicateDefinition.Name = "move-" + pi
                Parameters = [("?m", moveTypes.[tsIndexMap.[pi]])]
            }
            )
        |> Map.ofList

    let systemMoveEdgePredicates = 
        distinctSystems
        |> List.mapi (fun i _ -> 
            {
                PredicateDefinition.Name = "system-move-edge-" + string i
                Parameters = 
                    [
                        ("?s", systemStateTypes.[i]);
                        ("?m", moveTypes.[i]);
                        ("?ss", systemStateTypes.[i])
                    ]
            } 
            )

    // ============== PDDL Actions ==============

    // One action per automaton edge. Each action hard codes the automaton states, so no
    // automaton objects or parameters are needed
    let automatonMoveActions = 
        dpa.States
        |> Set.toList
        |> List.map (fun q -> 
            dpa.Edges.[q]
            |> List.map (fun (g, qq) -> 
                // The guard is a disjunction of clauses, each clause a conjunction of literals
                // over AP indices; literals are mapped to the current-AP predicates
                let remappedGuard = 
                    g 
                    |> List.map (fun clause -> 
                        let literals = 
                            clause 
                            |> List.map (fun lit -> 
                                let i = Literal.getValue lit
                                let f = BooleanFormula.Atom(apCurrentPredicates.[i].Name, [])
                                match lit with 
                                | PL _ -> f 
                                | NL _ -> BooleanFormula.Not f
                                )

                        // A singleton conjunction is emitted without the surrounding And
                        match literals with 
                        | [x] -> x
                        | xs -> BooleanFormula.And xs
                        )
                    |> function 
                        | [x] -> x
                        | xs -> BooleanFormula.Or xs

                // Non-deterministically set the winning predicate, to ensure safe plans always have
                // the chance to win. Only allowed when the current state is non-empty
                let winEffect = 
                    if sem = SAFE && nonEmptyStates.Contains q then 
                        [
                            BooleanFormula.Oneof [
                                BooleanFormula.Atom (winPredicate.Name, []);
                                BooleanFormula.Not (BooleanFormula.Atom (winPredicate.Name, []))
                            ]
                        ]
                    else 
                        []

                {
                    ActionDefinition.Name = "automaton-update-" + string q + "-to-" + string qq 
                    Parameters = []
                    Precondition = 
                        [
                            BooleanFormula.Atom(updateStagePredicate.Name, []);
                            // Currently in the hard coded automaton state
                            BooleanFormula.Atom(automatonStatePredicates.[q].Name, []);
                            // The guard of the transition from q to qq holds
                            remappedGuard
                        ]
                        |> BooleanFormula.And
                    Effect = 
                        // Move to the next stage
                        [
                            BooleanFormula.Not (BooleanFormula.Atom(updateStagePredicate.Name, []));
                            BooleanFormula.Atom(moveSelectionStagePredicates.[traceVariables.[0]].Name, []) // Move to the selection of a move for the first trace
                        ]
                        @
                        // Update the automaton state, only if the states differ
                        (if q = qq then [] else [BooleanFormula.Atom(automatonStatePredicates.[qq].Name, []); BooleanFormula.Not (BooleanFormula.Atom(automatonStatePredicates.[q].Name, []))])
                        @ 
                        winEffect
                        |> BooleanFormula.And
                }
                )
            )
        |> List.concat
    
    // One action per trace variable that fixes the move of that trace for the current round
    let selectMoveActions = 
        traceVariables
        |> List.mapi (fun i pi -> 
            let isUniversal = List.contains pi uTraceVariables

            if not isUniversal then 
                assert(List.contains pi eTraceVariables)

            // Either select a move for the next trace or move to the application phase of the first trace
            let nextStage = 
                if i < traceVariables.Length - 1 then 
                    BooleanFormula.Atom(moveSelectionStagePredicates.[traceVariables.[i+1]].Name, []) 
                else 
                    BooleanFormula.Atom(moveApplyStagePredicates.[traceVariables.[0]].Name, [])

            // Existential traces receive the move as an action parameter; universal traces take none
            let parameters = 
                if isUniversal then 
                    [] 
                else 
                    [("?m-" + pi, moveTypes.[tsIndexMap.[pi]])]

            let moveEffect = 
                if isUniversal then 
                    // Universal variable, we pick the move non-deterministically
                    moveConstants.[tsIndexMap.[pi]]
                    |> List.map (fun (moveName, _) -> BooleanFormula.Atom(movePredicatesMap.[pi].Name, [moveName]))
                    |> BooleanFormula.Oneof
                else 
                    // Set the move for pi to the one selected by the parameter
                    BooleanFormula.Atom(movePredicatesMap.[pi].Name, ["?m-" + pi])

            {
                ActionDefinition.Name = "select-move-" + pi
                Parameters = parameters
                Precondition = BooleanFormula.Atom(moveSelectionStagePredicates.[pi].Name, [])
                Effect = 
                    [
                        // Move to the next stage
                        BooleanFormula.Not (BooleanFormula.Atom(moveSelectionStagePredicates.[pi].Name, []));
                        nextStage;
                        moveEffect
                    ]
                    |> BooleanFormula.And
            }
            )

    // One action per trace variable that executes the previously selected move
    let applyMoveActions = 
        traceVariables
        |> List.mapi (fun i pi -> 
            let tsIndex = tsIndexMap.[pi]
            let source = "?s-" + pi
            let move = "?m-" + pi
            let target = "?ss-" + pi

            // Either apply the move of the next trace or move to the automaton update stage
            let nextStage = 
                if i < traceVariables.Length - 1 then 
                    BooleanFormula.Atom(moveApplyStagePredicates.[traceVariables.[i+1]].Name, []) 
                else 
                    BooleanFormula.Atom(updateStagePredicate.Name, [])

            // We set the current predicates iff they hold on the next state based on the fact predicates
            let apUpdates = 
                dpa.APs
                |> List.mapi (fun apPos (_, pii) -> 
                    if pii = pi then 
                        [
                            BooleanFormula.When (
                                BooleanFormula.Atom(apEvaluationPredicates.[apPos].Name, [target]),
                                BooleanFormula.Atom(apCurrentPredicates.[apPos].Name, [])
                            );
                            BooleanFormula.When (
                                BooleanFormula.Not (BooleanFormula.Atom(apEvaluationPredicates.[apPos].Name, [target])),
                                BooleanFormula.Not (BooleanFormula.Atom(apCurrentPredicates.[apPos].Name, []))
                            )
                        ]
                        |> BooleanFormula.And
                        |> Some
                    else
                        None
                    )
                |> List.choose id

            {
                ActionDefinition.Name = "apply-move-" + pi
                Parameters = [(source, systemStateTypes.[tsIndex]); (move, moveTypes.[tsIndex]); (target, systemStateTypes.[tsIndex])]
                Precondition = 
                    [
                        BooleanFormula.Atom(moveApplyStagePredicates.[pi].Name, []);
                        BooleanFormula.Atom(atPredicatesMap.[pi].Name, [source]);
                        BooleanFormula.Atom(movePredicatesMap.[pi].Name, [move]);
                        BooleanFormula.Atom(systemMoveEdgePredicates.[tsIndex].Name, [source; move; target])
                    ]
                    |> BooleanFormula.And
                Effect = 
                    // Move to the next stage
                    [
                        BooleanFormula.Not (BooleanFormula.Atom(moveApplyStagePredicates.[pi].Name, []));
                        nextStage
                    ]
                    @
                    // Update the successor state for pi; the old position is only deleted when it differs from the new one
                    [
                        BooleanFormula.Atom(atPredicatesMap.[pi].Name, [target]); 
                        BooleanFormula.When (BooleanFormula.Not (BooleanFormula.Atom("=", [source; target])), BooleanFormula.Not (BooleanFormula.Atom(atPredicatesMap.[pi].Name, [source])))
                    ]
                    @
                    // Reset the move predicate for pi
                    [
                        BooleanFormula.Not (BooleanFormula.Atom(movePredicatesMap.[pi].Name, [move]))
                    ]
                    @
                    // Update the AP evaluations for pi
                    apUpdates
                    |> BooleanFormula.And
            }
        )

    let dom = 
        {
            PlanningDomain.Name = "dom"
            Types = systemStateTypes @ moveTypes
            Constants = 
                (moveConstants |> List.concat)
            Predicates = 
                (if sem = SAFE then [winPredicate] else []) // We only need the winning predicate for safety properties
                @
                (moveSelectionStagePredicates |> Map.values |> Seq.toList)
                @
                (moveApplyStagePredicates |> Map.values |> Seq.toList)
                @
                [updateStagePredicate]
                @
                (atPredicatesMap |> Map.values |> Seq.toList)
                @ 
                (movePredicatesMap |> Map.values |> Seq.toList)
                @
                apEvaluationPredicates
                @
                apCurrentPredicates
                @ 
                (automatonStatePredicates |> Map.values |> Seq.toList)
                @ 
                systemMoveEdgePredicates
            Actions = 
                automatonMoveActions
                @
                selectMoveActions
                @
                applyMoveActions
        }

    // ============================== The Planning Problem ==============================

    // For each distinct system, maps every state to its (object name, object type)
    let systemStateObjects = 
        distinctSystems
        |> List.mapi (fun i ts -> 
            ts.States
            |> Seq.toList
            |> List.map (fun s -> s, ("s-o-" + string s + "-" + string i, systemStateTypes.[i]))
            |> Map.ofList
            )

    // =============== Initial Facts ===============

    let automatonInitialFact = (automatonStatePredicates.[dpa.InitialState].Name, [])
    
    let systemInitialFacts = 
        traceVariables
        |> List.map (fun pi -> 
            let tsIndex = tsIndexMap.[pi]
            let initState = tsMap.[pi].InitialStates |> Seq.head

            (atPredicatesMap.[pi].Name, [systemStateObjects.[tsIndex].[initState] |> fst])
            )

    // We start in the update stage, so the automaton first reads the labelling of the initial states
    let updateStageInitialFact = (updateStagePredicate.Name, [])

    // The static AP-labelling: one fact for each AP and each state in which it holds
    let apEvaluationPredicatesInitialFacts = 
        dpa.APs
        |> List.mapi (fun i (_, pi) ->
            let tsIndex = tsIndexMap.[pi] 
            let pred = apEvaluationPredicates.[i].Name
            let _, apIndex = apLookupList.[i]

            tsMap.[pi].States
            |> Seq.toList
            |> List.filter (fun s -> 
                Set.contains apIndex tsMap.[pi].ApEval.[s]
                )
            |> List.map (fun s -> 
                (pred, [systemStateObjects.[tsIndex].[s] |> fst])
                )
            )
        |> List.concat

    // The labelling of the current (= initial) state. These facts must use the zero-arity
    // current-AP predicates that the automaton guards read, not the one-parameter fact predicates
    let apCurrentPredicatesInitialFacts = 
        dpa.APs
        |> List.mapi (fun i (_, pi: TraceVariable) ->
            let _, apIndex = apLookupList.[i]
            let init = tsMap.[pi].InitialStates |> Seq.head
            let pred = apCurrentPredicates.[i].Name

            if Set.contains apIndex tsMap.[pi].ApEval.[init] then 
                // The AP holds in the initial state
                (pred, [])
                |> Some
            else 
                None
        )
        |> List.choose id

    // The edge relation of each distinct system, indexed by moves
    let systemMoveEdgeInitialFacts = 
        distinctSystems
        |> List.mapi (fun i ts -> 
            ts.States
            |> Seq.toList
            |> List.map (fun s -> 
                let sucs = ts.Edges.[s] |> Set.toList

                [0..tsDegrees.[i] - 1]
                |> List.map (fun moveIndex -> 
                    let ss = 
                        if moveIndex < List.length sucs then 
                            sucs.[moveIndex]
                        else 
                            // All move indices that are too large are mapped to the first successor
                            sucs.[0]
                    
                    let parameters = 
                        [
                            systemStateObjects.[i].[s] |> fst;
                            moveConstants.[i].[moveIndex] |> fst;
                            systemStateObjects.[i].[ss] |> fst
                        ]
                        
                    (systemMoveEdgePredicates.[i].Name, parameters)
                    )
                )
            |> List.concat
            )
        |> List.concat

    let goalFormula = 
        match sem with 
        | SAFE -> 
            // The goal is that the winning predicate is set 
            BooleanFormula.Atom(winPredicate.Name, [])
        | REACH -> 
            // The goal is to reach a winning (universal) state
            universalStates
            |> Set.toList
            |> List.map (fun q -> 
                BooleanFormula.Atom(automatonStatePredicates.[q].Name, [])
                )
            |> BooleanFormula.Or

    let prob = 
        {
            PlanningProblem.Name = "prob"
            Domain = "dom"
            Objects = systemStateObjects |> List.map Map.values |> List.map Seq.toList |> List.concat
            Init = 
                [updateStageInitialFact]
                @
                [automatonInitialFact]
                @
                systemInitialFacts
                @
                apEvaluationPredicatesInitialFacts
                @
                apCurrentPredicatesInitialFacts
                @
                systemMoveEdgeInitialFacts
            Goal = goalFormula
        }

    config.Logger.LogN $"Created planning instance in %i{sw.ElapsedMilliseconds}ms (~=%.2f{double(sw.ElapsedMilliseconds) / 1000.0}s)"

    dom, prob
        