import os
import pickle
import threading
import time
from queue import Empty, Queue

import cv2
import numpy as np
import torch
import torch.distributed as dist


# An imitation-learning (IL) dataset for distillation from a pretrained agent.

class Record(dict):
    """A ``dict`` with an attached ``metadata`` dictionary.

    String keys behave like ordinary dictionary lookups. Indexing with a
    ``slice`` applies that slice to every stored value and returns a new
    ``Record`` holding the results.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly like ``dict`` and start with empty metadata."""
        super().__init__(*args, **kwargs)
        self.metadata = {}

    def __getitem__(self, key):
        """Return one entry (``str`` key) or a sliced copy (``slice`` key).

        Raises:
            KeyError: if ``key`` is neither a ``str`` nor a ``slice``, or if
                a ``str`` key is not present.
        """
        if isinstance(key, slice):
            sliced = Record()
            for name, value in self.items():
                sliced[name] = value[key]
            # Shallow copy so that metadata edits on the slice stay local.
            sliced.metadata = self.metadata.copy()
            return sliced
        if isinstance(key, str):
            return super().__getitem__(key)
        raise KeyError(f"Invalid key type: {type(key)}. Expected str or slice.")

    def to(self, device):
        """Replace every tensor value with ``value.to(device)``; return ``self``."""
        for name, value in self.items():
            if not isinstance(value, torch.Tensor):
                continue
            # Only existing keys are rebound, so the dict size is unchanged
            # while it is being iterated.
            self[name] = value.to(device)
        return self

    def set_metadata(self, **kwargs):
        """Merge the given keyword arguments into ``metadata``; return ``self``."""
        self.metadata.update(kwargs)
        return self

    @classmethod
    def stack(cls, records, dim=0):
        """Stack the values of several records along a new dimension ``dim``.

        The keys are taken from the first record. The result's metadata holds
        only ``batch_size`` (the number of input records); metadata of the
        inputs is not carried over. An empty ``records`` gives an empty
        instance.
        """
        if not records:
            return cls()

        first = records[0]
        result = cls()
        for name in first.keys():
            columns = [rec[name] for rec in records]
            result[name] = torch.stack(columns, dim=dim)

        result.metadata = {"batch_size": len(records)}
        return result

