import os
import h5py
import torch as th
import numpy as np
from PIL import Image

############## DataBatch ##############
class OfflineDataBatch():
    """Dict-like batch of offline tensors, truncated along dim 1 and moved to a device.

    Every value in ``data`` is sliced as ``v[:, :max_seq_length]`` and moved to
    ``device``. Values are presumably laid out as ``(batch_size, T, n_agents, *shape)``
    (per the original author's note) — TODO confirm for every key.

    Indexing with a string returns the stored tensor for that key, or falls back
    to an attribute of the same name (e.g. ``batch["batch_size"]``).
    """

    def __init__(self, data, batch_size, max_seq_length, device='cpu') -> None:
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length  # None keeps the full time dimension
        self.device = device
        # Truncate here (interface used directly by the offline buffer). Build a new
        # dict instead of rewriting the caller's mapping in place.
        self.data = {k: v[:, :max_seq_length].to(device) for k, v in data.items()}

    def __getitem__(self, item):
        """Return ``self.data[item]`` or, failing that, the attribute ``item``.

        Raises:
            ValueError: if ``item`` is not a string, or names neither a data key
                nor an attribute.
        """
        if isinstance(item, str):
            if item in self.data:
                return self.data[item]
            if hasattr(self, item):
                return getattr(self, item)
        raise ValueError('Cannot index OfflineDataBatch with key "{}"'.format(item))

    def to(self, device=None):
        """Move every stored tensor to ``device`` (default: current device).

        Returns ``self`` so calls can be chained.
        """
        if device is None:
            device = self.device
        for k, v in self.data.items():
            self.data[k] = v.to(device)
        self.device = device  # keep the recorded device in sync
        return self

    def keys(self):
        """Return the data keys as a list."""
        return list(self.data.keys())

    def assign(self, key, value):
        """Add a new entry ``key -> value``. The value is stored as-is (not moved to ``self.device``).

        Raises:
            ValueError: if ``key`` already exists (overwriting is not allowed).
        """
        # An explicit raise instead of `assert`, which is stripped under `python -O`.
        if key in self.data:
            raise ValueError("Cannot assign to existing key")
        self.data[key] = value

# StarCraft 2 Unit Visualization Specification:
# 1. Unit Core Visualization
# Main unit circle:
# - Ally units: Green
# - Enemy units: Red
# - Health shown via darker shading when damaged
# - Black outline for visibility
# Range indicators:
# - Sight range: Cyan circle
# - Attack range: Dashed red circle
# 2. Unit Status Display System:
# Concentric Status Arcs (from outer to inner):
# - [Shield] - Blue arc at radius-0.05 
# - [Energy] - Red arc at radius-0.1
# - [Orders] - Cyan arc at radius-0.15  
# - [Buffs]  - White arc at radius-0.2
# - [Upgrades] - Colored arcs at radius-0.25
# 3. Movement & Target System
# - Facing indicator: White arc showing the unit's facing direction
# - Attack visualization: Cyan lines connecting each allied unit to the target it is attacking
# 4. Information Display
# - Unit names: White text above units
# - Game overlay: Score, steps, time, APM in top-right

