from argparse import Namespace
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.training.train_state import TrainState
import optax
from distrax import Normal
from orbax.checkpoint import PyTreeCheckpointer

from opelab.core.baselines.pgd.util import *
from opelab.core.baselines.pgd.environments import DatasetRolloutGenerator, OfflineRolloutGenerator
from opelab.core.baselines.pgd.rl import DETERMINISTIC_ACTORS


# diffusion
class TimeEmbedding(nn.Module):
    """Sinusoidal encoding of the noise-level input followed by a small MLP.

    Attributes:
        features: Width of both dense layers (and of the returned embedding).
    """

    features: int = 64

    @nn.compact
    def __call__(self, t):
        """Embed `t`; the result has one row per element of `t` and `features` columns."""
        # Transformer sinusoidal positional encoding over `features // 8` frequencies
        n_freqs = self.features // 8
        log_step = jnp.log(10000) / (n_freqs - 1)
        freqs = jnp.exp(jnp.arange(n_freqs) * -log_step)
        # `jnp.outer` flattens `t`, so a scalar `t` yields a single row
        angles = jnp.outer(t, freqs)
        sinusoids = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
        # Dense -> Swish -> Dense (dense layers are created in this order, which
        # fixes their auto-generated parameter names)
        hidden = nn.Dense(self.features)(sinusoids)
        hidden = hidden * nn.sigmoid(hidden)
        return nn.Dense(self.features)(hidden)


class Encoder(nn.Module):
    """Downsampling path of the U-Net.

    Each block applies two (conv, group-norm, ReLU) stages with the time
    embedding added after the first convolution, records the block output as a
    skip activation and then halves the sequence length with max-pooling.

    Attributes:
        features: Channel count of the first block; doubled at every block.
        n_blocks: Number of blocks (and of returned skip activations).
        n_groups: Number of groups used by every GroupNorm layer.
    """

    features: int = 64
    n_blocks: int = 5
    n_groups: int = 8

    @nn.compact
    def __call__(self, x, t):
        """Return the list of per-block activations, shallowest first."""
        skips = []
        # Swish on the time embedding, shared by all blocks
        t_act = t * nn.sigmoid(t)
        h = x
        for level in range(self.n_blocks):
            width = self.features * (2**level)
            h = nn.Conv(width, kernel_size=(3,))(h)
            # Inject the time embedding (broadcast over the sequence axis)
            h = h + nn.Dense(width)(t_act)
            h = nn.relu(nn.GroupNorm(num_groups=self.n_groups)(h))
            h = nn.Conv(width, kernel_size=(3,))(h)
            h = nn.relu(nn.GroupNorm(num_groups=self.n_groups)(h))
            skips.append(h)
            # Pooled output feeds the next block (unused after the last one)
            h = nn.max_pool(h, window_shape=(2,), strides=(2,))
        return skips


class Decoder(nn.Module):
    """Upsampling path of the U-Net.

    Starts from the deepest encoder activation and, for every shallower level,
    upsamples to that level's sequence length, mixes in the time embedding,
    concatenates the skip activation and applies two conv stages.

    Attributes:
        out_features: Channel count of the final 1-wide convolution.
        features: Base channel count; level `i` uses `features * 2**i`.
        n_blocks: Number of encoder blocks that produced `zs`.
        n_groups: Number of groups used by every GroupNorm layer.
    """

    out_features: int
    features: int = 64
    n_blocks: int = 5
    n_groups: int = 8

    def _upsample(self, x, target_length):
        """Resize the second-to-last axis of `x` to `target_length`."""
        # Deconvolution currently just duplicates elements
        # TODO: Test alternative upsampling methods
        resized_shape = (*x.shape[:-2], target_length, x.shape[-1])
        return jax.image.resize(x, shape=resized_shape, method="nearest")

    @nn.compact
    def __call__(self, zs, t):
        """Decode the encoder activations `zs` (shallowest first) given time embedding `t`."""
        h = zs[-1]
        # Swish on the time embedding, shared by all levels
        t_act = t * nn.sigmoid(t)
        # Walk from the second-deepest level back up to level 0
        for level in reversed(range(self.n_blocks - 1)):
            skip = zs[level]
            width = self.features * (2**level)
            h = self._upsample(h, skip.shape[-2])
            h = nn.Conv(width, kernel_size=(2,))(h)
            h = h + nn.Dense(width)(t_act)
            h = nn.relu(nn.GroupNorm(num_groups=self.n_groups)(h))
            # Skip connection along the channel axis
            h = jnp.concatenate([h, skip], axis=-1)
            for _ in range(2):
                h = nn.Conv(width, kernel_size=(3,))(h)
                h = nn.relu(nn.GroupNorm(num_groups=self.n_groups)(h))
        return nn.Conv(self.out_features, kernel_size=(1,))(h)


