import re
import os
import torch as th
import numpy as np
import h5py
from sys import stderr
from typing import Optional, Dict, Any


class OfflineSample():
    """Batch of per-key tensors truncated to a common sequence length.

    Every tensor in ``data`` is sliced as ``v[:, :max_seq_length]``, so each
    value needs at least two dimensions. Tensors are not moved on
    construction; ``device`` only records where the caller says they live.
    Use :meth:`to` to actually move them.
    """

    def __init__(self, data, batch_size, max_seq_length, device="cpu"):
        """Store the batch metadata and truncate every tensor along dim 1.

        Args:
            data: Mapping from key to tensor. The mapping itself is left
                untouched; a new dict holding the sliced tensors is stored.
            batch_size: Batch size reported by this sample (stored as given).
            max_seq_length: Number of leading steps to keep along dim 1. The
                buffers in this file pass the result of ``Tensor.item()``,
                which is a float for floating-point masks, so the value is
                coerced to ``int`` before it is used as a slice bound.
            device: Device label reported by this sample.
        """
        self.batch_size = batch_size
        self.max_seq_length = int(max_seq_length)
        self.device = device
        # Build a fresh dict so the caller's mapping is never rewritten.
        self.data = {k: v[:, :self.max_seq_length] for k, v in data.items()}

    def __getitem__(self, item):
        """Return the tensor stored under ``item`` or the attribute of that name.

        Raises:
            ValueError: If ``item`` is not a string, or names neither a data
                key nor an attribute of this object.
        """
        if isinstance(item, str):
            if item in self.data:
                return self.data[item]
            if hasattr(self, item):
                return getattr(self, item)
        raise ValueError('Cannot index OfflineSample with key "{}"'.format(item))

    def to(self, device):
        """Move every tensor to ``device`` and return ``self`` for chaining."""
        self.data = {k: v.to(device) for k, v in self.data.items()}
        self.device = device
        return self

    def keys(self):
        """Return the data keys as a list."""
        return list(self.data.keys())