############## OfflineBuffer ##############
class OfflineBufferH5(): # One Task
    """Offline buffer built from every ``.h5`` file under the skill dataset path.

    ``__init__`` walks ``env_args["skill"]["skill_configs"]["skill_dataset_path"]``
    and loads each ``.h5`` file with :meth:`_read_imgs`. That loader writes every
    stored frame to a JPEG under ``self.imgs_path`` and keeps only the file paths.
    ``self.data`` is therefore a list of episodes, each a list of image paths, and
    ``self.buffer_size`` is the number of episodes.

    :meth:`_read_data` is an alternative loader (not called by ``__init__``) that
    keeps the raw per-episode arrays. :meth:`sample` expects ``self.data`` to be a
    dict of stacked arrays. It does not support the list layout ``__init__`` produces.
    """

    _DEFAULT_IMGS_PATH = "./res/smacv2/dataset/imgs/episode_records"

    def __init__(self, args, algo_args, env_args) -> None:
        self.args = args
        self.algo_args = algo_args
        self.env_args = env_args
        self.base_data_folder = env_args["skill"]["skill_configs"]["skill_dataset_path"]
        self.map_name = env_args["map_name"]
        # Output folder for extracted frames; overridable via algo_args["offline"]["imgs_path"].
        self.imgs_path = algo_args["offline"].get("imgs_path", self._DEFAULT_IMGS_PATH)
        # `sample` needs a target device; use args.device when present, else CPU.
        self.device = getattr(args, "device", "cpu")
        self.episode_idx = 0  # running episode counter used to name image folders
        self.h5_files = self._find_h5_files(self.base_data_folder)
        print(self.h5_files)

        self.max_buffer_size = algo_args["offline"]["max_buffer_size"]
        # NOTE: only used by `_read_data`; `_read_imgs` keeps the file order.
        self.shuffle = algo_args["offline"]["shuffle"]
        dataset = []
        for h5_path in self.h5_files:
            dataset.extend(self._read_imgs(h5_path))
        self.data = dataset
        self.buffer_size = len(self.data)

    @staticmethod
    def _find_h5_files(folder):
        """Return the paths of all ``.h5`` files below ``folder`` in a deterministic order."""
        h5_files = []
        for root, dirs, files in os.walk(folder):
            # os.walk order depends on the filesystem; sort it so episode numbering is reproducible.
            dirs.sort()
            for file in sorted(files):
                if file.endswith(".h5"):
                    h5_files.append(os.path.join(root, file))
        return h5_files

    @staticmethod
    def _map_name_from_path(h5_path):
        """Return the name of the directory containing ``h5_path`` (e.g. "protoss_5_vs_5")."""
        return os.path.basename(os.path.dirname(h5_path))

    @staticmethod
    def _log_length_stats(map_name, lengths):
        """Print count/mean/min/max of episode lengths; only the count when there are none."""
        print(f"{map_name}- Total episodes: {len(lengths)}")
        if not lengths:
            return  # np.mean/np.min/np.max are undefined or raise on empty input
        print(f"{map_name}- Mean length: {np.mean(lengths):.2f}")
        print(f"{map_name}- Min length: {np.min(lengths)}")
        print(f"{map_name}- Max length: {np.max(lengths)}")

    @staticmethod
    def _maybe_shuffle(items, shuffle):
        """Return ``items`` as a new list, in random order when ``shuffle`` is truthy."""
        if not shuffle:
            return list(items)
        size = len(items)
        shuffled_idx = np.random.choice(size, size, replace=False)
        return [items[i] for i in shuffled_idx]

    def _read_data(self, h5_path, shuffle, keep_fraction=0.01):
        """Load raw episodes from ``h5_path`` and keep only the shortest ones.

        Each top-level group of the file is one episode. Each dataset in it is
        loaded as a NumPy array; byte-string datasets are decoded to ``str``.
        Episodes are ranked by ``len(episode["actions"])``. The shortest
        ``keep_fraction`` of them are kept, and at least one is always kept.

        Args:
            h5_path: path to the ``.h5`` file.
            shuffle: shuffle the kept episodes when truthy.
            keep_fraction: fraction of episodes to keep (default 1%).

        Returns:
            list of ``{dataset_name: np.ndarray}`` dicts, one per kept episode.
        """
        loaded_data_batch = []
        map_name = self._map_name_from_path(h5_path)
        with h5py.File(h5_path, "r") as f:
            for episode_name in f.keys():
                episode_data = {}
                for key, dataset in f[episode_name].items():
                    if dataset.dtype.kind == "S":  # Handle byte strings
                        episode_data[key] = dataset.asstr()[:]
                    else:
                        episode_data[key] = np.array(dataset)
                loaded_data_batch.append(episode_data)

        lengths = [d["actions"].shape[0] for d in loaded_data_batch]
        self._log_length_stats(map_name, lengths)
        if not loaded_data_batch:
            return []

        num_total = len(loaded_data_batch)
        num_episodes_to_keep = max(1, int(num_total * keep_fraction))  # keep at least 1 episode
        sorted_indices = np.argsort(lengths)[:num_episodes_to_keep]
        loaded_data_batch = [loaded_data_batch[i] for i in sorted_indices]

        kept_pct = 100.0 * num_episodes_to_keep / num_total
        print(f"{map_name}- Kept shortest {num_episodes_to_keep} episodes ({kept_pct:.1f}% of total)")
        print(f"{map_name}- New mean length: {np.mean([len(d['actions']) for d in loaded_data_batch]):.2f}")

        # Previously `data` was only bound when shuffling (UnboundLocalError otherwise).
        return self._maybe_shuffle(loaded_data_batch, shuffle)

    def _read_imgs(self, h5_path):
        """Dump every frame in ``h5_path`` to JPEG files and return their paths per episode.

        Each top-level group becomes ``<imgs_path>/episode_<n>/step_<i>.jpg``,
        where ``n`` is the running ``self.episode_idx``.

        Returns:
            list of episodes, each a list of image file paths.
        """
        loaded_data_batch = []
        map_name = self._map_name_from_path(h5_path)
        with h5py.File(h5_path, "r") as f:
            for episode_name in f.keys():
                episode_data = []
                episode_path = os.path.join(self.imgs_path, "episode_{}".format(self.episode_idx))
                os.makedirs(episode_path, exist_ok=True)
                for key in f[episode_name].keys():
                    # NOTE(review): filenames do not include `key`, so a group holding
                    # several datasets overwrites earlier frames. Confirm that each
                    # episode group stores exactly one frame dataset.
                    for frame_idx, frame in enumerate(f[episode_name][key]):
                        screen_image_filename = os.path.join(episode_path, "step_{}.jpg".format(frame_idx))
                        Image.fromarray(np.array(frame)).save(screen_image_filename)
                        # Store the paths instead of the actual images
                        episode_data.append(screen_image_filename)
                loaded_data_batch.append(episode_data)
                self.episode_idx += 1

        self._log_length_stats(map_name, [len(d) for d in loaded_data_batch])
        return loaded_data_batch

    @staticmethod
    def max_t_filled(filled):
        """Sum ``filled`` over dim 1, then take the max over dim 0.

        Presumably ``filled`` is a (batch, T, ...) 0/1 mask, so the result is the
        longest filled episode length in the batch — TODO confirm.
        """
        return th.sum(filled, 1).max(0)[0]

    def can_sample(self, batch_size):
        """True when at least ``batch_size`` episodes are stored."""
        return self.buffer_size >= batch_size

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct episodes and wrap them in an ``OfflineDataBatch``.

        Requires ``self.data`` to be a dict of arrays indexed by episode on
        axis 0, including a ``"filled"`` mask.

        Raises:
            TypeError: if ``self.data`` is not a dict (e.g. the list of image paths
                built by ``__init__``).
        """
        if not isinstance(self.data, dict):
            raise TypeError(
                "OfflineBufferH5.sample expects self.data to be a dict of stacked arrays, "
                "got {}".format(type(self.data).__name__))
        sampled_ep_idx = np.random.choice(self.buffer_size, batch_size, replace=False)

        sampled_data = {k: th.tensor(v[sampled_ep_idx]) for k, v in self.data.items()}
        if getattr(self.args, "use_corrected_terminated", False) and "corrected_terminated" in sampled_data:
            sampled_data["terminated"] = sampled_data["corrected_terminated"]

        max_ep_t = self.max_t_filled(filled=sampled_data['filled']).item()
        return OfflineDataBatch(data=sampled_data,
                                batch_size=batch_size,
                                max_seq_length=max_ep_t,
                                device=self.device)


class OfflineBuffer():
    """Facade that chooses a concrete offline buffer from the config and forwards calls to it.

    Only ``algo_args["offline"]["offline_data_type"] == "h5"`` is supported; any
    other value raises ``NotImplementedError``.
    """

    def __init__(self, args, algo_args, env_args) -> None:
        data_type = algo_args["offline"]["offline_data_type"]
        if data_type != "h5":
            raise NotImplementedError("Do not support offline data type: {}".format(data_type))
        self.buffer = OfflineBufferH5(args, algo_args, env_args)
        self.buffer_size = self.buffer.buffer_size

    def can_sample(self, batch_size):
        """Forward to the wrapped buffer's ``can_sample``."""
        inner = self.buffer
        return inner.can_sample(batch_size)

    def sample(self, batch_size):
        """Forward to the wrapped buffer's ``sample``."""
        inner = self.buffer
        return inner.sample(batch_size)

    def sequential_iter(self, batch_size):
        """Forward to the wrapped buffer's ``sequential_iter``."""
        inner = self.buffer
        return inner.sequential_iter(batch_size)

    def reset_sequential_iter(self):
        """Forward to the wrapped buffer's ``reset_sequential_iter``."""
        inner = self.buffer
        inner.reset_sequential_iter()