class UNet(nn.Module):
    """1-D U-Net denoiser: time embedding -> encoder -> decoder.

    Attributes:
        features: Base channel count passed to the embedding, encoder and decoder.
        n_blocks: Number of encoder blocks (the decoder mirrors them).
    """

    features: int = 64
    n_blocks: int = 5

    @nn.compact
    def __call__(self, x, t):
        """Return a prediction with the same last-axis size as `x`."""
        t_emb = TimeEmbedding(features=self.features)(t)
        encoder = Encoder(features=self.features, n_blocks=self.n_blocks)
        skips = encoder(x, t_emb)
        decoder = Decoder(
            out_features=x.shape[-1],
            features=self.features,
            n_blocks=self.n_blocks,
        )
        return decoder(skips, t_emb)


# edm
@struct.dataclass
class DenoiserHyperparams:
    """Hyperparameters shared by EDM training (`train_step`) and sampling (`sample_trajectory`).

    `p_mean`/`p_std` parameterize the training noise level
    `sigma = exp(p_mean + p_std * N(0, 1))`; `sigma_data` enters the
    preconditioning coefficients and the loss weight; the remaining fields
    configure the sampling noise schedule and stochastic churn.
    """

    p_mean: float = -1.2  # Mean of log-normal noise distribution
    p_std: float = 1.2  # Standard deviation of log-normal noise distribution
    sigma_data: float = 1.0  # Standard deviation of data distribution
    sigma_min: float = 0.002  # Minimum noise level
    sigma_max: float = 80  # Maximum noise level
    rho: float = 7.0  # Sampling schedule
    edm_first_order: bool = False  # Disable second order Heun integration
    diffusion_timesteps: int = 200  # Number of diffusion timesteps for sampling
    # Stochastic sampling coefficients
    s_tmin: float = 0.05  # Churn is only applied for sigma >= s_tmin ...
    s_tmax: float = 50.0  # ... and sigma <= s_tmax
    s_churn: float = 80.0  # Per-step churn is min(s_churn / diffusion_timesteps, sqrt(2) - 1)
    s_noise: float = 1.003  # Scale of the noise injected when churning


# Derived preconditioning params - EDM Table 1
def c_skip(sigma, sigma_data):
    """Skip-connection scale: sigma_data^2 / (sigma^2 + sigma_data^2)."""
    data_var = sigma_data**2
    return data_var / (sigma**2 + data_var)


def c_out(sigma, sigma_data):
    """Network-output scale: sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)."""
    total_var = sigma_data**2 + sigma**2
    return sigma * sigma_data * total_var**-0.5


def c_in(sigma, sigma_data):
    """Network-input scale: 1 / sqrt(sigma^2 + sigma_data^2)."""
    total_var = sigma**2 + sigma_data**2
    return total_var**-0.5


def c_noise(sigma):
    """Noise-level conditioning fed to the network: a quarter of log(sigma)."""
    return 0.25 * jnp.log(sigma)


