#%%
import os
from datetime import datetime
from pathlib import Path
from typing import Sequence
import matplotlib.pyplot as plt
from datargs import argsclass, parse
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import jax.nn as jnn

plt.rcParams['figure.facecolor'] = 'white'

import jax
jax.config.update("jax_enable_x64", True)


@argsclass
class Args:
    # Iteration budget: length of every history series and upper bound of the main loop
    T: int = 10_000
    # Number of inner iterations used by the "RAPP" and "PD-RAPP" methods
    inner: int = 1
    # Threshold of the "HPE" refinement test; derived from delta and gamma when None
    sigma: float = None
    # Offset entering the "HPE" step length; falls back to rho when None
    delta: float = None
    # Number of leading columns of z that form the x-block
    d: int = 1
    # Number of columns of the y-block; set to d by setup() when None
    d2: int = None
    # Number of rows of the random initial point (only read when d != 1)
    batch_size: int = 1
    # Multiplier applied to F in every method
    gamma: float = 1.0
    # Averaging weight of the "RAPP" / "PD-RAPP" update
    alpha: float = 1/18
    # PRNG seed; drawn from os.urandom by setup() when None
    seed: int = None
    # One of "RAPP", "PD-RAPP", "HPE"
    method: str = "RAPP"
    # 'linf' or None
    projection: str = None
    # Clipping bound of the 'linf' projection
    radius: int = 1

    # "bilinear" or "quadratic"
    problem: str = "quadratic"
    # L and rho define the coefficients of the "quadratic" problem
    L: float = 1.0
    rho: float = -0.4
    # Initial point used when d == 1
    init: Sequence[float] = (0.5, 0.5)
    # Shift subtracted from x and y inside the objective
    offset: float = 0.0

    # Whether to draw the trajectory plot at the end of the run
    plot_field: bool = False

    # Optional prefix of the output directory name
    name: str = None

    def setup(self):
        """Fill in derived defaults, create the output directory and dump the arguments.

        A missing seed is replaced by 4 random bytes from the OS and a missing
        ``d2`` by ``d``. The run directory ``output/<name>`` is created (the
        name is a timestamp, wrapped as ``<self.name>(<timestamp>)`` when a
        name was given) and ``str(self)`` is written to ``args.txt`` inside it.

        Returns:
            ``self``, so that the call can be chained.
        """
        if self.seed is None:
            self.seed = int.from_bytes(os.urandom(4), byteorder="big")

        if self.d2 is None:
            self.d2 = self.d

        stamp = datetime.now().strftime("%Y-%m-%d_%Hh%Mm%Ss")
        run_name = stamp if self.name is None else f"{self.name}({stamp})"

        self.dir = os.path.join("output", run_name)
        Path(self.dir).mkdir(parents=True, exist_ok=True)
        with open(self.path("args.txt"), "w") as args_file:
            args_file.write(str(self))
        return self

    def path(self, filename):
        """Return ``filename`` joined onto the run directory created by ``setup``."""
        return os.path.join(self.dir, filename)


def init_history(keys=(), length=0):
    """Create the per-key storage used to record values over the iterations.

    Args:
        keys: Iterable of series names. The default is an immutable empty
            tuple so that no mutable object is shared between calls.
        length: Number of slots in each series.

    Returns:
        A dict mapping every key to its own float32 array of ``length``
        entries, pre-filled with ``0, 1, ..., length - 1``. Slots that are
        never overwritten therefore keep their index as value.
    """
    return {key: jnp.arange(length, dtype="float32") for key in keys}

def update_history(history, t, values: dict):
    """Write ``values`` into slot ``t`` of the matching series in ``history``.

    Entries of ``values`` whose key has no series in ``history`` are ignored.
    The arrays are replaced through ``.at[t].set(...)``; the ``history`` dict
    itself is updated in place and also returned.
    """
    for name in values:
        if name not in history:
            # Only series registered up front are recorded
            continue
        series = history[name]
        history[name] = series.at[t].set(values[name])
    return history