class OfflineBufferH5():
    """Offline episode buffer that loads HDF5 files fully onto one device.

    All datasets of all files are concatenated along axis 0 (one row per
    episode), optionally subsampled, converted to tensors and moved to
    ``device``. Sampling then only indexes tensors that already live there.
    The data must contain a ``'filled'`` entry; its sum over dim 1 is used as
    the episode length.
    """

    def __init__(self, datapaths, offline_data_size=2000, device="cuda", random_sample=True, val=False, val_size=0.8):
        """Load the data and set up the train/validation bookkeeping.

        Args:
            datapaths: Iterable of HDF5 file paths.
            offline_data_size: Maximum number of episodes to keep in total;
                a value ``<= 0`` means "effectively unlimited".
            device: Device the tensors are moved to.
            random_sample: If true, keep a random subset of the episodes,
                otherwise keep the last ones.
            val: If true, reserve the tail of the data for validation.
            val_size: Fraction used to size the training part when ``val`` is
                true (see the note in the body).

        Raises:
            ValueError: If ``val`` is true and fewer than 64 episodes were
                loaded, or if the files contain no datasets.
        """
        offline_data_size = 100000000 if offline_data_size <= 0 else offline_data_size

        # The files are pooled before subsampling, so the full requested size
        # is the budget for the pooled data. Passing a per-file share here
        # would shrink the buffer by a factor of ``len(datapaths)``.
        self.data = self._load_and_preprocess_data(datapaths, offline_data_size, random_sample, device)

        self.keys = list(self.data.keys())
        data_size = self.data[self.keys[0]].shape[0]

        if val:
            if data_size < 64:
                raise ValueError("Offline data size is too small, must be no smaller than 64")
            self.buffer_size = int(data_size * val_size)
            # NOTE(review): the sampling range is scaled by ``val_size`` a
            # second time, so training draws from the first
            # ``int(buffer_size * val_size)`` episodes only - confirm intent.
            self.batch_size = int(self.buffer_size * val_size)
            self.episodes_in_buffer = int(self.buffer_size * val_size)
            self.val_buffer_size = int(data_size * (1 - val_size))
        else:
            self.buffer_size = data_size
            self.batch_size = self.buffer_size
            self.episodes_in_buffer = self.buffer_size
            # No validation split; ``val_sample`` reports this explicitly.
            self.val_buffer_size = 0

        self.device = device

    def _load_and_preprocess_data(self, datapaths, offline_data_size, random_sample, device):
        """Read every file, pool the episodes, subsample and move to ``device``.

        Args:
            datapaths: Iterable of HDF5 file paths.
            offline_data_size: Maximum number of pooled episodes to keep.
            random_sample: Random subset if true, otherwise the last episodes.
            device: Target device of the returned tensors.

        Returns:
            Dict mapping each dataset name to a tensor on ``device``.
            float32/float64 arrays become float32 tensors, int32/int64/uint8
            arrays become int64 tensors, other dtypes are converted as-is.

        Raises:
            ValueError: If no dataset was found in any file.
        """
        chunks = {}
        for path in datapaths:
            with h5py.File(path, 'r') as f:
                for k in f.keys():
                    chunks.setdefault(k, []).append(f[k][:])

        if not chunks:
            raise ValueError("No datasets found in offline data files {}".format(list(datapaths)))

        # Concatenate once per key instead of re-copying the growing array
        # after every file.
        all_data = {
            k: parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)
            for k, parts in chunks.items()
        }

        keys = list(all_data.keys())
        original_buffer_size = all_data[keys[0]].shape[0]
        offline_data_size = min(original_buffer_size, offline_data_size)

        if random_sample:
            idx = np.random.choice(original_buffer_size, offline_data_size, replace=False)
            all_data = {k: v[idx] for k, v in all_data.items()}
        elif original_buffer_size > offline_data_size:
            # Explicit start index: ``v[-0:]`` would keep everything.
            start = original_buffer_size - offline_data_size
            all_data = {k: v[start:] for k, v in all_data.items()}

        float_dtypes = (np.float32, np.float64)
        int_dtypes = (np.int32, np.int64, np.uint8)

        data = {}
        for k, v in all_data.items():
            tensor = th.from_numpy(v)
            if v.dtype in float_dtypes:
                tensor = tensor.float()
            elif v.dtype in int_dtypes:
                tensor = tensor.long()
            data[k] = tensor.to(device)

        return data

    def max_t_filled(self, filled):
        """Return the largest per-episode sum of ``filled`` over dim 1."""
        return th.sum(filled, 1).max(0)[0]

    def can_sample(self, batch_size):
        """Return whether at least ``batch_size`` training episodes exist."""
        return self.episodes_in_buffer >= batch_size

    def sample(self, batch_size, seq_length=0):
        """Draw a random batch of training episodes.

        Args:
            batch_size: Number of episodes requested. If fewer are available,
                all available episodes are returned.
            seq_length: If positive and not longer than the longest sampled
                episode, every episode is cut to a random window of this
                length; otherwise episodes are returned up to the longest
                sampled length.

        Returns:
            An ``OfflineSample`` on this buffer's device.
        """
        ep_ids = th.randperm(self.episodes_in_buffer, device=self.device)[:batch_size]
        # The slice above may return fewer ids than requested; every index
        # tensor below must match the number actually drawn.
        n_sampled = ep_ids.shape[0]

        episode_data = {k: v[ep_ids] for k, v in self.data.items()}
        # ``item()`` yields a float for floating-point masks; lengths are
        # used as slice bounds and ``randint`` limits, so make it an int.
        max_ep_t = int(self.max_t_filled(episode_data['filled']).item())

        if max_ep_t >= seq_length > 0:
            # The window start is bounded by the longest sampled episode, so
            # shorter episodes may receive windows that reach into padding.
            seq_start = th.randint(0, max_ep_t - seq_length + 1, (n_sampled,), device=self.device)
            batch_indices = th.arange(n_sampled, device=self.device).unsqueeze(1)
            time_indices = seq_start.unsqueeze(1) + th.arange(seq_length, device=self.device).unsqueeze(0)

            # Indexing the first two dims keeps all trailing dims, whatever
            # the rank. The window axis is kept even when ``seq_length`` is 1
            # because ``OfflineSample`` slices along dim 1.
            truncated_data = {k: v[batch_indices, time_indices] for k, v in episode_data.items()}
            return OfflineSample(truncated_data, n_sampled, seq_length, device=self.device)

        return OfflineSample(episode_data, n_sampled, max_ep_t, device=self.device)

    def val_sample(self, batch_size):
        """Draw a random batch from the validation tail of the data.

        Raises:
            ValueError: If the buffer has no validation split.
        """
        val_buffer_size = getattr(self, 'val_buffer_size', 0)
        if val_buffer_size <= 0:
            raise ValueError("Offline buffer has no validation split to sample from")

        # Validation episodes are stored directly after the training part.
        ep_ids = th.randperm(val_buffer_size, device=self.device)[:batch_size] + self.buffer_size
        episode_data = {k: v[ep_ids] for k, v in self.data.items()}
        max_ep_t = int(self.max_t_filled(episode_data['filled']).item())
        return OfflineSample(episode_data, ep_ids.shape[0], max_ep_t, device=self.device)


