from pathlib import Path
import jax
import numpy as np
import jax.numpy as jnp
import wandb
from hfm.potentials.n_body_gravity import NBodyGravityPotential
from hfm.simulation.nve_integrator import VelocityVerletIntegrator

from omegaconf import OmegaConf
import os
import hydra

from hfm.potentials.neural_force_field import NeuralForceField
from omegaconf import OmegaConf
from tqdm import tqdm

import os
import hydra

from hfm.simulation.integration_filters import CoupledConservationFilter, EnergyConservationFilterFlashMD, RandomRotationFilter, RemoveAngularMomentumFilter, RemoveDriftFilterFlashMD
from hfm.simulation.mean_flow_integrator import MeanFlowIntegrator
import sys

import pandas as pd
from datetime import datetime, timezone, timedelta


class BatchHelper:
    """Cut paired arrays into fixed-size batches and gather per-batch outputs.

    Intended use: iterate over ``batch_data``, hand every batch output to
    ``log_results`` and merge everything at the end with ``cat_results``.
    """

    def __init__(self, batch_size):
        """Store the batch size and start with an empty result list."""
        self.results = []
        self.batch_size = batch_size

    def batch_data(self, data_pos, data_mom):
        """Yield matching ``(data_pos, data_mom)`` slices along the first axis.

        This is a generator, so the stored results are cleared when iteration
        begins rather than when the method is called. The final batch is
        shorter when the number of rows is not a multiple of the batch size.
        """
        # every pass over the data starts from an empty result list
        self.results = []

        batch_starts = range(0, data_pos.shape[0], self.batch_size)
        for start in tqdm(batch_starts, leave=False, desc="Batches"):
            stop = start + self.batch_size
            yield data_pos[start:stop], data_mom[start:stop]

    def log_results(self, result):
        """Record the output computed for one batch."""
        self.results.append(result)

    def cat_results(self):
        """Return all recorded batch outputs concatenated along axis 0."""
        return jnp.concatenate(self.results, axis=0)


def load_gt_data(rng, rollout_len, n_test_samples, potential, tstart, datadir):
    """Load the pre-generated gravity test set and pick the start states.

    Args:
        rng: JAX PRNG key, used only when a random subset is drawn.
        rollout_len: Upper bound (exclusive) of the returned time grid.
        n_test_samples: Number of trajectories to draw without replacement,
            or ``-1`` to keep every trajectory in file order.
        potential: Object providing ``masses`` and ``n_balls``.
        tstart: Start time; converted to a frame index as ``tstart * 10``
            (assumes the stored frames are 0.1 time units apart - TODO confirm).
        datadir: Directory that contains the ``gravity_dataset`` folder.

    Returns:
        Tuple ``(time_test, mom_start_test, pos_start_test, vs_traj, xs_traj)``
        where the two trajectories start at the selected start frame.
    """
    dataset_root = Path(datadir)
    # the GT data was generated using the original script
    xs_test = np.load(str(dataset_root / "gravity_dataset/loc_test_gravity100_initvel1small.npy"))
    vs_test = np.load(str(dataset_root / "gravity_dataset/vel_test_gravity100_initvel1small.npy"))

    time_test = jnp.arange(0, rollout_len, step=0.1)

    n_available = xs_test.shape[0]
    if n_test_samples != -1:
        sample_indices = jax.random.choice(rng, n_available, (n_test_samples,), replace=False)
    else:
        sample_indices = jnp.arange(n_available)

    start_frame = tstart * 10

    # momentum = velocity * mass at the start frame
    mom_start_test = vs_test[sample_indices, start_frame] * potential.masses
    # with n_test_samples == -1 the leading dimension is inferred by reshape
    pos_start_test = xs_test[sample_indices, start_frame].reshape(n_test_samples, potential.n_balls, 3)

    vs_traj = vs_test[sample_indices, start_frame:]
    xs_traj = xs_test[sample_indices, start_frame:]
    return time_test, mom_start_test, pos_start_test, vs_traj, xs_traj


