import wandb
import numpy as np
import torch
import collections
import pathlib
import tqdm
import dill
import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder

from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner

class SequentialBlockPushImageRunner(BaseImageRunner):
    """Evaluate a policy on ``BlockPushMultimodal`` environments one at a time.

    Every environment is rolled out sequentially in the calling process, so the
    policy always receives a batch of size 1. Train and test environments differ
    only in their seed range and in the ``'train/'`` / ``'test/'`` prefix used
    for the logged metric names.
    """

    def __init__(self,
                 output_dir,
                 n_train=10,
                 n_train_vis=3,
                 train_start_seed=0,
                 n_test=22,
                 n_test_vis=6,
                 test_start_seed=10000,
                 max_steps=200,
                 n_obs_steps=8,
                 n_action_steps=8,
                 fps=5,
                 crf=22,
                 past_action=False,
                 abs_action=False,
                 obs_eef_target=True,
                 eef_only=False,
                 tqdm_interval_sec=5.0,
                 n_envs=None):
        """Create one wrapped environment per train and test seed.

        Args:
            output_dir: Run directory; episode videos go to ``<output_dir>/media``.
            n_train: Number of train environments (consecutive seeds starting
                at ``train_start_seed``).
            n_train_vis: Only the first ``n_train_vis`` train episodes are
                recorded to video.
            train_start_seed: Seed of the first train environment.
            n_test: Number of test environments (consecutive seeds starting
                at ``test_start_seed``).
            n_test_vis: Only the first ``n_test_vis`` test episodes are
                recorded to video.
            test_start_seed: Seed of the first test environment.
            max_steps: ``max_episode_steps`` passed to ``MultiStepWrapper``;
                also the total of the per-episode progress bar.
            n_obs_steps: Observation history length (``MultiStepWrapper``).
            n_action_steps: Actions executed per policy call (``MultiStepWrapper``).
            fps: Video frame rate; also sets ``steps_per_render``.
            crf: ``crf`` quality setting passed to the H.264 video encoder.
            past_action: If True, also feed the policy its most recent actions.
            abs_action: Forwarded to ``BlockPushMultimodal``.
            obs_eef_target: If False, state columns 8:10 are zeroed before
                being given to the policy.
            eef_only: If True, only state columns 6:10 are given to the policy.
            tqdm_interval_sec: Minimum seconds between progress-bar refreshes.
            n_envs: Not used by this runner (presumably accepted so configs
                shared with vectorized runners still load).
        """
        super().__init__(output_dir)
        self.media_dir = pathlib.Path(output_dir).joinpath('media')
        self.media_dir.mkdir(parents=True, exist_ok=True)

        def env_fn(seed):
            # The recorder starts with file_path=None; run() assigns a path
            # per episode only for environments that should be visualized.
            return MultiStepWrapper(
                VideoRecordingWrapper(
                    BlockPushMultimodal(
                        control_frequency=10,  # task_fps
                        shared_memory=False,
                        seed=seed,
                        abs_action=abs_action,
                        image_size=(240, 320)
                    ),
                    video_recoder=VideoRecorder.create_h264(
                        fps=fps,
                        codec='h264',
                        input_pix_fmt='rgb24',
                        crf=crf,
                        thread_type='FRAME',
                        thread_count=1
                    ),
                    file_path=None,
                    steps_per_render=max(10 // fps, 1)
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        self.envs = []
        self.env_seeds = []
        self.env_prefixs = []
        self.env_init_fn_dills = []
        # Per-environment flag: record this episode to video or not.
        self.env_enable_renders = []
        self.output_dir = output_dir
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.tqdm_interval_sec = tqdm_interval_sec
        self.abs_action = abs_action
        self.obs_eef_target = obs_eef_target
        self.eef_only = eef_only

        # Train environments first, then test environments.
        splits = (
            ('train/', train_start_seed, n_train, n_train_vis),
            ('test/', test_start_seed, n_test, n_test_vis),
        )
        for prefix, start_seed, n_split, n_vis in splits:
            for i in range(n_split):
                seed = start_seed + i
                self.envs.append(env_fn(seed))
                self.env_seeds.append(seed)
                self.env_prefixs.append(prefix)
                self.env_enable_renders.append(i < n_vis)
                # Bind `seed` as a default argument so each serialized
                # closure keeps its own value (avoids late binding).
                self.env_init_fn_dills.append(
                    dill.dumps(lambda e, s=seed: e.seed(s)))

    def run(self, policy: BaseImagePolicy,
            avoid_ood=False,
            h_step=False,
            optim_lr=1e-4,
            num_iters=20,
            use_embed=False,
            weight_decay=0.0,
            input_type='state',
            use_history=False):
        """Roll out ``policy`` in every environment and return metrics to log.

        The keyword arguments after ``policy`` are forwarded unchanged to
        ``policy.predict_action`` on every call.

        Returns:
            dict containing, for each prefix ``p`` (``'train/'``/``'test/'``):
              * ``p + f'sim_max_reward_{seed}'``: sum of the distinct reward
                values seen in that episode (the episode score);
              * ``p + f'sim_video_{seed}'``: ``wandb.Video`` for each episode
                whose video file exists;
              * ``p + 'mean_score'``: mean episode score;
              * ``p + 'p1'`` / ``p + 'p2'``: fraction of episodes with
                score > 0.4 / > 0.9;
              * ``p + key``: for each key of the final step's ``info``, the
                fraction of episodes whose final value is > 0.
        """
        predict_kwargs = dict(
            avoid_ood=avoid_ood,
            h_step=h_step,
            optim_lr=optim_lr,
            num_iters=num_iters,
            use_embed=use_embed,
            weight_decay=weight_decay,
            input_type=input_type,
            use_history=use_history,
        )

        all_video_paths = []
        all_rewards = []
        last_infos = []

        episodes = zip(self.envs, self.env_init_fn_dills, self.env_seeds,
                       self.env_prefixs, self.env_enable_renders)
        for env, init_fn_dill, seed, prefix, enable_render in tqdm.tqdm(
                episodes,
                total=len(self.envs),
                desc="Running environments sequentially",
                mininterval=self.tqdm_interval_sec):
            video_path, rewards, final_info = self._run_episode(
                policy, env, init_fn_dill, seed, prefix, enable_render,
                predict_kwargs)
            all_video_paths.append(video_path)
            all_rewards.append(rewards)
            last_infos.append(final_info)

        return self._summarize(all_video_paths, all_rewards, last_infos)

    def _build_obs_dict(self, obs, past_action, device, dtype):
        """Turn one wrapper observation into the batched tensor dict for the policy.

        Keys produced: ``'state'`` (all non-image entries concatenated along
        axis 1), ``'image'`` only when the observation has an ``'rgb'`` entry,
        and ``'past_action'`` when enabled and available. Every value gets a
        leading batch dimension of size 1 and is moved to ``device``/``dtype``.
        """
        obs = dict(obs)  # shallow copy so popping 'rgb' leaves the caller's dict intact
        img = obs.pop('rgb', None)

        state = np.concatenate(list(obs.values()), axis=1).astype(np.float32)
        if not self.obs_eef_target:
            state[..., 8:10] = 0
        if self.eef_only:
            state = state[..., 6:10]

        np_obs_dict = {'state': state}
        # Add the image only when present: a None entry cannot be turned into
        # a tensor (np.expand_dims(None) yields an object array).
        if img is not None:
            # Scale to [0, 1] and move the last axis to position 1 -- assumes
            # channels-last frames, i.e. (T, H, W, C) -> (T, C, H, W); TODO confirm.
            np_obs_dict['image'] = np.moveaxis(img.astype(np.float32) / 255, -1, 1)
        if self.past_action and (past_action is not None):
            # `past_action` already had its batch dim squeezed off, so the
            # leading axis is presumably time; keep the last n_obs_steps - 1 rows.
            np_obs_dict['past_action'] = (
                past_action[-(self.n_obs_steps - 1):].astype(np.float32))

        return dict_apply(
            np_obs_dict,
            lambda x: torch.from_numpy(np.expand_dims(x, axis=0)).to(
                device=device, dtype=dtype))

    def _run_episode(self, policy, env, init_fn_dill, seed, prefix,
                     enable_render, predict_kwargs):
        """Roll out a single episode in ``env``.

        Returns:
            ``(video_path, rewards, final_info)``: ``video_path`` is None when
            this episode is not visualized; ``rewards`` holds the reward from
            each ``env.step``; ``final_info`` maps each key of the last step's
            info to its last entry (None if no step was taken).

        Raises:
            KeyError: if the policy's output has no ``'action'`` key.
        """
        device = policy.device
        dtype = policy.dtype

        # Re-seed, then reset, so every evaluation of this env is reproducible.
        env.run_dill_function(init_fn_dill)
        obs = env.reset()
        policy.reset()

        video_path = None
        if enable_render:
            video_path = str(self.media_dir.joinpath(
                f'{prefix.strip("/")}_seed_{seed}.mp4'))
        # None matches the wrapper's construction-time default (no file for
        # this episode); presumably it also skips rendering -- confirm.
        env.env.file_path = video_path

        rewards = []
        last_info = None
        past_action = None
        done = False
        with tqdm.tqdm(total=self.max_steps,
                       desc=f"Eval BlockPushSequentialRunner {prefix} Seed {seed}",
                       leave=False,
                       mininterval=self.tqdm_interval_sec) as pbar:
            while not done:
                obs_dict = self._build_obs_dict(obs, past_action, device, dtype)
                action_dict = policy.predict_action(
                    obs_dict, info=last_info, **predict_kwargs)

                # Back to numpy, dropping the batch dimension of size 1.
                np_action_dict = dict_apply(
                    action_dict, lambda x: x.detach().cpu().numpy().squeeze(0))
                if 'action' not in np_action_dict:
                    raise KeyError("The action dictionary returned by the policy does not contain the 'action' key.")
                action = np_action_dict['action']

                obs, reward, done, info = env.step(action)
                done = np.all(done)
                past_action = action
                rewards.append(reward)
                last_info = info

                # The bar total is max_steps (the wrapper's max_episode_steps);
                # presumably each row of `action` is one env step -- confirm.
                pbar.update(len(action))
                # Stop early once the reward is essentially maximal.
                if reward > 0.98:
                    break

        # Stop the recorder explicitly so the video file is finalized.
        env.env.video_recoder.stop()

        final_info = None
        if last_info is not None:
            final_info = {k: v[-1] for k, v in last_info.items()}
        return video_path, rewards, final_info

    def _summarize(self, all_video_paths, all_rewards, last_infos):
        """Aggregate per-episode results (in ``self.envs`` order) into ``log_data``."""
        total_rewards = collections.defaultdict(list)
        total_p1 = collections.defaultdict(list)
        total_p2 = collections.defaultdict(list)
        prefix_event_counts = collections.defaultdict(
            lambda: collections.defaultdict(int))
        prefix_counts = collections.defaultdict(int)
        log_data = {}

        per_episode = zip(self.env_seeds, self.env_prefixs, all_rewards,
                          last_infos, all_video_paths)
        for seed, prefix, this_rewards, final_info, video_path in per_episode:
            # Episode score = sum of the distinct reward values seen.
            total_reward = np.unique(this_rewards).sum()
            total_rewards[prefix].append(total_reward)
            # NOTE(review): presumably p1/p2 mean partial / full task
            # completion for block pushing -- confirm the thresholds.
            total_p1[prefix].append(total_reward > 0.4)
            total_p2[prefix].append(total_reward > 0.9)
            log_data[prefix + f'sim_max_reward_{seed}'] = total_reward

            prefix_counts[prefix] += 1
            if final_info is not None:
                for key, value in final_info.items():
                    prefix_event_counts[prefix][key] += 1 if value > 0 else 0

            if video_path is not None and pathlib.Path(video_path).exists():
                log_data[prefix + f'sim_video_{seed}'] = wandb.Video(
                    str(video_path), fps=self.fps)

        for prefix, rewards in total_rewards.items():
            log_data[prefix + 'mean_score'] = np.mean(rewards)
        for prefix, p1_flags in total_p1.items():
            log_data[prefix + 'p1'] = np.mean(p1_flags)
        for prefix, p2_flags in total_p2.items():
            log_data[prefix + 'p2'] = np.mean(p2_flags)

        # Fraction of episodes per prefix whose final info value was positive.
        for prefix, events in prefix_event_counts.items():
            n_episodes = prefix_counts[prefix]
            for event, count in events.items():
                log_data[prefix + event] = count / n_episodes

        return log_data