# Usage: python extension_d4rl_mujoco.py task=halfcheetah-medium-expert-v2 mode=<train_diffusion|stitch>
import copy
import os
from tqdm import tqdm
import d4rl
import gym
import hydra
from datetime import datetime
from sklearn.ensemble import IsolationForest
import pickle as pkl
import numpy as np
import torch
import torch.nn as nn
import random
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import time
import sys 
from cleandiffuser.dataset.dataset_utils import loop_dataloader
from cleandiffuser.diffusion import ContinuousDiffusionSDE
from cleandiffuser.invdynamic import MlpInvDynamic, RewardPredictor
from cleandiffuser.nn_condition import MLPCondition
from cleandiffuser.nn_diffusion import DiT1d
from cleandiffuser.utils import report_parameters, DD_RETURN_SCALE
#from cleandiffuser.rl_agent.iql import *
# from pure_iql_mujoco import *
from cleandiffuser.dataset.d4rl_mujoco_dataset import return_reward_range
from cleandiffuser.dataset.d4rl_mujoco_dataset import TwoDirectionD4RLMuJoCoDataset, D4RLMuJoCoDataset, NumpyD4RLMuJoCoDataset
# from cleandiffuser.cleandiffuser.rl_agent.iql import *
def set_seed(seed: int):
    """Seed PyTorch, NumPy and Python's ``random`` module with the same value.

    Seeding happens in that order: torch first, then NumPy, then ``random``.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)



class StateSampler:
    """Endless source of random mini-batches drawn from an indexable array.

    Rows inside one batch are distinct (sampling without replacement); rows may
    repeat across batches. Randomness comes from the global NumPy RNG.

    NOTE(review): legacy ``np.random.choice(..., replace=False)`` permutes the
    whole index array on every draw, so each batch costs O(len(dataset)).
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        # Candidate row positions; batches are drawn from these.
        self.indices = np.arange(len(self.dataset))

    def sample(self):
        """Yield ``dataset[rows]`` for ``batch_size`` distinct random rows, forever."""
        while True:
            yield self.dataset[np.random.choice(self.indices, size=self.batch_size, replace=False)]

