# main.py
import jax
import jax.numpy as jnp
import jax.random as jr

import argparse
import os
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
import time
import random
import pickle
import pandas as pd
import wandb
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from jax import grad, vmap, jacrev, jacfwd
from jax import jit
from jax import value_and_grad
import optax


def init_path(experiment_name, exp_name, args, subdirs=None):
    """Create the output directory of one run and write its bookkeeping files.

    Creates ``experiment_name`` and ``experiment_name/exp_name``, writes a
    fresh ``log.txt`` holding a one-line header, dumps ``vars(args)`` to
    ``args.yaml`` and finally creates every entry of ``subdirs`` below the
    run directory.

    Args:
        experiment_name: Parent directory for all runs.
        exp_name: Name of this run; used as the sub-directory name.
        args: Namespace-like object; ``vars(args)`` is serialised to YAML.
        subdirs: Optional iterable of sub-directory names to create.

    Returns:
        ``(exp_path, outfile)``: the run directory and the log file path.
    """
    os.makedirs(experiment_name, exist_ok=True)
    exp_path = os.path.join(experiment_name, exp_name)
    os.makedirs(exp_path, exist_ok=True)

    # Mode 'w' truncates a log left behind by an earlier run of the same name.
    outfile = os.path.join(exp_path, 'log.txt')
    with open(outfile, 'w') as log_file:
        log_file.write(f'Experiment {exp_name}\n')

    args_file = os.path.join(exp_path, 'args.yaml')
    with open(args_file, 'w') as yaml_file:
        yaml.dump(vars(args), yaml_file)

    for name in (() if subdirs is None else subdirs):
        os.makedirs(os.path.join(exp_path, name), exist_ok=True)

    return exp_path, outfile


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Train pinn')

# pde
parser.add_argument('--exp_name', type=str, default='tmp',
                    help='exp name')

# model parameters & optimization
parser.add_argument('--batch_size', type=int, default=4096,
                    help='Number of samples in each minibatch')
parser.add_argument('--n_boundary', type=int, default=1024,
                    help='n_boundary')
parser.add_argument('--epochs', type=int, default=200,
                    help='Number of training epochs')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='Learning rate')
parser.add_argument('--scheduler_step', type=int, default=1000,
                    help='scheduler_step')
parser.add_argument('--n_iter', type=int, default=100,
                    help='Number of optimizer iterations per epoch')
parser.add_argument('--t_0', type=float, default=0.1,
                    help='t_0')
# argparse does not apply `type` to non-string defaults, so the default is
# written as a float to match values supplied on the command line.
parser.add_argument('--t_max', type=float, default=5.0,
                    help='t_max')

parser.add_argument('--n_layers', type=int, default=4,
                    help='n_layers')
# Only these two schedulers are handled further down in the script.
parser.add_argument('--scheduler', type=str, default='step',
                    choices=['step', 'cosine'],
                    help='scheduler')
parser.add_argument('--lr_min', type=float, default=1e-6,
                    help='lr_min')
parser.add_argument('--n_hidden', type=int, default=256,
                    help='n_hidden')
parser.add_argument('--n_fourier', type=int, default=256,
                    help='n_fourier')
parser.add_argument('--sigma_fourier', type=float, default=1.,
                    help='sigma_fourier')
parser.add_argument('--t_bundle', type=int, default=32,
                    help='t_bundle')
parser.add_argument('--t_epsilon', type=float, default=0.01,
                    help='t_epsilon')
# BooleanOptionalAction generates --exp_bundle / --no-exp_bundle itself.
# Passing `type=` to it is deprecated since Python 3.12 and rejected from
# Python 3.14 on, so no `type` is given here.
parser.add_argument('--exp_bundle', default=False, action=argparse.BooleanOptionalAction,
                    help='exp_bundle')

parser.add_argument('--seed', type=int, default=None,
                    help='seed')

args = parser.parse_args()

