import os, time, pdb

import d4rl
import gym
import hydra, wandb, uuid
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from cleandiffuser.classifier import CumRewClassifier
from cleandiffuser.dataset.d4rl_antmaze_dataset import D4RLAntmazeSeqDataset
from cleandiffuser.dataset.dataset_utils import loop_dataloader, loop_two_dataloaders
from cleandiffuser.diffusion import ContinuousDiffusionSDE, DiscreteDiffusionSDE, ContinuousEDM, ContinuousConsistencyTrajecoteryModel
from cleandiffuser.invdynamic import MlpInvDynamic
from cleandiffuser.nn_condition import MLPCondition, IdentityCondition
from cleandiffuser.nn_diffusion import DiT1d, DAMlp, DiT1d2n
from cleandiffuser.nn_classifier import HalfJannerUNet1d
from cleandiffuser.nn_diffusion import JannerUNet1d
from cleandiffuser.utils import report_parameters, DD_RETURN_SCALE, DAHorizonCritic, IDQLVNet
from utils import set_seed
from tqdm import tqdm
from omegaconf import OmegaConf
from cleandiffuser.utils.iql import IQL



def _zero_train_log():
    """Return a fresh accumulator for the per-interval distillation statistics.

    NOTE(review): the keys are assumed to match the dict returned by
    ``planner.update(..., loss_type="distillation")`` — confirm against
    ContinuousConsistencyTrajecoteryModel; an unexpected key raises KeyError
    in the training loop.
    """
    return {
        "loss": 0.,
        "denoise_weight": 0.,
        "g_loss": 0.,
        "logits_real": 0.,
        "logits_fake": 0.,
    }