def subsample_or_adapt_time(xs_simulated, rollout_len, time_test, subsample=True):
    """Reduce a simulated trajectory to its final frame and that frame's time.

    Subsampling onto the ground-truth time grid is currently disabled, so
    ``time_test`` and ``subsample`` are accepted but not used.

    Args:
        xs_simulated: Simulated trajectory with the time axis at position 1.
        rollout_len: Rollout length; frames are placed on an evenly spaced
            grid from 0 to ``rollout_len - 0.1``.
        time_test: Unused.
        subsample: Unused.

    Returns:
        Tuple ``(last_frame, last_time)``; both keep a time axis of length 1.
    """
    n_frames = xs_simulated.shape[1]
    frame_times = jnp.linspace(0, rollout_len - 0.1, n_frames)

    # only the end point of the rollout is evaluated
    last_frame = xs_simulated[:, -1:]
    last_time = frame_times[-1:]
    return last_frame, last_time


def recover_trajectory(rng, potential, integration_timestep, pos_start_test, mom_start_test, rollout_len, time_test, batch_size, subsample=True):
    """Re-simulate the test trajectories with a velocity-Verlet integrator.

    The start states are integrated batch by batch for ``rollout_len - 0.1``
    time units using ``potential``. The start positions are prepended as the
    first frame and the result is passed through ``subsample_or_adapt_time``.

    Returns:
        Whatever ``subsample_or_adapt_time`` returns for the assembled
        trajectory (positions and their time stamps).
    """
    verlet = VelocityVerletIntegrator(potential, integration_timestep=integration_timestep)
    collector = BatchHelper(batch_size)

    integration_time = rollout_len - 0.1
    for pos_batch, mom_batch in collector.batch_data(pos_start_test, mom_start_test):
        # only the first of the four integrator outputs (positions) is kept
        xs_batch, _, _, _ = verlet(pos_batch, mom_batch, integration_time, rng=rng)
        collector.log_results(xs_batch)

    simulated = collector.cat_results()
    # put the initial positions in front as frame 0
    with_start = jnp.concat([pos_start_test[:, None], simulated], axis=1)
    return subsample_or_adapt_time(with_start, rollout_len, time_test, subsample=subsample)


def load_mf_model(model_path, potential):
    """Build the mean-flow and neural-force-field integrators for a checkpoint.

    The model and data module are instantiated from the Hydra config stored at
    ``<directory of model_path>/.hydra/config.yaml``.

    Args:
        model_path: Path of the parameter file; passed as ``params`` to both
            the mean-flow integrator and the neural force field.
        potential: Potential handed to the ``CoupledConservationFilter``.

    Returns:
        Tuple ``(mf_integrator, vv_integrator)``: the filtered mean-flow
        integrator and a velocity-Verlet integrator driven by the neural force
        field with a step of ``1e-3``.
    """
    config_file = str(Path(model_path).parent / ".hydra" / "config.yaml")
    cfg_model = OmegaConf.load(config_file)

    model = hydra.utils.instantiate(cfg_model.model)
    datamodule = hydra.utils.instantiate(cfg_model.data_module)

    # integration_timestep is a placeholder here; it is overridden later
    mf_integrator = MeanFlowIntegrator(
        model=model,
        params=model_path,
        data_module=datamodule,
        integration_timestep=0.0,
    )
    integration_filters = (
        RandomRotationFilter(),
        RemoveDriftFilterFlashMD(),
        CoupledConservationFilter(potential=potential),
    )
    for integration_filter in integration_filters:
        mf_integrator.add_integration_filter(integration_filter)

    neural_ff = NeuralForceField(model=model, data_module=datamodule, params=model_path)
    vv_integrator = VelocityVerletIntegrator(neural_ff, integration_timestep=1e-3)

    return mf_integrator, vv_integrator


