from argparse import Namespace
from functools import partial
import random
import pickle
import os
import sys
import numpy as np
from tqdm import tqdm

import jax
import jax.numpy as jnp
from orbax.checkpoint import PyTreeCheckpointer

from opelab.core.baseline import Baseline
from opelab.core.baselines.pgd.util import *
from opelab.core.baselines.pgd.diffusion import get_denoiser_hypers, create_denoiser_train_state


def c_skip(sigma, sigma_data):
    """Return the skip weight ``sigma_data^2 / (sigma^2 + sigma_data^2)``.

    In this file the result multiplies the noisy trajectory when the
    denoised prediction is assembled from the network output.
    """
    data_var = sigma_data**2
    return data_var / (sigma**2 + data_var)


def c_out(sigma, sigma_data):
    """Return the output weight ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``.

    In this file the result multiplies the raw network output when the
    denoised prediction is assembled.
    """
    inv_norm = (sigma_data**2 + sigma**2) ** -0.5
    return sigma * sigma_data * inv_norm


def c_in(sigma, sigma_data):
    """Return the input scale ``1 / sqrt(sigma^2 + sigma_data^2)``.

    In this file the result multiplies the noisy trajectory before it is
    fed to the denoiser network.
    """
    total_var = sigma**2 + sigma_data**2
    return total_var ** -0.5


def c_noise(sigma):
    """Return the noise-level conditioning value ``log(sigma) / 4``.

    In this file the result is passed to the denoiser network as its
    noise-level input.
    """
    return 0.25 * jnp.log(sigma)
    

# --- Construct guidance function ---
@partial(jax.jit, static_argnames=['denoised_guidance', 'obs_dim', 'action_dim'])
def denoise_step_before_guidance(
    obs_dim, action_dim,
    denoiser_hyperparams,
    denoiser_norm_stats,
    policy_guidance_coeff,
    policy_guidance_delay_steps,
    policy_guidance_cosine_coeff,
    denoised_guidance,
    runner_state,
    step_coeffs,
    denoiser_state=None,
):
    """Prepare the inputs of the policy-guidance term for one denoising step.

    Computes the guidance weight for the current step and extracts the
    unnormalised observation/action columns the guidance gradient is
    evaluated on.

    Args:
        obs_dim: Number of observation columns in the trajectory (static).
        action_dim: Number of action columns in the trajectory (static).
        denoiser_hyperparams: Object exposing ``diffusion_timesteps`` and,
            when ``denoised_guidance`` is set, ``sigma_data``.
        denoiser_norm_stats: Mapping with ``"obs"`` and ``"action"`` entries
            that are handed to ``unnormalise_traj``.
        policy_guidance_coeff: Overall scale of the guidance weight.
        policy_guidance_delay_steps: Guidance is disabled for steps before
            this index.
        policy_guidance_cosine_coeff: Scale of the sinusoidal bump added to
            the linearly decaying weight.
        denoised_guidance: If true (static), guidance inputs are taken from
            the denoiser's prediction instead of the noisy trajectory.
        runner_state: Tuple ``(rng, noised_traj, step_idx)``.
        step_coeffs: Tuple ``(sigma, next_sigma, gamma)`` for this step.
        denoiser_state: Train state exposing ``apply_fn`` and ``params``.
            Required only when ``denoised_guidance`` is true.

    Returns:
        Tuple ``(noised_traj, obs, action, lambd, step_coeffs, rng,
        step_idx)`` where ``obs``/``action`` are the unnormalised guidance
        inputs (action additionally squashed by ``tanh``) and ``lambd`` is
        the guidance weight for this step.

    Raises:
        ValueError: If ``denoised_guidance`` is true but no
            ``denoiser_state`` was supplied.
    """
    # The denoised-guidance branch needs the network; previously it read an
    # undefined name. Fail early with an actionable message instead.
    if denoised_guidance and denoiser_state is None:
        raise ValueError(
            "denoised_guidance=True requires denoiser_state to be passed to "
            "denoise_step_before_guidance"
        )

    rng, noised_traj, step_idx = runner_state
    sigma, next_sigma, gamma = step_coeffs

    # --- Compute guidance coefficient ---
    n_steps = denoiser_hyperparams.diffusion_timesteps
    # Linear decay from 1 towards 0 over the schedule, plus a sinusoidal bump.
    lambd = 1.0 - (step_idx / n_steps)
    cosine_adjustment = jnp.sin(jnp.pi * ((step_idx + 1) / n_steps))
    lambd += policy_guidance_cosine_coeff * cosine_adjustment
    # Guidance is off before the delay and on the final step.
    do_apply_guidance_this_step = jnp.logical_and(
        step_idx >= policy_guidance_delay_steps,
        step_idx < n_steps - 1,
    )
    lambd = jnp.where(
        do_apply_guidance_this_step, policy_guidance_coeff * lambd, 0.0
    )

    # --- Compute denoised trajectory for guidance ---
    guidance_traj = noised_traj
    if denoised_guidance:
        noise_pred = denoiser_state.apply_fn(
            denoiser_state.params,
            c_in(sigma, denoiser_hyperparams.sigma_data) * noised_traj,
            c_noise(sigma),
        )
        guidance_traj = (
            c_skip(sigma, denoiser_hyperparams.sigma_data) * noised_traj
            + c_out(sigma, denoiser_hyperparams.sigma_data) * noise_pred
        )

    # --- Extract guidance inputs ---
    obs = guidance_traj[:, :obs_dim]
    obs = unnormalise_traj(obs, denoiser_norm_stats["obs"])
    action = guidance_traj[:, obs_dim : obs_dim + action_dim]
    action = unnormalise_traj(action, denoiser_norm_stats["action"])
    action = jnp.tanh(action)
    return noised_traj, obs, action, lambd, step_coeffs, rng, step_idx


