module ConstructParityGame 

open System.Collections.Generic
open System.IO

open FsOmegaLib.LTL
open FsOmegaLib.SAT
open FsOmegaLib.DPA
open FsOmegaLib.Operations

open TransitionSystemLib.TransitionSystem

open Util
open SolverConfiguration
open HyperLTL
open PlanningInstance


/// A small mutable integer counter.
/// Note: Reset() sets the value to 0, not back to the value passed to the constructor.
type private Counter(init : int) =
    let mutable count = init

    /// Creates a counter starting at 0.
    new () = Counter(0)

    /// Sets the counter to 0.
    member _.Reset() = count <- 0

    /// The current value.
    member _.Get = count

    /// Increments by one.
    member _.Inc() = count <- count + 1

    /// Increments by x.
    member _.Inc(x) = count <- count + x

    /// Decrements by one.
    member _.Dec() = count <- count - 1

    /// Decrements by x.
    member _.Dec(x) = count <- count - x


/// The two players of a parity game: PlayerZero (the system / \exists side in this file) and PlayerOne (the adversary).
type ParityGamePlayer = PlayerZero | PlayerOne

module ParityGamePlayer = 
    /// Owner label used in the textual game format: "0" for PlayerZero, "1" for PlayerOne.
    let asString = function 
        | PlayerZero -> "0"
        | PlayerOne -> "1"

    /// Returns the opponent of the given player.
    let flip = function 
        | PlayerOne -> PlayerZero
        | PlayerZero -> PlayerOne


/// An explicit parity game: every vertex is mapped to (successor vertices, owning player, colour).
type ParityGame<'T when 'T : comparison> = { Properties : Map<'T, Set<'T> * ParityGamePlayer * int> }