class DataSaver():
    """Accumulate episode batches and flush them to numbered ``part_<k>.h5`` files.

    ``append`` adds the leading-dimension size of each batch's first key to
    ``cur_size``. When ``cur_size`` reaches ``max_size`` the buffered batches are
    written to ``<save_path>/part_<part_cnt>.h5``: one group ``episode_<idx>``
    per appended batch, one dataset per key. ``close`` flushes whatever remains.
    """

    def __init__(self, save_path, logger=None, max_size=2000) -> None:
        self.save_path = save_path
        self.max_size = max_size
        self.data_batch = []  # appended batches not yet written to disk
        self.cur_size = 0     # sum of shape[0] over the first key of each buffered batch
        self.part_cnt = 0     # index of the next part_<k>.h5 file
        self.logger = logger
        os.makedirs(save_path, exist_ok=True)

    def append(self, data):
        """Buffer ``data`` (a mapping with ``keys()``/``[]``) and flush when full."""
        self.data_batch.append(data)  # data \in OfflineDataBatch/EpisodeBatch/dict
        first_key = list(data.keys())[0]
        self.cur_size += data[first_key].shape[0]
        if self.cur_size >= self.max_size:
            self.save_batch()

    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor, list or array into a NumPy array suitable for h5py."""
        if isinstance(value, th.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    def save_batch(self):
        """Write all buffered batches to the next ``part_<k>.h5`` file and reset the buffer.

        Does nothing when ``cur_size`` is 0.
        """
        if self.cur_size == 0:
            return

        save_file = os.path.join(self.save_path, "part_{}.h5".format(self.part_cnt))

        with h5py.File(save_file, "w") as f:
            for idx, episode_data in enumerate(self.data_batch):
                episode_group = f.create_group(f"episode_{idx}")
                # keys() + [] instead of .items(): works for dicts and OfflineDataBatch
                # alike, and matches how `append` reads the batch.
                for key in episode_data.keys():
                    value = self._to_numpy(episode_data[key])
                    if value.dtype.kind == "U":  # Handle Unicode strings
                        converted_value = value.astype('S')
                        dt = h5py.special_dtype(vlen=str)
                        episode_group.create_dataset(key, data=converted_value, dtype=dt)
                    else:
                        episode_group.create_dataset(key, data=value)

        message = "Save offline buffer to {} with {} episodes".format(save_file, self.cur_size)
        if self.logger is not None:
            self.logger.console_logger.info(message)
        else:
            print(message)
        self.data_batch.clear()
        self.cur_size = 0
        self.part_cnt += 1

    def close(self):
        """Flush any remaining buffered batches to disk."""
        self.save_batch()
