import time
import os

from absl import logging

import numpy as np
import scipy

import jax
import jax.numpy as jnp
from jax import random, vmap
from jax import vmap, jacrev
from jax.tree_util import tree_map

from flax import jax_utils

import ml_collections
import matplotlib.pyplot as plt

import wandb

from jaxpi.archs import Embedding
from jaxpi.samplers import UniformSampler
from jaxpi.logging import Logger
from jaxpi.utils import save_checkpoint

import models
from utils import get_dataset, sample_points_on_square_boundary


def train_curriculum(config, workdir, model, step_offset, max_steps, Re):
    """Train ``model`` for one curriculum stage at Reynolds number ``Re``.

    Args:
        config: Experiment configuration. Reads the ``training``,
            ``weighting``, ``logging``, ``saving`` and ``wandb`` sections.
        workdir: Working directory. Not used here; checkpoints are written
            under ``os.getcwd()/<config.wandb.name>/ckpt/Re<Re>``.
        model: Model wrapper exposing ``state``, ``step``, ``update_weights``
            and ``old_state``. ``model.state`` and ``model.old_state`` are
            reassigned in place.
        step_offset: Global step at which this stage starts; added to the
            local step when logging to W&B so consecutive stages form one
            continuous curve.
        max_steps: Number of optimisation steps to run in this stage.
        Re: Reynolds number for this stage; the viscosity is set to ``1 / Re``.

    Returns:
        ``(model, step_offset)``, where the returned ``step_offset`` is the
        global step at which the next stage should start
        (``step_offset + max_steps``).
    """
    # Reference solution used for evaluation. The viscosity returned by
    # get_dataset is ignored in favour of 1 / Re below.
    u_ref, v_ref, x_star, y_star, _ = get_dataset(Re)
    U_ref = jnp.sqrt(u_ref**2 + v_ref**2)

    # Rectangular domain spanned by the reference grid.
    dom = jnp.array([[x_star[0], x_star[-1]], [y_star[0], y_star[-1]]])

    # Infinite iterator of residual (collocation) batches.
    res_sampler = iter(UniformSampler(dom, config.training.batch_size))

    evaluator = models.NavierStokesEvaluator(config, model)
    logger = Logger()

    nu = 1 / Re

    # The first iterations include JIT warm-up of the model's step functions.
    print("Waiting for JIT...")
    for step in range(max_steps):
        start_time = time.time()

        batch = next(res_sampler)
        model.state = model.step(model.state, batch, nu)

        # Periodically refresh adaptive loss weights.
        if config.weighting.scheme in ["grad_norm", "ntk"]:
            if step % config.weighting.update_every_steps == 0:
                model.state = model.update_weights(model.state, batch, nu)

        # Log training metrics; only host 0 records results.
        if jax.process_index() == 0 and step % config.logging.log_every_steps == 0:
            # Evaluate on the first replica of the replicated state and batch.
            state = jax.tree.map(lambda x: x[0], model.state)
            local_batch = jax.tree.map(lambda x: x[0], batch)

            log_dict = evaluator(state, local_batch, x_star, y_star, U_ref, nu)
            wandb.log(log_dict, step + step_offset)

            end_time = time.time()
            logger.log_iter(step, start_time, end_time, log_dict)

        # Save periodically, and always after the last step of this stage.
        # config.training.max_steps is a per-stage list, so the stage's own
        # max_steps must be used for the end-of-stage check.
        if config.saving.save_every_steps is not None:
            is_last_step = (step + 1) == max_steps
            if (step + 1) % config.saving.save_every_steps == 0 or is_last_step:
                ckpt_path = os.path.join(
                    os.getcwd(), config.wandb.name, "ckpt", "Re{}".format(Re)
                )
                save_checkpoint(
                    model.state, ckpt_path, keep=config.saving.num_keep_ckpts
                )

        # Record the current first-replica state as the model's "old" state
        # (presumably consumed by the model's step/loss; confirm in models).
        model.old_state = jax.tree.map(lambda x: x[0], model.state)

    # The next stage starts right after this stage's last step. Using
    # max_steps (not the loop variable) also works when max_steps == 0.
    step_offset = step_offset + max_steps

    return model, step_offset


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Run curriculum training over the Reynolds numbers in ``config.training.Re``.

    Initialises W&B. If ``config.use_pi_init`` is set, it computes output-layer
    coefficients for a physics-informed initialisation and stores them in
    ``config.arch.pi_init``. It then trains one stage per entry of
    ``config.training.Re``, using the matching entry of
    ``config.training.max_steps``.

    Args:
        config: Experiment configuration.
        workdir: Working directory, forwarded to ``train_curriculum``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``config.training.Re`` and ``config.training.max_steps``
            have different lengths.
    """
    wandb_config = config.wandb
    wandb.init(project=wandb_config.project, name=wandb_config.name)

    if config.use_pi_init:
        print("Use physics-informed initialization...")

        # Temporary model used only to evaluate features at boundary points.
        model = models.NavierStokes2D(config)
        state = jax.device_get(tree_map(lambda x: x[0], model.state))
        params = state.params

        # Presumably apply_fn returns (features, ...) here; evaluate it at the
        # boundary points x_bc1 to build the least-squares design matrix.
        feat_matrix, _ = vmap(state.apply_fn, (None, 0))(params, model.x_bc1)

        # Fit u on the boundary. The v fit is computed only to report its
        # residual; its coefficients are replaced by random values below.
        u_coeffs, u_res, _, _ = jnp.linalg.lstsq(feat_matrix, model.u_bc, rcond=None)
        _, v_res, _, _ = jnp.linalg.lstsq(feat_matrix, model.v_bc, rcond=None)

        print("u_res = {}".format(u_res))
        print("v_res = {}".format(v_res))

        # Small random coefficients for the v and p outputs.
        v_key, p_key = random.split(random.PRNGKey(config.seed + 1), 2)
        v_coeffs = 1e-2 * random.normal(v_key, shape=u_coeffs.shape)
        p_coeffs = 1e-3 * random.normal(p_key, shape=u_coeffs.shape)

        # One column per output (u, v, p).
        config.arch.pi_init = jnp.stack([u_coeffs, v_coeffs, p_coeffs]).T

        del model, state, params

    # Validate the curriculum before building the model. An explicit raise
    # (rather than assert) survives `python -O`.
    Re_schedule = config.training.Re
    steps_schedule = config.training.max_steps
    if len(steps_schedule) != len(Re_schedule):
        raise ValueError(
            "config.training.max_steps and config.training.Re must have the "
            "same length, got {} and {}".format(len(steps_schedule), len(Re_schedule))
        )

    # The model is built after pi_init may have been written to the config.
    model = models.NavierStokes2D(config)

    # Curriculum training: one stage per (Re, max_steps) pair; step_offset
    # keeps W&B step numbering continuous across stages.
    step_offset = 0
    for Re, max_steps in zip(Re_schedule, steps_schedule):
        print("Training for Re = {}".format(Re))
        model, step_offset = train_curriculum(
            config, workdir, model, step_offset, max_steps, Re
        )

    return model