@partial(jax.jit, static_argnames=['denoiser_hyperparams', 'normalize_action_guidance', 'obs_dim', 'action_dim'])
def denoise_step_after_guidance(
    obs_dim, action_dim,
    denoiser_state,
    denoiser_hyperparams, 
    normalize_action_guidance, 
    noised_traj, 
    action, 
    lambd, 
    step_coeffs, 
    rng, 
    step_idx, 
    action_guidance
):
    """Apply action guidance and one EDM denoising step to a trajectory.

    Args:
        obs_dim: Number of observation columns (static).
        action_dim: Number of action columns (static).
        denoiser_state: Train state exposing ``apply_fn`` and ``params``.
        denoiser_hyperparams: Static object exposing ``sigma_data``,
            ``s_noise`` and ``edm_first_order``.
        normalize_action_guidance: If true (static), the guidance term is
            divided by its norm (plus 1e-8) before being applied.
        noised_traj: Current noisy trajectory; actions live in columns
            ``obs_dim : obs_dim + action_dim``.
        action: Not read; the action columns are re-sliced from
            ``noised_traj``.
        lambd: Guidance weight for this step.
        step_coeffs: Tuple ``(sigma, next_sigma, gamma)``.
        rng: PRNG key; split once to draw the churn noise.
        step_idx: Index of the current step.
        action_guidance: Guidance term added (scaled by ``lambd``) to the
            action columns.

    Returns:
        ``((rng, denoised_traj, step_idx + 1), None)``.
    """
    sigma, next_sigma, gamma = step_coeffs
    sigma_data = denoiser_hyperparams.sigma_data
    action_cols = slice(obs_dim, obs_dim + action_dim)

    def _denoise(traj, noise_level):
        # Preconditioned network call: scale the input, run the network,
        # then blend the network output with the (skip-weighted) input.
        net_out = denoiser_state.apply_fn(
            denoiser_state.params,
            c_in(noise_level, sigma_data) * traj,
            c_noise(noise_level),
        )
        return (
            c_skip(noise_level, sigma_data) * traj
            + c_out(noise_level, sigma_data) * net_out
        )

    # --- Apply guidance to the action columns ---
    if normalize_action_guidance:
        guidance_norm = jnp.linalg.norm(action_guidance)
        action_guidance = action_guidance / (guidance_norm + 1e-8)
    noisy_action = noised_traj[:, action_cols]
    noised_traj = noised_traj.at[:, action_cols].set(
        noisy_action + lambd * action_guidance
    )

    # --- Compute first-order EDM denoise step ---
    rng, noise_key = jax.random.split(rng)
    eps = denoiser_hyperparams.s_noise * jax.random.normal(noise_key, noised_traj.shape)
    sigma_hat = sigma + gamma * sigma
    churned_traj = noised_traj + jnp.sqrt(sigma_hat**2 - sigma**2) * eps
    # Pick the un-churned trajectory explicitly when gamma is 0
    # (original note: JIT instability when gamma is 0).
    traj_hat = jnp.where(gamma > 0, churned_traj, noised_traj)
    slope = (traj_hat - _denoise(traj_hat, sigma_hat)) / sigma_hat

    # --- Apply first-order EDM denoise step ---
    # NOTE(review): this Euler update starts from noised_traj while the
    # second-order update below starts from traj_hat; the two coincide when
    # gamma == 0 -- confirm this is intended when churn is enabled.
    denoised_traj = noised_traj + (next_sigma - sigma_hat) * slope

    if denoiser_hyperparams.edm_first_order:
        return (rng, denoised_traj, step_idx + 1), None

    # --- Compute EDM second-order correction ---
    # The 1e-9 keeps the division finite when next_sigma is 0.
    next_slope = (denoised_traj - _denoise(denoised_traj, next_sigma)) / (
        next_sigma + 1e-9
    )

    # --- Apply second-order EDM denoise step ---
    # Keep the first-order result when next_sigma is exactly 0.
    corrected_traj = traj_hat + 0.5 * (next_sigma - sigma_hat) * (slope + next_slope)
    denoised_traj = jnp.where(next_sigma != 0, corrected_traj, denoised_traj)

    return (rng, denoised_traj, step_idx + 1), None