class OfflineBufferPickle():
    """Offline episode buffer built from files readable by ``torch.load``.

    Each file is expected to hold a sequence of episode dicts mapping keys to
    tensors whose first dimension is the episode length. Episodes are
    zero-padded to the longest one, stacked into one tensor per key and moved
    to ``device``. Every episode must contain a ``'filled'`` entry.
    """

    def __init__(self, datapaths, offline_data_size=2000, device="cuda", random_sample=True, val=False, val_size=0.8):
        """Load the episodes and set up the train/validation bookkeeping.

        Args:
            datapaths: Sequence of data sources. Each source is either one
                file path or an iterable of file paths.
            offline_data_size: Maximum number of episodes in total; it is
                split evenly between the sources. ``<= 0`` means
                "effectively unlimited".
            device: Device the stacked tensors are moved to.
            random_sample: If true, keep a random subset per source,
                otherwise keep the last episodes of each source.
            val: If true, reserve the tail of the data for validation.
            val_size: Fraction used to size the training part when ``val`` is
                true (see the note in the body).

        Raises:
            ValueError: If ``val`` is true and fewer than 64 episodes were
                loaded.
        """
        offline_data_size = 100000000 if offline_data_size <= 0 else offline_data_size

        dataset_sources = len(datapaths)
        data_size_per_source = offline_data_size // dataset_sources

        raw_data = []
        for source in datapaths:
            raw_data.extend(self._read_data(source, data_size_per_source, random_sample))

        data_size = len(raw_data)

        if val:
            if data_size < 64:
                raise ValueError("Offline data size is too small, must be no smaller than 64")
            self.buffer_size = int(data_size * val_size)
            # NOTE(review): the sampling range is scaled by ``val_size`` a
            # second time, so training draws from the first
            # ``int(buffer_size * val_size)`` episodes only - confirm intent.
            self.batch_size = int(self.buffer_size * val_size)
            self.episodes_in_buffer = int(self.buffer_size * val_size)
            self.val_buffer_size = int(data_size * (1 - val_size))
        else:
            self.buffer_size = data_size
            self.batch_size = self.buffer_size
            self.episodes_in_buffer = self.buffer_size
            # No validation split; ``val_sample`` reports this explicitly.
            self.val_buffer_size = 0

        self.keys = list(raw_data[0].keys())
        self.device = device

        # Stack episodes and move directly to device
        self.data = self._stack_episodes(raw_data)

    def _read_data(self, datapaths, offline_data_size, random_sample):
        """Load the episodes of one data source and subsample them.

        Args:
            datapaths: A single file path or an iterable of file paths.
            offline_data_size: Maximum number of episodes to return.
            random_sample: Random subset if true, otherwise the last episodes.

        Returns:
            List of episode objects as stored in the files.
        """
        # A bare path must be wrapped: iterating a string would yield its
        # characters and try to load each of them as a file.
        if isinstance(datapaths, (str, bytes, os.PathLike)):
            datapaths = [datapaths]

        data = []
        for path in datapaths:
            # NOTE(security): torch.load unpickles the file content; only
            # load dataset files from trusted locations.
            data.extend(th.load(path))
            if not random_sample and len(data) > offline_data_size:
                # Explicit start index: ``data[-0:]`` would keep everything.
                data = data[len(data) - offline_data_size:]

        offline_data_size = min(len(data), offline_data_size)

        if random_sample:
            idx = np.random.choice(len(data), offline_data_size, replace=False)
            data = [data[i] for i in idx]
        return data

    def _stack_episodes(self, episodes):
        """Zero-pad all episodes to the longest one and stack them per key.

        Tensors are padded at the end of their first dimension. Per-episode
        tensors with one dimension gain a trailing singleton dimension after
        stacking; tensors with two or three dimensions keep their trailing
        dimensions. The rank is taken from the first episode.

        Raises:
            ValueError: If a key holds tensors with zero or more than three
                dimensions.
        """
        max_seq_length = max(ep['filled'].shape[0] for ep in episodes)

        stacked_data = {}
        for k in self.keys:
            ndim = len(episodes[0][k].shape)
            if not 1 <= ndim <= 3:
                raise ValueError(f"Unsupported tensor shape for key {k}")

            # ``pad`` takes (left, right) pairs starting from the last
            # dimension, so the first dimension's pair comes last.
            untouched_dims = (0, 0) * (ndim - 1)
            padded = [
                th.nn.functional.pad(ep[k], untouched_dims + (0, max_seq_length - ep[k].shape[0]))
                for ep in episodes
            ]
            stacked = th.stack(padded, dim=0)
            if ndim == 1:
                stacked = stacked.unsqueeze(-1)
            stacked_data[k] = stacked.to(self.device)

        return stacked_data

    def max_t_filled(self, filled):
        """Return the largest per-episode sum of ``filled`` over dim 1."""
        return th.sum(filled, 1).max(0)[0]

    def can_sample(self, batch_size):
        """Return whether at least ``batch_size`` training episodes exist."""
        return self.episodes_in_buffer >= batch_size

    def sample(self, batch_size, seq_length=0):
        """Draw a random batch of training episodes.

        Args:
            batch_size: Number of episodes requested. If fewer are available,
                all available episodes are returned.
            seq_length: If positive and not longer than the longest sampled
                episode, every episode is cut to a random window of this
                length; otherwise episodes are returned up to the longest
                sampled length.

        Returns:
            An ``OfflineSample`` on this buffer's device.
        """
        ep_ids = th.randperm(self.episodes_in_buffer, device=self.device)[:batch_size]
        # The slice above may return fewer ids than requested; every index
        # tensor below must match the number actually drawn.
        n_sampled = ep_ids.shape[0]

        episode_data = {k: v[ep_ids] for k, v in self.data.items()}
        # ``item()`` yields a float for floating-point masks; lengths are
        # used as slice bounds and ``randint`` limits, so make it an int.
        max_ep_t = int(self.max_t_filled(episode_data['filled']).item())

        if max_ep_t >= seq_length > 0:
            # The window start is bounded by the longest sampled episode, so
            # shorter episodes may receive windows that reach into padding.
            seq_start = th.randint(0, max_ep_t - seq_length + 1, (n_sampled,), device=self.device)
            batch_indices = th.arange(n_sampled, device=self.device).unsqueeze(1)
            time_indices = seq_start.unsqueeze(1) + th.arange(seq_length, device=self.device).unsqueeze(0)

            # Indexing the first two dims keeps all trailing dims, whatever
            # the rank. The window axis is kept even when ``seq_length`` is 1
            # because ``OfflineSample`` slices along dim 1.
            truncated_data = {k: v[batch_indices, time_indices] for k, v in episode_data.items()}
            return OfflineSample(truncated_data, n_sampled, seq_length, device=self.device)

        return OfflineSample(episode_data, n_sampled, max_ep_t, device=self.device)

    def val_sample(self, batch_size):
        """Draw a random batch from the validation tail of the data.

        Raises:
            ValueError: If the buffer has no validation split.
        """
        val_buffer_size = getattr(self, 'val_buffer_size', 0)
        if val_buffer_size <= 0:
            raise ValueError("Offline buffer has no validation split to sample from")

        # Validation episodes are stored directly after the training part.
        ep_ids = th.randperm(val_buffer_size, device=self.device)[:batch_size] + self.buffer_size
        episode_data = {k: v[ep_ids] for k, v in self.data.items()}
        max_ep_t = int(self.max_t_filled(episode_data['filled']).item())
        return OfflineSample(episode_data, ep_ids.shape[0], max_ep_t, device=self.device)