def train_step(
    rng,
    batch,
    denoiser_state,
    denoiser_hyperparams,
):
    """
    Run one gradient update of the denoiser on a batch using the EDM loss
    (https://openreview.net/pdf?id=k7FuTOWMOc7).

    Params:
        rng: PRNG key; split into one independent key per batch element.
        batch: Array whose leading axis indexes sequences - presumably
            (batch, seq_len, obs_dim + action_dim + 2); confirm against caller.
        denoiser_state: Train state exposing `apply_fn`, `params` and
            `apply_gradients`.
        denoiser_hyperparams: Supplies `p_mean`, `p_std` and `sigma_data`.

    Returns:
        (updated denoiser_state, scalar loss averaged over every element).
    """
    sigma_data = denoiser_hyperparams.sigma_data

    def _edm_weight(sigma):
        # Loss weight paired with the c_skip / c_out preconditioning
        return (sigma**2 + sigma_data**2) * (sigma * sigma_data) ** -2

    def _weighted_sq_error(params, key, clean_seq):
        # Draw a log-normal noise level for this sequence
        key, sigma_key = jax.random.split(key)
        log_sigma = (
            denoiser_hyperparams.p_mean
            + denoiser_hyperparams.p_std * jax.random.normal(sigma_key)
        )
        sigma = jnp.exp(log_sigma)

        # Corrupt the sequence (alphas are 1. in the paper)
        _, noise_key = jax.random.split(key)
        gaussian = jax.random.normal(noise_key, shape=clean_seq.shape)
        noisy_seq = clean_seq + sigma * gaussian

        # Preconditioned network call and denoised estimate
        net_out = denoiser_state.apply_fn(
            params,
            c_in(sigma, sigma_data) * noisy_seq,
            c_noise(sigma),
        )
        denoised = (
            c_skip(sigma, sigma_data) * noisy_seq
            + c_out(sigma, sigma_data) * net_out
        )
        return jnp.square(denoised - clean_seq) * _edm_weight(sigma)

    def _mean_batch_loss(params):
        keys = jax.random.split(rng, batch.shape[0])
        per_element = jax.vmap(_weighted_sq_error, in_axes=(None, 0, 0))(
            params, keys, batch
        )
        return jnp.mean(per_element)

    loss_val, grads = jax.value_and_grad(_mean_batch_loss)(denoiser_state.params)
    return denoiser_state.apply_gradients(grads=grads), loss_val