# Without an explicit --seed, derive a 32-bit seed from the run name and the
# current wall-clock time; the value is stored back on `args`.
if args.seed is None:
    args.seed = hash(f'{args.exp_name}{time.time()}') % (1 << 32)

key = jr.PRNGKey(args.seed)

# Used both as the root of the output directory tree and as the wandb project.
eqn_name = 'FP_S2_jax'

exp_path, _ = init_path(
    f'{eqn_name}/data-pinn',
    args.exp_name,
    args,
    subdirs=['checkpoints'],
)

run = wandb.init(project=eqn_name, name=args.exp_name, config=vars(args), dir=exp_path)
run.summary["experiment_start"] = str(datetime.now())

# from PDEs import FokkerPlanck
# eqn = FokkerPlanck()
exp_name = args.exp_name

n_indep = 1
n_dep = 4

# Loss terms of the objective; every term starts with weight 1.
objectives = ['pde', 'boundary']
weight_dict = dict.fromkeys(objectives, 1.)

from itertools import chain


def sample_sphere(key, n_sample):
    """Draw ``n_sample`` random points on the unit sphere in 3-D.

    The third coordinate is drawn uniformly from [-1, 1] and the azimuth
    uniformly from [0, 2*pi); the first two coordinates are then scaled so
    that each point has unit norm.

    Args:
        key: JAX PRNG key.
        n_sample: Number of points to draw.

    Returns:
        Array of shape ``(n_sample, 3)``.
    """
    height_key, angle_key = jr.split(key)
    height = jr.uniform(height_key, (n_sample,), minval=-1.0, maxval=1.0)
    azimuth = jr.uniform(angle_key, (n_sample,)) * 2 * jnp.pi
    # Radius of the horizontal circle at the sampled height.
    radius = jnp.sqrt(1 - height**2)
    columns = [jnp.sin(azimuth) * radius, jnp.cos(azimuth) * radius, height]
    return jnp.stack(columns, axis=1)


import sympy as sp

# Symbolic variable r, the function D(r) = sin(r) / r and its logarithmic
# derivative DD = D'(r) / D(r), which `lap` uses below.
r = sp.Symbol('r')
D = sp.sin(r) / r
DD = (sp.diff(D, r) / D).cancel().simplify()


def lap(u):
    """Return ``-u'' - (1/r + DD) * u'`` for a sympy expression ``u``.

    Derivatives are taken with respect to the module-level symbol ``r``.
    """
    first = sp.diff(u, r)
    second = sp.diff(u, r, 2)
    return -second - (1 / r + DD) * first


# Dummy integration variable used by `iterate` and `iterate_approx`.
s = sp.symbols('s')


def iterate(u, i):
    """Transform ``u`` through an exact symbolic integral.

    Integrates ``sqrt(D) * lap(u) * s**(i - 1)`` (with ``r`` replaced by the
    dummy variable ``s``) using sympy's heurisch option, substitutes ``r``
    back and returns ``-r**(-i) / sqrt(D)`` times that antiderivative,
    simplified.

    Args:
        u: sympy expression in ``r``.
        i: Exponent used in the weight ``s**(i - 1)`` and prefactor ``r**(-i)``.
    """
    weighted = (sp.sqrt(D) * lap(u)).cancel().subs(r, s)
    integrand = weighted * s ** (i - 1)
    antiderivative = sp.integrate(integrand.cancel(), s, heurisch=True)
    antiderivative = antiderivative.cancel().simplify().subs(s, r)
    prefactor = -1 * (r ** (-i)) / sp.sqrt(D)
    return (prefactor * antiderivative).cancel().simplify()


# Rebinds `s` with the same statement used earlier in this file
# (`s = sp.symbols('s')`); the assignment is redundant.
s = sp.symbols('s')


