using Revise
using StatsBase
using Random
using MDPs
using MetaMDPs
using Flux
using CUDA
using BSON
using PPO
using Dates
using Bandits
using Transformers
import EllipsisNotation
import PPO: rollouts_parallel
using Flux: glorot_normal, orthogonal, Recur, zeros32

const TEST_SEED = 42


"""
    WithIncrementalCaching(model)

Callable wrapper around `model` that remembers the shape of the last input and the last
output it produced. When a call's input has the same first dimension and trailing
dimensions as the previous one and is exactly one entry longer along dimension 2, only the
newest slice is passed to `model` and the result is concatenated onto the cached output;
otherwise the whole input is passed to `model`.
"""
mutable struct WithIncrementalCaching
    model              # wrapped callable, invoked as `model(x)`
    prev_input_shape   # `(D, L, Bs)` of the previous input; `(0, 0, 0)` right after construction/reset
    prev_output        # output of the previous call; `nothing` right after construction/reset
end
# Only `model` is listed as a functor field; the two cache fields are deliberately left out.
Flux.@functor WithIncrementalCaching (model,)
# Convenience constructor: wrap `model` with an empty cache (zero input shape, no stored output).
WithIncrementalCaching(model) = WithIncrementalCaching(model, (0, 0, 0), nothing)
"""
    Flux.reset!(m::WithIncrementalCaching)

Reset the wrapped model and drop the cached input shape and output, so that the next call
runs the wrapped model on its full input. A cached `CuArray` output is freed eagerly.
"""
function Flux.reset!(m::WithIncrementalCaching)
    Flux.reset!(m.model)
    m.prev_input_shape = (0, 0, 0)
    # The free is wrapped in `@ignore` so it is excluded from differentiation.
    Flux.Zygote.@ignore if m.prev_output isa CUDA.CuArray
        CUDA.unsafe_free!(m.prev_output)
    end
    m.prev_output = nothing
    return nothing
end
# Forward pass with output caching along dimension 2.
#
# If `x` has the same first dimension and the same trailing dimensions (3 and up) as the
# previous input and is exactly one entry longer along dimension 2, only the newest slice is
# passed to the wrapped model and its output is appended to the cached output. In every
# other case the wrapped model is applied to the whole of `x`.
#
# The returned array is also stored as the cache, and the previously cached array is freed
# on the next call when it is a `CuArray`.
function (m::WithIncrementalCaching)(x)
    feat_dim, seq_len = size(x, 1), size(x, 2)
    trailing_dims = size(x)[3:end]
    last_feat_dim, last_seq_len, last_trailing_dims = m.prev_input_shape

    extends_previous = feat_dim == last_feat_dim &&
                       seq_len == last_seq_len + 1 &&
                       trailing_dims == last_trailing_dims

    if extends_previous
        # Copy the newest slice so the model receives a dense array rather than a view.
        newest_input = copy(selectdim(x, 2, seq_len:seq_len))
        newest_output = m.model(newest_input)
        y = cat(m.prev_output, newest_output; dims=2)
        Flux.Zygote.@ignore if newest_input isa CUDA.CuArray
            CUDA.unsafe_free!(newest_input)
        end
        Flux.Zygote.@ignore if newest_output isa CUDA.CuArray
            CUDA.unsafe_free!(newest_output)
        end
    else
        # @info "Disabling incremental caching" D L Bs prev_D prev_L prev_Bs
        # Release the stale cache before running the full forward pass.
        Flux.Zygote.@ignore if m.prev_output isa CUDA.CuArray
            CUDA.unsafe_free!(m.prev_output)
        end
        y = m.model(x)
    end

    m.prev_input_shape = (feat_dim, seq_len, trailing_dims)
    # The old cached output is superseded by `y`; free it before overwriting the reference.
    Flux.Zygote.@ignore if m.prev_output isa CUDA.CuArray
        CUDA.unsafe_free!(m.prev_output)
    end
    m.prev_output = y
    return y
end