def sample_trajectory(
    rng,
    denoiser_state,
    seq_len,
    obs_dim,
    action_dim,
    denoiser_norm_stats,
    denoiser_hyperparams,
    policy_guidance_coeff=0.0,
    policy_guidance_delay_steps=0,
    policy_guidance_cosine_coeff=0.0,
    normalize_action_guidance=True,
    denoised_guidance=False,
    det_guidance=False,
    agent_apply_fn=None,
    agent_params=None,
):
    """Sample one trajectory with the EDM stochastic sampler, optionally policy-guided.

    The trajectory has `obs_dim + action_dim + 2` features per step, laid out
    as [observation | action | reward, done].

    Args:
        rng: PRNG key.
        denoiser_state: Train state exposing `apply_fn` and `params`.
        seq_len: Number of steps in the sampled trajectory.
        obs_dim: Size of the observation slice.
        action_dim: Size of the action slice.
        denoiser_norm_stats: Mapping with "obs" and "action" entries used to
            unnormalise the trajectory; also forwarded to `construct_rollout`.
        denoiser_hyperparams: Noise schedule / churn / integration settings.
        policy_guidance_coeff: Guidance scale; guidance is skipped when 0.
        policy_guidance_delay_steps: First denoising step at which guidance
            is applied.
        policy_guidance_cosine_coeff: Weight of the sinusoidal term added to
            the linearly decaying guidance coefficient.
        normalize_action_guidance: Rescale the guidance gradient to unit norm.
        denoised_guidance: Evaluate guidance on the denoised prediction
            instead of the noised trajectory.
        det_guidance: Guide towards a unit Gaussian centred on an action
            sampled from the policy with a fixed key.
        agent_apply_fn: Policy apply function `(params, obs) -> distribution`.
        agent_params: Policy parameters.

    Returns:
        Whatever `construct_rollout` builds from the denoised trajectory.
    """

    # --- Compute noise schedule ---
    def _get_noise_schedule(num_diffusion_timesteps):
        inv_rho = 1 / denoiser_hyperparams.rho
        sigmas = (
            denoiser_hyperparams.sigma_max**inv_rho
            + (jnp.arange(num_diffusion_timesteps + 1) / (num_diffusion_timesteps - 1))
            * (
                denoiser_hyperparams.sigma_min**inv_rho
                - denoiser_hyperparams.sigma_max**inv_rho
            )
        ) ** denoiser_hyperparams.rho
        return sigmas.at[-1].set(0.0)  # last step has sigma value of 0.

    sigmas = _get_noise_schedule(denoiser_hyperparams.diffusion_timesteps)
    # Churn is only active for noise levels inside [s_tmin, s_tmax]
    gammas = jnp.where(
        (sigmas >= denoiser_hyperparams.s_tmin)
        & (sigmas <= denoiser_hyperparams.s_tmax),
        jnp.minimum(
            denoiser_hyperparams.s_churn / denoiser_hyperparams.diffusion_timesteps,
            jnp.sqrt(2) - 1,
        ),
        0.0,
    )

    # --- Sample random noise trajectory ---
    rng, _rng = jax.random.split(rng)
    # Add 2 dimensions for reward and done
    init_noise = jax.random.normal(_rng, (seq_len, obs_dim + action_dim + 2))
    init_noise *= sigmas[0]

    # --- Construct guidance function ---
    do_apply_guidance = (
        agent_apply_fn is not None
        and agent_params is not None
        and policy_guidance_coeff != 0.0
    )

    def _compute_action_guidance(traj):
        # --- Unnormalize observation ---
        obs = traj[:, :obs_dim]
        obs = unnormalise_traj(obs, denoiser_norm_stats["obs"])

        # --- Compute guidance from policy ---
        pi = agent_apply_fn(agent_params, obs)
        if det_guidance:
            # Apply guidance to unit Gaussian around deterministic action
            agent_action = pi.sample(seed=jax.random.PRNGKey(0))
            pi = Normal(agent_action, 1.0)

        def _transformed_action_log_prob(action):
            action = unnormalise_traj(action, denoiser_norm_stats["action"])
            action = jnp.tanh(action)
            return pi.log_prob(action).sum()

        action = traj[:, obs_dim : obs_dim + action_dim]
        action_guidance = jax.grad(_transformed_action_log_prob)(action)

        # --- Normalize and return guidance ---
        if normalize_action_guidance:
            # Epsilon belongs in the denominator: it guards against a zero
            # gradient norm instead of offsetting every element by 1e-8.
            action_guidance = action_guidance / (
                jnp.linalg.norm(action_guidance) + 1e-8
            )
        return action_guidance

    def denoise_step(runner_state, step_coeffs):
        rng, noised_traj, step_idx = runner_state
        sigma, next_sigma, gamma = step_coeffs

        if do_apply_guidance:
            # --- Compute guidance coefficient ---
            n_steps = denoiser_hyperparams.diffusion_timesteps
            lambd = 1.0 - (step_idx / n_steps)
            cosine_adjustment = jnp.sin(jnp.pi * ((step_idx + 1) / n_steps))
            lambd += policy_guidance_cosine_coeff * cosine_adjustment
            # No guidance before the delay or on the final denoising step
            do_apply_guidance_this_step = jnp.logical_and(
                step_idx >= policy_guidance_delay_steps,
                step_idx < n_steps - 1,
            )
            lambd = jnp.where(
                do_apply_guidance_this_step, policy_guidance_coeff * lambd, 0.0
            )

            # --- Compute denoised trajectory for guidance ---
            guidance_traj = noised_traj
            if denoised_guidance:
                noise_pred = denoiser_state.apply_fn(
                    denoiser_state.params,
                    c_in(sigma, denoiser_hyperparams.sigma_data) * noised_traj,
                    c_noise(sigma),
                )
                guidance_traj = (
                    c_skip(sigma, denoiser_hyperparams.sigma_data) * noised_traj
                    + c_out(sigma, denoiser_hyperparams.sigma_data) * noise_pred
                )

            # --- Apply guidance ---
            action_guidance = _compute_action_guidance(guidance_traj)
            action = noised_traj[:, obs_dim : obs_dim + action_dim]
            guided_action = action + lambd * action_guidance
            noised_traj = noised_traj.at[:, obs_dim : obs_dim + action_dim].set(
                guided_action
            )

        # --- Compute first-order EDM denoise step ---
        rng, _rng = jax.random.split(rng)
        eps = denoiser_hyperparams.s_noise * jax.random.normal(_rng, noised_traj.shape)
        sigma_hat = sigma + gamma * sigma
        # JIT instability when gamma is 0
        traj_hat = jnp.where(
            gamma > 0,
            noised_traj + jnp.sqrt(sigma_hat**2 - sigma**2) * eps,
            noised_traj,
        )
        noise_pred = denoiser_state.apply_fn(
            denoiser_state.params,
            c_in(sigma_hat, denoiser_hyperparams.sigma_data) * traj_hat,
            c_noise(sigma_hat),
        )
        denoised_pred = (
            c_skip(sigma_hat, denoiser_hyperparams.sigma_data) * traj_hat
            + c_out(sigma_hat, denoiser_hyperparams.sigma_data) * noise_pred
        )
        denoised_over_sigma = (traj_hat - denoised_pred) / sigma_hat

        # --- Apply first-order EDM denoise step ---
        # The Euler step must start from the churned sample `traj_hat` (the
        # point at which the derivative was evaluated), matching the
        # second-order update below.
        denoised_traj = traj_hat + (next_sigma - sigma_hat) * denoised_over_sigma

        # --- Compute EDM second-order correction ---
        if not denoiser_hyperparams.edm_first_order:
            next_noise_pred = denoiser_state.apply_fn(
                denoiser_state.params,
                c_in(next_sigma, denoiser_hyperparams.sigma_data) * denoised_traj,
                c_noise(next_sigma),
            )
            next_denoised_pred = (
                c_skip(next_sigma, denoiser_hyperparams.sigma_data) * denoised_traj
                + c_out(next_sigma, denoiser_hyperparams.sigma_data) * next_noise_pred
            )
            # Epsilon keeps the (discarded) final-step branch finite
            denoised_prime_over_sigma = (denoised_traj - next_denoised_pred) / (
                next_sigma + 1e-9
            )

            # --- Apply second-order EDM denoise step ---
            # Skipped on the final step, where next_sigma is 0
            denoised_traj = jnp.where(
                next_sigma != 0,
                traj_hat
                + 0.5
                * (next_sigma - sigma_hat)
                * (denoised_over_sigma + denoised_prime_over_sigma),
                denoised_traj,
            )

        return (rng, denoised_traj, step_idx + 1), None

    # --- Denoise trajectory ---
    (rng, denoised_traj, _), _ = jax.lax.scan(
        denoise_step,
        (rng, init_noise, 0),
        (sigmas[:-1], sigmas[1:], gammas[:-1]),
    )

    # --- Construct rollout ---
    return construct_rollout(
        denoised_traj,
        denoiser_norm_stats,
        obs_dim,
        action_dim,
    )