def iterate_approx(u, i):
    """Series-based counterpart of ``iterate``.

    The integrand ``sqrt(D) * lap(u) * s**(i - 1)`` is replaced by its series
    about ``s = 0`` (order term at 10 removed) before integration, and the
    final result is likewise expanded about ``r = 0`` to the same order.

    Args:
        u: sympy expression in ``r``.
        i: Exponent used in the weight ``s**(i - 1)`` and prefactor ``r**(-i)``.
    """
    weighted = (sp.sqrt(D) * lap(u)).cancel().subs(r, s)
    truncated = (weighted * s ** (i - 1)).series(s, 0, 10).removeO()
    antiderivative = sp.integrate(truncated.cancel(), s).cancel().simplify()
    antiderivative = antiderivative.subs(s, r)
    prefactor = -1 * (r ** (-i)) / sp.sqrt(D)
    exact = (prefactor * antiderivative).cancel().simplify()
    return exact.series(r, 0, 10).removeO()


# Coefficients u_0..u_3 of the truncated expansion used in S3 / logS3 below.
#
# NOTE: u_0 used to be written as
#     sp.Piecewise((1 / sp.sqrt(D), r != 0), (1, r == 0))
# but Python's `!=` / `==` on sympy objects are structural comparisons that
# evaluate to plain True / False immediately, so the piecewise expression
# collapsed to its first branch and the `r == 0` branch was never part of
# the expression. The line below is the expression that was actually used.
# A real piecewise definition would need sp.Ne / sp.Eq.
u_0 = 1 / sp.sqrt(D)
u_1 = iterate(u_0, 1)
u_2 = iterate_approx(u_1, 2)
u_3 = iterate_approx(u_2, 3)

t = sp.symbols('t')
# Gaussian factor and its logarithm, as functions of r and t.
G = (4 * sp.pi * t) ** (-1) * sp.exp(-r**2 / (4 * t))
logG = -sp.log(4 * sp.pi * t) - r**2 / (4 * t)

# Gaussian factor times the cubic-in-t polynomial built from u_0..u_3,
# and the logarithm of the same product.
S3 = G * (u_0 + u_1 * t + u_2 * t**2 + u_3 * t**3)
logS3 = logG + sp.log(u_0 + u_1 * t + u_2 * t**2 + u_3 * t**3)

pi = jnp.pi

# Training hyper-parameters taken from the command line.
epochs = args.epochs
n_iter = args.n_iter
batch_size = args.batch_size
lr = args.lr
n_boundary = args.n_boundary
# Samples of a batch that are left after the boundary samples.
n_interior = batch_size - n_boundary

t_0 = args.t_0
t_max = args.t_max

# logS3 with t fixed to t_0, compiled into a JAX-backed function of r.
f_logS3 = sp.lambdify(r, logS3.subs(t, t_0).simplify(), modules="jax")

# Smoothing factor handed to `update_weight` in the training loop.
alpha = 0.9

t_epsilon = args.t_epsilon
t_bundle = args.t_bundle

# Bucket edges in time used by the PDE loss: linearly spaced by default,
# logarithmically spaced with --exp_bundle.
if not args.exp_bundle:
    t_bundlestep = jnp.linspace(t_0 + 1e-3, t_max, t_bundle)
else:
    t_bundlestep = jnp.exp(jnp.linspace(jnp.log(t_0) + 1e-3, jnp.log(t_max), t_bundle))

# One weight per time bucket; all ones, i.e. uniform weighting.
t_weight = jnp.ones(t_bundle)

from archs import ModifiedMlp

# Network with a single output (out_dim=1); it is applied to the
# concatenation of x and t in `f_0`.
model = ModifiedMlp(
    num_layers=args.n_layers - 1,
    hidden_dim=args.n_hidden,
    out_dim=1,
    fourier_emb=dict(embed_scale=args.sigma_fourier, embed_dim=args.n_fourier),
    reparam=dict(type='weight_fact', mean=1., stddev=0.1),
)


def f_0(variables, x, t):
    """Evaluate the model on ``x`` and ``t`` concatenated along the last axis.

    The output is returned twice so the function can be differentiated with
    ``has_aux=True`` while still exposing the undifferentiated value.
    """
    prediction = model.apply(variables, jnp.concatenate([x, t], axis=-1))
    return prediction, prediction