# Jitted version of construct_rollout; obs_dim and action_dim are marked
# static so they can be used as Python ints inside the traced function.
construct_rollout_func = partial(
    jax.jit, static_argnames=['obs_dim', 'action_dim']
)(construct_rollout)


@partial(jax.jit, static_argnames=['denoiser_hyperparams'])
def compute_noise_schedule(denoiser_hyperparams):
    """Build the per-step noise levels and churn factors for sampling.

    Args:
        denoiser_hyperparams: Static (hashable) object exposing ``rho``,
            ``sigma_max``, ``sigma_min``, ``diffusion_timesteps``,
            ``s_tmin``, ``s_tmax`` and ``s_churn``.

    Returns:
        Tuple ``(sigmas, gammas)``, each of length
        ``diffusion_timesteps + 1``. ``sigmas`` interpolates from
        ``sigma_max`` to ``sigma_min`` in ``sigma**(1/rho)`` space and ends
        with an exact ``0.0``. ``gammas`` is
        ``min(s_churn / diffusion_timesteps, sqrt(2) - 1)`` where
        ``s_tmin <= sigma <= s_tmax`` and ``0`` elsewhere.
    """
    hp = denoiser_hyperparams
    inv_rho = 1 / hp.rho
    ramp = jnp.arange(hp.diffusion_timesteps + 1) / (hp.diffusion_timesteps - 1)
    sigmas = (
        hp.sigma_max**inv_rho
        + ramp * (hp.sigma_min**inv_rho - hp.sigma_max**inv_rho)
    ) ** hp.rho
    # JAX arrays are immutable: `.at[...].set(...)` returns a new array, so
    # the result must be rebound or the final noise level is never zeroed.
    sigmas = sigmas.at[-1].set(0.0)
    churn = jnp.minimum(
        hp.s_churn / hp.diffusion_timesteps,
        jnp.sqrt(2) - 1,
    )
    in_churn_range = (sigmas >= hp.s_tmin) & (sigmas <= hp.s_tmax)
    gammas = jnp.where(in_churn_range, churn, 0.0)
    return sigmas, gammas