# Default arguments; replaced by the command-line values when run as a script
args = Args()

#%%

if __name__ == '__main__':
    # Only parse here. setup() creates a timestamped output directory and
    # writes args.txt on every call, so it must run exactly once: in the
    # cell below, which is executed in both script and cell-by-cell use.
    args = parse(Args)

#%%

# Setup: fill in derived defaults and create the output directory
args.setup()

def create_problem(args):
    """Build the objective ``L(x, y)`` selected by ``args.problem``.

    Supported problems (``x`` and ``y`` are both shifted by ``args.offset``):

    * ``"bilinear"``: ``x^T y``.
    * ``"quadratic"``: ``a * x^T y + b/2 * sum(x**2) - b/2 * sum(y**2)`` with
      ``a = sqrt(L**2 - L**4 * rho**2)`` and ``b = L**2 * rho`` taken from
      ``args.L`` and ``args.rho``.

    Args:
        args: Object providing ``problem`` and ``offset`` (plus ``L`` and
            ``rho`` for the quadratic problem).

    Returns:
        A function ``L(x, y)`` evaluating the objective.

    Raises:
        ValueError: If ``args.problem`` is not one of the supported names.
    """
    if args.problem == "bilinear":
        def L(x, y):
            return (x - args.offset).transpose().dot(y - args.offset)
    elif args.problem == "quadratic":
        # The coefficients depend only on the arguments, so compute them once
        a = jnp.sqrt(args.L**2 - args.L**4 * args.rho ** 2)
        b = args.L**2 * args.rho

        def L(x, y):
            x = x - args.offset
            y = y - args.offset
            return a * x.transpose().dot(y) + b/2 * jnp.sum(x**2, axis=-1) - b/2 * jnp.sum(y**2, axis=-1)
    else:
        # Without this branch an unknown name would only surface as an
        # UnboundLocalError on the return statement below.
        raise ValueError(f"Problem not supported: {args.problem!r}")

    return L

def make_F(L):
    """Build the batched field ``F`` from the gradients of ``L``.

    The gradients of ``L`` with respect to its first and second argument are
    vectorised over the leading axis with ``vmap`` and compiled with ``jit``.

    Returns:
        A function ``F(z)`` that splits ``z`` column-wise after the first
        ``args.d`` columns into ``x`` and ``y`` and returns the gradient in
        ``x`` followed by the negated gradient in ``y``, concatenated along
        the last axis.
    """
    grad_x = jit(vmap(grad(L, argnums=0)))
    grad_y = jit(vmap(grad(L, argnums=1)))

    def F(z):
        d = args.d
        x = z[:, :d]
        y = z[:, d:]
        gx = grad_x(x, y)
        gy = grad_y(x, y)
        return jnp.concatenate([gx, -gy], axis=-1)

    return F


def split(z):
    """Split ``z`` column-wise into its first ``args.d`` columns and the rest."""
    d = args.d
    head = z[:, :d]
    tail = z[:, d:]
    return head, tail

def combine(x,y):
    """Join ``x`` and ``y`` along the last axis (the counterpart of ``split``)."""
    parts = (x, y)
    return jnp.concatenate(parts, axis=-1)


# Build the objective for the selected problem and the field F derived from its gradients
L = create_problem(args)
F = make_F(L)

# Projection: 'linf' clips every coordinate to [-radius, radius]; None is the identity.
# NOTE(review): any other value of args.projection leaves `proj` undefined, and
# `proj` is not referenced anywhere else in the visible source — confirm it is still needed.
if args.projection == 'linf':
    def proj(z):
        return jnp.clip(z,  -args.radius, args.radius)
elif args.projection is None:
    proj = lambda z: z

# Root PRNG key of the run, derived from the seed stored on args
globalkey = jax.random.PRNGKey(args.seed)

# Initialize: for d == 1 start from args.init as a single row (batch_size is not
# read on this path); otherwise draw a uniform random point of shape
# (batch_size, d + d2) from a key split off the root key.
if args.d == 1:
    z0 = jnp.array([args.init])