def f_t(variables, x, t):
    """Jacobian of ``f_0`` with respect to ``t``.

    Returns:
        ``(u_t, (u_t, u))``, where ``u`` is the value returned by ``f_0``.
    """
    d_dt = jacrev(f_0, argnums=2, has_aux=True)
    time_jac, value = d_dt(variables, x, t)
    return time_jac, (time_jac, value)


def f_x(variables, x, t):
    """Jacobian of ``f_0`` with respect to ``x``.

    Returns:
        ``(u_x, (u_x, u))``, where ``u`` is the value returned by ``f_0``.
    """
    d_dx = jacrev(f_0, argnums=1, has_aux=True)
    space_jac, value = d_dx(variables, x, t)
    return space_jac, (space_jac, value)


def f_xx(variables, x, t):
    """Second derivative of ``f_0`` with respect to ``x`` (Jacobian of ``f_x``).

    Returns:
        ``(u_xx, (u_xx, u_x, u))``, carrying along the first derivative and
        the value obtained from ``f_x``.
    """
    d2_dx2 = jacrev(f_x, argnums=1, has_aux=True)
    hessian, (gradient, value) = d2_dx2(variables, x, t)
    return hessian, (hessian, gradient, value)


def equation(x, t, u_0_, u_t_, u_x_, u_xx_):
    """Return the PDE residual for a batch of points.

    With ``g = u_x_`` and ``H = u_xx_`` the residual is
    ``u_t - (tr(H) + |g|^2 - x^T H x - (x.g)^2 - 2 x.g)``.

    Args:
        x: Points, shape ``(batch, dim)``.
        t: Unused; kept for the call signature.
        u_0_: Unused; kept for the call signature.
        u_t_: Time derivative, shape ``(batch,)``.
        u_x_: Spatial gradient, shape ``(batch, dim)``.
        u_xx_: Spatial Hessian, shape ``(batch, dim, dim)``.

    Returns:
        Residual of shape ``(batch,)``.
    """
    hess_trace = jnp.trace(u_xx_, axis1=-1, axis2=-2)
    grad_sq = jnp.sum(u_x_ ** 2, axis=-1)
    hess_along_x = jnp.einsum('bi,bij,bj->b', x, u_xx_, x)
    # x . grad(u) appears both squared and doubled in the right-hand side.
    grad_along_x = jnp.einsum('bi,bi->b', x, u_x_)
    rhs = hess_trace + grad_sq - hess_along_x - grad_along_x ** 2 - 2 * grad_along_x
    return u_t_ - rhs


def boundary(x):
    """Evaluate the lambdified ``f_logS3`` at the points ``x``.

    The argument passed to ``f_logS3`` is ``arccos`` of column 2 of ``x``,
    clipped to ``[1e-4, pi - 1e-4]``.
    """
    angle = jnp.arccos(x[:, 2])
    # Keep the angle strictly inside (0, pi) to avoid NaN.
    angle = jnp.clip(angle, 1e-4, np.pi - 1e-4)
    return f_logS3(angle)  # JAX-compatible via lambdify(modules="jax")


def boundary_data_gen(key, n_data, t_0_):
    """Sample ``n_data`` sphere points that all carry the time ``t_0_``.

    Returns:
        ``(x_, t_)``: points from ``sample_sphere`` and a ``(n_data, 1)``
        array filled with ``t_0_``.
    """
    (sphere_key,) = jr.split(key, 1)
    points = sample_sphere(sphere_key, n_data)
    times = jnp.ones((n_data, 1)) * t_0_
    return points, times


def interior_data_gen(key, n_data, t_0_, t_max_):
    """Sample ``n_data`` sphere points with times uniform between ``t_0_`` and ``t_max_``.

    Returns:
        ``(x_, t_)``: points from ``sample_sphere`` and a ``(n_data, 1)``
        array of times.
    """
    time_key, sphere_key = jr.split(key)
    unit = jr.uniform(time_key, (n_data, 1))
    times = unit * (t_max_ - t_0_) + t_0_
    points = sample_sphere(sphere_key, n_data)
    return points, times