# rollout_generator
class SyntheticRolloutGenerator(DatasetRolloutGenerator):
    """Rollout generator whose dataset is sampled from a restored diffusion denoiser.

    The denoiser checkpoint and its config are downloaded from a wandb run,
    an initial (unguided) synthetic dataset is generated, and the parent
    `DatasetRolloutGenerator` is initialised on that dataset.
    """

    def __init__(
        self,
        rng,
        args,
        obs_shape,
        action_dim,
        action_lims,
        num_env_steps,
        agent_apply_fn=None,
        batch_size=None,
    ):
        """Restore the denoiser and build the initial synthetic dataset.

        Args:
            rng: PRNG key used to sample the initial dataset.
            args: Run configuration; must define `denoiser_checkpoint`.
            obs_shape: Observation shape; only `obs_shape[0]` is used here.
            action_dim: Action dimensionality.
            action_lims: Stored on the instance, not used by this class.
            num_env_steps: Rollout length; trajectories are sampled with
                `num_env_steps + 1` steps.
            agent_apply_fn: Policy apply function used for guidance.
            batch_size: Batch size for the parent generator; defaults to
                `args.batch_size`.

        Raises:
            ValueError: If `args.denoiser_checkpoint` is not set.
        """
        self.num_env_steps = num_env_steps
        self.agent_apply_fn = agent_apply_fn
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.action_lims = action_lims
        self.policy_guidance_coeff = args.policy_guidance_coeff
        self.policy_guidance_cosine_coeff = args.policy_guidance_cosine_coeff
        self.num_synth_workers = args.num_synth_workers
        self.num_synth_rollouts = args.num_synth_rollouts

        if not args.denoiser_checkpoint:
            raise ValueError(
                "Must specify generator checkpoint to use synthetic experience"
            )
        self._restore_diffusion_model(args)
        self.obs_stats = self.denoiser_norm_stats["obs"]
        det_guidance = args.agent in DETERMINISTIC_ACTORS
        self.diffusion_sample_fn = partial(
            make_sample_fn(
                self.denoiser_config,
                args.normalize_action_guidance,
                args.denoised_guidance,
                det_guidance,
            ),
            denoiser_state=self.denoiser_state,
            seq_len=self.num_env_steps + 1,
            obs_dim=self.obs_shape[0],
            action_dim=self.action_dim,
            denoiser_norm_stats=self.denoiser_norm_stats,
            policy_guidance_coeff=self.policy_guidance_coeff,
            policy_guidance_cosine_coeff=self.policy_guidance_cosine_coeff,
        )

        # Generate unguided synthetic dataset
        self.update_synthetic_dataset(rng, None)
        if batch_size is None:
            batch_size = args.batch_size
        super().__init__(self._dataset, batch_size)

    def set_apply_fn(self, agent_apply_fn):
        """Replace the policy apply function used for guidance."""
        self.agent_apply_fn = agent_apply_fn

    def _generate_single_rollout(self, rng, agent_params):
        """Sample one rollout with the current apply function and given params."""
        return self.diffusion_sample_fn(
            rng=rng, agent_params=agent_params, agent_apply_fn=self.agent_apply_fn
        )

    def update_synthetic_dataset(self, rng, agent_params=None):
        """Regenerate `self._dataset` from `num_synth_rollouts` batches of
        `num_synth_workers` sampled rollouts each."""
        # Regenerate synthetic dataset from the current agent state
        synth_rollouts = []
        # Re-jitted on every call, so the current `self.agent_apply_fn` is traced
        batch_rollout_fn = jax.jit(
            jax.vmap(self._generate_single_rollout, in_axes=(0, None))
        )
        for _ in range(self.num_synth_rollouts):
            rng, _rng = jax.random.split(rng)
            _rng = jax.random.split(_rng, self.num_synth_workers)
            synth_rollouts.append(batch_rollout_fn(_rng, agent_params))
        # Stack and flatten rollouts
        # `jax.tree_util.tree_map` replaces the removed `jax.tree_map` alias
        self._dataset = jax.jit(
            lambda x: jax.tree_util.tree_map(
                lambda y: y.reshape((-1, y.shape[-1])), tree_stack(x)
            )
        )(synth_rollouts)

    def _restore_diffusion_model(self, args):
        """Download the checkpoint run from wandb and restore the denoiser
        state, its config and its normalization statistics onto `self`."""
        # Download checkpoint from wandb
        api = wandb.Api()
        ckpt_run = api.run(
            f"{args.wandb_team}/{args.wandb_project}/{args.denoiser_checkpoint}"
        )
        for file in ckpt_run.files():
            file.download(
                root=os.path.join("tmp", args.denoiser_checkpoint), exist_ok=True
            )
        # Create placeholder train state
        ckpt_dict = ckpt_run.config
        self.denoiser_config = Namespace(**ckpt_dict)
        # Optionally override the number of sampling steps stored with the run
        if args.diffusion_timesteps is not None:
            self.denoiser_config.diffusion_timesteps = args.diffusion_timesteps
        placeholder_train_state = create_denoiser_train_state(
            jax.random.PRNGKey(0),
            self.obs_shape[0],
            self.action_dim,
            self.denoiser_config,
            10000,  # Random dataset length to create LR schedule
        )
        # Restore checkpoint into placeholder train state
        ckptr = PyTreeCheckpointer()
        self.denoiser_state = ckptr.restore(
            os.path.join("tmp", args.denoiser_checkpoint, CHECKPOINT_DIR),
            item=placeholder_train_state,
        )
        # Restore normalization statistics
        # Temporary hack, some of the stats are stored as strings
        def conv_str(s):
            s = s.replace("\n", "")
            s = s.replace("[", "")
            s = s.replace("]", "")
            return [float(x) for x in s.split(" ") if x != ""]

        ckpt_dict["norm_stats"] = {
            k: {k1: v if not isinstance(v, str) else conv_str(v) for k1, v in x.items()}
            for k, x in ckpt_dict["norm_stats"].items()
        }
        self.denoiser_norm_stats = {
            attr: {
                stat_name: jnp.array(v, dtype=jnp.float32)
                for stat_name, v in attr_stats.items()
            }
            for attr, attr_stats in ckpt_dict["norm_stats"].items()
        }
        # Promote scalar statistics to 1-D so every stat has at least one axis
        self.denoiser_norm_stats = jax.tree_util.tree_map(
            lambda x: jnp.expand_dims(x, 0) if len(x.shape) == 0 else x,
            self.denoiser_norm_stats,
        )
        print(f"Restored synthetic rollout generator from {args.denoiser_checkpoint}")