class OfflineBuffer():
    """Facade that locates a dataset on disk and wraps the matching buffer.

    Dataset files are looked up under
    ``<dataset_folder>/<env>/<map_name>/<quality>/<data_folder>`` and loaded
    with ``OfflineBufferPickle`` or ``OfflineBufferH5`` depending on their
    file extension. Sampling calls are forwarded to that buffer.
    """

    def __init__(self, env, map_name, quality, data_folder=None, dataset_folder='dataset', offline_data_size=2000, device="cuda", random_sample=True, val=False, val_size=0.8):
        """Collect the data files and build the underlying buffer.

        Args:
            env: Environment name; for ``'gymma'`` the real environment and
                map are taken from ``map_name`` (``"<env>:<map>"``).
            map_name: Map / scenario name.
            quality: Dataset quality folder; ``'medium-expert'`` combines the
                ``'medium'`` and ``'expert'`` folders.
            data_folder: Sub-folder to load; if ``None`` or empty, the last
                sub-folder in sorted order is used.
            dataset_folder: Root folder of all datasets.
            offline_data_size, device, random_sample, val, val_size:
                Forwarded unchanged to the underlying buffer.

        Raises:
            ValueError: If the files are neither all pickle nor all HDF5.
        """
        qualities = ('medium', 'expert') if quality == 'medium-expert' else (quality,)
        datapaths = []
        for source_quality in qualities:
            datapaths.extend(self._load_data_sources(dataset_folder, env, map_name, source_quality, data_folder))

        # Decide by file extension only; matching against the whole path
        # would also react to directory names that contain 'pkl' or 'h5'.
        extensions = [os.path.splitext(os.path.basename(path))[1] for path in datapaths]
        if all('pkl' in ext for ext in extensions):
            self.buffer = OfflineBufferPickle(datapaths, offline_data_size=offline_data_size, device=device, random_sample=random_sample, val=val, val_size=val_size)
        elif all('h5' in ext for ext in extensions):
            self.buffer = OfflineBufferH5(datapaths, offline_data_size=offline_data_size, device=device, random_sample=random_sample, val=val, val_size=val_size)
        else:
            raise ValueError("Cannot find parser for data files including {}".format(datapaths))

        # NOTE(review): all three mirror ``buffer_size``; with ``val=True``
        # the wrapped buffer's own ``batch_size`` / ``episodes_in_buffer``
        # are smaller - confirm which values readers of these fields expect.
        self.buffer_size = self.buffer.buffer_size
        self.batch_size = self.buffer.buffer_size
        self.episodes_in_buffer = self.buffer.buffer_size
        self.device = device

    def _load_data_sources(self, dataset_folder, env, map_name, quality, data_folder):
        """Return the data file paths of one dataset folder.

        If the folder contains files named ``part_<n>.<ext>``, exactly those
        files are returned, ordered by ``<n>``. Otherwise every entry of the
        folder is returned in sorted order. Also records the folder in
        ``self.dataset_path`` and prints it to stderr.

        Raises:
            AssertionError: If a folder does not exist, no data sub-folder is
                available, or the folder is empty.
        """
        if env == 'gymma':
            env, map_name = map_name.split(':')
        datapath = os.path.join(dataset_folder, env, map_name, quality)
        assert os.path.exists(datapath), "Offline data path {} does not exist".format(datapath)

        if data_folder is None or data_folder == '':
            existing_folders = [f for f in sorted(os.listdir(datapath)) if os.path.isdir(os.path.join(datapath, f))]
            assert len(existing_folders) > 0, "Offline data path {} contains no data folders".format(datapath)
            data_folder = existing_folders[-1]

        dataset_path = os.path.join(datapath, data_folder)
        assert os.path.exists(dataset_path), 'Offline data path {} does not exist'.format(dataset_path)
        self.dataset_path = dataset_path
        print('Load dataset from {}'.format(dataset_path), file=stderr)

        # ``os.listdir`` order is arbitrary; sort for reproducible loading.
        filenames = sorted(os.listdir(dataset_path))

        # Use the part files that really exist instead of rebuilding names
        # ``part_0..part_max`` from one file's extension, which breaks on
        # gaps and on unrelated files that merely contain "part".
        part_pattern = re.compile(r'part_(\d+)\..*')
        numbered_parts = []
        for filename in filenames:
            matched = part_pattern.match(filename)
            if matched is not None:
                numbered_parts.append((int(matched.group(1)), filename))

        if numbered_parts:
            datafiles = [filename for _, filename in sorted(numbered_parts)]
        else:
            datafiles = filenames

        datapaths = [os.path.join(dataset_path, f) for f in datafiles]
        assert len(datapaths) > 0, 'dataset path {} contains no readable data files'.format(dataset_path)
        return datapaths

    def max_t_filled(self, filled):
        """Forward to the wrapped buffer's ``max_t_filled``."""
        return self.buffer.max_t_filled(filled)

    def can_sample(self, batch_size):
        """Forward to the wrapped buffer's ``can_sample``."""
        return self.buffer.can_sample(batch_size)

    def sample(self, batch_size, seq_length=0):
        """Forward to the wrapped buffer's ``sample``."""
        return self.buffer.sample(batch_size, seq_length)

    def val_sample(self, batch_size):
        """Forward to the wrapped buffer's ``val_sample``."""
        return self.buffer.val_sample(batch_size)

    def get_stats(self, gamma=0.99):
        """Return summary statistics over all loaded episodes.

        Args:
            gamma: Discount factor passed to ``compute_q_values``.

        Returns:
            Dict with the mean of the ``'filled'`` sums (``'length'``), the
            mean and standard deviation of the first-column values returned
            by ``compute_q_values`` on the rewards (``'return_mean'``,
            ``'return_std'``) and the buffer size (``'size'``).
        """
        from utils.calc import compute_q_values
        rewards = self.buffer.data['reward']
        # Only move to CPU for the stats calculation, and only once.
        returns = compute_q_values(rewards, discount_factor=gamma)[:, 0].cpu().numpy()
        filled_sum = th.sum(self.buffer.data['filled'], dim=1).cpu().numpy()

        infos = {
            'length': np.mean(filled_sum),
            'return_mean': np.mean(returns),
            'return_std': np.std(returns),
            'size': self.buffer_size,
        }

        return infos