def bucketize(t_, boundaries):
    """Map each value of ``t_`` to the index of its bucket.

    ``boundaries`` must be sorted. A value equal to a boundary falls into the
    bucket ending at that boundary (``side='left'``); values above the last
    boundary get index ``len(boundaries)``.
    """
    bucket_index = jnp.searchsorted(boundaries, t_, side='left')
    return bucket_index


def segment_mean(data, segment_ids, num_segments):
    """Average ``data`` within each segment.

    Args:
        data: Values to average.
        segment_ids: Segment index of every entry of ``data``.
        num_segments: Number of output segments.

    Returns:
        Per-segment means; the 1e-10 added to the count makes empty segments
        yield 0 instead of dividing by zero.
    """
    def _per_segment(values):
        return jax.ops.segment_sum(values, segment_ids, num_segments=num_segments)

    totals = _per_segment(data)
    sizes = _per_segment(jnp.ones_like(data))
    return totals / (sizes + 1.0e-10)


def grad_norm(grad_):
    """Return the L2 norm of all leaves of the pytree ``grad_`` taken together.

    Args:
        grad_: Pytree of arrays (e.g. the gradients of the model variables).

    Returns:
        Scalar JAX array holding the norm of the flattened pytree.
    """
    # `import jax` alone does not load the `jax.flatten_util` submodule, so
    # `jax.flatten_util.ravel_pytree` only works if some other library has
    # imported it first. Import it explicitly instead.
    from jax.flatten_util import ravel_pytree

    flattened_grad, _ = ravel_pytree(grad_)
    return jnp.linalg.norm(flattened_grad)


@jit
def compute_pde_loss(variables, x, t):
    """Compute the PDE residual loss on a batch.

    Args:
        variables: Model variables passed to ``f_t`` / ``f_xx``.
        x: Batch of points; one row per sample.
        t: Batch of times with a trailing axis of length 1.

    Returns:
        ``(loss_pde, loss_unweighted)``: the mean over time buckets of the
        ``t_weight``-weighted per-bucket mean absolute residual, and the plain
        mean absolute residual over the batch.
    """
    per_sample = (None, 0, 0)  # share the variables, map over rows of x and t
    _, (time_deriv, _) = vmap(f_t, in_axes=per_sample)(variables, x, t)
    _, (hessian, gradient, value) = vmap(f_xx, in_axes=per_sample)(variables, x, t)

    times = t[:, 0]
    # Index 0 on the axis after the batch axis selects the model output.
    residual = equation(
        x,
        times,
        value[:, 0],
        time_deriv[:, 0, 0],
        gradient[:, 0],
        hessian[:, 0],
    )
    abs_residual = jnp.abs(residual)

    # Average |residual| inside each time bucket first, then across buckets.
    bucket_ids = bucketize(times, t_bundlestep)
    bucket_loss = segment_mean(abs_residual, bucket_ids, t_bundlestep.shape[0])
    loss_pde = jnp.mean(t_weight * bucket_loss)
    loss_unweighted = jnp.mean(abs_residual)

    return loss_pde, loss_unweighted


@jit
def compute_boundary_loss(variables, x, t):
    """Compute the loss against the reference values at the boundary points.

    Args:
        variables: Model variables passed to ``f_0``.
        x: Boundary points; one row per sample.
        t: Times of the boundary points with a trailing axis of length 1.

    Returns:
        Scalar loss: mean absolute error between prediction and
        ``boundary(x)`` plus the mean absolute error between their
        exponentials.
    """
    _, u_0_ = vmap(f_0, in_axes=(None, 0, 0))(variables, x, t)
    pred_b = u_0_[:, 0]

    # The target is evaluated on every row that was passed in. The former
    # `x[:n_boundary]` slice depended on a module-level constant and produced
    # a shape mismatch with `pred_b` whenever more than n_boundary rows were
    # supplied; callers already pass only the boundary rows.
    b = boundary(x)
    loss_b = jnp.mean(jnp.abs(pred_b - b)) + jnp.mean(jnp.abs(jnp.exp(pred_b) - jnp.exp(b)))

    return loss_b