class Stitcher:
    """Generate stitched synthetic trajectories from two conditional diffusion models.

    For each anchor state, ``diff_pre`` samples a horizon-length state sequence
    whose prior holds the anchor at its LAST step, and ``diff_nxt`` samples one
    whose prior holds it at its FIRST step. The two are concatenated around the
    shared anchor. Actions and rewards are filled in with the inverse-dynamics and
    reward models. The batch is filtered first by an IsolationForest
    in-distribution score and then by discounted return.

    Args:
        diff_pre: diffusion model sampled with the anchor at the last step.
        diff_nxt: diffusion model sampled with the anchor at the first step.
        invdyn: inverse-dynamics model exposing ``predict(s, s_next)``.
        rew_model: reward model exposing ``predict(s, a)``.
        dataset: dataset exposing ``normalizers["state"].normalize``.
        state: raw (unnormalized) states; used as anchors and to fit the OOD scorer.
        args: config object; reads ``stitch_batch_size``, ``discount``,
            ``task.horizon``, ``task.target_return``, ``task.w_cfg``, ``device``,
            ``solver``, ``sampling_steps``, ``use_ema``, ``temperature``,
            ``top_k_ood``, ``top_k_R`` and ``diffusion_ckpt``.
    """

    def __init__(self, diff_pre, diff_nxt, invdyn, rew_model, dataset, state, args):
        self.diff_1 = diff_pre
        self.diff_2 = diff_nxt
        self.args = args
        self.invdyn = invdyn
        self.rew_model = rew_model
        self.dataset = dataset
        # Anchors are sampled in the dataset's normalized state space.
        self.candidate_state = self.dataset.normalizers["state"].normalize(state)
        self.state_sampler = StateSampler(self.candidate_state, args.stitch_batch_size)
        # A stitched trajectory has 2 * horizon - 1 states (the anchor is shared),
        # i.e. 2 * (horizon - 1) transitions/rewards to discount.
        self.discount = args.discount ** np.arange(2 * (args.task.horizon - 1), dtype=np.float32)
        # decision_function = score_samples - offset_, and contamination only moves
        # offset_. Every trajectory sums the same number of per-state scores, so
        # contamination does not change the ranking in filter_data_ood.
        self.clf = IsolationForest(contamination=0.00005, random_state=42)
        self.clf.fit(self.candidate_state)

    def get_full_data(self, traj_s):
        """Label a batch of state sequences with actions, rewards and terminals.

        Args:
            traj_s: tensor of state sequences indexed as (batch, time, obs_dim).

        Returns:
            Numpy arrays ``(states, actions, rewards, terminals)`` for the
            ``time - 1`` transitions. The last state is dropped because it has no
            outgoing action. Rewards and terminals are shaped (batch, time - 1);
            terminals are all zero.
        """
        s, s_next = traj_s[:, :-1, :], traj_s[:, 1:, :]
        traj_a = self.invdyn.predict(s, s_next)
        # Reshape rather than squeeze(): squeeze() would also drop the batch axis
        # when a single trajectory is being labelled.
        traj_r = self.rew_model.predict(s, traj_a).reshape(traj_a.shape[0], traj_a.shape[1])
        traj_tml = torch.zeros_like(traj_r)
        return (s.detach().cpu().numpy(), traj_a.detach().cpu().numpy(),
                traj_r.detach().cpu().numpy(), traj_tml.detach().cpu().numpy())

    def filter_data_ood(self, trajectory_batch: np.ndarray, top_k: int):
        """Keep the ``top_k`` trajectories whose states look most in-distribution.

        Every state is scored with the fitted IsolationForest (lower
        ``decision_function`` means more abnormal). Scores are summed per
        trajectory.

        Args:
            trajectory_batch: array indexed as (batch, ..., obs_dim).
            top_k: number of trajectories to keep.

        Returns:
            Tensor on ``args.device`` with the selected trajectories, best first.
        """
        batch_size, ob_dim = trajectory_batch.shape[0], trajectory_batch.shape[-1]
        per_state = self.clf.decision_function(trajectory_batch.reshape(-1, ob_dim))
        score = per_state.reshape(batch_size, -1).sum(axis=-1)
        keep = np.argsort(score)[::-1][:top_k]
        return torch.from_numpy(trajectory_batch[keep]).to(self.args.device)

    def filter_data_return(self, traj_s, traj_a, traj_r, traj_tml, top_k):
        """Keep the ``top_k`` trajectories with the largest discounted reward sum.

        Args:
            traj_s, traj_a, traj_r, traj_tml: numpy arrays sharing the leading
                batch axis; ``traj_r`` is indexed as (batch, time).
            top_k: number of trajectories to keep.

        Returns:
            The four arrays restricted to the selected trajectories, best first.
        """
        discount = self.discount
        if discount.shape[-1] != traj_r.shape[-1]:
            # Rebuild the weights for the actual reward length instead of failing
            # to broadcast.
            discount = self.args.discount ** np.arange(traj_r.shape[-1], dtype=np.float32)
        scores = (traj_r * discount).sum(-1)
        keep = np.argsort(scores)[::-1][:top_k]
        return traj_s[keep], traj_a[keep], traj_r[keep], traj_tml[keep]

    def generate_sample(self, state_batch, add_noise=0.0):
        """Sample, label and filter one batch of stitched trajectories.

        Args:
            state_batch: numpy array of normalized anchor states, (batch, obs_dim).
            add_noise: std of Gaussian noise added to the anchors; 0 disables it.

        Returns:
            ``(states, actions, rewards, terminals)`` numpy arrays for the
            trajectories that survive both the OOD filter and the return filter.
        """
        if add_noise != 0.0:
            noise = np.random.normal(loc=0, scale=add_noise, size=state_batch.shape)
            # np.random.normal yields float64; keep the anchors' original dtype.
            state_batch = (state_batch + noise).astype(state_batch.dtype)
        state_batch = torch.from_numpy(state_batch).to(self.args.device)
        batch_size, obs_dim = state_batch.shape[0], state_batch.shape[1]
        condition = torch.ones((batch_size, 1), device=self.args.device) * self.args.task.target_return
        # diff_1's prior carries the anchor at its last step, diff_2's at its first.
        prior_1 = torch.zeros((batch_size, self.args.task.horizon, obs_dim), device=self.args.device)
        prior_2 = prior_1.clone()
        prior_1[:, -1, :] = state_batch
        prior_2[:, 0, :] = state_batch
        sample_kwargs = dict(
            solver=self.args.solver, n_samples=batch_size, sample_steps=self.args.sampling_steps,
            use_ema=self.args.use_ema, condition_cfg=condition, w_cfg=self.args.task.w_cfg,
            temperature=self.args.temperature)
        traj_pre, _ = self.diff_1.sample(prior_1, **sample_kwargs)
        traj_nxt, _ = self.diff_2.sample(prior_2, **sample_kwargs)

        # Both halves contain the anchor state; keep only one copy of it.
        traj_stitch = torch.concat([traj_pre, traj_nxt[:, 1:, :]], dim=1).detach().cpu().numpy()

        traj_stitch = self.filter_data_ood(traj_stitch, self.args.top_k_ood)
        traj_s, traj_a, traj_r, traj_tml = self.get_full_data(traj_stitch)
        return self.filter_data_return(traj_s, traj_a, traj_r, traj_tml, self.args.top_k_R)

    def load(self, save_path):
        """Load the four pretrained components and switch them to eval mode.

        Checkpoints are read from
        ``{save_path}pretrain_model/<name>_ckpt_{args.diffusion_ckpt}.pt``, so
        ``save_path`` is expected to end with a path separator.
        """
        ckpt = self.args.diffusion_ckpt
        self.diff_1.load(save_path + f"pretrain_model/diffusion_pre_ckpt_{ckpt}.pt")
        self.diff_2.load(save_path + f"pretrain_model/diffusion_nxt_ckpt_{ckpt}.pt")
        self.invdyn.load(save_path + f"pretrain_model/invdyn_ckpt_{ckpt}.pt")
        self.rew_model.load(save_path + f"pretrain_model/rew_model_ckpt_{ckpt}.pt")
        print('Load Stitcher Finished!')
        self.diff_1.eval()
        self.diff_2.eval()
        self.invdyn.eval()
        # The reward model labels rewards at inference time too.
        self.rew_model.eval()

