#%%
import os
from datetime import datetime
from pathlib import Path
from typing import Sequence
import matplotlib.pyplot as plt
from datargs import argsclass, parse
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import jax.nn as jnn

plt.rcParams['figure.facecolor'] = 'white'


@argsclass
class Args:
    # Command-line configuration for the experiment. The field names, types
    # and defaults below are consumed by `argsclass`/`parse`, so they are kept
    # exactly as declared.

    # Run length, dimensions and algorithm settings.
    T: int = 10_000
    d: int = 1
    batch_size: int = 1
    gamma: float = 1.0
    alpha0: float = 1/18
    beta0: float = 1
    seed: int = None  # None means "draw a random seed in setup()"
    method: str = "BCSEG+"
    decrease: str = "linear"
    decrease_factor: float = 1
    projection: str = None
    radius: int = 1

    # Problem definition (problem, L, rho and offset are read by create_problem).
    problem: str = "quadratic"
    L: float = 1.0
    rho: float = -1/20
    init: Sequence[float] = (0.5, 0.5)
    offset: float = 0.0

    # Noise scale and the name of the noise distribution.
    noise: float = 0.1
    noise_model: str = "additive"

    plot_field: bool = False
    name: str = None

    def setup(self):
        # Finalise the configuration and return `self` so calls can be chained.
        # An explicitly provided seed is left untouched; otherwise 4 bytes of
        # OS randomness are turned into a big-endian unsigned integer seed.
        if self.seed is not None:
            return self
        self.seed = int.from_bytes(os.urandom(4), byteorder="big")
        return self

    def path(self, filename):
        # Join `filename` onto this run's directory.
        # NOTE(review): `dir` is not declared among the fields above; it has to
        # be assigned on the instance elsewhere before this is called - confirm.
        base_dir = self.dir
        return os.path.join(base_dir, filename)


def init_history(keys=(), length=0):
    """Create one fixed-length float32 buffer per metric name.

    Args:
        keys: iterable of metric names to track. Defaults to an immutable
            empty tuple (a mutable ``[]`` default would be shared across
            calls).
        length: number of slots in each buffer.

    Returns:
        A new dict mapping every key to its own ``float32`` array of shape
        ``(length,)``. The buffers are pre-filled with ``0, 1, ..., length-1``
        (not zeros); slots are expected to be overwritten via
        ``update_history``.
    """
    return {key: jnp.arange(length, dtype="float32") for key in keys}

def update_history(history, t, values: dict):
    """Write the entries of ``values`` into slot ``t`` of the tracked buffers.

    Names in ``values`` that are not already keys of ``history`` are ignored.
    The ``history`` dict is updated in place (each touched buffer is replaced
    by a functionally updated array) and is also returned.
    """
    tracked = (name for name in values if name in history)
    for name in tracked:
        buffer = history[name]
        history[name] = buffer.at[t].set(values[name])
    return history

# Default configuration, so that the cells below always have an `args` object
# available even when nothing has been parsed.
args = Args()

#%%

# When executed as the main module, replace the defaults with the result of
# `parse(Args)` and finalise it (this also fills in the seed).
if __name__ == '__main__':
    args = parse(Args).setup()

#%%

# Setup
# Ensures the seed is set on whichever `args` is active. `Args.setup` only
# assigns a seed while it is still None, so a second call leaves it unchanged.
args.setup()

def create_problem(args):
    """Build the objective ``L(x, y)`` selected by ``args.problem``.

    Only ``"quadratic"`` is implemented. With ``x' = x - offset`` and
    ``y' = y - offset`` it evaluates

        a * <x', y'> + b/2 * sum(x'**2) - b/2 * sum(y'**2)

    where ``a = sqrt(L**2 - L**4 * rho**2)`` and ``b = L**2 * rho``.

    Args:
        args: configuration object providing ``problem``, ``L``, ``rho`` and
            ``offset`` attributes.

    Returns:
        A function ``L(x, y)``. ``a`` and ``b`` are fixed when this factory
        is called; ``args.offset`` is read each time ``L`` is evaluated.

    Raises:
        ValueError: if ``args.problem`` is not a supported problem name.
            (Previously this surfaced as an ``UnboundLocalError`` on return.)
    """
    if args.problem != "quadratic":
        raise ValueError(f"Unknown problem: {args.problem!r}")

    a = jnp.sqrt(args.L**2 - args.L**4 * args.rho ** 2)
    b = args.L**2 * args.rho

    def L(x, y):
        x = x - args.offset
        y = y - args.offset
        # For 1-D inputs `transpose()` is a no-op and `dot` is the inner
        # product. (assumes x, y are per-sample vectors - confirm with caller)
        return a * x.transpose().dot(y) + b/2 * jnp.sum(x**2, axis=-1) - b/2 * jnp.sum(y**2, axis=-1)

    return L