class Dataset:
    """Episode dataset stored as one archive per episode in ``<data_root>/npz``.

    Indexing loads a single episode as a :class:`Record` of tensors.
    :meth:`data_loader` re-chunks the episodes into fixed-length segments
    using one background thread per requested environment.
    """

    def __init__(self, data_root, image_size=None):
        """Load the pickled spaces and list the episode files.

        Args:
            data_root: directory containing ``observation_space.pkl``,
                ``action_space.pkl`` and an ``npz`` sub-directory.
            image_size: if not ``None``, every ``rgb`` frame is resized to
                ``(image_size, image_size)`` when an episode is loaded.
        """
        self.data_root = data_root
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # point data_root at trusted data.
        with open(os.path.join(data_root, "observation_space.pkl"), "rb") as f:
            self.observation_space = pickle.load(f)
        with open(os.path.join(data_root, "action_space.pkl"), "rb") as f:
            self.action_space = pickle.load(f)
        self.episodes = sorted(os.listdir(os.path.join(data_root, "npz")))
        self.image_size = image_size

    def distribute(self):
        """Keep only this process's share of the episodes.

        Uses ``dist.get_rank()`` / ``dist.get_world_size()`` and keeps every
        ``world_size``-th episode starting at ``rank``.
        """
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        self.episodes = self.episodes[rank::world_size]
        print(f"Process {rank} has {len(self.episodes)} episodes.")

    def __len__(self):
        """Return the number of episodes currently held."""
        return len(self.episodes)

    def __getitem__(self, idx) -> Record:
        """Load episode ``idx`` and return it as a :class:`Record` of tensors.

        Besides the arrays stored in the file, the record gets:

        * ``continued_mask``: bool vector, ``False`` at step 0 and ``True``
          afterwards.
        * ``prev_actions``: ``demonstration`` shifted by one step, with zeros
          at step 0.

        Metadata carries ``episode`` (file name without its last four
        characters) and ``episode_length``.
        """
        episode = self.episodes[idx]
        episode_path = os.path.join(self.data_root, "npz", episode)
        # The object returned by np.load for an archive is read-only and keeps
        # the file open. Copy the arrays into a plain dict so entries can be
        # replaced below, and close the archive right away.
        with np.load(episode_path, allow_pickle=True) as archive:
            data = {key: archive[key] for key in archive.files}

        if self.image_size is not None:
            target = (self.image_size, self.image_size)
            data['rgb'] = np.stack(
                [
                    cv2.resize(img, target, interpolation=cv2.INTER_LINEAR)
                    for img in data['rgb']
                ],
                axis=0,
            )

        ret = Record()
        for key, value in data.items():
            ret[key] = torch.from_numpy(value)
        ret['continued_mask'] = torch.ones((ret['rgb'].shape[0],), dtype=torch.bool)
        ret['continued_mask'][0] = 0  # first frame is always a reset
        ret['prev_actions'] = torch.zeros_like(ret['demonstration'], dtype=torch.long)
        ret['prev_actions'][1:] = ret['demonstration'][:-1]

        ret.set_metadata(
            episode=episode[:-4],
            episode_length=ret['rgb'].shape[0],
        )
        return ret

    def _update_batch(self, batch, episode, cur_pos) -> int:
        """Append as many steps of ``episode`` as still fit into ``batch``.

        Steps are taken from ``cur_pos`` onwards; at most
        ``self.num_steps - batch['length']`` are copied. Returns the new
        position inside the episode (unchanged if nothing fitted).
        """
        episode_length = episode["rgb"].shape[0]
        length = min(episode_length - cur_pos, self.num_steps - batch['length'])
        if length <= 0:
            return cur_pos

        for key in episode:
            if key not in batch:
                batch[key] = []
            batch[key].append(episode[key][cur_pos:cur_pos + length])
        batch['length'] += length
        return cur_pos + length

    def _finalize_batch(self, batch) -> Record:
        """Concatenate the collected pieces of ``batch`` into a :class:`Record`.

        The bookkeeping entry ``length`` is not part of the result.
        """
        fields = {
            key: torch.cat(value, dim=0) if isinstance(value, list) else value
            for key, value in batch.items()
            if key != 'length'
        }
        return Record(**fields)

    def _iter_batches(self, stop_event=None):
        """Yield consecutive ``num_steps``-long segments over all episodes.

        Episodes are visited in index order, or in random order when
        ``self.shuffle`` is set. A segment may span several episodes. The
        final, shorter segment is yielded only if ``self.drop_last`` is
        false. If ``stop_event`` is given and set, iteration ends early.
        """
        def stopped():
            return stop_event is not None and stop_event.is_set()

        cur_batch = dict(length=0)
        idxes = np.arange(len(self))
        if self.shuffle:
            np.random.shuffle(idxes)
        for i in idxes:
            if stopped():
                return
            episode = self[i]
            episode_length = episode["rgb"].shape[0]
            cur_pos = 0
            while cur_pos < episode_length:
                cur_pos = self._update_batch(cur_batch, episode, cur_pos)
                if cur_batch['length'] >= self.num_steps:
                    yield self._finalize_batch(cur_batch)
                    cur_batch = dict(length=0)
                    if stopped():
                        return
        if not self.drop_last and cur_batch['length'] > 0:
            yield self._finalize_batch(cur_batch)

    def _iter_worker(self, queue: Queue, stop_event=None):
        """Thread target: push every segment to ``queue``, then a terminator.

        After the last segment ``None`` is enqueued so the consumer knows the
        stream has ended. If producing a segment raises, the exception object
        is enqueued instead, so the consumer can re-raise it rather than wait
        forever.
        """
        try:
            for batch in self._iter_batches(stop_event):
                queue.put(batch)
        except Exception as exc:
            queue.put(exc)
        else:
            queue.put(None)

    def data_loader(self, shuffle=False, num_steps=10, num_envs=1, drop_last=True):
        """Generate stacked segments from ``num_envs`` parallel episode streams.

        Every stream is produced by its own thread that walks the whole
        dataset. Each yielded :class:`Record` stacks one segment per stream
        along dimension 1.

        Args:
            shuffle: visit episodes in random order within each stream.
            num_steps: number of steps per segment; must be at least 1.
            num_envs: number of parallel streams; must be at least 1.
            drop_last: drop the final, shorter segment. Only honoured when
                ``num_envs == 1``.

        Raises:
            ValueError: if ``num_steps`` or ``num_envs`` is smaller than 1.
            Exception: any exception raised inside a worker thread is
                re-raised here.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        if num_envs < 1:
            raise ValueError(f"num_envs must be >= 1, got {num_envs}")

        self.shuffle = shuffle
        self.num_steps = num_steps
        self.num_envs = num_envs
        self.drop_last = drop_last and (self.num_envs == 1)  # Drop last only if single environment

        stop_event = threading.Event()
        queues = [Queue(maxsize=5) for _ in range(num_envs)]
        # Daemon threads: an abandoned generator must not keep the
        # interpreter alive.
        workers = [
            threading.Thread(
                target=self._iter_worker, args=(q, stop_event), daemon=True
            )
            for q in queues
        ]
        for worker in workers:
            worker.start()

        try:
            while True:
                cur_batch = []
                for q in queues:
                    data = q.get()
                    if data is None:
                        return
                    if isinstance(data, Exception):
                        raise data
                    cur_batch.append(data)
                yield Record.stack(cur_batch, dim=1)
        finally:
            # Ask the workers to stop, then keep draining their queues so a
            # worker blocked in put() can observe the request and exit.
            stop_event.set()
            for q, worker in zip(queues, workers):
                while worker.is_alive():
                    try:
                        q.get(timeout=0.05)
                    except Empty:
                        pass
                worker.join()
        