def generate_log_name(cfg):
    """Flatten a mapping into ``key1_value1_key2_value2...`` in iteration order.

    Each key/value pair is rendered as ``f"{key}_{value}"`` and the pairs are
    joined with underscores. An empty mapping yields ``""``.
    """
    return "_".join(f"{key}_{value}" for key, value in cfg.items())

# Root directory for this experiment. It is used both as Hydra's ``config_path``
# (decorator below) and as the prefix of ``save_path`` inside ``pipeline``.
# NOTE(review): with the empty default, ``save_path`` becomes "/<experiment_id>/",
# i.e. a directory directly under the filesystem root -- set this before running.
your_path = ""
@hydra.main(config_path=f"{your_path}", config_name="mujoco", version_base=None)
def pipeline(args):
    """Train the two-direction diffusion stitcher, or use it to synthesize data.

    ``args.mode`` selects the stage:

    * ``"train_diffusion"``: jointly trains ``diff_1`` (last state fixed,
      conditioned on ``rev_val``) and ``diff_2`` (first state fixed, conditioned
      on ``val``), plus the inverse-dynamics and reward models. Every
      ``args.save_interval`` steps it writes ``*_ckpt_latest.pt`` files under
      ``<save_path>pretrain_model/``.
    * ``"stitch"``: loads the ``args.diffusion_ckpt`` checkpoints and runs
      ``args.stitch_round`` rounds of ``Stitcher.generate_sample``. It then
      pickles the concatenated observations/actions/rewards/terminals to
      ``<save_path><env_name>_bitrajdiff_<horizon - 1>.pkl``.

    Raises:
        ValueError: if ``args.mode`` is neither of the above.
    """
    return_scale = DD_RETURN_SCALE[args.task.env_name]

    set_seed(args.seed)
    experiment_id = f'{args.task.env_name}_horizon_{args.task.horizon}'
    save_path = f'{your_path}/{experiment_id}/'

    env = gym.make(args.task.env_name)
    obs_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]

    # --------------- Network Architecture -----------------
    nn_diffusion_1 = DiT1d(
        obs_dim, emb_dim=args.emb_dim,
        d_model=args.d_model, n_heads=args.n_heads, depth=args.depth, timestep_emb_type="fourier")
    nn_diffusion_2 = DiT1d(
        obs_dim, emb_dim=args.emb_dim,
        d_model=args.d_model, n_heads=args.n_heads, depth=args.depth, timestep_emb_type="fourier")

    nn_condition_1 = MLPCondition(
        in_dim=1, out_dim=args.emb_dim, hidden_dims=[args.emb_dim, ], act=nn.SiLU(), dropout=args.label_dropout)
    nn_condition_2 = MLPCondition(
        in_dim=1, out_dim=args.emb_dim, hidden_dims=[args.emb_dim, ], act=nn.SiLU(), dropout=args.label_dropout)
    print(f"======================= Parameter Report of Diffusion Model =======================")
    report_parameters(nn_diffusion_1)
    report_parameters(nn_diffusion_2)
    print(f"==============================================================================")

    # ----------------- Masking -------------------
    # fix_mask marks the anchor state: the last step for diff_1, the first for
    # diff_2. The state adjacent to each anchor gets its own loss weight.
    fix_mask_1 = torch.zeros((args.task.horizon, obs_dim))
    fix_mask_2 = torch.zeros((args.task.horizon, obs_dim))
    fix_mask_1[-1] = 1.
    fix_mask_2[0] = 1.
    loss_weight_1 = torch.ones((args.task.horizon, obs_dim))
    loss_weight_2 = torch.ones((args.task.horizon, obs_dim))
    loss_weight_1[-2] = args.next_obs_loss_weight
    loss_weight_2[1] = args.next_obs_loss_weight

    # --------------- Diffusion Model with Classifier-Free Guidance --------------------
    diff_1 = ContinuousDiffusionSDE(
        nn_diffusion_1, nn_condition_1,
        fix_mask=fix_mask_1, loss_weight=loss_weight_1, ema_rate=args.ema_rate,
        device=args.device, predict_noise=args.predict_noise, noise_schedule="linear")

    diff_2 = ContinuousDiffusionSDE(
        nn_diffusion_2, nn_condition_2,
        fix_mask=fix_mask_2, loss_weight=loss_weight_2, ema_rate=args.ema_rate,
        device=args.device, predict_noise=args.predict_noise, noise_schedule="linear")

    # --------------- Inverse Dynamic -------------------
    invdyn = MlpInvDynamic(obs_dim, act_dim, 512, nn.Tanh(), {"lr": 2e-4}, device=args.device)
    rew_model = RewardPredictor(obs_dim, act_dim, 512, optim_params={"lr": 2e-4}, device=args.device)

    # ---------------------- Training ----------------------
    if args.mode == "train_diffusion":
        # makedirs also creates save_path itself and tolerates re-runs.
        os.makedirs(save_path + 'pretrain_model', exist_ok=True)
        dataset = TwoDirectionD4RLMuJoCoDataset(
            env.get_dataset(), horizon=args.task.horizon, terminal_penalty=args.terminal_penalty,
            discount=args.discount)
        dataloader = DataLoader(
            dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

        diffusion_lr_scheduler_1 = CosineAnnealingLR(diff_1.optimizer, args.diffusion_gradient_steps)
        diffusion_lr_scheduler_2 = CosineAnnealingLR(diff_2.optimizer, args.diffusion_gradient_steps)
        invdyn_lr_scheduler = CosineAnnealingLR(invdyn.optim, args.invdyn_gradient_steps)
        rew_model_lr_scheduler = CosineAnnealingLR(rew_model.optim, args.invdyn_gradient_steps)

        diff_1.train()
        diff_2.train()
        invdyn.train()
        rew_model.train()

        log_keys = ("avg_loss_diffusion_1", "avg_loss_diffusion_2", "avg_loss_invdyn", "avg_loss_rew_model")
        log = dict.fromkeys(log_keys, 0.)

        data_iter = loop_dataloader(dataloader)

        for n_gradient_step in tqdm(range(args.diffusion_gradient_steps), desc="Training steps", dynamic_ncols=True):
            batch = next(data_iter)

            obs = batch["obs"]["state"].to(args.device)
            act = batch["act"].to(args.device)
            val = batch["val"].to(args.device) / return_scale
            rev_val = batch["rev_val"].to(args.device) / return_scale
            rew = batch["rew"].to(args.device)

            # ----------- Gradient Step ------------
            # diff_1 is conditioned on rev_val, diff_2 on val (both divided by return_scale).
            log["avg_loss_diffusion_1"] += diff_1.update(obs, rev_val)['loss']
            log["avg_loss_diffusion_2"] += diff_2.update(obs, val)['loss']
            diffusion_lr_scheduler_1.step()
            diffusion_lr_scheduler_2.step()
            # Exactly invdyn_gradient_steps updates, matching the cosine schedulers' T_max.
            if n_gradient_step < args.invdyn_gradient_steps:
                log["avg_loss_invdyn"] += invdyn.update(obs[:, :-1], act[:, :-1], obs[:, 1:])['loss']
                invdyn_lr_scheduler.step()

                log["avg_loss_rew_model"] += rew_model.update(obs, act, rew)['loss']
                rew_model_lr_scheduler.step()

            # ----------- Logging ------------
            if (n_gradient_step + 1) % args.log_interval == 0:
                log["gradient_steps"] = n_gradient_step + 1
                for key in log_keys:
                    log[key] /= args.log_interval
                # Averages are computed but not emitted; uncomment to report them.
                # print(log)
                log = dict.fromkeys(log_keys, 0.)

            # ----------- Saving ------------
            if (n_gradient_step + 1) % args.save_interval == 0:
                diff_1.save(save_path + f"pretrain_model/diffusion_pre_ckpt_latest.pt")
                diff_2.save(save_path + f"pretrain_model/diffusion_nxt_ckpt_latest.pt")
                invdyn.save(save_path + f"pretrain_model/invdyn_ckpt_latest.pt")
                rew_model.save(save_path + f"pretrain_model/rew_model_ckpt_latest.pt")

    # ---------------------- Stitching ----------------------
    elif args.mode == "stitch":
        original_dataset = env.get_dataset()
        dataset = D4RLMuJoCoDataset(
            original_dataset, horizon=args.task.horizon, terminal_penalty=args.terminal_penalty,
            discount=args.discount)
        sti = Stitcher(diff_1, diff_2, invdyn, rew_model, dataset,
                       original_dataset['observations'].astype(np.float32), args)
        sti.load(save_path)
        # One long-lived batch generator, reused across rounds.
        state_batches = sti.state_sampler.sample()
        sti_tra = {'s': [], 'a': [], 'r': [], 'tml': []}
        for _ in tqdm(range(args.stitch_round)):
            traj_s, traj_a, traj_r, traj_tml = sti.generate_sample(next(state_batches), args.add_noise)
            sti_tra['s'].append(traj_s)
            sti_tra['a'].append(traj_a)
            sti_tra['r'].append(traj_r)
            sti_tra['tml'].append(traj_tml)

        stitch_data = {
            'observations': np.concatenate(sti_tra['s'], axis=0),
            'actions': np.concatenate(sti_tra['a'], axis=0),
            'rewards': np.concatenate(sti_tra['r'], axis=0),
            'terminals': np.concatenate(sti_tra['tml'], axis=0),
        }
        out_file = f'{save_path}{args.task.env_name}_bitrajdiff_{args.task.horizon - 1}.pkl'
        print(out_file)
        with open(out_file, 'wb') as f:
            pkl.dump(stitch_data, f)

    else:
        raise ValueError(f"Invalid mode: {args.mode}")


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on ``pipeline`` builds
    # the config (including command-line overrides such as ``task=...``) and
    # passes it in as ``args``.
    pipeline()