class MixedRolloutGenerator:
    """Rollout generator with mixed real and synthetic rollouts

    The synthetic share of each batch is split evenly across
    `args.synth_batch_lifetime` synthetic generators, which are refreshed one
    at a time in round-robin order; the remainder of the batch comes from an
    `OfflineRolloutGenerator`.
    """

    def __init__(
        self,
        rng,
        args,
        obs_shape,
        action_dim,
        action_lims,
        num_env_steps,
        agent_apply_fn=None,
    ):
        """Create the synthetic generators and, if needed, the real one.

        Requires `0 < args.synth_batch_size <= args.batch_size` and
        `args.synth_batch_size` divisible by `args.synth_batch_lifetime`.
        """
        # TODO: remove these as input (here and others) and compute them from the dataset
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.action_lims = action_lims
        assert 0 < args.synth_batch_size <= args.batch_size
        self.synth_batch_size = args.synth_batch_size
        self.real_batch_size = args.batch_size - self.synth_batch_size
        self.synth_batch_lifetime = args.synth_batch_lifetime
        assert self.synth_batch_size % self.synth_batch_lifetime == 0
        # From here on `synth_batch_size` is the per-generator share
        self.synth_batch_size = self.synth_batch_size // self.synth_batch_lifetime
        self.synth_rollout_gens = []
        for _ in range(args.synth_batch_lifetime):
            rng, _rng = jax.random.split(rng)
            self.synth_rollout_gens.append(
                SyntheticRolloutGenerator(
                    _rng,
                    args,
                    obs_shape,
                    action_dim,
                    action_lims,
                    num_env_steps,
                    agent_apply_fn,
                    self.synth_batch_size,
                )
            )
        self.synth_batch_pointer = 0
        if self.real_batch_size > 0:
            self.real_rollout_gen = OfflineRolloutGenerator(
                args,
                obs_shape,
                action_dim,
                action_lims,
                num_env_steps,
                agent_apply_fn,
                self.real_batch_size,
            )
            self.obs_stats = self.real_rollout_gen.obs_stats
        else:
            print("WARNING: real batch size is 0, using only synthetic rollouts")
            # Without a real generator, fall back to the denoiser's observation
            # statistics so `obs_stats` is always defined.
            # NOTE(review): assumes both generators expose stats in the same
            # format - confirm against OfflineRolloutGenerator.
            self.obs_stats = self.synth_rollout_gens[0].obs_stats

    def update_synthetic_dataset(self, rng, agent_params=None):
        """Refresh the next synthetic generator in round-robin order."""
        # Regenerate synthetic dataset from the current agent state
        self.synth_rollout_gens[self.synth_batch_pointer].update_synthetic_dataset(
            rng, agent_params
        )
        self.synth_batch_pointer = (
            self.synth_batch_pointer + 1
        ) % self.synth_batch_lifetime

    def set_apply_fn(self, agent_apply_fn):
        """Propagate a new policy apply function to every underlying generator."""
        for synth_gen in self.synth_rollout_gens:
            synth_gen.set_apply_fn(agent_apply_fn)
        if self.real_batch_size > 0:
            self.real_rollout_gen.set_apply_fn(agent_apply_fn)

    def batch_rollout(self, rng):
        """Concatenate one batch from every generator along axis 0, then swap
        axes 0 and 1 of every leaf."""
        flattened_batches = []
        for synth_gen in self.synth_rollout_gens:
            rng, _rng = jax.random.split(rng)
            flattened_batches.append(synth_gen.batch_rollout(_rng))
        if self.real_batch_size > 0:
            rng, _rng = jax.random.split(rng)
            flattened_batches.append(self.real_rollout_gen.batch_rollout(_rng))
        # `jax.tree_util.tree_map` replaces the removed `jax.tree_map` alias
        traj_batch = jax.tree_util.tree_map(
            lambda *x: jnp.concatenate(x, axis=0),
            *flattened_batches,
        )
        return jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch)