def get_mf_prediction(rng, integrator, time_array, pos_start_test, mom_start_test, batch_size):
    """Predict positions directly from the start state for each requested time.

    Every entry of ``time_array`` is handled independently: the start states
    are pushed through ``integrator.integrate_with_filters`` in batches and
    the per-batch positions are concatenated.

    Returns:
        Array of predicted positions with the time entries stacked on axis 1.
    """
    # auxiliary integration variables, if any; computed once on the full data
    aux = integrator.init_aux(pos_start_test, mom_start_test)

    per_time_predictions = []
    for horizon in tqdm(time_array, leave=False):
        collector = BatchHelper(batch_size)
        # filter state is created from the first batch and reused for the rest
        filter_aux = None
        for pos_batch, mom_batch in collector.batch_data(pos_start_test, mom_start_test):
            if filter_aux is None:
                filter_aux = integrator.init_filter_aux(pos_batch, mom_batch, integrator.masses)
            x_pred_batch, _, _, _, _, _ = integrator.integrate_with_filters(pos_batch, mom_batch, horizon, aux, filter_aux, rng)
            collector.log_results(x_pred_batch)
        per_time_predictions.append(collector.cat_results())

    return jnp.stack(per_time_predictions, axis=1)


def get_mf_rollout(rng, integrator, integration_timestep, pos_start_test, mom_start_test, rollout_len, time_test, batch_size):
    """Roll the start states forward with ``integrator`` at a fixed step size.

    The integrator is called batch by batch for ``rollout_len - 0.1`` time
    units with ``intermediate_steps=10``. The start positions are prepended as
    the first frame before the result goes through
    ``subsample_or_adapt_time``.

    Returns:
        Whatever ``subsample_or_adapt_time`` returns for the assembled rollout
        (positions and their time stamps).
    """
    collector = BatchHelper(batch_size)
    integration_time = rollout_len - 0.1

    for pos_batch, mom_batch in collector.batch_data(pos_start_test, mom_start_test):
        # only the first of the four integrator outputs (positions) is kept
        xs_batch, _, _, _ = integrator(
            pos_batch,
            mom_batch,
            integration_time=integration_time,
            integration_timestep=integration_timestep,
            rng=rng,
            intermediate_steps=10,
        )
        collector.log_results(xs_batch)

    rollout = collector.cat_results()
    # put the initial positions in front as frame 0
    rollout = jnp.concat([pos_start_test[:, None], rollout], axis=1)
    return subsample_or_adapt_time(rollout, rollout_len, time_test, subsample=True)