def sample_trajectory(
    rng,
    denoiser_state,
    seq_len,
    obs_dim,
    action_dim,
    denoiser_norm_stats,
    denoiser_hyperparams,
    policy_guidance_coeff=0.0,
    policy_guidance_delay_steps=0,
    policy_guidance_cosine_coeff=0.3,
    normalize_action_guidance=True,
    denoised_guidance=False,
    target_grad_log_prob_fn=None
):
    """Sample one trajectory by iteratively denoising Gaussian noise.

    Each step first prepares the guidance inputs, then (outside of jit)
    evaluates ``target_grad_log_prob_fn`` on them, and finally applies the
    guided denoising update.

    Args:
        rng: PRNG key.
        denoiser_state: Train state of the denoiser network.
        seq_len: Number of rows of the sampled trajectory.
        obs_dim: Number of observation columns.
        action_dim: Number of action columns.
        denoiser_norm_stats: Normalisation statistics forwarded to the step
            functions and to the rollout constructor.
        denoiser_hyperparams: Hyperparameters forwarded to the noise
            schedule and step functions.
        policy_guidance_coeff: Overall guidance scale.
        policy_guidance_delay_steps: Steps before guidance starts.
        policy_guidance_cosine_coeff: Scale of the sinusoidal guidance bump.
        normalize_action_guidance: Whether the guidance term is normalised.
        denoised_guidance: Whether guidance inputs come from the denoised
            prediction rather than the noisy trajectory.
        target_grad_log_prob_fn: Optional callable ``(obs, action) ->
            guidance``; when ``None`` the guidance term is ``0.0``.

    Returns:
        Whatever ``construct_rollout_func`` returns for the final denoised
        trajectory.
    """
    # --- Compute noise schedule ---
    sigmas, gammas = compute_noise_schedule(denoiser_hyperparams)
    cur_sigmas, next_sigmas, cur_gammas = sigmas[:-1], sigmas[1:], gammas[:-1]

    # --- Sample random noise trajectory ---
    rng, noise_key = jax.random.split(rng)
    # Add 2 dimensions for reward and done
    traj_shape = (seq_len, obs_dim + action_dim + 2)
    init_noise = jax.random.normal(noise_key, traj_shape) * sigmas[0]

    # --- Run the guided denoising loop ---
    # NOTE(review): the denoiser state is not forwarded to
    # denoise_step_before_guidance here; confirm whether the
    # denoised_guidance=True path needs it.
    has_guidance = target_grad_log_prob_fn is not None
    carry = (rng, init_noise, 0)
    for i in range(cur_sigmas.shape[0]):
        coeffs = (cur_sigmas[i], next_sigmas[i], cur_gammas[i])
        (
            noised_traj, obs, action, lambd, step_coeffs, rng, step_idx
        ) = denoise_step_before_guidance(
            obs_dim, action_dim,
            denoiser_hyperparams, denoiser_norm_stats,
            policy_guidance_coeff,
            policy_guidance_delay_steps,
            policy_guidance_cosine_coeff,
            denoised_guidance,
            carry, coeffs
        )

        action_guidance = target_grad_log_prob_fn(obs, action) if has_guidance else 0.0

        carry, _ = denoise_step_after_guidance(
            obs_dim, action_dim,
            denoiser_state, denoiser_hyperparams, normalize_action_guidance,
            noised_traj, action, lambd, step_coeffs, rng, step_idx, action_guidance
        )
    _, denoised_traj, _ = carry

    # --- Construct rollout ---
    return construct_rollout_func(
        denoised_traj,
        denoiser_norm_stats,
        obs_dim,
        action_dim,
    )


def unpack_checkpoint(file):
    """Extract ``<file>.zip`` into the local ``chkpt`` directory if needed.

    Nothing happens when ``file`` already exists as a directory. Otherwise
    the archive ``file + '.zip'`` is extracted into the ``chkpt`` folder
    located next to this module.
    """
    if os.path.isdir(file):
        return

    import zipfile

    with zipfile.ZipFile(file + '.zip', 'r') as archive:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        archive.extractall(os.path.join(module_dir, 'chkpt'))


