using Format
using CSV
using Printf
using DataFrames
using SplitApplyCombine
using ProgressBars
using DataFramesMeta
using StatsBase
using IterTools
using ProgressBars
using LinearAlgebra
using Functional


"""
    lstdθ(A, b; α=0.)

Solve the ridge-regularized linear system `(A + α I) θ = b` and return `θ`.

Accepts any real-valued matrix/vector (not only `Matrix{Float64}`/`Vector{Float64}`);
`α = 0` gives the unregularized solve `A \\ b`.
"""
lstdθ(A::AbstractMatrix{<:Real}, b::AbstractVector{<:Real}; α=0.) = (A + α * I) \ b

"""
    LSTDResult(θ, η, α)

One solution of an LSTD linear system, as produced by `lstd_path` / `lstd_off_path`.

# Fields
- `θ`: coefficients on the state features.
- `η`: scalar offset. `lstd_path` stores the sample mean reward here; `lstd_off_path`
  stores the coefficient on its appended constant column (presumably the
  average-reward estimate — confirm).
- `α`: regularization strength recorded with this solution.
"""
struct LSTDResult
    θ::Vector{Float64}
    η::Float64
    α::Float64
end

"""
    lstd_off_path(ss, snews, rs; αs=[0.])

Fit an LSTD system with an extra constant feature, once per ridge strength in `αs`.

A column of ones is appended to the current-state features `ss` and a column of zeros
to the next-state features `snews`; then `(X' (X - Xnew) + α I) θη = X' rs` is solved
for each `α`. The last entry of the solution is the coefficient on the constant column
and is returned as `η` (presumably the average-reward estimate); the remaining entries
are returned as `θ`.

Returns a `Vector{LSTDResult}` with one entry per element of `αs`. With the default
`αs=[0.]` this is a single unregularized solve.
"""
function lstd_off_path(ss, snews, rs; αs=[0.])
    nrows = size(ss, 1)
    X = hcat(ss, ones(nrows))
    # The next state gets a zero in the constant slot, so the constant column of
    # X - Xnew is all ones and its coefficient acts as a per-step offset.
    Xnew = hcat(snews, zeros(nrows))
    A = X' * (X - Xnew)
    b = X' * rs
    return [let θη = (A + α * I) \ b
                LSTDResult(θη[1:end-1], θη[end], α)
            end
            for α in αs]
end

"""
    lstd_path(ss, snews, rs; αs=[0.])

Ridge-regularized LSTD fit on mean-centred rewards, one solution per `α` in `αs`.

With `r̄ = mean(rs)`, solves `(ss' (ss - snews) + α D) θ = Σᵢ ss[i, :] (rs[i] - r̄)`
where `D` is the identity. Returns a `Vector{LSTDResult}` holding `θ`, `r̄` (stored as
`η`) and the `α` used.
"""
function lstd_path(ss, snews, rs; αs=[0.])
    rbar = mean(rs)
    centred = rs .- rbar
    A = ss' * (ss - snews)
    b = vec(sum(ss .* centred; dims=1))
    # Per-coordinate ridge weights; currently all ones, i.e. α times the identity.
    penalty = diagm(ones(size(A, 1)))
    return [LSTDResult((A + α * penalty) \ b, rbar, α) for α in αs]
end