ndim = 3

# Initialise the model variables from a placeholder input of shape (1, ndim + 1).
key, subkey = jr.split(key)
x = jnp.empty((1, ndim + 1))
variables = model.init(subkey, x)

# Boolean pytree with the structure of `variables`; the Fourier embedding
# kernel is the only leaf marked False (see `mask_fn`).
mask = jax.tree_util.tree_map(lambda _leaf: True, variables)
mask['params']['FourierEmbs_0']['kernel'] = False


def mask_fn(grad_):
    """Zero the leaves of ``grad_`` whose entry in the module-level ``mask`` is False.

    Leaves marked True are multiplied by 1.0 and therefore left unchanged.
    """
    def _apply(leaf, keep):
        return leaf * float(keep)

    return jax.tree_util.tree_map(_apply, grad_, mask)


def stepwise_lr_schedule(lr_: float, lr_min_: float, step_interval: int, decay_factor: float):
    """Build a staircase learning-rate schedule with a lower bound.

    Args:
        lr_: Initial learning rate.
        lr_min_: Floor below which the rate never drops.
        step_interval: Number of steps between two decays.
        decay_factor: Multiplier applied at every decay.

    Returns:
        Function mapping a step count to
        ``max(lr_ * decay_factor ** (step // step_interval), lr_min_)``.
    """
    def schedule_fn(step):
        completed_intervals = step // step_interval
        return jnp.maximum(lr_ * (decay_factor ** completed_intervals), lr_min_)

    return schedule_fn


# Select the learning-rate schedule requested on the command line.
if args.scheduler == 'step':
    lr_schedule = stepwise_lr_schedule(
        lr_=args.lr,
        lr_min_=args.lr_min,
        step_interval=args.scheduler_step,
        decay_factor=0.9
    )
elif args.scheduler == 'cosine':
    lr_schedule = optax.cosine_decay_schedule(
        init_value=args.lr,
        decay_steps=args.epochs * args.n_iter,
        alpha=args.lr_min / args.lr
    )
else:
    # Without this branch an unknown name left `lr_schedule` undefined and
    # the script failed with a NameError on the next statement.
    raise ValueError(
        f"Unknown scheduler {args.scheduler!r}; expected 'step' or 'cosine'"
    )

optimizer = optax.adam(lr_schedule)
opt_state = optimizer.init(variables)
# Best epoch-mean unweighted PDE loss seen so far (checkpoint criterion).
min_loss = 1e10
weight_dict = {obj: 1. for obj in objectives}

# Each returns the loss value(s) and the gradient w.r.t. the model variables.
loss_func_pde = jit(value_and_grad(compute_pde_loss, argnums=0, has_aux=True))
loss_func_boundary = jit(value_and_grad(compute_boundary_loss, argnums=0))


def update_weight(objectives_, weight_dict_, gnorm_dict_, alpha_):
    """Update the loss weights from the gradient norms of the objectives.

    For every objective the target weight is ``sum(gnorms) / gnorm[obj]``;
    the stored weight moves towards it by an exponential moving average
    ``alpha_ * old + (1 - alpha_) * target``.

    Args:
        objectives_: Names of the objectives to update.
        weight_dict_: Current weights; updated in place.
        gnorm_dict_: Gradient norm of every objective.
        alpha_: Smoothing factor; the share of the old weight that is kept.

    Returns:
        ``weight_dict_`` (the same object that was passed in).
    """
    gnorm_sum = np.sum(list(gnorm_dict_.values()))
    # A NaN/inf norm would poison every weight; keep the old weights instead.
    if not np.isfinite(gnorm_sum):
        return weight_dict_
    for obj in objectives_:
        gnorm = gnorm_dict_[obj]
        # A vanishing gradient norm would give an infinite target weight.
        if not gnorm > 0:
            continue
        new_weight = gnorm_sum / gnorm
        weight_dict_[obj] = alpha_ * weight_dict_[obj] + (1 - alpha_) * new_weight
    return weight_dict_