def make_F(L, d=None):
    """Build the batched operator ``F(z) = [grad_x L(x, y), -grad_y L(x, y)]``.

    Args:
        L: scalar-valued function ``L(x, y)`` of two per-sample vectors; it is
            differentiated with ``grad`` and batched over the leading axis
            with ``vmap``.
        d: number of leading columns of ``z`` that form ``x``; the remaining
            columns form ``y``. When ``None`` (the default) the module-level
            ``args.d`` is read on every call of ``F``, which is the original
            behaviour.

    Returns:
        A function ``F(z)`` taking a 2-D array whose rows are ``[x, y]`` and
        returning an array of the same shape with rows
        ``[grad_x L, -grad_y L]``.
    """
    Fx = jit(vmap(grad(L, argnums=0)))
    Fy = jit(vmap(grad(L, argnums=1)))

    def F(z):
        # Resolve the split point lazily so later changes to `args.d` are
        # honoured when no explicit `d` was given.
        split = args.d if d is None else d
        x, y = z[:, :split], z[:, split:]
        return jnp.concatenate([Fx(x, y), -Fy(x, y)], axis=-1)

    return F


# Objective L(x, y) for the configured problem, and the batched operator
# F(z) = [grad_x L, -grad_y L] built from its gradients.
L = create_problem(args)
F = make_F(L)

# Stochastic
# Root PRNG key derived from the configured seed; it is split further wherever
# randomness is drawn below.
globalkey = jax.random.PRNGKey(args.seed)
    
# Every supported noise model perturbs F(z) additively; the models differ only
# in the distribution the perturbation is drawn from.
_XI_DRAWS = {
    "gaussian": lambda key, shape: jax.random.normal(key, shape=shape),
    "laplace": lambda key, shape: jax.random.laplace(key, shape=shape),
    # Student-t with 2.0 degrees of freedom.
    "t": lambda key, shape: jax.random.t(key, 2.0, shape=shape),
}

# NOTE(review): for any other `args.noise_model` value neither `Fhat` nor
# `sample_xi` is defined by this section.
if args.noise_model in _XI_DRAWS:
    _draw_xi = _XI_DRAWS[args.noise_model]

    def Fhat(z, xi):
        # Stochastic operator: F(z) plus the supplied noise sample.
        return F(z) + xi

    def sample_xi(globalkey):
        # Split the key, draw one scaled noise sample of shape
        # (batch_size, 2*d), and hand back the advanced key with it.
        globalkey, subkey = jax.random.split(globalkey)
        noise = args.noise * _draw_xi(subkey, (args.batch_size, args.d*2))

        return noise, globalkey
    
# Projection
# No projection configured: `proj` is the identity. With 'linf', `proj` clips
# every coordinate to [-radius, radius]. Any other value leaves `proj`
# undefined.
if args.projection is None:
    def proj(z):
        return z
elif args.projection == 'linf':
    def proj(z):
        lower, upper = -args.radius, args.radius
        return jnp.clip(z, lower, upper)

# Initialize
# For d == 1 the starting point is the configured `init` sequence wrapped as a
# single row; otherwise a uniformly random (batch_size, 2*d) array is drawn
# from a fresh subkey.
if args.d != 1:
    globalkey, subkey = jax.random.split(globalkey)
    z0 = jax.random.uniform(subkey, (args.batch_size, args.d*2))
else:
    z0 = jnp.array([args.init])

# Every piece of iterate-shaped state starts from the same initial point.
z = zprev = zbar = gbar = z0

#%%

# Step-size schedules.
# NOTE(review): only the "sqrt" schedule is defined here; for any other
# `args.decrease` value `alpha` is not defined by this section - confirm.
if args.decrease == "sqrt":
    def alpha(t):
        # alpha0 / sqrt(t / decrease_factor + 1)
        return args.alpha0/jnp.sqrt(t/args.decrease_factor+1)