class PartialOfflineBuffer():
    """Small per-task replay memory drawn once from each task's buffer.

    On construction, one sample of ``n_sample_per_task`` episodes is taken
    from every task's offline buffer and kept. ``sample`` later returns a
    random subset of the stored episodes of one randomly chosen task that
    precedes the current task in ``tasks``.
    """

    def __init__(self, tasks: list[str], task2offline_buffer: "dict[str, OfflineBuffer]", n_sample_per_task=20):
        """Draw and store the per-task samples.

        Args:
            tasks: Task names in order.
            task2offline_buffer: Offline buffer for every task name.
            n_sample_per_task: Episodes to store per task; ``0`` disables the
                memory, and no buffer is touched.
        """
        self.tasks = tasks
        self.task2id = {t: i for i, t in enumerate(self.tasks)}
        self.n_task = len(self.tasks)
        self.n_sample_per_task = n_sample_per_task
        # Defined on every path so a disabled memory still has all fields.
        self.n_sample = n_sample_per_task * self.n_task
        self.task2parbuf = {}
        if n_sample_per_task == 0:
            return

        # [ task 1 samples (train) | task 2 samples (trans) | ... | task K samples (trans) ]
        for t in self.tasks:
            self.task2parbuf[t] = task2offline_buffer[t].sample(n_sample_per_task)

    def sample(self, batch_size, cur_task):  # sample only from preceding tasks: t = 1...(task-1)
        """Return stored episodes of a random task preceding ``cur_task``.

        Args:
            batch_size: Number of episodes requested; capped at the number of
                episodes stored for the chosen task.
            cur_task: Name of the current task.

        Returns:
            ``(sample, task_name)``. ``sample`` is ``None`` and ``task_name``
            is ``cur_task`` when the memory is disabled or ``cur_task`` is
            the first task.
        """
        if self.n_sample_per_task == 0:
            return None, cur_task

        task_id = self.task2id[cur_task]
        if task_id == 0:
            return None, cur_task
        prev_task_id = np.random.randint(0, task_id)
        prev_task = self.tasks[prev_task_id]

        parbuf = self.task2parbuf[prev_task]
        # The source buffer may have held fewer than ``n_sample_per_task``
        # episodes, so index against what is really stored.
        n_stored = min(self.n_sample_per_task, next(iter(parbuf.data.values())).shape[0])
        batch_size = min(batch_size, n_stored)

        idx = np.random.choice(n_stored, batch_size, replace=False)
        # Convert indices to tensor on device
        idx_tensor = th.tensor(idx, device=parbuf.device)
        # Only slice the data, no deepcopy
        sliced_data = {k: v[idx_tensor] for k, v in parbuf.data.items()}
        offsample = OfflineSample(sliced_data, batch_size, parbuf.max_seq_length, device=parbuf.device)
        return offsample, prev_task
    