# Training loop: `epochs` epochs of `n_iter` optimizer steps each.
for epoch in tqdm(range(epochs)):
    # Running sums over the optimizer steps of this epoch.
    acc_loss_pde = 0.0
    acc_loss_boundary = 0.0
    acc_unweighted = 0.0
    acc_gnorm_pde = 0.0
    acc_gnorm_boundary = 0.0

    for _ in range(n_iter):
        # Advance the PRNG state on every step. Splitting without rebinding
        # `key` made every iteration of every epoch draw the very same
        # boundary and interior samples.
        key, key_b, key_i = jr.split(key, 3)
        x_b, t_b = boundary_data_gen(key_b, n_boundary, t_0)
        x_i, t_i = interior_data_gen(key_i, n_interior, t_0, t_max)

        # Boundary rows come first, so [:n_boundary] below selects exactly them.
        x = jnp.concatenate([x_b, x_i], axis=0)
        t = jnp.concatenate([t_b, t_i], axis=0)

        (loss_pde, loss_unweighted), grad_pde = loss_func_pde(variables, x, t)
        loss_boundary, grad_boundary = loss_func_boundary(variables, x[:n_boundary], t[:n_boundary])

        # Zero the gradient of the leaves excluded by `mask`.
        grad_pde, grad_boundary = mask_fn(grad_pde), mask_fn(grad_boundary)

        # Combine the two gradients with the current objective weights.
        weight_pde, weight_boundary = weight_dict['pde'], weight_dict['boundary']
        grad_all = jax.tree_util.tree_map(
            lambda gx, gy: weight_pde * gx + weight_boundary * gy,
            grad_pde, grad_boundary
        )

        updates, opt_state = optimizer.update(grad_all, opt_state)
        variables = optax.apply_updates(variables, updates)

        gnorm_pde = float(grad_norm(grad_pde))
        gnorm_boundary = float(grad_norm(grad_boundary))

        acc_loss_pde += float(loss_pde)
        acc_loss_boundary += float(loss_boundary)
        acc_unweighted += float(loss_unweighted)
        acc_gnorm_pde += gnorm_pde
        acc_gnorm_boundary += gnorm_boundary

    # Epoch means of the per-step metrics, plus the weights used this epoch.
    metric_dict = {
        'loss_pde': acc_loss_pde / n_iter,
        'loss_boundary': acc_loss_boundary / n_iter,
        'gnorm_pde': acc_gnorm_pde / n_iter,
        'gnorm_boundary': acc_gnorm_boundary / n_iter,
        'weight_pde': float(weight_dict['pde']),
        'weight_boundary': float(weight_dict['boundary']),
        'unweighted_pde': acc_unweighted / n_iter,
    }

    wandb.log(metric_dict, step=epoch)

    # Re-balance the objective weights once per epoch from the mean gradient norms.
    gnorm_dict = {
        'pde': metric_dict['gnorm_pde'],
        'boundary': metric_dict['gnorm_boundary']
    }
    weight_dict = update_weight(objectives, weight_dict, gnorm_dict, alpha)

    # Keep the variables of the epoch with the lowest unweighted PDE loss.
    # NOTE(review): `variables` is a pytree, not a single array, so the saved
    # file presumably needs `allow_pickle=True` when loaded — confirm.
    if metric_dict['unweighted_pde'] < min_loss:
        min_loss = metric_dict['unweighted_pde']
        jnp.save(os.path.join(exp_path, 'min_loss.npy'), variables)
        run.summary["best_unweighted_pde"] = float(min_loss)

    jnp.save(os.path.join(exp_path, 'last_epoch.npy'), variables)

wandb.finish()