"""
    make_random_mdps(; seed=0, nstates=10, nactions=5, task_horizon=10, horizon=100,
                     gamma=1, ood=false, random_mdps_dirichlet_alpha=1.0,
                     random_mdps_rewards_std=1.0, kwargs...)

Create a training and a test generator of `RandomDiscreteMDP`s and return the problem tuple
`(mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)`.

# Keywords
- `seed`: seeds the training generator with `Xoshiro(seed)`; the test generator is seeded
  with `Xoshiro(TEST_SEED * (seed + 1))`.
- `nstates`, `nactions`: sizes passed to `RandomDiscreteMDP`.
- `task_horizon`, `horizon`, `gamma`: returned unchanged as `TH`, `H` and `γ`.
- `ood`: when `true`, the Dirichlet parameter is forced to `0.25` regardless of
  `random_mdps_dirichlet_alpha`.
- `random_mdps_dirichlet_alpha`, `random_mdps_rewards_std`: forwarded to
  `RandomDiscreteMDP` as `α` and `β`.
- `kwargs...`: accepted and ignored.

# Returns
The generators, the horizons and discount, the state/action spaces reported by the training
generator, and `m, n = size(sspace, 1), size(aspace, 1)`.
"""
function make_random_mdps(; seed=0, nstates=10, nactions=5, task_horizon=10, horizon=100, gamma=1, ood=false, random_mdps_dirichlet_alpha=1.0, random_mdps_rewards_std=1.0, kwargs...)
    # `ood=true` overrides the requested Dirichlet parameter.
    α = ood ? 0.25 : random_mdps_dirichlet_alpha
    β = random_mdps_rewards_std

    # The factory only closes over locals that are never reassigned. Capturing a variable
    # that is rebound later (as the observed sizes `m, n` below are) would make the lazily
    # invoked generator read the rebound value and would force Julia to box the capture.
    new_mdp(i, rng) = RandomDiscreteMDP(rng, nstates, nactions; uniform_dist_rewards=false, α=α, β=β)

    mdps = MDPGenerator(new_mdp, Xoshiro(seed))
    mdps_test = MDPGenerator(new_mdp, Xoshiro(TEST_SEED * (seed + 1)))

    sspace = state_space(mdps)
    aspace = action_space(mdps)
    m, n = size(sspace, 1), size(aspace, 1)
    return mdps, mdps_test, task_horizon, gamma, horizon, sspace, aspace, m, n
end

"""
    make_bandits(; seed=0, narms=5, horizon=100, gamma=1, ood=false, kwargs...)

Create a training and a test generator of `BernauliMultiArmedBandit`s with `narms` arms and
return the problem tuple `(mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)`, with `TH = 1`.

The training generator is seeded with `Xoshiro(seed)`, the test generator with
`Xoshiro(TEST_SEED)`. Extra keyword arguments are accepted and ignored.
"""
function make_bandits(; seed=0, narms=5, horizon=100, gamma=1, ood=false, kwargs...)
    draw_arm_params = if ood
        # use normal distribution (mean = 0.5, std = 0.5) to generate bandit success probabilities
        rng -> randn(rng, narms) .* 0.5 .+ 0.5
    else
        # use uniform distribution to generate bandit success probabilities
        rng -> rand(rng, narms)
    end
    new_bandit(i, rng) = BernauliMultiArmedBandit(draw_arm_params(rng))

    mdps = MDPGenerator(new_bandit, Xoshiro(seed))
    mdps_test = MDPGenerator(new_bandit, Xoshiro(TEST_SEED))

    task_horizon = 1
    sspace = state_space(mdps)
    aspace = action_space(mdps)
    m, n = size(sspace, 1), size(aspace, 1)
    # @info "created random bandits" TH, γ, H, m, n
    return mdps, mdps_test, task_horizon, gamma, horizon, sspace, aspace, m, n
end