else:
    globalkey, subkey = jax.random.split(globalkey)
    z0 = jax.random.uniform(subkey, (args.batch_size, args.d+args.d2))
# All three iterates start at the same point
z = z0
zprev = z
zbar = z


#%%

# Series recorded by the main loop; every series has one slot per iteration index.
history_keys = [
    '|Fbar|',
    '|F|',
    'refinements',
]
# The trajectory plot reads history['x'] and history['y'], so these series have
# to exist whenever that plot is going to be drawn (plot_field with d == 1).
# Without them update_history silently drops the entries and the plot fails
# with a KeyError.
record_trajectory = args.plot_field and args.d == 1
if record_trajectory:
    history_keys += ['x', 'y']
history = init_history(keys=history_keys, length=args.T)

# delta falls back to rho when it is not given explicitly
if args.delta is None:
    delta = args.rho
else:
    delta = args.delta

# sigma falls back to a value just below 1 + delta/gamma
if args.sigma is None:
    sigma = (1 + delta/args.gamma) * 0.999
else:
    sigma = args.sigma


def loop_body(t, state):
    """Perform one iteration of the selected method and record it in the history.

    Written as the body function of ``jax.lax.fori_loop``.

    Args:
        t: Iteration index; also the history slot that is written.
        state: Tuple ``(history, z, zbar, zprev, globalkey)``.

    Returns:
        The updated state, in the same layout as ``state``. ``zprev`` and
        ``globalkey`` are passed through unchanged.

    Raises:
        ValueError: If ``args.method`` is not "RAPP", "PD-RAPP" or "HPE".
    """
    history, z, zbar, zprev, globalkey = state 

    # Log prior to step to ensure first iterate is logged
    entry = {
        '|Fbar|': jnp.linalg.norm(F(zbar)), 
        '|F|': jnp.linalg.norm(F(z)),
        #'step': t,
    }

    refinements = 0

    if "RAPP" == args.method:
        # args.inner iterations of w <- z - gamma * F(w) starting from w = z,
        # followed by averaging z with the result using weight alpha
        w = z
        w = jax.lax.fori_loop(0, args.inner, lambda i, w: z - args.gamma * F(w), w)
        z = (1-args.alpha) * z + args.alpha * w

    elif "PD-RAPP" == args.method:
        # Same outer averaging, but each inner iteration first updates the
        # x-block and then evaluates F at (new x, old y) to update the y-block
        w = z
        for i in range(args.inner):
            _,y = split(w)
            xnext,_ = split(z - args.gamma * F(w))
            _,ynext = split(z - args.gamma * F(combine(xnext, y)))
            w = combine(xnext, ynext)
        z = (1-args.alpha) * z + args.alpha * w

    elif "HPE" == args.method:
        zbar = z
        vbar = args.gamma * F(zbar)
        
        def cond_fun(val):
            # Keep refining while -<eps, vbar> > sigma * |vbar|^2,
            # where eps = z - zbar - vbar and z is the iterate before the step
            zbar, vbar, refinements = val
            eps = z - zbar - vbar
            return -jnp.vdot(eps, vbar) > sigma * jnp.vdot(vbar, vbar)

        def body_fun(val):
            # One refinement: move zbar to z - vbar and re-evaluate vbar there
            zbar, vbar, refinements = val
            zbar = z - vbar
            vbar = args.gamma * F(zbar)
            return zbar, vbar, refinements + 1

        zbar, vbar, refinements = jax.lax.while_loop(cond_fun, body_fun, (zbar, vbar, 0))

        # Step along -vbar with length (<vbar, z - zbar> + delta/gamma * |vbar|^2) / |vbar|^2
        vbar_norm_sq = jnp.vdot(vbar, vbar)
        alpha = (jnp.vdot(vbar, z - zbar) + delta / args.gamma * vbar_norm_sq) / vbar_norm_sq
        z = z - alpha * vbar

    else:
        raise ValueError("Method not supported")
    
    # Only record the logging after so that #(refinements) can be recorded
    entry['refinements'] = refinements
    if args.d == 1:
        # For d == 1 the iterate has a single row; store scalars so that each
        # value fits one history slot (a length-1 slice would not).
        entry['x'] = z[0, 0]
        entry['y'] = z[0, 1]
    history = update_history(history, t, entry)

    # if (t % 1000 == 0):
    #     print("t=", t)
    
    return (history, z, zbar, zprev, globalkey)