def beta(t):
    # beta follows the alpha schedule; `alpha` is looked up at call time.
    return alpha(t)


def loop_body_bcseg(t, state, gamma):
    """Advance the ``bcseg`` iteration by one step.

    Args:
        t: iteration index; used as the history slot, in the step-size
            schedules and for the periodic debug print.
        state: tuple ``(history, z0, z, zbar, zprev, gbar, globalkey)``.
        gamma: step-size multiplier applied to the operator evaluations.

    Returns:
        The updated state tuple in the same layout. ``z0`` and ``gbar`` are
        passed through unchanged.
    """
    history, z0, z, zbar, zprev, gbar, globalkey = state

    # Record metrics of the current iterate before it is updated.
    record = {'squared operator norm': jnp.linalg.norm(F(z)) ** 2}
    if args.d == 1:
        record.update({'x': z[:,0], 'y': z[:,1]})
    history = update_history(history, t, record)

    # First step: both operator evaluations share the same noise sample `xi`;
    # the correction term uses the previous `zbar` and `zprev`.
    xi, globalkey = sample_xi(globalkey)
    correction = zbar - zprev + gamma * Fhat(zprev, xi)
    zbar = z - gamma * Fhat(z, xi) + (1 - beta(t)) * correction
    zprev = z

    # Second step: move `z` using the operator at `zbar` with a fresh sample.
    xibar, globalkey = sample_xi(globalkey)
    z = z - alpha(t) * gamma * Fhat(zbar, xibar)

    def report(s):
        # Print progress for the freshly updated `z`; the state is unchanged.
        jax.debug.print("TP-t = {}, gamma = {}, norm(F(z))^2 = {}", t, gamma, jnp.linalg.norm(F(z)) ** 2)
        return s

    def keep(s):
        return s

    # Only iterations that are multiples of 100000 emit the debug print.
    return jax.lax.cond(
        (t % 100000) == 0,
        report,
        keep,
        (history, z0, z, zbar, zprev, gbar, globalkey),
    )
    
def loop_body_halp(t, state, gamma):
    """Advance the ``halp`` iteration by one step.

    Args:
        t: iteration index; used as the history slot, in the anchoring weight
            ``1/(t+3)``, in the step-size schedules and for the periodic
            debug print.
        state: tuple ``(history, z0, z, zbar, zprev, gbar, globalkey)``.
        gamma: step-size multiplier applied to the operator estimates.

    Returns:
        The updated state tuple in the same layout. ``z0`` and ``zprev`` are
        passed through unchanged.
    """
    history, z0, z, zbar, zprev, gbar, globalkey = state

    # Record metrics of the current iterate before it is updated.
    record = {'squared operator norm': jnp.linalg.norm(F(z)) ** 2}
    if args.d == 1:
        record.update({'x': z[:,0], 'y': z[:,1]})
    history = update_history(history, t, record)

    # Keep the pre-update iterate; it is needed for the `gbar` update below.
    z_old = z
    xi, globalkey = sample_xi(globalkey)

    # Convex combination of the initial point and the current iterate.
    anchor_weight = 1/(t+3)
    zbar = anchor_weight * z0 + (1-anchor_weight) * z
    zhalf = zbar - gamma * gbar
    z = zbar - alpha(t) * (zbar - zhalf - gamma * gbar + gamma * Fhat(zhalf, xi))

    # Refresh the running operator estimate; the new and the old iterate are
    # evaluated with the same noise sample `xibar`.
    xibar, globalkey = sample_xi(globalkey)
    gbar = Fhat(z, xibar) + (1-beta(t)) * (gbar - Fhat(z_old, xibar))

    def report(s):
        # Print progress for the freshly updated `z`; the state is unchanged.
        jax.debug.print("Halp-t = {}, gamma = {}, norm(F(z))^2 = {}", t, gamma, jnp.linalg.norm(F(z)) ** 2)
        return s

    def keep(s):
        return s

    # Only iterations that are multiples of 100000 emit the debug print.
    return jax.lax.cond(
        (t % 100000) == 0,
        report,
        keep,
        (history, z0, z, zbar, zprev, gbar, globalkey),
    )