import Lean
import Mathlib
import Lean.Meta
import Lean.Elab
import Lean.Environment
import Lean.Data.Json
import Std.Data.HashMap
import Std.Data.HashSet
import Qq


open Lean Meta Elab Tactic IO Qq

-- Number of theorems converted concurrently per batch; also the number of scratch files
-- `tmp/file_{i}.csv` (0 ≤ i < batchSize) created by `initializeFiles`.
def batchSize : Nat := 20

-- NOTE(review): despite its name, this value is passed to `processBatch_csv` as the target
-- *table* name (see `mergeBatchFiles`); the database itself is `dbName`.
def db_name : String := "mathlibnew"

-- Root directory for every file this program writes.
-- NOTE(review): machine-specific absolute path; consider making it configurable.
def rootDir : String := "/Users/princhern/Documents/LeanWithVersion/v4.17.0-rc1"

-- Merged CSV that is bulk-loaded into PostgreSQL after each batch and then reset to `header`.
def csvPath : String := rootDir ++ "/backup.csv"

-- Text log listing theorems that were skipped (statement too long, or JSON payload too large).
def skipped_records : String := rootDir ++ "/skipped_records.txt"

-- CSV header row; its column order must match the `\copy` column list in `processBatch_csv`.
def header : String := "name,statement_str,expr_json,expr_cse_json\n"

-- Batch range to process, 1-based: `processAllTheorems` iterates 0-based indices from
-- `startbatch - 1` up to (excluding) `endbatch`, i.e. batches `startbatch`..`endbatch`.
def startbatch := 3413

def endbatch := 3413

-- psql connection settings.
-- NOTE(review): the password is hard-coded in source; consider reading it from the environment.
def dbHost := "127.0.0.1"
def dbPort := "5432"
def dbName := "pfdb"
def dbUser := "postgres"
def dbPass := "mysecretpassword"
-- Payload-size threshold used by `insertIntoDB`; 268435455 = 2^28 - 1, presumably PostgreSQL's
-- jsonb size limit in bytes — confirm.
def maxJsonbSize := 268435455
-- Theorems whose `dbgToString` statement is longer than this many characters are skipped.
def maxlength := 1000000


section

/-- One exported theorem record, with the same four fields as the CSV columns in `header`.

NOTE(review): not referenced by the code shown here; rows are built as raw `Json` objects
instead — confirm whether this type is still needed. -/
structure TheoremData where
  -- Fully qualified declaration name.
  name : String
  -- Rendered statement text.
  statement_str : String
  -- JSON encoding of the statement's expression tree.
  expr_json : Json
  -- Presumably a common-subexpression-eliminated encoding; always `null` in the code shown.
  expr_cse_json : Json
deriving ToJson

-- NOTE(review): `maxRetries` and `retryDelay` are not referenced by the code shown here;
-- presumably leftovers from a retry loop (the unit of `retryDelay` is unknown — confirm).
def maxRetries : Nat := 0
def retryDelay : Nat := 3

/-- A serialisable mirror of `Lean.Expr`.

It has the same constructors, but names, ids, universe levels, binder infos, literals and
metadata are stored as strings (rendered with `toString`/`repr` in
`ExprConversionNonRecNoMemo.reconstruct`), so the whole tree can be emitted through the derived
`ToJson` instance. Renaming constructors or fields changes the exported JSON format. -/
inductive YourExpr where
  | bvar (deBruijnIndex : Nat)
  | fvar (fvarId : String)
  | mvar (mvarId : String)
  | sort (u : String)
  | const (declName : String) (us : List String)
  | app (fn : YourExpr) (arg : YourExpr)
  | lam (binderName : String) (binderType : YourExpr) (body : YourExpr) (binderInfo : String)
  | forallE (binderName : String) (binderType : YourExpr) (body : YourExpr) (binderInfo : String)
  | letE (declName : String) (type : YourExpr) (value : YourExpr) (body : YourExpr) (nonDep : Bool)
  | lit (literal : String)
  | mdata (data : String) (expr : YourExpr)
  | proj (typeName : String) (idx : Nat) (struct : YourExpr)
  deriving ToJson, FromJson, Repr, Inhabited


namespace ExprConversionNonRecNoMemo

/-- One pending node of the explicit traversal stack: the `Expr` being converted, the converted
results of its already-finished children (in child order), and the children still to visit. -/
structure Frame where
  node         : Expr
  childResults : Array YourExpr
  remaining    : List Expr
  deriving Inhabited

