import wandb
import numpy as np
import torch
import collections
import pathlib
import tqdm
import dill
import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder

from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner

class PushTImageRunner(BaseImageRunner):
    """Evaluate image-based policies on the PushT environment.

    The runner builds ``n_envs`` asynchronous environments and one serialized
    init function per initial condition (train seeds first, then test seeds).
    Initial conditions are rolled out in chunks of ``n_envs``. When the last
    chunk is not full, the spare environments re-run the first init function
    and their results are discarded.
    """

    def __init__(self,
            output_dir,
            n_train=10,
            n_train_vis=3,
            train_start_seed=0,
            n_test=22,
            n_test_vis=6,
            legacy_test=False,
            test_start_seed=10000,
            max_steps=200,
            n_obs_steps=8,
            n_action_steps=8,
            fps=10,
            crf=22,
            render_size=96,
            past_action=False,
            tqdm_interval_sec=5.0,
            n_envs=None
        ):
        """Build the vectorized environments and the per-episode init functions.

        Args:
            output_dir: run output directory. Recorded videos are written to
                ``<output_dir>/media/``. That folder is created with
                ``parents=False``, so ``output_dir`` itself must already exist.
            n_train: number of train initial conditions, using consecutive
                seeds starting at ``train_start_seed``.
            n_train_vis: how many of the first train episodes record video.
            train_start_seed: first train seed.
            n_test: number of test initial conditions, using consecutive seeds
                starting at ``test_start_seed``.
            n_test_vis: how many of the first test episodes record video.
            legacy_test: forwarded as ``legacy`` to every ``PushTImageEnv``
                (train and test alike).
            test_start_seed: first test seed.
            max_steps: ``max_episode_steps`` of the ``MultiStepWrapper``.
            n_obs_steps: forwarded to ``MultiStepWrapper``. Also bounds how
                many past actions are fed back when ``past_action`` is set.
            n_action_steps: forwarded to ``MultiStepWrapper``. Also sets the
                extra steps kept after a success in the augmented runs.
            fps: video frame rate. Also sets
                ``steps_per_render = max(10 // fps, 1)``.
            crf: h264 constant rate factor for recorded videos.
            render_size: render size passed to ``PushTImageEnv``.
            past_action: if True, previous actions are added to the policy
                observation under ``'past_action'`` (untested path).
            tqdm_interval_sec: minimum progress-bar refresh interval.
            n_envs: number of parallel environments. Defaults to
                ``n_train + n_test``, i.e. one environment per initial
                condition.
        """
        super().__init__(output_dir)
        if n_envs is None:
            n_envs = n_train + n_test

        steps_per_render = max(10 // fps, 1)
        def env_fn():
            return MultiStepWrapper(
                VideoRecordingWrapper(
                    PushTImageEnv(
                        legacy=legacy_test,
                        render_size=render_size
                    ),
                    video_recoder=VideoRecorder.create_h264(
                        fps=fps,
                        codec='h264',
                        input_pix_fmt='rgb24',
                        crf=crf,
                        thread_type='FRAME',
                        thread_count=1
                    ),
                    file_path=None,
                    steps_per_render=steps_per_render
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        env_fns = [env_fn] * n_envs
        env_seeds = list()
        env_prefixs = list()
        env_init_fn_dills = list()

        # Train initial conditions come first, then test ones. This order
        # defines the global index that run() uses to map results back to
        # seeds and prefixes.
        episode_groups = (
            ('train/', train_start_seed, n_train, n_train_vis),
            ('test/', test_start_seed, n_test, n_test_vis),
        )
        for prefix, start_seed, n_episodes, n_vis in episode_groups:
            for i in range(n_episodes):
                seed = start_seed + i
                enable_render = i < n_vis

                # Default arguments bind the current seed/enable_render
                # values, avoiding late-binding of the loop variables.
                def init_fn(env, seed=seed, enable_render=enable_render):
                    """Configure video recording and seed for one episode.

                    Executed inside each environment via 'run_dill_function'.
                    """
                    # video_wrapper: stop any previous recording and disable
                    # output unless this episode is one of the visualized ones.
                    assert isinstance(env.env, VideoRecordingWrapper)
                    env.env.video_recoder.stop()
                    env.env.file_path = None
                    if enable_render:
                        filename = pathlib.Path(output_dir).joinpath(
                            'media', wv.util.generate_id() + ".mp4")
                        filename.parent.mkdir(parents=False, exist_ok=True)
                        filename = str(filename)
                        env.env.file_path = filename

                    # set seed
                    assert isinstance(env, MultiStepWrapper)
                    env.seed(seed)

                env_seeds.append(seed)
                env_prefixs.append(prefix)
                env_init_fn_dills.append(dill.dumps(init_fn))

        env = AsyncVectorEnv(env_fns)

        self.env = env
        self.env_fns = env_fns
        self.env_seeds = env_seeds
        self.env_prefixs = env_prefixs
        self.env_init_fn_dills = env_init_fn_dills
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.tqdm_interval_sec = tqdm_interval_sec

    def _init_chunk(self, chunk_idx):
        """Run the init functions for one chunk of initial conditions.

        The chunk covers global indices ``[chunk_idx * n_envs, end)``. When it
        has fewer initial conditions than environments, the spare environments
        are padded with the first init function.

        Returns:
            ``(global_slice, local_slice)``. ``global_slice`` indexes the
            per-initial-condition result lists; ``local_slice`` selects the
            matching (non-padding) environments of this chunk.
        """
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        start = chunk_idx * n_envs
        end = min(n_inits, start + n_envs)
        global_slice = slice(start, end)
        local_slice = slice(0, end - start)

        # List slicing copies, so padding never mutates self.env_init_fn_dills.
        init_fns = self.env_init_fn_dills[global_slice]
        n_diff = n_envs - len(init_fns)
        if n_diff > 0:
            init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
        assert len(init_fns) == n_envs

        self.env.call_each('run_dill_function',
                           args_list=[(x,) for x in init_fns])
        return global_slice, local_slice

    def _obs_to_torch(self, obs, past_action, device):
        """Convert a vectorized numpy observation dict to torch tensors on ``device``.

        When ``self.past_action`` is enabled and a previous action exists, the
        last ``n_obs_steps - 1`` entries along axis 1 are added as
        ``'past_action'`` (float32).
        """
        np_obs_dict = dict(obs)
        if self.past_action and (past_action is not None):
            # TODO: not tested
            np_obs_dict['past_action'] = past_action[
                                         :, -(self.n_obs_steps - 1):].astype(np.float32)
        return dict_apply(np_obs_dict,
                          lambda x: torch.from_numpy(x).to(device=device))

    def run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` on every initial condition and collect metrics.

        Args:
            policy: policy exposing ``device``, ``reset()`` and
                ``predict_action(obs_dict)``. The returned dict must contain
                an ``'action'`` entry.

        Returns:
            dict of log entries:
                - ``<prefix>sim_max_reward_<seed>`` for every episode;
                - ``<prefix>sim_video_<seed>`` (``wandb.Video``) for episodes
                  that recorded a video;
                - ``<prefix>mean_score`` for each prefix.
        """
        device = policy.device
        env = self.env

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # allocate data
        all_video_paths = [None] * n_inits
        all_rewards = [None] * n_inits

        for chunk_idx in range(n_chunks):
            this_global_slice, this_local_slice = self._init_chunk(chunk_idx)

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval PushtImageRunner {chunk_idx + 1}/{n_chunks}",
                             leave=False, mininterval=self.tqdm_interval_sec)
            done = False
            while not done:
                obs_dict = self._obs_to_torch(obs, past_action, device)

                # run policy
                with torch.no_grad():
                    action_dict = policy.predict_action(obs_dict)

                # device_transfer
                np_action_dict = dict_apply(action_dict,
                                            lambda x: x.detach().to('cpu').numpy())
                action = np_action_dict['action']

                # step env
                obs, reward, done, info = env.step(action)
                done = np.all(done)
                past_action = action

                # update pbar
                pbar.update(action.shape[1])
            pbar.close()

            all_video_paths[this_global_slice] = env.render()[this_local_slice]
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
        # clear out video buffer
        _ = env.reset()

        # log
        max_rewards = collections.defaultdict(list)
        log_data = dict()
        # results reported in the paper are generated using the commented out line below
        # which will only report and average metrics from first n_envs initial condition and seeds
        # fortunately this won't invalidate our conclusion since
        # 1. This bug only affects the variance of metrics, not their mean
        # 2. All baseline methods are evaluated using the same code
        # to completely reproduce reported numbers, uncomment this line:
        # for i in range(len(self.env_fns)):
        # and comment out this line
        for i in range(n_inits):
            seed = self.env_seeds[i]
            prefix = self.env_prefixs[i]
            max_reward = np.max(all_rewards[i])
            max_rewards[prefix].append(max_reward)
            log_data[prefix + f'sim_max_reward_{seed}'] = max_reward

            # visualize sim
            video_path = all_video_paths[i]
            if video_path is not None:
                sim_video = wandb.Video(video_path)
                log_data[prefix + f'sim_video_{seed}'] = sim_video

        # log aggregate metrics
        for prefix, value in max_rewards.items():
            name = prefix + 'mean_score'
            value = np.mean(value)
            log_data[name] = value

        return log_data

    def _augmented_rollout(self, policy, predict_fn):
        """Roll out every initial condition, recording trajectories.

        Each env step executes only ``action_pred[:, 1:2]`` of the prediction.
        Per chunk, it records the most recent observation frame
        (``[:, -1:]`` of every obs entry), that executed action, and the
        reward.

        Args:
            policy: provides ``device`` and ``reset()`` (called once per
                chunk).
            predict_fn: callable ``obs_dict -> action_dict``, invoked under
                ``torch.no_grad()``. Its result must contain ``'action'`` and
                ``'action_pred'``.

        Returns:
            ``(log_data, traj_list)``. ``log_data`` is an empty dict: no
            metrics are logged here. ``traj_list`` is as returned by
            ``_extract_successful_trajs``.
        """
        device = policy.device
        env = self.env

        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        save_data_list = []
        for chunk_idx in range(n_chunks):
            _, this_local_slice = self._init_chunk(chunk_idx)

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            # Per-step tensors are collected in lists and concatenated once at
            # the end of the chunk (avoids quadratic repeated torch.cat).
            obs_steps = collections.defaultdict(list)
            action_steps = []
            reward_steps = []

            pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval PushtImageRunner {chunk_idx + 1}/{n_chunks}",
                             leave=False, mininterval=self.tqdm_interval_sec)
            done = False
            while not done:
                obs_dict = self._obs_to_torch(obs, past_action, device)

                # run policy
                with torch.no_grad():
                    action_dict = predict_fn(obs_dict)

                # device_transfer
                np_action_dict = dict_apply(action_dict,
                                            lambda x: x.detach().to('cpu').numpy())
                action = np_action_dict['action']

                # Record the latest observation frame and the executed action.
                for key, value in obs_dict.items():
                    obs_steps[key].append(value[:, -1:].cpu())
                # NOTE(review): index 1 of action_pred is hard-coded as "the"
                # action to execute; this presumably matches a policy with two
                # observation steps. Confirm for other n_obs_steps settings.
                action_steps.append(action_dict['action_pred'][:, 1:2].cpu())
                env_action = np_action_dict['action_pred'][:, 1:2]

                # step env
                obs, reward, done, info = env.step(env_action)
                done = np.all(done)
                # NOTE(review): past_action stores the policy's 'action'
                # output, not the env_action actually executed.
                past_action = action

                reward_steps.append(torch.from_numpy(reward).unsqueeze(1))

                # Advance by the number of actions actually executed this step.
                pbar.update(env_action.shape[1])
            pbar.close()

            save_data_list.append({
                'obs': {key: torch.cat(steps, dim=1) for key, steps in obs_steps.items()},
                'action': torch.cat(action_steps, dim=1),
                'reward': torch.cat(reward_steps, dim=1),
            })

            # Return value unused. The call is kept for its side effects on
            # the video-recording wrappers (presumably stopping the
            # recordings), matching run().
            env.render()[this_local_slice]
        # clear out video buffer
        _ = env.reset()

        log_data = dict()
        traj_list = self._extract_successful_trajs(save_data_list)
        return log_data, traj_list

    def _extract_successful_trajs(self, save_data_list):
        """Cut successful episodes out of the recorded per-chunk rollout data.

        An episode counts as successful if any recorded reward equals exactly
        1. Its trajectory is truncated at the first such step plus
        ``n_action_steps`` extra steps (clipped to the recorded length).
        Padding environments of a partially-filled chunk are skipped, since
        they duplicate the first initial condition.

        Returns:
            list of dicts, each ``{'actions': ndarray, 'obs': {key: ndarray}}``.
        """
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        traj_list = []
        for chunk_idx, save_data in enumerate(save_data_list):
            n_active = min(n_envs, n_inits - chunk_idx * n_envs)
            for i in range(n_active):
                reward_tensor = save_data['reward'][i]  # traj_len
                success_steps = (reward_tensor == 1).nonzero(as_tuple=True)[0]
                if len(success_steps) == 0:
                    continue
                # Slicing clips at the sequence end, so no bounds check is needed.
                stop = int(success_steps[0]) + self.n_action_steps
                traj_list.append({
                    'actions': save_data['action'][i, :stop].numpy(),
                    'obs': {key: value[i, :stop].numpy()
                            for key, value in save_data['obs'].items()},
                })
        return traj_list

    def augwog_run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` alone and collect successful trajectories.

        Args:
            policy: policy exposing ``device``, ``reset()`` and
                ``predict_action(obs_dict)``. The returned dict must contain
                ``'action'`` and ``'action_pred'``.

        Returns:
            ``(log_data, traj_list)`` as described in ``_augmented_rollout``.
        """
        return self._augmented_rollout(policy, policy.predict_action)

    def aug_run(self, s_policy, policy: BaseImagePolicy, current_epoch):
        """Roll out ``policy`` guided by ``s_policy`` and collect successful trajectories.

        At every step, ``s_policy.predict_action`` is evaluated first. Its
        ``'action'`` is passed to
        ``policy.predict_action_mse_guide(obs_dict, s_action, current_epoch)``,
        whose output drives the environment.

        NOTE(review): only ``policy.reset()`` is called per chunk;
        ``s_policy`` is never reset here.

        Returns:
            ``(log_data, traj_list)`` as described in ``_augmented_rollout``.
        """
        def predict_fn(obs_dict):
            s_action_dict = s_policy.predict_action(obs_dict)
            return policy.predict_action_mse_guide(obs_dict, s_action_dict['action'], current_epoch)

        return self._augmented_rollout(policy, predict_fn)