def load_diffusion_model(args):
    """Restore the denoiser checkpoint and its side-car metadata.

    The checkpoint is looked up under the ``chkpt`` directory next to this
    module, using ``args.dataset_name`` with ``-`` replaced by ``_``.

    Args:
        args: Namespace exposing ``dataset_name`` and ``diffusion_timesteps``
            (``None`` keeps the value stored in the checkpoint config).

    Returns:
        Tuple ``(denoiser_config, denoiser_state, denoiser_norm_stats,
        denoiser_info)``; the normalisation stats are float32 JAX arrays
        with at least one dimension.
    """
    file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'chkpt',
        args.dataset_name.replace('-', '_'),
    )

    unpack_checkpoint(file)

    def _load_pickle(suffix):
        # SECURITY: pickle can execute arbitrary code while loading; these
        # side-car files must only ever come from a trusted checkpoint.
        with open(file + suffix, 'rb') as f:
            return pickle.load(f)

    denoiser_config = _load_pickle('_config.pkl')
    norm_stats = _load_pickle('_norm_stats.pkl')
    denoiser_info = _load_pickle('_info.pkl')
    if args.diffusion_timesteps is not None:
        denoiser_config.diffusion_timesteps = args.diffusion_timesteps

    # Create placeholder train state
    placeholder_train_state = create_denoiser_train_state(
        jax.random.PRNGKey(0),
        denoiser_info['obs_dim'],
        denoiser_info['action_dim'],
        denoiser_config,
        10000,  # Random dataset length to create LR schedule
    )
    # Restore checkpoint into placeholder train state
    ckptr = PyTreeCheckpointer()
    denoiser_state = ckptr.restore(
        file,
        item=placeholder_train_state,
    )

    # Restore normalization statistics
    # Temporary hack, some of the stats are stored as strings
    def conv_str(s):
        # Treat brackets as separators and split on any whitespace, so that
        # numbers on either side of a line break are never glued together.
        tokens = s.replace("[", " ").replace("]", " ").split()
        return [float(token) for token in tokens]

    norm_stats = {
        k: {k1: conv_str(v) if isinstance(v, str) else v for k1, v in x.items()}
        for k, x in norm_stats.items()
    }
    denoiser_norm_stats = {
        attr: {
            stat_name: jnp.array(v, dtype=jnp.float32)
            for stat_name, v in attr_stats.items()
        }
        for attr, attr_stats in norm_stats.items()
    }
    # Promote scalar stats to shape (1,). `jax.tree_map` is deprecated and
    # removed in newer JAX releases; `jax.tree_util.tree_map` works on all.
    denoiser_norm_stats = jax.tree_util.tree_map(
        lambda x: jnp.expand_dims(x, 0) if x.ndim == 0 else x,
        denoiser_norm_stats,
    )
    print(f"Restored synthetic rollout generator for {args.dataset_name}.")
    return denoiser_config, denoiser_state, denoiser_norm_stats, denoiser_info


def ope_diffusion(args, gamma, target_grad_log_prob_fn=None):
    """Estimate a policy's value from synthetic diffusion rollouts.

    Loads the denoiser, samples
    ``args.num_synth_rollouts * args.num_synth_workers`` trajectories and
    averages their discounted returns.

    Args:
        args: Namespace exposing ``dataset_name``, ``diffusion_timesteps``,
            ``num_rollout_steps``, ``num_synth_rollouts``,
            ``num_synth_workers``, ``policy_guidance_coeff``,
            ``policy_guidance_cosine_coeff``, ``normalize_action_guidance``
            and ``denoised_guidance``.
        gamma: Discount factor.
        target_grad_log_prob_fn: Optional guidance callable forwarded to
            ``sample_trajectory``.

    Returns:
        Mean over rollouts of ``sum_t gamma**t * reward_t * (1 - done_t)``.

    Raises:
        ValueError: If the configuration requests zero rollouts.
    """
    denoiser_config, denoiser_state, denoiser_norm_stats, denoiser_info = load_diffusion_model(args)
    hypers = get_denoiser_hypers(denoiser_config)

    _generate_single_rollout = partial(
        sample_trajectory,
        denoiser_state=denoiser_state,
        seq_len=args.num_rollout_steps + 1,
        obs_dim=denoiser_info['obs_dim'],
        action_dim=denoiser_info['action_dim'],
        denoiser_norm_stats=denoiser_norm_stats,
        denoiser_hyperparams=hypers,
        policy_guidance_coeff=args.policy_guidance_coeff,
        policy_guidance_delay_steps=0,
        policy_guidance_cosine_coeff=args.policy_guidance_cosine_coeff,
        normalize_action_guidance=args.normalize_action_guidance,
        denoised_guidance=args.denoised_guidance,
        target_grad_log_prob_fn=target_grad_log_prob_fn,
    )

    rng = jax.random.PRNGKey(random.randint(0, sys.maxsize))

    # Only rewards and done flags feed the estimate, so observations and
    # actions of the sampled rollouts are not retained.
    rewards = []
    dones = []
    for _ in tqdm(range(args.num_synth_rollouts)):
        rng, batch_rng = jax.random.split(rng)
        for worker_rng in jax.random.split(batch_rng, args.num_synth_workers):
            rollout = _generate_single_rollout(rng=worker_rng)
            rewards.append(rollout.reward)
            dones.append(rollout.done)

    if not rewards:
        raise ValueError(
            "No rollouts were generated; num_synth_rollouts and "
            "num_synth_workers must both be positive."
        )

    # Stack to (num_rollouts, seq_len) by dropping the trailing axis.
    rewards = np.stack(rewards, axis=0)[:, :, 0]
    dones = np.stack(dones, axis=0)[:, :, 0]
    discounts = gamma ** np.arange(rewards.shape[1])
    # NOTE(review): the done flag masks only the reward of its own step;
    # rewards after a termination are still counted -- confirm intended.
    disc_rewards = rewards * (1 - dones) * discounts.reshape((1, -1))
    return np.mean(np.sum(disc_rewards, axis=1))