/-- Direct sub-expressions of `e`, in the order `reconstruct` expects their results back.
Leaf nodes (`bvar`, `fvar`, `mvar`, `sort`, `const`, `lit`) have none. -/
def getChildren (e : Expr) : List Expr :=
  match e with
  | Expr.app f a           => [f, a]
  | Expr.lam _ binderType body _         => [binderType, body]
  | Expr.forallE _ binderType body _       => [binderType, body]
  | Expr.letE _ type value body _          => [type, value, body]
  | Expr.mdata _ expr      => [expr]
  | Expr.proj _ _ struct   => [struct]
  | _                      => []

/-- Build the `YourExpr` node for `e` from its already-converted children, given in the same
order as `getChildren e`.

Names are rendered with `toString`. Ids, levels, binder infos, literals and metadata are
rendered with `repr`. Panics if the number of children does not fit the constructor. -/
def reconstruct (e : Expr) (children : List YourExpr) : YourExpr :=
  match e with
  | Expr.bvar idx           => YourExpr.bvar idx
  | Expr.fvar fvarId        => YourExpr.fvar ((repr fvarId).pretty)
  | Expr.mvar mvarId        => YourExpr.mvar ((repr mvarId).pretty)
  | Expr.sort lvl           => YourExpr.sort ((repr lvl).pretty)
  | Expr.const declName us  => YourExpr.const (declName.toString) (us.map (fun lvl => (repr lvl).pretty))
  | Expr.app _ _            =>
      match children with
      | [f', a'] => YourExpr.app f' a'
      | _        => panic "Unexpected children count in app"
  | Expr.lam binderName _ _ binderInfo  =>
      match children with
      | [binderType', body'] => YourExpr.lam (binderName.toString) binderType' body' ((repr binderInfo).pretty)
      | _                    => panic "Unexpected children count in lam"
  | Expr.forallE binderName _ _ binderInfo =>
      match children with
      | [binderType', body'] => YourExpr.forallE (binderName.toString) binderType' body' ((repr binderInfo).pretty)
      | _                    => panic "Unexpected children count in forallE"
  | Expr.letE declName _ _ _ nonDep       =>
      match children with
      | [type', value', body'] => YourExpr.letE (declName.toString) type' value' body' nonDep
      | _                      => panic "Unexpected children count in letE"
  | Expr.lit lit            => YourExpr.lit ((repr lit).pretty)
  | Expr.mdata data _       =>
      match children with
      | [expr'] => YourExpr.mdata ((repr data).pretty) expr'
      | _       => panic "Unexpected children count in mdata"
  | Expr.proj typeName idx _  =>
      match children with
      | [struct'] => YourExpr.proj (typeName.toString) idx struct'
      | _         => panic "Unexpected children count in proj"

/-- Convert `e` to `YourExpr` using an explicit stack (post-order traversal), so that deeply
nested expressions cannot exhaust the native call stack.

Throws `IO.userError "Iteration depth exceeded threshold"` if more than `maxDepth` frames are
pending at once. There is no memoisation: a sub-term that occurs several times is converted
each time it occurs. -/
def exprToYourExprIter (e : Expr) (maxDepth : Nat) : IO YourExpr := do
  -- The top of the stack is the last element. `Array` gives O(1) `size`/`back!`/`push`/`pop`,
  -- whereas checking `List.length` on every step would cost O(depth) per node.
  let mut stack : Array Frame := #[{ node := e, childResults := #[], remaining := getChildren e }]
  while !stack.isEmpty do
    if stack.size > maxDepth then
      throw $ IO.userError "Iteration depth exceeded threshold"
    let top := stack.back!
    match top.remaining with
    | child :: rest =>
      -- Mark `child` as visited in the current frame, then descend into it.
      stack := stack.pop.push { top with remaining := rest }
      stack := stack.push { node := child, childResults := #[], remaining := getChildren child }
    | [] =>
      -- All children are done: build this node and hand it to its parent, or finish at the root.
      let result := reconstruct top.node top.childResults.toList
      stack := stack.pop
      if stack.isEmpty then
        return result
      else
        let parent := stack.back!
        stack := stack.pop.push { parent with childResults := parent.childResults.push result }
  unreachable!

end ExprConversionNonRecNoMemo

/-- Render `expr` with its `ToJson` instance and the `Json` `ToString` printer. Then remove every
escaped-newline sequence (backslash followed by `n`) and every space character. -/
def yourExprToJson (expr : YourExpr) : IO String := do
  let rendered : String := toString (toJson expr)
  let withoutNewlines := rendered.replace "\\n" ""
  return withoutNewlines.replace " " ""

/-- The success value of `ex`, or `default` when `ex` is an error (the error itself is dropped). -/
def myGetOrElse {ε α : Type} (ex : Except ε α) (default : α) : α :=
  if let Except.ok value := ex then value else default



/-- Convert `expr` to `YourExpr` (depth limit 1000) and write the text produced by
`yourExprToJson` to `filePath`, replacing any existing contents. -/
def saveExprToJson (expr : Expr) (filePath : String) : IO Unit := do
  let converted ← ExprConversionNonRecNoMemo.exprToYourExprIter expr 1000
  IO.FS.writeFile filePath (← yourExprToJson converted)

end


-- Scratch directory holding one CSV per concurrent task (`file_{i}.csv`, i < batchSize).
def tmpDir : String := rootDir ++ "/tmp"

/-- Prepare the output files:
- ensure `tmpDir` exists;
- recreate each scratch file `file_0.csv` … `file_{batchSize-1}.csv` as an empty file;
- reset `csvPath` so it contains only the header row. -/
def initializeFiles : IO Unit := do
  IO.FS.createDirAll tmpDir
  let scratch : List System.FilePath :=
    (List.range batchSize).map fun i => ⟨tmpDir ++ s!"/file_{i}.csv"⟩
  for path in scratch do
    -- Remove any leftover file from an earlier run, then create it afresh with no content.
    if (← path.pathExists) then
      IO.FS.removeFile path
    IO.FS.writeFile path ""
  IO.FS.writeFile csvPath header

/-- Truncate the scratch file `tmpDir/file_{fileId}.csv` to empty (creating it if absent). -/
def resetFile (fileId : Nat) : IO Unit :=
  IO.FS.writeFile (System.FilePath.mk (tmpDir ++ s!"/file_{fileId}.csv")) ""

/-- Append `contents` to the file at `path`, creating the file if it does not exist.

The file is opened in append mode instead of being read back and rewritten. This way:
- a missing file (e.g. the first write to `skipped_records`) no longer throws;
- each call costs O(|contents|) instead of O(file size);
- concurrent callers are far less prone to overwriting each other's lines. -/
def appendFile (path : System.FilePath) (contents : String) : IO Unit :=
  IO.FS.withFile path .append fun h => h.putStr contents

/-- Append one CSV row, built from the JSON object `data`, to the file at `filePath`.

`data` has the keys `name`, `statement_str`, `expr_json` and `expr_cse_json`. A missing key
reads as `null` via `getObjValD`.

Each field is rendered with `Json`'s `ToString`, so string values arrive wrapped in double quotes
with JSON escapes:
- If `expr_json` renders as `null`, the row is written verbatim.
- Otherwise `expr_json` is unwrapped and unescaped, and then it and `expr_cse_json` are
  CSV-quoted.

NOTE(review): `name` and `statement_str` are written without CSV escaping. A JSON-escaped quote
(`\"`) inside them is not valid CSV quoting — confirm such values cannot occur. -/
def backupToCSV (filePath : String) (data : Json) : IO Unit := do
  let name := data.getObjValD "name"
  let name : String := s!"{name}"
  let statement := data.getObjValD "statement_str"
  let statement : String := s!"{statement}"
  let exprJson := data.getObjValD "expr_json"

  let exprJson : String := s!"{exprJson}"

  let exprCseJson := data.getObjValD "expr_cse_json"
  let exprCseJson : String := s!"{exprCseJson}"

  if exprJson = "null" then
    -- No expression (e.g. a skipped theorem): emit every field exactly as rendered.
    let csvRow := s!"{name},{statement},{exprJson},{exprCseJson}\n"
    appendFile filePath csvRow
  else
    -- Strip the surrounding JSON string quotes.
    let exprJson := exprJson.drop 1 |>.dropRight 1
    -- Undo JSON escaping: first `\\` => `\`, then `\"` => `"`.
    let exprJson := exprJson.replace "\\\\" "\\"
    let exprJson := exprJson.replace "\\\"" "\""
    -- Drop escaped newlines.
    -- NOTE(review): this also deletes any literal backslash-`n` pair from the original text.
    let exprJson := exprJson.replace "\\n" ""

    -- CSV quoting: wrap in double quotes and double every embedded double quote.
    let escape (s : String) : String :=
      "\"" ++ (s.replace "\"" "\"\"") ++ "\""

    let csvRow := s!"{name},{statement},{escape exprJson},{escape exprCseJson}\n"

    appendFile filePath csvRow

/-- Convert one theorem's type into a CSV row, appended to the scratch file
`tmpDir/file_{taskId}.csv`.

If the statement's `dbgToString` is longer than `maxlength` characters:
- the name is logged to `skipped_records`;
- the row is written with null statement and expression columns.

Otherwise the expression is converted with depth limit 1000. A conversion error propagates to
the caller. -/
def processTheoremTask (taskId : Nat) (name : Name) (decl : ConstantInfo) : IO Unit := do
  let outPath := tmpDir ++ s!"/file_{taskId}.csv"
  let stmt := decl.type.dbgToString
  let tooLong := stmt.length > maxlength
  if tooLong then
    appendFile skipped_records s!"Skipped: {name}\n"
  let row : Json ←
    if tooLong then
      pure <| Json.mkObj [
        ("name", name.toString),
        ("statement_str", Json.null),
        ("expr_json", Json.null),
        ("expr_cse_json", Json.null)
      ]
    else do
      let converted ← ExprConversionNonRecNoMemo.exprToYourExprIter decl.type 1000
      let exprText ← yourExprToJson converted
      pure <| Json.mkObj [
        ("name", name.toString),
        ("statement_str", stmt),
        ("expr_json", exprText),
        ("expr_cse_json", Json.null)
      ]
  backupToCSV outPath row



/-- Bulk-load `csvPath` into `table` with psql's client-side `\copy`.

Prints "Success" or the psql error output, then resets `csvPath` to just the header row.

NOTE(review): the CSV is reset even when the load fails, so that batch's rows are not retried. -/
def processBatch_csv (table : String) : IO Unit := do
  -- `IO.Process.output` drains stdout/stderr while the child runs. Waiting for exit before
  -- reading piped output can deadlock once psql fills a pipe buffer.
  let out ← IO.Process.output {
    cmd := "psql",
    args := #[
      "-U", dbUser,
      "-h", dbHost,
      "-p", dbPort,
      "-d", dbName,
      "-c", s!"\\copy \"{table}\" (name, statement_str, expr_json, expr_cse_json) FROM '{csvPath}' WITH (FORMAT csv, HEADER true);"
    ],
    env := #[("PGPASSWORD", dbPass)]
  }
  if out.exitCode = 0 then
    IO.println "Success"
  else
    IO.println s!"Fail: {out.stderr}"

  IO.FS.writeFile csvPath header

/-- Queue one row for bulk loading, unless table `method` already has a row named `name`.

Queued rows are appended to `csvPath`. If the rendered JSON of `data` exceeds `maxJsonbSize`
UTF-8 bytes, the row is logged to `skipped_records` and queued with null expression columns.

NOTE(review): despite the name this does not insert directly; rows reach the database only when
`csvPath` is later loaded. If the existence check itself fails (non-zero psql exit), the row is
still queued. -/
def insertIntoDB {d : Type} [ToJson d] (method : String) (name : String) (statement : String) (data : d) : IO Unit := do
  let data := toJson data
  -- Escape for SQL. `'` occurs in many Lean names (e.g. `foo'`); left unescaped it would end
  -- the string literal early and make the existence check fail.
  let sqlName := name.replace "'" "''"
  let sqlTable := method.replace "\"" "\"\""
  let check ← IO.Process.output {
    cmd := "psql",
    args := #[
      "-q",
      "-U", dbUser,
      "-h", dbHost,
      "-p", dbPort,
      "-d", dbName,
      "-t", "-c", s!"select 1 from \"{sqlTable}\" where name = '{sqlName}' limit 1;"
    ],
    env := #[("PGPASSWORD", dbPass)]
  }
  if check.exitCode == 0 && check.stdout.trim != "" then
    -- Already present: nothing to queue.
    pure ()
  else
    -- The jsonb limit is a byte limit, so measure UTF-8 bytes rather than characters.
    let dataSize := (toString data).utf8ByteSize
    if dataSize > maxJsonbSize then
      appendFile skipped_records s!"Skipped: {name} (JSONB size: {dataSize} bytes)\n"
      let placeholder : Json := Json.mkObj [
        ("name", name),
        ("statement_str", statement),
        ("expr_json", Json.null),
        ("expr_cse_json", Json.null)
      ]
      backupToCSV csvPath placeholder
    else
      backupToCSV csvPath data


/-- Find the first backslash in `s`. If it is immediately followed by `"`, return the text after
that `"` up to (excluding) the next `"`. Otherwise, or if no closing `"` exists, return `""`. -/
def extractBetweenBQuotes (s : String) : String :=
  let start := s.find (· = '\\')
  if start == s.endPos then
    ""
  else
    -- Slice by `String.Pos` (a byte offset) instead of passing it to `drop`/`take`, which count
    -- characters. Otherwise multi-byte characters before the backslash would shift the result.
    let afterB := s.extract (s.next start) s.endPos
    if afterB.get? 0 == some '"' then
      let rest := afterB.drop 1
      let stop := rest.find (· = '"')
      if stop == rest.endPos then
        ""
      else
        rest.extract 0 stop
    else
      ""

/-- Prefix every backslash and double-quote character in `s` with a backslash. -/
def escapeChars (s : String) : String :=
  s.foldl (fun out ch =>
    let out := if ch == '\\' || ch == '"' then out.push '\\' else out
    out.push ch) ""

/-- Split `s` on the two-character delimiter backslash-quote and concatenate the pieces.

Odd-indexed pieces (text that sat between such delimiters) are passed through `escapeChars` and
wrapped in plain double quotes. Even-indexed pieces are kept verbatim. -/
def processInput (s : String) : String := Id.run do
  let mut out := ""
  let mut idx := 0
  for piece in s.splitOn "\\\"" do
    if idx % 2 == 1 then
      out := out ++ "\"" ++ escapeChars piece ++ "\""
    else
      out := out ++ piece
    idx := idx + 1
  return out

/-- Copy every scratch file into `csvPath`, clearing each one after it is copied. Then bulk-load
the merged CSV into table `db_name` via `processBatch_csv`. -/
def mergeBatchFiles : IO Unit := do
  for idx in List.range batchSize do
    let partPath : String := tmpDir ++ s!"/file_{idx}.csv"
    appendFile csvPath (← IO.FS.readFile partPath)
    resetFile idx
  processBatch_csv db_name

/-- Export every theorem in the current environment, in batches of `batchSize`.

Processes the 1-based batches `startbatch`..`endbatch`, with `endbatch` clamped to the number of
batches. For each batch:
- the theorems are converted concurrently, task `i` writing `tmpDir/file_{i}.csv`;
- the scratch files are merged and loaded via `mergeBatchFiles`.

The first failing task aborts the run with its error. -/
def processAllTheorems : MetaM Unit := do
  let env ← getEnv
  -- Materialise as an `Array` so each batch is an O(batchSize) slice rather than an O(n)
  -- `List.drop` (which made the whole run quadratic in the number of theorems).
  let theorems := (env.constants.toList.filter (·.2.isTheorem)).toArray
  let declCount := theorems.size
  let totalBatches := (declCount + batchSize - 1) / batchSize

  IO.println s!"thm {declCount}"

  -- The loop range is 0-based and end-exclusive. Clamping to `totalBatches - 1` would silently
  -- drop the final batch, so clamp to `totalBatches` instead.
  let startBatch := startbatch - 1
  let endBatch := min endbatch totalBatches

  for batchNum in [startBatch : endBatch] do
    let batchTheorems :=
      (theorems.extract (batchNum * batchSize) ((batchNum + 1) * batchSize)).toList

    let tasks ← batchTheorems.zipIdx.mapM fun ⟨(name, decl), taskId⟩ =>
      IO.asTask (processTheoremTask taskId name decl)

    -- Wait for every task; surface the first failure as a `MetaM` error.
    for task in tasks do
      match task.get with
      | .ok _ => pure ()
      | .error e => throwError m!"{e}"

    mergeBatchFiles
    IO.println s!"Finish {batchNum + 1}/{totalBatches}"

/-- Entry point. Steps:
1. Import `Mathlib`, searching the sysroot plus `build/lib`.
2. Reset the output files.
3. Run `processAllTheorems` in a fresh `MetaM` context.
4. Print the elapsed wall time in milliseconds. -/
def main : IO Unit := do
  Lean.initSearchPath (← Lean.findSysroot) ["build/lib"]
  let env ← Lean.importModules #[⟨`Mathlib, false⟩] {}

  initializeFiles

  let coreCtx : Core.Context := {
    fileName := "",
    fileMap := ⟨"", #[]⟩,
    options := {}
  }
  let coreState : Core.State := {
    env := env,
    nextMacroScope := 0,
    traceState := {}
  }
  let startMs ← IO.monoMsNow
  discard <| processAllTheorems.toIO coreCtx coreState
  let endMs ← IO.monoMsNow
  IO.println s!"time: {endMs - startMs}"