"""
    make_gridworlds(; seed=0, horizon=200, task_horizon=horizon, gamma=1,
                    variation="11x11", kwargs...)

Create a training and a test generator of `GridWorldContinuous{Float32}` environments built
from `CustomGridWorld` with the parameter preset named by `variation`, and return the
problem tuple `(mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)`.

An unknown `variation` raises a `KeyError`. The training generator is seeded with
`Xoshiro(seed)`, the test generator with `Xoshiro(TEST_SEED)`. Extra keyword arguments are
accepted and ignored.
"""
function make_gridworlds(; seed=0, horizon=200, task_horizon=horizon, gamma=1, variation="11x11", kwargs...)
    # Parameter presets keyed by the `variation` string.
    presets = Dict(
        "11x11" => CUSTOMGW_11x11_PARAMS,
        "11x11_deterministic" => CUSTOMGW_11x11_PARAMS_DETERMINISTIC,
        "13x13" => CUSTOMGW_13x13_PARAMS,
        "13x13_dense" => CUSTOMGW_13x13_PARAMS_DENSE,
        "13x13_deterministic" => CUSTOMGW_13x13_PARAMS_DETERMINISTIC,
        "13x13_watery" => CUSTOMGW_13x13_PARAMS_WATERY,
        "13x13_dangerous" => CUSTOMGW_13x13_PARAMS_DANGEROUS,
        "13x13_corner" => CUSTOMGW_13x13_PARAMS_CORNER,
    )
    preset = presets[variation]
    new_gridworld(i, rng) = GridWorldContinuous{Float32}(CustomGridWorld(rng, preset))

    mdps = MDPGenerator(new_gridworld, Xoshiro(seed))
    mdps_test = MDPGenerator(new_gridworld, Xoshiro(TEST_SEED))

    sspace = state_space(mdps)
    aspace = action_space(mdps)
    m, n = size(sspace, 1), size(aspace, 1)
    return mdps, mdps_test, task_horizon, gamma, horizon, sspace, aspace, m, n
end


"""
    wrap_onehot_mdps(problem)

Lazily wrap every MDP of the training and test iterators in `problem` with
`OneHotStateReprWrapper{Float32}` and return a new problem tuple
`(mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)`.

The spaces and sizes in the result are read from the first wrapped training MDP, which is
obtained by calling `iterate` once on the wrapped training iterator.
"""
function wrap_onehot_mdps(problem)
    (mdps, mdps_test, TH, γ, H, _, _, _, _) = problem
    onehot_train = Iterators.map(OneHotStateReprWrapper{Float32}, mdps)
    onehot_test = Iterators.map(OneHotStateReprWrapper{Float32}, mdps_test)

    # Probe one wrapped MDP to learn the spaces of the wrapped problem.
    probe, _ = iterate(onehot_train)
    sspace = state_space(probe)
    aspace = action_space(probe)
    m, n = size(sspace, 1), size(aspace, 1)
    # @info "wrapped OneHotStateRepr" TH, γ, H, m, n
    return onehot_train, onehot_test, TH, γ, H, sspace, aspace, m, n
end

"""
    wrap_VAMDPs(problem; task_horizon=Inf, abstraction_radius=0, abstraction_cluster_size=1,
                action_num_bins=nothing, Q_DENOM, VI_EP, drop_observation=false)

Lazily wrap every MDP of the training and test iterators in `problem` with
`ValueAugmentedMDP`, forwarding all keyword arguments to it, and return a new problem tuple
`(mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)`.

`TH`, `γ` and `H` are taken from `problem` unchanged (the `task_horizon` keyword is only
forwarded to `ValueAugmentedMDP`). The spaces and sizes are read from the first wrapped
training MDP, obtained by calling `iterate` once on the wrapped training iterator.
"""
function wrap_VAMDPs(problem; task_horizon=Inf, abstraction_radius=0, abstraction_cluster_size=1, action_num_bins=nothing, Q_DENOM, VI_EP, drop_observation=false)
    # println("action_num_bins: $action_num_bins")
    (mdps, mdps_test, TH, γ, H, _, _, _, _) = problem

    augment(mdp) = ValueAugmentedMDP(
        mdp;
        task_horizon=task_horizon,
        abstraction_radius=abstraction_radius,
        abstraction_cluster_size=abstraction_cluster_size,
        action_num_bins=action_num_bins,
        Q_DENOM=Q_DENOM,
        VI_EP=VI_EP,
        drop_observation=drop_observation,
    )
    augmented_train = Iterators.map(augment, mdps)
    augmented_test = Iterators.map(augment, mdps_test)

    # Probe one wrapped MDP to learn the spaces of the wrapped problem.
    probe, _ = iterate(augmented_train)
    sspace = state_space(probe)
    aspace = action_space(probe)
    m, n = size(sspace, 1), size(aspace, 1)
    # @info "wrapped VA wrapper" TH, γ, H, m, n
    return augmented_train, augmented_test, TH, γ, H, sspace, aspace, m, n
end