# diffuser (diffusion.py)
def create_denoiser_train_state(rng, obs_dim, action_dim, args, dataset_len):
    """Initialise the U-Net denoiser and wrap it in a `TrainState`.

    Args:
        rng: PRNG key for parameter initialisation.
        obs_dim: Observation dimensionality.
        action_dim: Action dimensionality.
        args: Config providing `num_features`, `num_blocks`, `batch_size`,
            `num_epochs` and `lr`.
        dataset_len: Dataset size, used only to derive the LR schedule length.

    Returns:
        A `TrainState` using Adam with a warmup + cosine-decay learning rate.
    """
    # --- Create U-Net model ---
    # Trajectory features: observation, action, plus reward and done
    feature_dim = obs_dim + action_dim + 2
    denoiser = UNet(args.num_features, args.num_blocks)
    dummy_traj = jnp.ones((2, 64, feature_dim))
    dummy_noise_level = jnp.ones((1,))
    denoiser_params = denoiser.init(rng, dummy_traj, dummy_noise_level)

    # --- Create cosine decay schedule ---
    # Warm up from 10% of the peak LR over the first tenth of training
    total_steps = (dataset_len // args.batch_size) * args.num_epochs
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=args.lr * 0.1,
        peak_value=args.lr,
        warmup_steps=total_steps // 10,
        decay_steps=total_steps,
    )

    # --- Return train state ---
    optimizer = optax.adam(learning_rate=lr_schedule)
    return TrainState.create(
        apply_fn=denoiser.apply,
        params=denoiser_params,
        tx=optimizer,
    )