def generate_predictions(rng_sim, models, time_test, long_timesteps, potential, pos_start_test, mom_start_test, rollout_len, predictions, batch_size, run_ff):
    """Run every evaluation mode for every model and append the results.

    For each model this produces, in order:
      1. a direct mean-flow prediction at the times in ``time_test``,
      2. optionally (``run_ff``) a force-field rollout with step ``1e-3``,
      3. one mean-flow rollout per entry of ``long_timesteps``.

    ``rng_sim`` is split the same way for every model, so all models are
    evaluated with identical random keys.

    Args:
        rng_sim: JAX PRNG key for all simulations.
        models: Mapping from model name to parameter file path.
        time_test: Times at which the direct prediction is evaluated.
        long_timesteps: Step sizes for the mean-flow rollouts.
        potential: Potential passed on to ``load_mf_model``.
        pos_start_test: Start positions.
        mom_start_test: Start momenta.
        rollout_len: Rollout length passed on to ``get_mf_rollout``.
        predictions: List that receives the result records (mutated in place).
        batch_size: Number of samples processed per batch.
        run_ff: Whether to also run the force-field rollout.
    """
    long_timesteps = list(long_timesteps)
    # One key per step size. At least 20 keys are always drawn so that runs
    # with up to 20 step sizes use an unchanged number of keys; more are drawn
    # when needed so that zip() below can never silently drop step sizes.
    n_rollout_keys = max(20, len(long_timesteps))

    for model_name, model_path in models.items():
        print(f"Generating predictions for model: {model_name}")

        mf_integrator, ff_integrator = load_mf_model(model_path, potential)

        # direct prediction
        rng_ff, rng_mf, rng_rol = jax.random.split(rng_sim, 3)
        mf_prediction = get_mf_prediction(rng_mf, mf_integrator, time_test, pos_start_test, mom_start_test, batch_size)
        append_prediction(f"MF: {model_name}", "direct", mf_prediction, time_test, predictions)

        if run_ff:
            # FF rollout
            mf_prediction_rollout, time_rollout = get_mf_rollout(rng_ff, ff_integrator, 1e-3, pos_start_test, mom_start_test, rollout_len, time_test, batch_size)
            append_prediction(f"FF: {model_name}", 1e-3, mf_prediction_rollout, time_rollout, predictions)

        # different MF rollouts
        rollout_keys = jax.random.split(rng_rol, n_rollout_keys)
        for integration_timestep, rng in zip(long_timesteps, rollout_keys):
            mf_prediction_rollout, time_rollout = get_mf_rollout(rng, mf_integrator, integration_timestep, pos_start_test, mom_start_test, rollout_len, time_test, batch_size)
            append_prediction(f"MF: {model_name}", integration_timestep, mf_prediction_rollout, time_rollout, predictions)


def append_prediction(integrator, timestep, trajectory, time_traj, predictions):
    """Append one result record to ``predictions``.

    A ``float`` timestep is stored as a string with five decimals; any other
    value (for example the label ``"direct"``) is stored unchanged.

    Args:
        integrator: Label of the integrator/model that produced the result.
        timestep: Step size or label for this result.
        trajectory: Predicted trajectory.
        time_traj: Time stamps belonging to ``trajectory``.
        predictions: List that receives the record (mutated in place).
    """
    label = f"{timestep:.5f}" if isinstance(timestep, float) else timestep
    record = {
        "integrator": integrator,
        "timestep": label,
        "trajectory": trajectory,
        "time": time_traj,
    }
    predictions.append(record)


def save_predictions(filename, predictions):
    """Compute the MSE of every prediction against the reference and save it.

    The first entry of ``predictions`` is the reference (ground truth); it is
    not written out. Each remaining entry is compared with the reference frame
    whose time stamp is closest to the entry's own time stamps.

    Args:
        filename: Path of the CSV file to write.
        predictions: Records as created by ``append_prediction``.

    Returns:
        List of row dicts (``integrator``, ``timestep``, ``diff``, ``time``),
        one per evaluated time stamp; the same rows are written to the CSV.
    """
    reference = predictions[0]
    gt_traj = reference["trajectory"]
    gt_time = reference["time"]

    rows = []
    for entry in predictions[1:]:
        traj = entry["trajectory"]
        ttime = entry["time"]

        # index of the reference frame closest in time to each predicted frame
        nearest = [jnp.argmin(jnp.abs(gt_time - t)) for t in ttime]
        idxs = jnp.array(nearest)

        # mean squared error over every axis except the time axis (axis 1)
        squared_error = jnp.square(traj - gt_traj[:, idxs])
        diff = jnp.mean(squared_error, axis=(0, 2, 3))

        rows.extend(
            {
                "integrator": entry["integrator"],
                "timestep": entry["timestep"],
                "diff": diff[i].item(),
                "time": ttime[i].item(),
            }
            for i in range(diff.shape[0])
        )

    pd.DataFrame(rows).to_csv(filename, index=False)
    return rows


