import os
import wandb
import numpy as np
import torch
import collections
import pathlib
import tqdm
import dill
import collections
import torch
import numpy as np
import wandb.sdk.data_types.video as wv
import zarr
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner

class SequentialPushTImageRunner(BaseImageRunner):
    """Evaluate an image policy on PushT environments, one episode at a time.

    One wrapped environment is built per train/test seed at construction
    time. ``run`` then rolls the policy out in each environment in turn,
    records a video per seed under ``<output_dir>/media`` and returns a
    dictionary of per-seed and aggregated maximum rewards.

    Optional behavior is toggled through attributes set after construction:

    * ``save_hdf5``  -- when True, every rollout is also appended to a zarr
      store at ``output_path`` (despite the attribute name, the on-disk
      format is zarr).
    * ``perturb``    -- when True, ``perturb = True`` is set on the innermost
      environment before each reset.
    * ``output_path`` -- destination of the zarr store; must be set when
      ``save_hdf5`` is True.
    """

    # An episode is cut short as soon as a step's reward exceeds this value.
    SUCCESS_REWARD_THRESHOLD = 0.98

    def __init__(self,
                 output_dir,
                 n_train=10,
                 n_train_vis=3,
                 train_start_seed=0,
                 n_test=22,
                 n_test_vis=6,
                 legacy_test=False,
                 test_start_seed=10000,
                 max_steps=200,
                 n_obs_steps=8,
                 n_action_steps=8,
                 fps=10,
                 crf=22,
                 render_size=140,
                 past_action=False,
                 tqdm_interval_sec=5.0,
                 n_envs=None):
        """Build one environment per train seed followed by one per test seed.

        Args:
            output_dir: Directory under which the ``media`` folder is created.
            n_train: Number of train environments (seeds
                ``train_start_seed .. train_start_seed + n_train - 1``).
            n_train_vis: Accepted but currently unused; a video path is set
                for every episode in ``run``.
            train_start_seed: First train seed.
            n_test: Number of test environments.
            n_test_vis: Accepted but currently unused (see ``n_train_vis``).
            legacy_test: Forwarded to ``PushTImageEnv(legacy=...)``.
            test_start_seed: First test seed.
            max_steps: Forwarded as ``max_episode_steps`` to the multi-step
                wrapper; also the total of the per-episode progress bar.
            n_obs_steps: Forwarded to the multi-step wrapper.
            n_action_steps: Forwarded to the multi-step wrapper.
            fps: Video frame rate; also used when logging ``wandb.Video``.
            crf: Quality setting forwarded to the H.264 video recorder.
            render_size: Forwarded to ``PushTImageEnv(render_size=...)``.
            past_action: When True, the previous action chunk is fed back to
                the policy under the ``past_action`` observation key.
            tqdm_interval_sec: Stored on the instance; not otherwise used here.
            n_envs: Accepted but unused.
        """
        super().__init__(output_dir)
        self.media_dir = pathlib.Path(output_dir).joinpath('media')
        self.media_dir.mkdir(parents=True, exist_ok=True)

        def make_env():
            # PushT env -> video recording -> multi-step (obs/action chunking).
            # file_path starts as None; run() assigns a per-seed path.
            return MultiStepWrapper(
                VideoRecordingWrapper(
                    PushTImageEnv(
                        legacy=legacy_test,
                        render_size=render_size
                    ),
                    video_recoder=VideoRecorder.create_h264(
                        fps=fps,
                        codec='h264',
                        input_pix_fmt='rgb24',
                        crf=crf,
                        thread_type='FRAME',
                        thread_count=1
                    ),
                    file_path=None,
                    steps_per_render=max(10 // fps, 1)
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        self.output_dir = output_dir
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.tqdm_interval_sec = tqdm_interval_sec

        # Parallel lists: entry i of each list describes environment i.
        self.envs = []
        self.env_seeds = []
        self.env_prefixs = []
        self.env_init_fn_dills = []

        # Train environments are registered first, then test environments.
        for start_seed, count, prefix in (
                (train_start_seed, n_train, 'train/'),
                (test_start_seed, n_test, 'test/')):
            for offset in range(count):
                seed = start_seed + offset
                self.envs.append(make_env())
                self.env_seeds.append(seed)
                self.env_prefixs.append(prefix)
                # `s=seed` binds the seed now, avoiding late-binding closures.
                self.env_init_fn_dills.append(
                    dill.dumps(lambda e, s=seed: e.seed(s)))

        self.save_hdf5 = False
        self.perturb = False
        self.output_path = None

    @staticmethod
    def _select_predict_fn(policy, avoid_ood, method):
        """Return the policy method used to predict actions for this run."""
        if not avoid_ood:
            return policy.predict_action
        if 'classifier_guidance' in method:
            return policy.predict_action_classifier_guidance
        if method == 'gd':
            return policy.predict_action_gd
        # Previously this case left the action unassigned inside the rollout
        # loop; fail early with a clear message instead.
        raise ValueError(
            f"Unsupported method {method!r} for avoid_ood=True; expected 'gd' "
            "or a name containing 'classifier_guidance'.")

    def _create_zarr_store(self):
        """Create a fresh zarr store at ``output_path``.

        Any existing directory at that path is removed first.

        Returns:
            Tuple ``(arrays, meta_group)`` where ``arrays`` maps dataset name
            to the empty zarr array under ``/data``.
        """
        if self.output_path is None:
            raise ValueError(
                "output_path must be set when save_hdf5 is True.")
        if os.path.exists(self.output_path):
            import shutil
            shutil.rmtree(self.output_path)
        zarr_root = zarr.open(self.output_path, mode='w')

        data_group = zarr_root.create_group('data')
        meta_group = zarr_root.create_group('meta')

        # name -> (per-step shape, chunk length along the step axis)
        # NOTE(review): the image shape is fixed at 140x140 irrespective of
        # the `render_size` constructor argument -- confirm they must match.
        layout = {
            'action': ((2,), 1024),
            'img': ((140, 140, 3), 128),
            'keypoint': ((9, 2), 1024),
            'n_contacts': ((1,), 1024),
            'state': ((5,), 1024),
        }
        # NOTE(review): `append_dim=0` is passed through exactly as before;
        # confirm whether the installed zarr version actually uses it.
        arrays = {
            name: data_group.create_dataset(
                name,
                shape=(0,) + item_shape,
                chunks=(chunk_len,) + item_shape,
                dtype='f4',
                overwrite=True,
                append_dim=0)
            for name, (item_shape, chunk_len) in layout.items()
        }
        return arrays, meta_group

    def run(self, policy: BaseImagePolicy,
            avoid_ood=False,
            num_samples=40,
            method='mpc'):
        """Roll the policy out in every environment and collect metrics.

        Args:
            policy: Policy to evaluate. Its ``device`` and ``dtype`` are used
                to convert observations to tensors.
            avoid_ood: When False, ``policy.predict_action`` is used. When
                True, the prediction method is chosen by ``method``.
            num_samples: Accepted but unused.
            method: Only consulted when ``avoid_ood`` is True. A value
                containing ``'classifier_guidance'`` selects
                ``predict_action_classifier_guidance``; ``'gd'`` selects
                ``predict_action_gd``.

        Returns:
            dict with, for each seed, ``<prefix>sim_max_reward_<seed>`` and
            (if the video file exists) ``<prefix>sim_video_<seed>``, plus
            ``<prefix>mean_score`` for each prefix.

        Raises:
            ValueError: If ``avoid_ood`` is True with an unsupported
                ``method``, or ``save_hdf5`` is True without ``output_path``.
        """
        device = policy.device
        dtype = policy.dtype
        predict_fn = self._select_predict_fn(policy, avoid_ood, method)

        zarr_store = None
        meta_group = None
        episode_ends = []       # cumulative step count at the end of each episode
        total_step_count = 0
        if self.save_hdf5:
            zarr_store, meta_group = self._create_zarr_store()

        max_rewards = collections.defaultdict(list)
        log_data = dict()

        episodes = zip(self.envs, self.env_init_fn_dills,
                       self.env_seeds, self.env_prefixs)
        for env, init_fn_dill, seed, prefix in tqdm.tqdm(
                episodes,
                total=len(self.envs),
                desc="Running environments sequentially"):
            # Seed the environment via its serialized init function.
            env.run_dill_function(init_fn_dill)

            # env is the multi-step wrapper, env.env the video wrapper and
            # env.env.env the underlying PushT environment.
            video_env = env.env
            base_env = video_env.env
            if self.save_hdf5:
                base_env.save_hdf5 = True
            if self.perturb:
                base_env.perturb = True

            obs = env.reset()
            past_action = None
            policy.reset()

            # One video per seed; the path is assigned after reset.
            video_env.file_path = str(self.media_dir.joinpath(f'{seed}.mp4'))

            done = False
            rewards = []
            with tqdm.tqdm(total=self.max_steps,
                           desc=f"Eval Environment {prefix} Seed {seed}",
                           leave=False) as pbar:
                while not done:
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # `past_action` has no batch axis here (it was squeezed
                        # below), so the most recent steps are taken along the
                        # leading axis; the batch axis is added just after.
                        np_obs_dict['past_action'] = past_action[
                            -(self.n_obs_steps - 1):].astype(np.float32)

                    # Add the batch axis, then move to the policy's device/dtype.
                    np_obs_dict = dict_apply(
                        np_obs_dict, lambda x: np.expand_dims(x, axis=0))
                    obs_dict = dict_apply(
                        np_obs_dict,
                        lambda x: torch.from_numpy(x).to(device=device, dtype=dtype))

                    action_dict = predict_fn(obs_dict)

                    # Back to numpy without the batch axis.
                    np_action_dict = dict_apply(
                        action_dict,
                        lambda x: x.detach().cpu().numpy().squeeze(0))
                    action = np_action_dict['action']

                    obs, reward, done, _ = env.step(action)
                    done = np.all(done)
                    past_action = action

                    rewards.append(reward)
                    if reward > self.SUCCESS_REWARD_THRESHOLD:
                        break
                    # The bar's total is max_steps, so advance by the length
                    # of the executed chunk (assumes action is shaped
                    # (n_action_steps, action_dim) -- TODO confirm).
                    pbar.update(action.shape[0])

            if self.save_hdf5:
                # The trajectory buffers live on the innermost environment;
                # every stream is truncated to the number of actions taken.
                traj_len = len(base_env.traj_actions)
                zarr_store['action'].append(
                    np.asarray(base_env.traj_actions, dtype=np.float32))
                zarr_store['img'].append(
                    np.stack(base_env.traj_imgs[:traj_len]).astype(np.float32))
                zarr_store['keypoint'].append(
                    np.stack(base_env.traj_keypoints[:traj_len]).astype(np.float32))
                zarr_store['n_contacts'].append(
                    np.stack(base_env.traj_n_contacts[:traj_len]).astype(np.float32))
                zarr_store['state'].append(
                    np.stack(base_env.traj_states[:traj_len]).astype(np.float32))

                total_step_count += traj_len
                episode_ends.append(total_step_count)

            # Stop the recorder explicitly so the video file is written
            # before it is checked below.
            video_env.video_recoder.stop()
            video_path = video_env.file_path

            max_reward = np.max(rewards)
            max_rewards[prefix].append(max_reward)
            print(f"Environment {prefix} Seed {seed} Max Reward: {max_reward}")
            log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
            if video_path is not None and pathlib.Path(video_path).exists():
                log_data[prefix + f'sim_video_{seed}'] = wandb.Video(
                    str(video_path), fps=self.fps)

        if self.save_hdf5:
            meta_group['episode_ends'] = np.array(episode_ends, dtype=np.int64)
            print(f"Zarr file saved to: {self.output_path}")

        # Aggregate: mean of the per-episode maximum reward for each prefix.
        for prefix, values in max_rewards.items():
            log_data[prefix + 'mean_score'] = np.mean(values)

        return log_data