"""
    make_metamdp(problem; include_time_context)

Build a `MetaMDP` over the training MDPs and another over the test MDPs of `problem`, both
with horizon `H`, the given `include_time_context`, and `task_horizon=TH`.

Returns `(metamdp, metamdp_test, meta_problem)`, where `meta_problem` is `problem` with the
spaces and sizes replaced by those of the training meta-MDP.
"""
function make_metamdp(problem; include_time_context)
    (mdps, mdps_test, TH, γ, H, _, _, _, _) = problem
    build(tasks) = MetaMDP(tasks, H, include_time_context; task_horizon=TH)

    metamdp = build(mdps)
    metamdp_test = build(mdps_test)

    sspace = state_space(metamdp)
    aspace = action_space(metamdp)
    m, n = size(sspace, 1), size(aspace, 1)
    meta_problem = (mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)
    return metamdp, metamdp_test, meta_problem
end

"""
    test_random_policy(problem; test_episodes=1000, kwargs...)

Run a `RandomPolicy` for `test_episodes` episodes on the test meta-MDP built from `problem`
(with `include_time_context=:none`), using `Xoshiro(TEST_SEED)` as the RNG.

Returns `(score, [], [])`, where `score` is the mean of the first element returned by
`interact`. Extra keyword arguments are accepted and ignored.
"""
function test_random_policy(problem; test_episodes=1000, kwargs...)
    (_, _, _, γ, H, _, _, _, _) = problem
    _, metamdp_test = make_metamdp(problem; include_time_context=:none)

    policy = RandomPolicy(metamdp_test)
    progress = ProgressMeterHook(; desc="Evaluating Random Policy")
    outcome = interact(metamdp_test, policy, γ, H, test_episodes, progress; rng=Xoshiro(TEST_SEED))

    score = mean(outcome[1])
    return score, [], []
end