# Remember the starting point: the loop below overwrites z
z_init = z
init_state = (history, z, zbar, zprev, globalkey)
# The loop runs t = 1, ..., T-1, so slot 0 of every series is not written by it
state = jax.lax.fori_loop(1, args.T, loop_body, init_state)
(history, z, zbar, zprev, globalkey) = state

if record_trajectory:
    # Slots t >= 1 hold the iterate after step t; put the starting point into
    # slot 0 so that the plotted trajectory begins where the run began.
    history['x'] = history['x'].at[0].set(z_init[0, 0])
    history['y'] = history['y'].at[0].set(z_init[0, 1])


# One log-log figure per recorded series; the norm series are also stored as .npy
for k, series in history.items():
    fig, ax = plt.subplots(1, 1)
    ax.plot(series)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_title(k)
    fig.tight_layout()
    fig.savefig(args.path(f"{k}.png"))
    plt.close(fig)
    if k in ("|F|", "|Fbar|", "|H-Hbar|"):
        jnp.save(args.path(f"{k}.npy"), series)

# print("z:", z)
# Summary line: last recorded |F| and the rounded mean number of refinements
final_norm = history["|F|"][-1]
mean_refinements = round(history["refinements"].mean())
print(f'||Fz||: {final_norm} (refinements {mean_refinements})')

def plot_vectorfield_with_trajectory(args, history, F):
    """Draw the recorded trajectory on the square [-1, 1] x [-1, 1].

    The field ``F`` is evaluated on a 10 x 10 grid over the same square; the
    stream plot that would use it is currently commented out.

    Args:
        args: Run arguments; ``d`` and ``batch_size`` must both be 1, and
            ``method`` is used as the legend label.
        history: Mapping that provides the ``'x'`` and ``'y'`` series.
        F: Field evaluated on the grid points (one point per row).

    Returns:
        The ``(fig, ax)`` pair of the created figure.
    """
    assert args.d == 1
    assert args.batch_size == 1

    n_x, n_y = 10, 10
    (x_lo, x_hi), (y_lo, y_hi) = (-1, 1), (-1, 1)

    # Grid points as rows (x, y)
    grid_x, grid_y = jnp.meshgrid(
        jnp.linspace(x_lo, x_hi, n_x),
        jnp.linspace(y_lo, y_hi, n_y),
    )
    col_x = grid_x.reshape(-1, 1)
    col_y = grid_y.reshape(-1, 1)
    points = jnp.concatenate((col_x, col_y), axis=-1)

    # Field components on the grid, shaped back to 2-D for a stream plot
    field = F(points)
    grid_x = col_x.reshape(n_x, n_y)
    grid_y = col_y.reshape(n_x, n_y)
    field_u = field[:, 0].reshape(n_x, n_y)
    field_v = field[:, 1].reshape(n_x, n_y)

    fig, ax = plt.subplots(1, 1)

    # Vectorfield
    # ax.streamplot(grid_x, grid_y, field_u, field_v, color="grey")

    # Trajectory, with its first point marked in black
    traj_x = history['x']
    traj_y = history['y']
    ax.plot(traj_x, traj_y, color="red", label=args.method)
    ax.scatter(traj_x[0], traj_y[0], color="black")

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.legend(loc='lower left')
    return fig, ax

# The trajectory plot is only drawn on request and only for d == 1
if args.plot_field and args.d == 1:
    fig, ax = plot_vectorfield_with_trajectory(args, history, F)
    fig.tight_layout()
    fig.savefig(args.path("vectorfield.png"))