@hydra.main(config_path="../configs/veteran/antmaze", config_name="antmaze", version_base=None)
def pipeline(args):
    """Distill a pretrained EDM trajectory planner into a consistency trajectory model and evaluate it.

    Modes (``args.mode``):
      * ``"train"``:
        - Load the EDM planner checkpoint ``planner_ckpt_latest.pt`` from the results directory.
        - Distill it into a ``ContinuousConsistencyTrajecoteryModel`` for
          ``args.distill_gradient_steps`` gradient steps.
        - Save ``ctm_ckpt_<step>.pt`` and ``ctm_ckpt_latest.pt`` every ``args.save_interval`` steps.
      * ``"inference"``:
        - Load the planner, plus the critic / policy / inverse dynamics selected by
          ``args.guidance_type`` and ``args.pipeline_type``.
        - Roll out ``args.num_episodes`` rounds in ``args.num_envs`` vectorized environments.
        - Print the mean normalized score (x100) with its standard error, and the mean per-step
          planning/acting wall-clock time in milliseconds.

    Args:
        args: Hydra/OmegaConf config composed from ``configs/veteran/antmaze``.

    Raises:
        ValueError: if ``args.mode`` is neither ``"train"`` nor ``"inference"``.
    """
    args.device = args.device if torch.cuda.is_available() else "cpu"
    if args.enable_wandb and args.mode in ["inference", "train"]:
        wandb.require("core")
        print(args)
        wandb.init(
            reinit=True,
            id=str(uuid.uuid4()),
            project=str(args.project),
            group=str(args.group),
            name=str(args.name),
            config=OmegaConf.to_container(args, resolve=True)
        )

    set_seed(args.seed)

    # ---------------------- Output paths ----------------------
    # The run directory name encodes the hyper-parameters that distinguish runs.
    base_path = f"{args.pipeline_name}_H{args.task.planner_horizon}_Jump{args.task.stride}"
    base_path += f"_next{args.planner_next_obs_loss_weight}"
    # guidance type
    base_path += f"_{args.guidance_type}"
    # planner backbone
    base_path += f"_{args.planner_net}"
    if args.planner_net == "transformer":
        base_path += f"_d{args.planner_depth}"
        base_path += f"_width{args.planner_d_model}"
    elif args.planner_net == "unet":
        base_path += f"_width{args.unet_dim}"

    if not args.planner_predict_noise:
        base_path += "_pred_x0"

    # pipeline_type
    base_path += f"_{args.pipeline_type}"
    base_path += f"_dp{args.use_diffusion_policy}"
    # task name; the trailing slash lets checkpoint file names be appended directly
    base_path += f"/{args.task.env_name}/"

    # Pretrained teacher / critic / policy checkpoints are read from origin_path and
    # distilled CTM checkpoints are written to save_path (currently the same directory).
    origin_path = "/data/results/" + base_path
    video_path = "/data/video_outputs/" + base_path
    save_path = "/data/results/" + base_path

    # exist_ok avoids the race of a separate "does it exist?" check.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(video_path, exist_ok=True)

    # ---------------------- Create Dataset ----------------------
    env = gym.make(args.task.env_name)
    planner_dataset = D4RLAntmazeSeqDataset(
        env.get_dataset(), horizon=args.task.planner_horizon, discount=args.reward_mode.discount,
        continous_reward_at_done=args.reward_mode.continous_reward_at_done, reward_tune=args.reward_mode.reward_tune,
        stride=args.task.stride, learn_policy=False, center_mapping=(args.guidance_type != "cfg")
    )

    planner_dataloader = DataLoader(
        planner_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    obs_dim, act_dim = planner_dataset.o_dim, planner_dataset.a_dim

    # "separate": the planner generates observations only and a policy / inverse dynamics
    # model produces the action; otherwise the planner generates [obs, act] jointly.
    planner_dim = obs_dim if args.pipeline_type == "separate" else obs_dim + act_dim

    # --------------- Network Architecture -----------------

    # Teacher backbone (must match the architecture of the pretrained EDM checkpoint).
    nn_diffusion_edm = DiT1d(
        planner_dim, emb_dim=args.planner_emb_dim, timestep_emb_params={"scale": 0.02}, dropout=0.2,
        d_model=args.planner_d_model, n_heads=args.planner_d_model // 64, depth=args.planner_depth, timestep_emb_type="fourier")

    # Student backbone; unlike the teacher, no dropout=0.2 is passed here.
    nn_diffusion_planner = DiT1d2n(
        planner_dim, emb_dim=args.planner_emb_dim, timestep_emb_params={"scale": 0.02},
        d_model=args.planner_d_model, n_heads=args.planner_d_model // 64, depth=args.planner_depth, timestep_emb_type="fourier")

    # Scores whole trajectories; used to pick the best candidate under MCSS guidance.
    critic = DAHorizonCritic(
        planner_dim, emb_dim=args.planner_emb_dim,
        d_model=args.planner_d_model, n_heads=args.planner_d_model // 64, depth=2, norm_type="pre").to(args.device)

    if args.gan_training:
        discriminator = DAHorizonCritic(
            planner_dim, emb_dim=args.planner_emb_dim,
            d_model=args.planner_d_model, n_heads=args.planner_d_model // 64, depth=2, norm_type="pre").to(args.device)
    else:
        discriminator = None

    print("=============== Parameter Report of Planner ==================================")
    report_parameters(nn_diffusion_planner)
    print("==============================================================================")

    # ----------------- Masking -------------------
    # Marks the first step's observation slice. At sampling time that slice is filled
    # with the current observation through the prior.
    fix_mask = torch.zeros((args.task.planner_horizon, planner_dim))
    fix_mask[0, :obs_dim] = 1.
    # Re-weight the loss on step 1, the "next observation" the policy / inverse dynamics consume.
    loss_weight = torch.ones((args.task.planner_horizon, planner_dim))
    loss_weight[1] = args.planner_next_obs_loss_weight

    # --------------- Teacher: pretrained EDM planner --------------------
    edm_planner = ContinuousEDM(
        nn_diffusion_edm, nn_condition=None, sigma_data=1.0,
        fix_mask=fix_mask, loss_weight=loss_weight, classifier=None,
        ema_rate=args.planner_ema_rate, device=args.device)

    edm_planner.load(origin_path + "planner_ckpt_latest.pt")
    edm_planner.eval()

    # --------------- Student: consistency trajectory model --------------------
    planner = ContinuousConsistencyTrajecoteryModel(
        nn_diffusion_planner, nn_condition=None, sigma_data=1.0, optim_params={"lr": args.lr}, dloss_start_itr=args.dloss_start_itr,
        fix_mask=fix_mask, loss_weight=loss_weight, classifier=None, discriminator=discriminator,
        ema_rate=args.distiller_ema_rate, device=args.device, d_optim_params={"lr": args.d_lr},)

    planner.prepare_distillation(edm_planner, start_scales=40, num_heun_step=20)
    planner.train()

    # --------------- Inverse Dynamic (Policy) -------------------
    if args.use_diffusion_policy:
        nn_diffusion_policy = DAMlp(obs_dim, act_dim, emb_dim=64, hidden_dim=args.policy_hidden_dim, timestep_emb_type="positional").to(args.device)
        nn_condition_policy = IdentityCondition(dropout=0.0).to(args.device)
        print("=============== Parameter Report of Policy ===================================")
        report_parameters(nn_diffusion_policy)
        print("==============================================================================")
        # --------------- Diffusion Model Actor --------------------
        policy = DiscreteDiffusionSDE(
            nn_diffusion_policy, nn_condition_policy, predict_noise=args.policy_predict_noise, optim_params={"lr": args.policy_learning_rate},
            x_max=+1. * torch.ones((1, act_dim), device=args.device),
            x_min=-1. * torch.ones((1, act_dim), device=args.device),
            diffusion_steps=args.policy_diffusion_steps, ema_rate=args.policy_ema_rate, device=args.device)
    else:
        invdyn = MlpInvDynamic(obs_dim, act_dim, 512, nn.Tanh(), {"lr": 2e-4}, device=args.device)

    # ---------------------- Training ----------------------
    if args.mode == "train":

        n_gradient_step = 0
        log = _zero_train_log()

        # One progress-bar tick per logging interval.
        pbar = tqdm(total=args.distill_gradient_steps // args.log_interval)
        for planner_batch in loop_dataloader(planner_dataloader):

            # NOTE(review): only observations are fed to the planner. That matches
            # planner_dim only when pipeline_type == "separate"; confirm for joint pipelines.
            planner_horizon_data = planner_batch["obs"]["state"].to(args.device)

            # ----------- Planner Gradient Step ------------
            train_log = planner.update(planner_horizon_data, loss_type="distillation")
            for key, val in train_log.items():
                log[key] += val

            # ----------- Logging ------------
            if (n_gradient_step + 1) % args.log_interval == 0:
                log["gradient_steps"] = n_gradient_step + 1
                # Average over the logging interval (was a hard-coded 1000).
                log["loss"] /= args.log_interval
                log["denoise_weight"] /= args.log_interval
                # NOTE(review): adversarial terms are scaled by 2, presumably because they
                # are only produced on every other step; confirm against the model.
                log["g_loss"] = log["g_loss"] / args.log_interval * 2
                log["logits_real"] = log["logits_real"] / args.log_interval * 2
                log["logits_fake"] = log["logits_fake"] / args.log_interval * 2
                print(log)

                if args.enable_wandb:
                    wandb.log(log, step=n_gradient_step + 1)
                pbar.update(1)
                log = _zero_train_log()

            # ----------- Saving ------------
            if (n_gradient_step + 1) % args.save_interval == 0:
                planner.save(save_path + f"ctm_ckpt_{n_gradient_step + 1}.pt")
                planner.save(save_path + "ctm_ckpt_latest.pt")

            n_gradient_step += 1
            if n_gradient_step >= args.distill_gradient_steps:
                break

        pbar.close()
        if args.enable_wandb:
            wandb.finish()

    # ---------------------- Inference ----------------------
    elif args.mode == "inference":

        if args.guidance_type == "MCSS":
            # load planner
            planner.load(save_path + f"ctm_ckpt_{args.planner_ckpt}.pt")
            planner.eval()
            # load critic; map_location allows a GPU-saved checkpoint after the CPU fallback above
            critic_ckpt = torch.load(origin_path + f"critic_ckpt_{args.critic_ckpt}.pt", map_location=args.device)
            critic.load_state_dict(critic_ckpt["critic"])
            critic.eval()
            # load policy
            if args.pipeline_type == "separate":
                if args.use_diffusion_policy:
                    policy.load(origin_path + f"policy_ckpt_{args.policy_ckpt}.pt")
                    policy.eval()
                else:
                    invdyn.load(save_path + f"invdyn_ckpt_{args.invdyn_ckpt}.pt")
                    invdyn.eval()

        elif args.guidance_type == "cfg":
            # load planner
            planner.load(save_path + f"planner_ckpt_{args.planner_ckpt}.pt")
            planner.eval()
            # load policy
            if args.pipeline_type == "separate":
                if args.use_diffusion_policy:
                    policy.load(save_path + f"policy_ckpt_{args.policy_ckpt}.pt")
                    policy.eval()
                else:
                    invdyn.load(save_path + f"invdyn_ckpt_{args.invdyn_ckpt}.pt")
                    invdyn.eval()

        elif args.guidance_type == "cg":
            # load planner
            planner.load(save_path + f"planner_ckpt_{args.planner_ckpt}.pt")
            # load classifier
            # NOTE(review): the planner above is built with classifier=None, so this
            # presumably fails unless the model creates one internally; confirm.
            planner.classifier.load(save_path + f"classifier_ckpt_{args.planner_ckpt}.pt")
            planner.eval()
            # load policy
            if args.pipeline_type == "separate":
                if args.use_diffusion_policy:
                    policy.load(save_path + f"policy_ckpt_{args.policy_ckpt}.pt")
                    policy.eval()
                else:
                    invdyn.load(save_path + f"invdyn_ckpt_{args.invdyn_ckpt}.pt")
                    invdyn.eval()

        env_eval = gym.vector.make(args.task.env_name, args.num_envs)
        normalizer = planner_dataset.get_normalizer()
        episode_rewards = []
        # Wall-clock seconds per control step, from just before planner sampling to action ready.
        time_list = []

        for _ in range(args.num_episodes):
            obs, ep_reward, t = env_eval.reset(), 0., 0
            cum_done = np.zeros(args.num_envs, dtype=bool)
            while not np.all(cum_done) and t < args.task.max_path_length + 1:

                # 1) generate plan
                if args.guidance_type == "MCSS":
                    planner_prior = torch.zeros((args.num_envs * args.planner_num_candidates, args.task.planner_horizon, planner_dim), device=args.device)

                    obs = torch.tensor(normalizer.normalize(obs), device=args.device, dtype=torch.float32)
                    obs_repeat = obs.unsqueeze(1).repeat(1, args.planner_num_candidates, 1).view(-1, obs_dim)

                    # sample planner_num_candidates trajectories per environment
                    planner_prior[:, 0, :obs_dim] = obs_repeat
                    start_time = time.time()
                    traj, _ = planner.sample(
                        planner_prior,
                        n_samples=args.num_envs * args.planner_num_candidates, sample_steps=args.ctm_sampling_steps, use_ema=args.planner_use_ema,
                        condition_cfg=None, w_cfg=1.0, temperature=args.task.planner_temperature)

                    # keep the candidate with the highest critic value per environment
                    with torch.no_grad():
                        value = critic(traj)
                        value = value.view(args.num_envs, args.planner_num_candidates)
                        idx = torch.argmax(value, -1)
                        traj = traj.reshape(args.num_envs, args.planner_num_candidates, args.task.planner_horizon, planner_dim)
                        traj = traj[torch.arange(args.num_envs), idx]

                elif args.guidance_type == "cfg":
                    planner_prior = torch.zeros((args.num_envs, args.task.planner_horizon, planner_dim), device=args.device)
                    condition = torch.ones((args.num_envs, 1), device=args.device) * args.task.planner_target_return

                    obs = torch.tensor(normalizer.normalize(obs), device=args.device, dtype=torch.float32)

                    # sample trajectories
                    planner_prior[:, 0, :obs_dim] = obs
                    # Previously only set in the MCSS branch, causing a NameError below.
                    start_time = time.time()
                    traj, _ = planner.sample(
                        planner_prior, solver=args.planner_solver,
                        n_samples=args.num_envs, sample_steps=args.planner_sampling_steps, use_ema=args.planner_use_ema,
                        condition_cfg=condition, w_cfg=args.task.planner_w_cfg, temperature=args.task.planner_temperature)

                elif args.guidance_type == "cg":
                    planner_prior = torch.zeros((args.num_envs * args.planner_num_candidates, args.task.planner_horizon, planner_dim), device=args.device)

                    obs = torch.tensor(normalizer.normalize(obs), device=args.device, dtype=torch.float32)
                    obs_repeat = obs.unsqueeze(1).repeat(1, args.planner_num_candidates, 1).view(-1, obs_dim)

                    planner_prior[:, 0, :obs_dim] = obs_repeat
                    # Previously only set in the MCSS branch, causing a NameError below.
                    start_time = time.time()
                    traj, sample_log = planner.sample(
                        planner_prior, solver=args.planner_solver,
                        n_samples=args.num_envs * args.planner_num_candidates, sample_steps=args.planner_sampling_steps, use_ema=args.planner_use_ema,
                        w_cg=args.task.planner_w_cfg, temperature=args.task.planner_temperature)

                    # keep the candidate with the highest classifier log-probability
                    with torch.no_grad():
                        logp = sample_log["log_p"].view(args.num_envs, args.planner_num_candidates)
                        idx = torch.argmax(logp, -1)
                        traj = traj.reshape(args.num_envs, args.planner_num_candidates, args.task.planner_horizon, planner_dim)
                        traj = traj[torch.arange(args.num_envs), idx]

                # 2) generate action
                if args.pipeline_type == "separate":
                    if args.use_diffusion_policy:
                        policy_prior = torch.zeros((args.num_envs, act_dim), device=args.device)
                        with torch.no_grad():
                            next_obs_plan = traj[:, 1, :]
                            obs_policy = obs.clone()
                            next_obs_policy = next_obs_plan.clone()

                            # Express the first two obs dims relative to the current obs.
                            if args.rebase_policy:
                                next_obs_policy[:, :2] -= obs_policy[:, :2]
                                obs_policy[:, :2] = 0

                            act, _ = policy.sample(
                                policy_prior,
                                solver=args.policy_solver,
                                n_samples=args.num_envs,
                                sample_steps=args.policy_sampling_steps,
                                condition_cfg=torch.cat([obs_policy, next_obs_policy], dim=-1), w_cfg=1.0,
                                use_ema=args.policy_use_ema, temperature=args.policy_temperature)
                            act = act.cpu().numpy()
                    else:
                        # inverse dynamics: action from (current obs, planned next obs)
                        with torch.no_grad():
                            act = invdyn.predict(obs, traj[:, 1, :]).cpu().numpy()
                else:
                    # joint pipeline: the action is read directly from the planned first step
                    act = traj[:, 0, obs_dim:]
                    act = act.cpu().numpy()

                # step
                end_time = time.time()
                time_list.append(end_time - start_time)
                obs, rew, done, info = env_eval.step(act)

                t += 1
                cum_done = np.logical_or(cum_done, done)
                ep_reward += rew
                print(f'[t={t}] rew: {ep_reward}')

            episode_rewards.append(np.clip(ep_reward, 0., 1.))

        episode_rewards = [list(map(lambda x: env.get_normalized_score(x), r)) for r in episode_rewards]
        episode_rewards = np.array(episode_rewards).reshape(-1) * 100
        mean = np.mean(episode_rewards)
        err = np.std(episode_rewards) / np.sqrt(len(episode_rewards))
        print(mean, err)

        if args.enable_wandb:
            wandb.log({'Mean Reward': mean, 'Error': err})
            wandb.finish()
        print('time mean: ', 1000 * np.mean(time_list), 'err: ', 1000 * np.std(time_list) / np.sqrt(len(time_list)))

    else:
        raise ValueError(f"Invalid mode: {args.mode}")


if __name__ == "__main__":
    # Hydra's decorator parses command-line overrides and builds `args` itself,
    # so the entry point is invoked without arguments.
    pipeline()