"""
    dq_lstd(ss, rs, snews, sstr, ssco; αs=[1.])

DQ estimate: project the LSTD coefficients onto the treatment-minus-control feature gap.

Fits `lstd_path` on the transitions `(ss, snews)` with rewards `rs` for every `α` in
`αs`. Each fit yields `θ' * Δx̄`, where `Δx̄` is the column-mean of `sstr` minus the
column-mean of `ssco`.

Returns one `Dict` per `α` with keys `:est`, `:estimator` (always `"DQ"`) and `:α`.
"""
function dq_lstd(ss, rs, snews, sstr, ssco; αs=[1.])
    gap = vec(mean(sstr; dims=1) .- mean(ssco; dims=1))
    fits = lstd_path(ss, snews, rs; αs=αs)
    return [Dict(:est => res.θ' * gap, :estimator => "DQ", :α => res.α)
            for res in fits]
end

"""
    ope_lstd(ss, as, rs, sstr, ssco)

Off-policy estimate: difference of the `η` offsets of two `lstd_off_path` fits.

Steps with action 1 are fitted with next-state features taken from `ssco`. Steps with
action 2 are fitted with next-state features taken from `sstr`. In both cases the
next state is the row after the step's own row. The final step is excluded because it
has no successor row.

Returns one `Dict` per paired fit with keys `:est` (treated `η` minus control `η`) and
`:estimator` (always `"Off-Policy"`).
"""
function ope_lstd(ss, as, rs, sstr, ssco)
    acts = as[1:end-1]
    ctrl = findall(acts .== 1)
    trt = findall(acts .== 2)
    fits_co = lstd_off_path(ss[ctrl, :], ssco[ctrl .+ 1, :], rs[ctrl])
    fits_tr = lstd_off_path(ss[trt, :], sstr[trt .+ 1, :], rs[trt])
    return [Dict(:est => ftr.η - fco.η, :estimator => "Off-Policy")
            for (ftr, fco) in zip(fits_tr, fits_co)]
end

"""
    run_dynkin(ss, as, rs, sstr, ssco; αs=[1.], T0=1)

Compute DQ estimates (`dq_lstd`) on growing windows of the data.

Evaluation times `t` are the distinct values of `round(1.1^k)`, `k = 1:1000`, that are
below `length(rs)`. Only those with `t > T0` are used. At each such `t`, `dq_lstd` is
fitted on the transitions `ss[s, :] → ss[s+1, :]` for `s in T0:t-1`, together with the
matching rows of `rs`, `sstr` and `ssco`.

`as` is accepted for signature parity with `run_ope` and is not used.

Returns a `DataFrame` with a `:t` column plus the `dq_lstd` keys (`:est`, `:estimator`,
`:α`), one row per (t, α).
"""
function run_dynkin(ss, as, rs, sstr, ssco; αs=[1.], T0=1)
    summ_ts = round.(collect(1.1 .^ (1:1000))) |> unique
    summ_ts = Int.(summ_ts[summ_ts .< length(rs)])
    # A window needs at least one transition T0 -> T0+1.
    ts = filter(>(T0), summ_ts)
    # Flatten per-t result lists directly instead of splatting a generator into varargs.
    rows = Dict{Symbol,Any}[merge(Dict(:t => t), result)
                            for t in tqdm(ts)
                            for result in dq_lstd(ss[T0:t-1, :], rs[T0:t-1, :],
                                                  ss[T0+1:t, :],
                                                  sstr[T0:t-1, :],
                                                  ssco[T0:t-1, :]; αs=αs)]
    return DataFrame(rows)
end

"""
    run_ope(ss, as, rs, sstr, ssco)

Compute off-policy estimates (`ope_lstd`) on growing prefixes `1:t` of the data.

Candidate times `t` are the distinct values of `round(1.1^k)`, `k = 1:1000`, that are
below `length(rs)`. A time is skipped when either action arm (1 or 2) has fewer
transitions in steps `1:t-1` than the off-policy system has unknowns
(`size(ss, 2) + 1`). For such `t`, `X' * (X - Xnew)` in `lstd_off_path` has rank below
its size, so the unregularized solve is singular. At `t = 1`, for example, both arms
are empty and the system matrix is all zeros.

Returns a `DataFrame` with a `:t` column plus the `ope_lstd` keys (`:est`,
`:estimator`).
"""
function run_ope(ss, as, rs, sstr, ssco)
    summ_ts = round.(collect(1.1 .^ (1:1000))) |> unique
    summ_ts = Int.(summ_ts[summ_ts .< length(rs)])

    # Running counts of control (action 1) and treated (action 2) steps.
    acts = vec(as)
    n_co = cumsum(acts .== 1)
    n_tr = cumsum(acts .== 2)
    nunknowns = size(ss, 2) + 1  # features plus the appended constant column
    # The prefix 1:t only contains transitions out of steps 1:t-1.
    identifiable(t) = t >= 2 && min(n_co[t-1], n_tr[t-1]) >= nunknowns
    ts = filter(identifiable, summ_ts)

    rows = Dict{Symbol,Any}[merge(Dict(:t => t), result)
                            for t in tqdm(ts)
                            for result in ope_lstd(ss[1:t, :],
                                                   as[1:t, :],
                                                   rs[1:t, :],
                                                   sstr[1:t, :],
                                                   ssco[1:t, :])]
    return DataFrame(rows)
end

"""
    estimate_csv(dir)

Run every estimator on the CSV logs in `dir` and stack the results into one table.

`actual.csv` is read first and its fitted feature transform is reused for
`treatment.csv` and `control.csv`. The returned `DataFrame` (columns unioned) holds:
- the DQ estimates from `run_dynkin`, one per ridge strength in a fixed grid;
- the off-policy estimates from `run_ope`;
- a naive estimate: running mean reward of action-2 steps minus that of action-1
  steps, sampled at the same evaluation times.
"""
function estimate_csv(dir)
    ss, as, rs, trans = read_file(joinpath(dir, "actual.csv"))
    sstr = first(read_file(joinpath(dir, "treatment.csv"); trans=trans))
    ssco = first(read_file(joinpath(dir, "control.csv"); trans=trans))

    ts = unique(round.(collect(1.1 .^ (1:1000))))
    ts = Int.(ts[ts .< length(rs)])

    αs = [1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6]
    dynkins = run_dynkin(ss, as, rs, sstr, ssco; αs=αs)
    opes = run_ope(ss, as, rs, sstr, ssco)

    # With actions coded as 1/2, these select action-2 and action-1 steps respectively.
    treated = as .- 1
    control = 2 .- as
    naive_running =
        cumsum(rs .* treated; dims=1) ./ cumsum(treated; dims=1) .-
        cumsum(rs .* control; dims=1) ./ cumsum(control; dims=1)
    naive = DataFrame([Dict(:t => t,
                            :est => naive_running[t],
                            :estimator => "Naive")
                       for t in ts])

    return vcat(dynkins, opes, naive; cols=:union)
end

# Feature columns that `read_file` extracts from each CSV and z-score normalizes.
# Alternative state-of-the-world features, currently disabled:
#   "drivers_total_0", "drivers_total_1", "drivers_total_2", "drivers_total_3"
normalized_xcols = String["drivers_total", "profit", "action_indicator_tr"]

"""
    read_file(fname; trans=nothing, T0=20000, xcols=normalized_xcols)

Load one CSV log and return `(ss, as, rs, trans)`.

- Infinite values in the `price` column are replaced by `0.0`.
- Rows before `T0` are dropped (presumably a burn-in period — confirm).
- `ss`: the `xcols` columns as a `Matrix{Float64}`, z-scored per column. When `trans`
  is `nothing`, a `ZScoreTransform` is fitted on this file. Otherwise the given
  transform is reused, so several files can share one scaling.
- `as`: `is_treated .+ 1`, so 2 marks treated rows.
- `rs`: the `profit` column.
- `trans`: the transform that was applied.

Throws `ArgumentError` if the file has fewer than `T0` rows.
"""
function read_file(fname; trans=nothing, T0=20000, xcols=normalized_xcols)
    df = DataFrame(CSV.File(fname))
    T0 <= nrow(df) ||
        throw(ArgumentError("$fname has $(nrow(df)) rows; need at least T0 = $T0"))
    df[isinf.(df.price), :price] .= 0.
    df = df[T0:end, :]
    @debug "Feature columns" xcols
    ss = Matrix{Float64}(df[!, xcols])
    if trans === nothing
        trans = fit(ZScoreTransform, ss; dims=1)
    end
    ss = StatsBase.transform(trans, ss)
    as = df.is_treated .+ 1
    rs = df.profit
    return ss, as, rs, trans
end