module ParityGame = 
    /// Serialises a parity game into a PGSolver-style text format:
    ///   parity <n>;
    ///   <id> <colour> <owner> <succId,succId,...> "";     (one line per vertex)
    /// Vertices are numbered 0, 1, 2, ... in the iteration order of pg.Properties
    /// (i.e. sorted by key, as pg.Properties is an F# Map).
    ///
    /// NOTE(review): <n> is the number of vertices. PGSolver-style readers (e.g. Oink) interpret
    /// this value as the maximal vertex id — confirm the downstream solver accepts it.
    let convertParityGameToString (pg : ParityGame<'T>) = 
        // Dense integer id for every vertex, in the same order in which the vertices are written below.
        let ids = 
            pg.Properties
            |> Seq.mapi (fun i (KeyValue (s, _)) -> s, i)
            |> Map.ofSeq

        use sw = new StringWriter()

        sw.WriteLine ("parity " + string(ids.Count) + ";")

        // Iterate key/value pairs directly instead of re-looking up each key.
        for KeyValue (s, (sucs, player, color)) in pg.Properties do 
            let sucString = 
                sucs 
                |> Seq.map (fun x -> string(ids.[x]))
                |> String.concat ","

            sw.WriteLine(string(ids.[s]) + " " + string(color) + " " + ParityGamePlayer.asString player + " " + sucString + " " + "\"\"" + ";")

        sw.ToString()


// ============================================================================================================

/// Position at the start of a round: the current state of each system copy (keyed by trace variable)
/// together with the current state of the main DPA.
type GameState = 
    { SystemStates : Map<TraceVariable, int>
      MainDpaState : int }

/// The decision made by the \forall player: the next state of each universally quantified trace variable.
type UniversalMove = 
    { NextUniversalStates : Map<TraceVariable, int> }

/// The decision made by the \exists player: the next state of each trace variable not controlled by the \forall player.
type ExistentialMove = 
    { NextExistentialStates : Map<TraceVariable, int> }

/// The three kinds of positions making up one round of the game.
type ParityGameState =  
    /// Start of a round; the \forall player moves next.
    | ForallStage of GameState
    /// The \forall player has moved; the \exists player moves next.
    | ExistentialStage of GameState * UniversalMove
    /// Both players have moved; the DPA and the system states are advanced next.
    | UpdateStage of GameState * UniversalMove * ExistentialMove

/// Explores (breadth-first from the initial position) the parity game for a \forall^*\exists^* property.
///
/// Each round consists of three positions:
///   ForallStage      - owned by PlayerOne, who picks successor states for all universal trace variables;
///   ExistentialStage - owned by PlayerZero, who, knowing the universal move, picks successor states for
///                      every other system in systemMap;
///   UpdateStage      - a single forced move: the DPA advances on the labelling of the current system
///                      states and both players' moves are committed.
/// Every position is coloured with the colour of the current DPA state.
///
/// systemMap               - one transition system per trace variable.
/// universalTraceVariables - universally quantified trace variables (restricted to those occurring in aut.APs).
/// aut                     - DPA whose atomic propositions are pairs (ap, traceVariable).
///
/// Returns the explored game and its initial position.
/// Raises HyPlanException if a trace variable of the formula has no system, if a system has no initial
/// state, or if a DPA guard needs an atomic proposition that the corresponding system does not define.
let private compileToParityGame<'L when 'L : comparison> (systemMap :  Map<TraceVariable, TransitionSystem<'L>>) (universalTraceVariables : Set<TraceVariable>) (aut : DPA<int, 'L * TraceVariable>) = 
    
    let allTraceVariables = 
        aut.APs
        |> List.map snd 
        |> set

    allTraceVariables
    |> Set.iter (fun pi -> 
        if Map.containsKey pi systemMap |> not then 
            raise <| HyPlanException $"Trace variable %s{pi} is used in the formula but no system for %s{pi} is given"
        )

    // We only consider the variables that are actually used in the formula
    let universalTraceVariables = Set.intersect universalTraceVariables allTraceVariables

    let mainDpa = aut

    // Resolve every DPA atomic proposition (ap, pi) once: the system of pi and the index of ap in that
    // system's AP list. This avoids list indexing and a linear AP search on every guard evaluation.
    // systemMap.[pi] is safe here because every pi occurring in aut.APs was checked above.
    let apResolution = 
        mainDpa.APs
        |> List.map (fun (ap, pi) -> 
            let ts = systemMap.[pi]
            ap, pi, ts, List.tryFindIndex ((=) ap) ts.APs)
        |> List.toArray

    // Truth value of DPA atomic proposition number i under the given system states.
    let evalAp (systemStates : Map<TraceVariable, int>) (i : int) = 
        let ap, pi, ts, index = apResolution.[i]
        match index with 
        | Some idx -> Set.contains idx (ts.ApEval.[systemStates.[pi]])
        | None -> 
            raise <| HyPlanException $"Atomic proposition %A{ap} is used on trace %s{pi} but is not defined in the system for %s{pi}"

    // NOTE(review): only the first element of each system's InitialStates is explored; systems with
    // several initial states are not quantified over all of them — confirm this is intended.
    let initialState = 
        {
            GameState.SystemStates =   
                systemMap 
                |> Map.map (fun pi ts -> 
                    match Seq.tryHead ts.InitialStates with 
                    | Some s -> s
                    | None -> raise <| HyPlanException $"The system for trace variable %s{pi} has no initial state"
                    )
            MainDpaState = mainDpa.InitialState
        }
        |> ForallStage

    // Successors, owner and colour of a single game position.
    let expand (position : ParityGameState) = 
        match position with 
        | ForallStage gameState -> 
            // The adversary (PlayerOne) chooses the next states of all universal trace variables.
            let sucs = 
                gameState.SystemStates
                |> Map.filter (fun pi _ -> Set.contains pi universalTraceVariables)
                |> Map.map (fun pi s -> systemMap.[pi].Edges.[s])
                |> Util.cartesianProductMap
                |> Seq.toList
                |> List.map (fun x -> ExistentialStage(gameState, {UniversalMove.NextUniversalStates = x}))

            sucs, PlayerOne, mainDpa.Color.[gameState.MainDpaState]

        | ExistentialStage (gameState, universalMove) -> 
            // The system player (PlayerZero) chooses the next states of all remaining systems.
            let sucs = 
                gameState.SystemStates
                |> Map.filter (fun pi _ -> Set.contains pi universalTraceVariables |> not)
                |> Map.map (fun pi s -> systemMap.[pi].Edges.[s])
                |> Util.cartesianProductMap
                |> Seq.toList
                |> List.map (fun nextExistentialStates -> 
                    UpdateStage (gameState, universalMove, {ExistentialMove.NextExistentialStates = nextExistentialStates}))

            sucs, PlayerZero, mainDpa.Color.[gameState.MainDpaState]

        | UpdateStage (gameState, universalMove, existentialMove) -> 
            // Advance the DPA along the first edge whose guard holds for the labelling of the
            // current system states, then commit both players' moves.
            let nextMainDpaState = 
                mainDpa.Edges.[gameState.MainDpaState]
                |> List.find (fun (guard, _) -> DNF.eval (evalAp gameState.SystemStates) guard)
                |> snd

            let nextGameState = 
                {
                    GameState.SystemStates = 
                        Util.mergeMaps universalMove.NextUniversalStates existentialMove.NextExistentialStates
                    MainDpaState = nextMainDpaState
                }

            // Single forced successor: owner and colour do not affect the outcome here.
            [ ForallStage nextGameState ], PlayerZero, mainDpa.Color.[gameState.MainDpaState]

    let visited = new HashSet<_>(Seq.singleton initialState)

    let queue = new Queue<_>(Seq.singleton initialState)

    let propertyDict = new Dictionary<_,_>()

    while queue.Count <> 0 do 
        let s = queue.Dequeue() 
        let sucs, p, c = expand s

        propertyDict.Add(s, (set sucs, p, c))
        
        // HashSet.Add returns false for positions that were already discovered.
        for s' in sucs do 
            if visited.Add s' then 
                queue.Enqueue s'
         
    { Properties = Util.dictToMap propertyDict }, initialState


/// Builds the parity game for checking prop against the systems in systemMap.
///
/// Only \forall^*\exists^* quantifier prefixes are supported; any other prefix raises HyPlanException.
/// The LTL body is translated into a DPA via FsOmegaLib's LTLConversion.convertLTLtoDPA (using the
/// configured MainPath / Ltl2tgbaPath); on failure the debug info is logged and HyPlanException is raised.
///
/// NOTE(review): the initial position returned by compileToParityGame is discarded here, so callers only
/// receive the game itself — confirm they identify the initial vertex another way.
let constructParityGame (config: Configuration) (systemMap :  Map<TraceVariable, TransitionSystem<string>>) (prop : HyperLTL<string>) = 
    // Convert the prefix to a block prefix
    let blockPrefix = HyperLTL.extractBlocks prop.QuantifierPrefix

    let notSupported () = 
        raise <| HyPlanException "Only applicable to \\forall^*\\exists^* properties"

    // The universally quantified trace variables: the leading block, if it is a \forall block.
    let uTraceVariables = 
        match blockPrefix with 
        | [] -> 
            // No quantifiers at all: the degenerate \forall^0\exists^0 case, nothing is universal.
            []
        | (q, vars) :: rest when q = FORALL -> 
            // \forall block optionally followed by one \exists block
            if List.length rest > 1 then notSupported ()
            vars
        | [ _ ] -> 
            // a single \exists block
            []
        | _ -> notSupported ()

    let dpa =
        match FsOmegaLib.Operations.LTLConversion.convertLTLtoDPA config.Debug config.SolverConfig.MainPath config.SolverConfig.Ltl2tgbaPath prop.LTLMatrix with 
        | Success aut -> aut 
        | Fail err -> 
            config.Logger.LogN err.DebugInfo
            raise <| HyPlanException err.Info

    let pg, _ = compileToParityGame systemMap (set uTraceVariables) dpa

    pg