class PolicyGuidedDiffusion(Baseline):
    """Off-policy evaluation baseline backed by policy-guided diffusion.

    Sampling options are turned into a command-line style argument list
    and parsed with ``parse_agent_args``; ``evaluate`` then runs
    ``ope_diffusion`` with the target policy's guidance function.
    """

    def __init__(self, 
                 device,
                 dataset_name,
                 T,
                 num_samples,
                 policy_guidance_coeff=1.0, 
                 policy_guidance_cosine_coeff=0.3, 
                 normalize_action_guidance=True,  # causes NaN when no guidance
                 denoised_guidance=False,
                 target_model=None):
        """Configure the sampler.

        Args:
            device: Device the target model is moved to via ``.to(device)``.
            dataset_name: Name used to locate the diffusion checkpoint.
            T: Number of rollout steps (``--num_rollout_steps``).
            num_samples: Number of synthetic rollouts
                (``--num_synth_rollouts``).
            policy_guidance_coeff: Overall guidance scale.
            policy_guidance_cosine_coeff: Scale of the sinusoidal guidance
                bump.
            normalize_action_guidance: Whether to pass
                ``--normalize_action_guidance``.
            denoised_guidance: Whether to pass ``--denoised_guidance``.
            target_model: Optional default target policy, used whenever
                ``evaluate`` is called without a ``target``.
        """
        agent_args = [
            '--dataset_name', str(dataset_name),
            '--num_rollout_steps', str(T),
            '--num_synth_rollouts', str(num_samples),
            '--num_synth_workers', '1',
            '--policy_guidance_coeff', str(policy_guidance_coeff),
            '--policy_guidance_cosine_coeff', str(policy_guidance_cosine_coeff)
        ]
        if normalize_action_guidance:
            agent_args += ['--normalize_action_guidance']
        if denoised_guidance:
            agent_args += ['--denoised_guidance']
        print(agent_args)
        self.args = parse_agent_args(agent_args)
        self.device = device

        # Remembered so that evaluate(target=None) can fall back to it.
        self._default_target_model = target_model
        self._use_target(target_model)

    def _use_target(self, target_model):
        """Set ``target_model``/``target_fn``; ``None`` disables guidance."""
        self.target_model = target_model
        if target_model is None:
            self.target_fn = None
            return
        target_model.to(self.device)
        self.target_fn = target_model.grad_log_prob_extended_pgd

    def evaluate(self, data=None, target=None, behavior=None, gamma=1.0, reward_estimator=None):
        """Estimate the value of the target policy.

        Args:
            data: Not used by this baseline.
            target: Target policy exposing ``to(device)`` and
                ``grad_log_prob_extended_pgd``. When ``None``, the
                ``target_model`` given to the constructor is used; if that
                is ``None`` as well, sampling runs without guidance.
            behavior: Not used by this baseline.
            gamma: Discount factor forwarded to ``ope_diffusion``.
            reward_estimator: Not used by this baseline.

        Returns:
            The value estimate returned by ``ope_diffusion``.
        """
        # Previously a missing `target` unconditionally reset the model to
        # None, silently discarding the constructor's `target_model`.
        default_model = getattr(self, '_default_target_model', None)
        self._use_target(target if target is not None else default_model)

        if self.args.debug_nans:
            jax.config.update("jax_debug_nans", True)

        if self.args.debug:
            # Run eagerly so intermediate values can be inspected.
            with jax.disable_jit():
                return ope_diffusion(self.args, gamma, target_grad_log_prob_fn=self.target_fn)
        return ope_diffusion(self.args, gamma, target_grad_log_prob_fn=self.target_fn)


if __name__ == "__main__":
    # Smoke run: evaluate without a target policy on a fixed dataset.
    demo_config = dict(
        device='cuda',
        dataset_name='hopper-medium-v2',
        T=768,
        num_samples=5,
    )
    demo_baseline = PolicyGuidedDiffusion(**demo_config)
    print(demo_baseline.evaluate())