class DataSaver():
    """Collects data dicts and writes them to numbered HDF5 part files.

    Items are buffered in memory; once ``max_size`` items are collected they
    are concatenated per key and written to ``part_<n>.h5`` in ``datadir``.
    Part numbers start at 0 for every new saver.
    """

    def __init__(self, datadir, max_size=2000):
        """Create ``datadir`` if needed and start with an empty batch.

        Args:
            datadir: Output directory for the part files.
            max_size: Number of buffered items that triggers a write.
        """
        os.makedirs(datadir, exist_ok=True)
        self.datadir = datadir
        self.max_size = max_size
        self.data_batch = []
        self.part_no = 0

    def append(self, data):
        """Buffer one data dict and flush when ``max_size`` is reached."""
        self.data_batch.append(data)
        if len(self.data_batch) >= self.max_size:
            self.save_batch()

    def save_batch(self):
        """Write the buffered items to the next part file; no-op if empty.

        The keys of the first buffered item define the datasets. Values are
        concatenated with ``np.concatenate`` and stored gzip-compressed. The
        in-memory batch is cleared only after the file was written.
        """
        if not self.data_batch:
            return

        keys = list(self.data_batch[0].keys())
        datadic = {}
        for k in keys:
            values = []
            for d in self.data_batch:
                value = d[k]
                if isinstance(value, th.Tensor):
                    # ``Tensor.numpy()`` rejects tensors that live on an
                    # accelerator or require grad; detach and copy to CPU.
                    value = value.detach().cpu().numpy()
                values.append(value)
            datadic[k] = np.concatenate(values)

        part_path = os.path.join(self.datadir, "part_{}.h5".format(self.part_no))
        with h5py.File(part_path, 'w') as file:
            for k, v in datadic.items():
                file.create_dataset(k, data=v, compression='gzip', compression_opts=9)

        self.data_batch.clear()
        self.part_no += 1

    def close(self):
        """Flush any remaining items and return the output directory."""
        self.save_batch()
        return self.datadir