class NBodyGravityMetrics:
    """Evaluate trained mean-flow models on the N-body gravity test set.

    Calling an instance loads the last regular and EMA parameter files from
    the trainer's working directory, generates predictions, writes the
    per-time MSE values to a CSV file in that directory and, when a wandb run
    is active, stores the non-EMA values in the run summary.
    """

    def __init__(self, trainer, datadir, run_ff=False):
        """Store the evaluation settings.

        Args:
            trainer: Object exposing ``workdir`` (parameter files are read
                from it and the CSV is written to it).
            datadir: Directory that contains the gravity test dataset.
            run_ff: Whether to also evaluate the neural force-field rollout.
        """
        self.trainer = trainer
        self.datadir = datadir
        self.run_ff = run_ff

    @staticmethod
    def _halving_timesteps(count):
        """Return ``[1, 1/2, 1/4, ...]`` with ``count`` entries as Python floats."""
        # np.power is used because the np.pow alias only exists in NumPy >= 2.0
        exponents = np.arange(count)
        return (1 / np.power(2, exponents)).tolist()

    def __call__(self):
        """Run the full evaluation and write/log the resulting MSE values."""
        seed = 42
        tstart = 3
        rollout_len = 1.1
        n_test_samples = -1  # -1 selects every test trajectory
        batch_size = 200
        # long_timesteps = [0.32, 0.16, 0.08, 0.04, 0.02, 0.01, 0.005]
        long_timesteps = self._halving_timesteps(9)

        workdir = Path(self.trainer.workdir)
        models = {
            "noema": str(workdir / "gravity_mean_flowparams_last.pkl"),
            "ema": str(workdir / "gravity_mean_flowparamsEMA_last.pkl"),
        }

        rng = jax.random.PRNGKey(seed)
        potential = NBodyGravityPotential(n_balls=100, softening=.1)

        # The first two keys belong to the disabled GT/baseline simulations
        # below; the four-way split is kept so the other keys do not change.
        _rng_gt, _rng_vv, rng_sample, rng_sim = jax.random.split(rng, 4)
        time_test, mom_start_test, pos_start_test, _vs_test_traj, xs_test_traj = load_gt_data(rng_sample, rollout_len, n_test_samples, potential, tstart=tstart, datadir=self.datadir)

        # The first entry of `predictions` is the ground truth.
        predictions = []
        # xs_test_sim_gt, time_gt = recover_trajectory(rng_gt, potential, 1e-4, pos_start_test, mom_start_test, rollout_len, time_test, subsample=False)
        # append_prediction("GT", 1e-4, xs_test_sim_gt, time_gt, predictions)

        # use the stored numpy trajectories as GT
        append_prediction("GT", 1e-3, xs_test_traj, time_test, predictions)

        # Generate model predictions; direct predictions use the last time only
        print("Generating model predictions..")
        generate_predictions(rng_sim, models, time_test[-1:], long_timesteps, potential, pos_start_test, mom_start_test, rollout_len, predictions, batch_size, run_ff=self.run_ff)

        # print(f"Generating baselines..")
        # for ts, rng_vvx in zip(long_timesteps + [1e-3], jax.random.split(rng_vv, len(long_timesteps) + 1)):
        #     xs_test_sim, time_vv = recover_trajectory(rng_vvx, potential, ts, pos_start_test, mom_start_test, rollout_len, time_test, batch_size, subsample=True)
        #     append_prediction("VV", ts, xs_test_sim, time_vv, predictions)

        print("Saving predictions to file..")
        csv_path = str(workdir / f"predictions_gravity,tstart={tstart},seed={seed}.csv")
        preds = save_predictions(csv_path, predictions)

        if wandb.run:
            for pred in preds:
                # skip EMA logging
                if "noema" not in pred["integrator"]:
                    continue
                if "FF" in pred["integrator"]:
                    summary_key = "MSE FF step " + pred["timestep"]
                elif pred["timestep"] == "direct":
                    summary_key = "MSE direct"
                else:
                    summary_key = "MSE MF step " + pred["timestep"]
                wandb.run.summary[summary_key] = pred["diff"]