"""
    do_ppo_learning(project_name, experiment_name, problem_set, iters; kwargs...)

Train a PPO actor/critic on the meta-MDP built from `problem_set` (or load a saved model),
evaluate the resulting policy on the test meta-MDP, print the score statistics and return
the mean test score.

# Arguments
- `project_name`, `experiment_name`: used to build the output paths under `data/`, `plots/`,
  `videos/` and `models/`.
- `problem_set`: problem tuple `(mdps, mdps_test, TH, γ, H, sspace, aspace, m, n)`.
- `iters`: training budget; `interact` is run for `iters ÷ iters_per_postepisode` episodes.

# Keywords (all required unless noted)
- `problem_set_batch`: collection of problem tuples; each becomes one environment of the
  `VecEnv` handed to `PPOLearner` (and of the test `VecEnv` when `parallel_testing`).
- `model`: `"transformer"`, `"rnn"` or `"markov"`; any other value raises an error.
- `dmodel`, `nheads`, `ndecoders`, `no_pe`, `no_decoder`: network architecture options.
- `test_model`: when not `""`, models are loaded with `loadmodels(test_model)` and training
  is skipped.
- `continue_model`: when not `""` (transformer only), models are loaded with
  `loadmodels(continue_model)` instead of being freshly constructed.
- `obsnorm`, `rewardnorm`: wrap the environments in `NormalizeWrapper`; the training batch
  and all test environments share the running statistics of the first wrapper.
- `no_evidence_wrapper`: skip `EvidenceObservationWrapper`.
- `include_time_context`: forwarded to `make_metamdp`.
- `device`, `inference_device`: functions applied to models to place them; `device` is given
  to `PPOLearner`, `inference_device` is applied to the acting model.
- `act_greedy`: act/test with the greedy actor instead of the sampling one.
- `parallel_testing`: evaluate with `rollouts_parallel` on a `VecEnv` instead of `interact`.
- `test_episodes`: number of evaluation episodes when not `parallel_testing`.
- `seed`: seeds the RNG of the training `interact` call.
- `lr`, `nsteps`, `nepochs`, `ent_bonus`, `kl_target`, `ppo_epsilon`, `lambda`,
  `advantagenorm`, `adam_eps`, `adam_wd`, `clipnorm`, `minibatch_size`, `progressmeter`,
  `iters_per_postepisode`, `decay_ent_bonus`, `decay_lr`: forwarded to `PPOLearner`.
- `model_save_interval`, `video`, `video_interval`, `no_plots`, `no_multithreading`: control
  the hooks and the `VecEnv` threading flag.
- `log_interval`, `config`, `problem_name`, `algo`, `kwargs...`: accepted but not used here.

# Returns
The mean of the per-episode test returns.
"""
function do_ppo_learning(project_name, experiment_name, problem_set, iters; problem_set_batch, model, dmodel, lr, log_interval, model_save_interval, nsteps, nepochs, ent_bonus, kl_target, ppo_epsilon, lambda, seed,  advantagenorm, device, inference_device, adam_eps, adam_wd, clipnorm, minibatch_size, progressmeter, iters_per_postepisode, video, video_interval, act_greedy, nheads, ndecoders, test_model, continue_model, decay_ent_bonus, decay_lr, config, problem_name, obsnorm, rewardnorm, no_multithreading, no_plots, test_episodes, include_time_context, no_pe, no_decoder, no_evidence_wrapper, parallel_testing, algo, kwargs...)
    metamdp, metamdp_test, meta_problem = make_metamdp(problem_set; include_time_context=include_time_context)
    (mdps, mdps_test, TH, γ, H, sspace, aspace, m, n) = meta_problem
    @info "created meta_mdp" m, n
    T, Tₐ = Float32, eltype(eltype(aspace))

    norm_by_rew = false

    # ---- Environment wrapping -------------------------------------------------------------
    if !no_evidence_wrapper
        metamdp = EvidenceObservationWrapper{T}(metamdp)
    end
    if obsnorm || rewardnorm
        metamdp = NormalizeWrapper(metamdp, normalize_reward=rewardnorm, normalize_obs=obsnorm, normalize_reward_by_reward_std=norm_by_rew)
        # This wrapper does not update the statistics itself; its statistics objects are
        # handed to every other NormalizeWrapper created below.
        metamdp.update_stats = false
        obs_rmv, ret_rmv, rew_rmv = metamdp.obs_rmv, metamdp.ret_rmv, metamdp.rew_rmv
    end

    if !no_evidence_wrapper
        metamdp_test = EvidenceObservationWrapper{T}(metamdp_test)
    end
    if obsnorm || rewardnorm
        metamdp_test = NormalizeWrapper(metamdp_test, obs_rmv=obs_rmv, ret_rmv=ret_rmv, rew_rmv=rew_rmv, normalize_reward=rewardnorm, normalize_obs=obsnorm, normalize_reward_by_reward_std=norm_by_rew)
        metamdp_test.update_stats = false
    end

    # Vectorized training environments, one per entry of `problem_set_batch`.
    metamdps_batch = map(problem_set_batch) do _pb
        _metamdp = make_metamdp(_pb; include_time_context=include_time_context)[1]
        if !no_evidence_wrapper
            _metamdp = EvidenceObservationWrapper{T}(_metamdp)
        end
        _metamdp = obsnorm || rewardnorm ? NormalizeWrapper(_metamdp, obs_rmv=obs_rmv, ret_rmv=ret_rmv, rew_rmv=rew_rmv, normalize_reward=rewardnorm, normalize_obs=obsnorm, normalize_reward_by_reward_std=norm_by_rew) : _metamdp
        return _metamdp
    end
    metamdps_batch = VecEnv(metamdps_batch, !no_multithreading)

    # Factor passed to `interact` as `reward_multiplier` and logged with the stats.
    get_reward_multipler_fn() = rewardnorm ? (norm_by_rew ? std(rew_rmv) : std(ret_rmv)) : 1.0

    name = experiment_name
    PPOActor = aspace isa IntegerSpace ? PPOActorDiscrete{T} : PPOActorContinuous{T, Tₐ}


    # Input/output sizes are re-read after wrapping.
    m, n = size(state_space(metamdp), 1), size(action_space(metamdp), 1)
    if test_model == ""
        # ---- Training ---------------------------------------------------------------------
        # In every branch `_actor_model` is the model as built/loaded (it is displayed
        # below) and `actor_model` is that model after `inference_device` is applied.
        if model == "transformer"
            dim_k = dmodel ÷ nheads
            dim_v = dim_k
            dim_ff = 4 * dmodel
            # dim_ff = dmodel
            # MakeProjectToDimModelMDP() = Chain(Dense(m, dim_ff, relu), Dense(dim_ff, dim_ff, relu), Dense(dim_ff, dmodel, relu))
            MakeProjectToDimModelMDP() = Chain(Dense(m, dim_ff, relu), Dense(dim_ff, dmodel, relu))
            MakeProjectToDimModel = MakeProjectToDimModelMDP
            MakePositionalEncoder(incremental_inference_mode) = no_pe ? identity : LearnedPositionalEncoder(dmodel, H+1; incremental_inference_mode=incremental_inference_mode)
            MakeDecoder(incremental_inference_mode) = no_decoder ? identity : Decoder(dmodel, dim_k, dim_v, nheads, dim_ff, ndecoders; dropout=false, no_encoder=true, incremental_inference_mode=incremental_inference_mode)
            MakePreFinalLayer() = no_decoder ? Dense(dmodel, dmodel, relu) : relu  # relu required since the final layer in the decoder is a linear layer with no activation
            if continue_model != ""
                _actor_model, critic_model = loadmodels(continue_model)
            else
                _actor_model = WithIncrementalCaching(
                    Chain(
                        MakeProjectToDimModel(),
                        LayerNorm(dmodel, affine=true),
                        MakePositionalEncoder(true),
                        MakeDecoder(true),
                        MakePreFinalLayer(),
                        Dense(dmodel, n)
                    )
                )
                critic_model = Chain(
                    MakeProjectToDimModel(),
                    LayerNorm(dmodel, affine=true),
                    MakePositionalEncoder(false),
                    MakeDecoder(false),
                    MakePreFinalLayer(),
                    Dense(dmodel, 1)
                )
            end
            actor_model = _actor_model |> inference_device
            p = PPOActor(actor_model, false, aspace, TRANSFORMER)
            gp = PPOActor(actor_model, true, aspace, TRANSFORMER)
        elseif model == "rnn"
            # `_actor_model` must be bound here as well: it is displayed after this
            # if/elseif chain, and leaving it unassigned raised an UndefVarError.
            _actor_model = Chain(Dense(m, dmodel, relu), GRUv3(dmodel, 4*dmodel), Dense(4*dmodel, n))
            critic_model = Chain(Dense(m, dmodel, relu), GRUv3(dmodel, 4*dmodel), Dense(4*dmodel, 1))
            actor_model = _actor_model |> inference_device
            p = PPOActor(actor_model, false, aspace, RECURRENT)
            gp = PPOActor(actor_model, true, aspace, RECURRENT)
        elseif model == "markov"
            dim_ff = 4 * dmodel
            _actor_model = Chain(Dense(m, dim_ff, relu), Dense(dim_ff, dim_ff, relu), Dense(dim_ff, n))
            critic_model = Chain(Dense(m, dim_ff, relu), Dense(dim_ff, dim_ff, relu), Dense(dim_ff, 1))
            actor_model = _actor_model |> inference_device
            p = PPOActor(actor_model, false, aspace, MARKOV)
            gp = PPOActor(actor_model, true, aspace, MARKOV)
        else
            error("What's a $model model?")
        end

        println("Actor model: ")
        display(_actor_model)
        println("Critic model: ")
        display(critic_model)

        ppol = PPOLearner(envs = metamdps_batch, actor=p, critic=critic_model, nsteps=nsteps, batch_size = minibatch_size, nepochs=nepochs, entropy_bonus=ent_bonus, decay_ent_bonus=decay_ent_bonus, clipnorm=clipnorm, normalize_advantages = advantagenorm, lr_critic=lr, lr_actor=lr, decay_lr=decay_lr, min_lr=1f-5, device=device, ppo=true, kl_target=kl_target, ϵ=ppo_epsilon, λ=lambda, adam_epsilon=adam_eps, adam_weight_decay=adam_wd, early_stop_critic=false, iters_per_postepisode=iters_per_postepisode, progressmeter=progressmeter)

        get_stats() = Dict(ppol.stats..., :reward_multiplier => get_reward_multipler_fn())
        drh = DataRecorderHook(get_stats, "data/$project_name/$name.csv", overwrite=false)
        ph = no_plots ? EmptyHook() : PlotEverythingHook("data/$project_name", "plots/$project_name")
        video_dir = "videos/$project_name/$name"
        vrh = video ? VideoRecorderHook(video_dir, ceil(Int, video_interval / iters_per_postepisode); vmax=100) : EmptyHook()
        act_policy = act_greedy ? gp : p
        # The learner is passed to `interact` alongside the logging/saving hooks; the
        # return value of this call is not needed.
        interact(metamdp, act_policy, γ, H, iters ÷ iters_per_postepisode, ppol, LoggingHook(get_stats; smooth_over=1000), ProgressMeterHook(), drh, ph, vrh, ModelsSaveHook((actor_model, critic_model), "models/$project_name/$name", model_save_interval ÷ iters_per_postepisode), GCHook(), SleepHook(0.1); rng=Xoshiro(seed), reward_multiplier=get_reward_multipler_fn)
    else
        # ---- Load a saved model for testing only ------------------------------------------
        _actor_model, _critic_model = loadmodels(test_model)
        println("Model: ")
        display(_critic_model)

        testing_device = parallel_testing ? device : inference_device
        actor_model = _actor_model |> testing_device
        
        @info "Loaded models"
        if model == "transformer"
            p = PPOActor(actor_model, false, aspace, TRANSFORMER)
            gp = PPOActor(actor_model, true, aspace, TRANSFORMER)
        elseif model == "rnn"
            p = PPOActor(actor_model, false, aspace, RECURRENT)
            gp = PPOActor(actor_model, true, aspace, RECURRENT)
        elseif model == "markov"
            p = PPOActor(actor_model, false, aspace, MARKOV)
            gp = PPOActor(actor_model, true, aspace, MARKOV)
        else
            error("What's a $model?")
        end 
    end

    # ---- Evaluation -----------------------------------------------------------------------
    println("Testing policy. Greedy=$(act_greedy)")
    test_policy = act_greedy ? gp : p

    if parallel_testing
        # NOTE(review): `testing_device` is only assigned in the `test_model != ""` branch
        # above, so parallel testing right after training would hit an unbound variable —
        # confirm whether that combination is meant to be supported.
        metamdps_batch_test = map(problem_set_batch) do _pb
            _metamdp_test = make_metamdp(_pb; include_time_context=include_time_context)[2]
            if !no_evidence_wrapper
                _metamdp_test = EvidenceObservationWrapper{T}(_metamdp_test)
            end
            _metamdp_test = obsnorm || rewardnorm ? NormalizeWrapper(_metamdp_test, obs_rmv=obs_rmv, ret_rmv=ret_rmv, rew_rmv=rew_rmv, normalize_reward=rewardnorm, normalize_obs=obsnorm, normalize_reward_by_reward_std=norm_by_rew) : _metamdp_test
            if isa(_metamdp_test, NormalizeWrapper)
                _metamdp_test.update_stats = false
            end
            return _metamdp_test
        end
        metamdps_batch_test_venv = VecEnv(metamdps_batch_test, !no_multithreading)
        factory_reset!(metamdps_batch_test_venv)
        if H <= 1024
            _, _, 𝐫, _, _ = rollouts_parallel(metamdps_batch_test_venv, H, test_policy, testing_device, Xoshiro(TEST_SEED), true)
        else
            # split batch into two halves
            half_batch_size = div(length(metamdps_batch_test), 2)
            metamdps_batch_test_venv_1 = VecEnv(metamdps_batch_test[1:half_batch_size], !no_multithreading)
            metamdps_batch_test_venv_2 = VecEnv(metamdps_batch_test[half_batch_size+1:end], !no_multithreading)
            println("First half")
            _, _, 𝐫1, _, _ = rollouts_parallel(metamdps_batch_test_venv_1, H, test_policy, testing_device, Xoshiro(TEST_SEED), true)
            println("Second half")
            _, _, 𝐫2, _, _ = rollouts_parallel(metamdps_batch_test_venv_2, H, test_policy, testing_device, Xoshiro(TEST_SEED), true)
            𝐫 = cat(𝐫1, 𝐫2, dims=3)
        end
        # Sum rewards over dimension 2 and flatten to one return per environment.
        Rs = sum(𝐫, dims=2)[:]
    else
        vrh = video ? VideoRecorderHook("videos/$project_name/$name-test", video_interval; vmax=100) : EmptyHook()
        Rs = interact(metamdp_test, test_policy, γ, H, test_episodes, ProgressMeterHook(), vrh, GCHook(), SleepHook(0.01); rng=Xoshiro(TEST_SEED), reward_multiplier=get_reward_multipler_fn)[1];
    end

    score = mean(Rs)
    score_std = std(Rs)
    score_ste = score_std / sqrt(length(Rs))  # standard error of the mean
 
    println("Final score ", score, "±", score_ste, " (std ", score_std, ")", " (n=", length(Rs), ")")

    return score
end