def get_denoiser_hypers(args):
    """Build `DenoiserHyperparams` from the `edm_*` fields of `args`.

    Raises:
        ValueError: If `args.diffusion_method` is not "edm".
    """
    # Only EDM is supported; reject anything else up front
    if args.diffusion_method != "edm":
        raise ValueError(f"Unknown diffusion method {args.diffusion_method}.")
    return DenoiserHyperparams(
        p_mean=args.edm_p_mean,
        p_std=args.edm_p_std,
        sigma_data=args.edm_sigma_data,
        sigma_min=args.edm_sigma_min,
        sigma_max=args.edm_sigma_max,
        rho=args.edm_rho,
        edm_first_order=args.edm_first_order,
        diffusion_timesteps=args.diffusion_timesteps,
        s_tmin=args.edm_s_tmin,
        s_tmax=args.edm_s_tmax,
        s_churn=args.edm_s_churn,
        s_noise=args.edm_s_noise,
    )


def make_train_step(args):
    """Return `train_step` with the denoiser hyperparameters from `args` bound.

    Raises:
        ValueError: If `args.diffusion_method` is not "edm".
    """
    denoiser_hypers = get_denoiser_hypers(args)
    if args.diffusion_method != "edm":
        raise ValueError(f"Unknown diffusion method {args.diffusion_method}.")
    return partial(train_step, denoiser_hyperparams=denoiser_hypers)


def make_sample_fn(
    args,
    normalize_action_guidance,
    denoised_guidance,
    det_guidance,
):
    """Return `sample_trajectory` with hyperparameters and guidance flags bound.

    Raises:
        ValueError: If `args.diffusion_method` is not "edm".
    """
    denoiser_hypers = get_denoiser_hypers(args)
    if args.diffusion_method != "edm":
        raise ValueError(f"Unknown diffusion method {args.diffusion_method}.")
    bound_kwargs = dict(
        denoiser_hyperparams=denoiser_hypers,
        normalize_action_guidance=normalize_action_guidance,
        denoised_guidance=denoised_guidance,
        det_guidance=det_guidance,
    )
    return partial(sample_trajectory, **bound_kwargs)