import os
from pathlib import Path
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import runner as R 

def _configure_runner(args: R.Args):
    """Install the problem definition described by ``args`` onto the ``runner`` module.

    Sets these module-level attributes on ``R``: ``args``, the step-size
    schedules ``alpha``/``beta``, the objective ``L``, its operator ``F``
    (built with ``R.make_F``), the stochastic oracle ``Fhat`` and the noise
    sampler ``sample_xi``.

    Args:
        args: Experiment configuration. Reads ``alpha0``, ``decrease_factor``,
            ``L``, ``rho``, ``offset``, ``noise_model``, ``noise``,
            ``batch_size`` and ``d``.

    Raises:
        ValueError: If ``args.noise_model`` is not ``"gaussian"``,
            ``"laplace"`` or ``"t"``. Validation happens before ``R`` is
            modified, so a bad config does not leave ``R`` half-configured.
    """
    # Unit-scale noise draws per model; the result is scaled by args.noise below.
    samplers = {
        "gaussian": lambda key, shape: jax.random.normal(key, shape=shape),
        "laplace": lambda key, shape: jax.random.laplace(key, shape=shape),
        # Student-t with 2 degrees of freedom (heavy-tailed noise).
        "t": lambda key, shape: jax.random.t(key, 2.0, shape=shape),
    }
    try:
        draw = samplers[args.noise_model]
    except KeyError:
        raise ValueError(
            f"Unknown noise_model {args.noise_model!r}; "
            f"expected one of {sorted(samplers)}"
        ) from None

    R.args = args

    # Step sizes decay like 1 / sqrt(t / decrease_factor + 1); beta mirrors alpha.
    R.alpha = lambda t: args.alpha0 / jnp.sqrt(t / args.decrease_factor + 1.0)
    R.beta = lambda t: R.alpha(t)

    # Bilinear coupling plus strongly-convex/concave quadratic terms.
    # NOTE(review): a is NaN when L**2 * rho**2 > 1 (negative sqrt argument).
    a = jnp.sqrt(args.L**2 - args.L**4 * args.rho ** 2)
    b = args.L**2 * args.rho

    def L(x, y):
        # Both players are shifted so the solution sits at `offset`.
        # NOTE(review): x.transpose().dot(y) is a scalar for 1-D x, y; for
        # batched 2-D inputs it would not reduce per row like the axis=-1
        # sums do — confirm how R.make_F calls L.
        x = x - args.offset
        y = y - args.offset
        return a * x.transpose().dot(y) + b/2 * jnp.sum(x**2, axis=-1) - b/2 * jnp.sum(y**2, axis=-1)

    R.L = L
    R.F = R.make_F(R.L)

    # Additive-noise oracle; R.F is looked up at call time, as before.
    def Fhat(z, xi):
        return R.F(z) + xi

    def sample_xi(globalkey):
        """Draw a (batch_size, 2*d) noise batch; return it with the advanced key."""
        globalkey, subkey = jax.random.split(globalkey)
        noise = args.noise * draw(subkey, (args.batch_size, args.d * 2))
        return noise, globalkey

    R.Fhat = Fhat
    R.sample_xi = sample_xi

def _init_history_zero(keys, length):
    """Return a dict mapping each key to a zero-filled float32 array.

    Previously this used ``jnp.arange``, so entries held their own index
    rather than zero despite the function's name.

    Args:
        keys: Iterable of history metric names.
        length: Number of entries allocated per metric.

    Returns:
        ``{key: zeros of shape (length,), dtype float32}`` for every key.
    """
    return {k: jnp.zeros(length, dtype=jnp.float32) for k in keys}

def run_experiment(args: R.Args):
    """Configure ``runner`` for ``args`` and run the main iteration loop.

    Calls ``args.setup()``, installs the problem via ``_configure_runner``,
    builds the initial state and iterates ``t = 1 .. args.T - 1`` with
    ``jax.lax.fori_loop``. The step uses ``R.loop_body_bcseg`` when
    ``args.method == "PFL+23"`` and ``R.loop_body_halp`` otherwise. Each step
    also receives ``args.gamma`` as a JAX array.

    Returns:
        ``(history, z)``: the metric-history dict and the final iterate ``z``.
    """
    args = args.setup()
    _configure_runner(args)

    key = jax.random.PRNGKey(args.seed)

    # Starting point: taken from args.init when d == 1, otherwise a uniform
    # random batch of shape (batch_size, 2 * d).
    if args.d != 1:
        key, init_key = jax.random.split(key)
        start = jax.random.uniform(init_key, (args.batch_size, args.d * 2))
    else:
        start = jnp.array([args.init])

    history = _init_history_zero(['squared operator norm'], args.T)

    step_size = jnp.asarray(args.gamma)
    step = R.loop_body_bcseg if args.method == "PFL+23" else R.loop_body_halp

    def body(t, state):
        return step(t, state, step_size)

    # State layout: (history, z0, z, zbar, zprev, gbar, key); every iterate
    # slot (including gbar) starts at the initial point.
    init_state = (history, start, start, start, start, start, key)
    final_state = jax.lax.fori_loop(1, args.T, body, init_state)

    return final_state[0], final_state[2]
