import re
import os
import torch as th
import numpy as np
import h5py
from sys import stderr
import copy

from utils.calc import compute_q_values

class OfflineSample():
    def __init__(self, data, batch_size, max_seq_length, device="cpu"):
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.data = data
        self.device = device
        for k, v in self.data.items():
            self.data[k] = v[:, :max_seq_length].to(self.device)

    def __getitem__(self, item):
        if isinstance(item, str):
            if item in self.data:
                return self.data[item]
            elif hasattr(self, item):
                return getattr(self, item)
            else:
                raise ValueError('Cannot index OfflineSample with key "{}"'.format(item))
        else:
            raise ValueError('Cannot index OfflineSample with key "{}"'.format(item))

    def to(self, device):
        for k, v in self.data.items():
            self.data[k] = v.to(device)
        self.device = device

    def keys(self):
        return list(self.data.keys())


class OfflineBufferH5FullData():
    def __init__(self, datapath, offline_data_size=2000, device="cpu", random_sample=True):
        self.data = h5py.File(datapath, 'r')
        self.keys = list(self.data.keys())
        self.device = device
        original_buffer_size = self.data[self.keys[0]].shape[0]

        offline_data_size = original_buffer_size if offline_data_size > original_buffer_size or offline_data_size <= 0 else offline_data_size

        if random_sample:
            self.chosen_idx = np.random.choice(original_buffer_size, offline_data_size, replace=False)
        else:
            self.chosen_idx = np.array(range(original_buffer_size - offline_data_size, original_buffer_size))
        
        self.buffer_size = offline_data_size
        self.batch_size = self.buffer_size
        self.episodes_in_buffer = self.buffer_size
        
    def __del__(self):
        self.data.close()

    def max_t_filled(self, filled):
        return th.sum(filled, 1).max(0)[0]

    def can_sample(self, batch_size):
        return self.episodes_in_buffer >= batch_size
    
    def sample(self, batch_size):
        ep_ids = np.sort(np.random.choice(self.chosen_idx, batch_size, replace=False))
        episode_data = {k: th.tensor(self.data[k][ep_ids]) for k in self.keys}
        filled = episode_data['filled']
        max_ep_t = self.max_t_filled(filled).item()
        batch_sample = OfflineSample(episode_data, batch_size, max_ep_t, device=self.device)
        return batch_sample


class OfflineBufferH5():
    def __init__(self, datapaths, offline_data_size=2000, device="cpu", random_sample=True, val=False, val_size=0.8):
        offline_data_size = 100000000 if offline_data_size <= 0 else offline_data_size

        dataset_sources = len(datapaths)
        data_size_per_source = offline_data_size // dataset_sources
        dataset = [ self._read_data(datapaths[i], data_size_per_source, random_sample) for i in range(dataset_sources) ] 
        self.data = {
            k: np.concatenate([v[k] for v in dataset], axis=0) for k in dataset[0].keys()
        }

        self.keys = list(self.data.keys())
        
        data_size = self.data[self.keys[0]].shape[0]
        
        if val:
            if data_size < 64:
                raise ValueError("Offline data size is too small, must be no smaller than 64")
            self.buffer_size = int(data_size * val_size)
            self.batch_size = int(self.buffer_size * val_size)
            self.episodes_in_buffer = int(self.buffer_size * val_size)
            self.val_buffer_size = int(data_size * (1 - val_size))
        else:
            self.buffer_size = data_size
            self.batch_size = self.buffer_size
            self.episodes_in_buffer = self.buffer_size

        self.device = device

    def _read_data(self, datapaths, offline_data_size, random_sample):
        data = {}
        for path in (datapaths):
            with h5py.File(path, 'r') as f:
                for k in f.keys():
                    if k not in data:
                        data[k] = f[k][:]
                    else:
                        data[k] = np.concatenate((data[k], f[k][:]), axis=0)
                    
            # pre-slice for memory releasing
            if not random_sample and data[list(data.keys())[0]].shape[0] > offline_data_size:
                data = {k: v[-offline_data_size:] for k, v in data.items()}

        keys = list(data.keys())
        original_buffer_size = len(data[keys[0]])
        offline_data_size = min(original_buffer_size, offline_data_size)
        
        if random_sample:
            idx = np.random.choice(original_buffer_size, offline_data_size, replace=False)
            data = {k: v[idx] for k, v in data.items()}
        
        return data

    def max_t_filled(self, filled):
        return th.sum(filled, 1).max(0)[0]

    def can_sample(self, batch_size):
        return self.episodes_in_buffer >= batch_size
    
    def sample(self, batch_size, seq_length=0):
        ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False)
        episode_data = {k: th.tensor(v[ep_ids]) for k, v in self.data.items()}
        filled = episode_data['filled']
        max_ep_t = self.max_t_filled(filled).item()
        
        if max_ep_t >= seq_length > 0:
            seq_start = th.randint(0, max_ep_t - seq_length + 1, (batch_size,))
            truncated_data = {}
            for k, v in episode_data.items():
                arr = []
                for i in range(batch_size):
                    arr.append(v[i, seq_start[i]:seq_start[i]+seq_length])
                truncated_data[k] = th.stack(arr, dim=0)
                
            batch_sample = OfflineSample(truncated_data, batch_size, seq_length, device=self.device)
        else:
            batch_sample = OfflineSample(episode_data, batch_size, max_ep_t, device=self.device)
        return batch_sample

    def val_sample(self, batch_size):
        ep_ids = np.random.choice(self.val_buffer_size, batch_size, replace=False) + self.buffer_size
        episode_data = {k: th.tensor(v[ep_ids]) for k, v in self.data.items()}
        filled = episode_data['filled']
        max_ep_t = self.max_t_filled(filled).item()
        batch_sample = OfflineSample(episode_data, batch_size, max_ep_t, device=self.device)
        return batch_sample

class OfflineBufferPickle():
    def __init__(self, datapaths, offline_data_size=2000, device="cpu", random_sample=True, val=False, val_size=0.8):
        offline_data_size = 100000000 if offline_data_size <= 0 else offline_data_size

        dataset_sources = len(datapaths)
        data_size_per_source = offline_data_size // dataset_sources
        self.data = []
        for i in range(dataset_sources):
            self.data.extend(self._read_data(datapaths[i], data_size_per_source, random_sample))

        data_size = len(self.data)
        
        if val:
            if data_size < 64:
                raise ValueError("Offline data size is too small, must be no smaller than 64")
            self.buffer_size = int(data_size * val_size)
            self.batch_size = int(self.buffer_size * val_size)
            self.episodes_in_buffer = int(self.buffer_size * val_size)
            self.val_buffer_size = int(data_size * (1 - val_size))
        else:
            self.buffer_size = data_size
            self.batch_size = self.buffer_size
            self.episodes_in_buffer = self.buffer_size
        
        self.keys = list(self.data[0].keys())
        self.device = device

    def _read_data(self, datapaths, offline_data_size, random_sample):
        data = []
        for path in (datapaths):
            data.extend(th.load(path))
            if not random_sample:
                if len(data) > offline_data_size:
                    data = data[-offline_data_size:]
        
        original_buffer_size = len(data)
        offline_data_size = min(original_buffer_size, offline_data_size)

        if random_sample:
            idx = np.random.choice(len(data), offline_data_size, replace=False)
            data = [data[i] for i in idx]
        return data

    def max_t_filled(self, filled):
        return th.sum(filled, 1).max(0)[0]

    def can_sample(self, batch_size):
        return self.episodes_in_buffer >= batch_size
    
    def sample(self, batch_size, seq_length=0):
        ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False)
        episode_data = [self.data[i] for i in ep_ids]
        filled = th.cat([d['filled'].to(self.device) for d in episode_data], dim=0)
        max_ep_t = self.max_t_filled(filled).item()
        
        if max_ep_t >= seq_length > 0: # NOTE need debugging
            seq_start = th.randint(0, max_ep_t - seq_length + 1, (batch_size,))
            truncated_data = []
            for d in episode_data:
                arr = []
                for i in range(batch_size):
                    arr.append(d[i, seq_start[i]:seq_start[i]+seq_length])
                truncated_data.append(th.stack(arr, dim=0))
            data = {
                k: th.cat( [d[k] for d in truncated_data], dim=0 ) for k in self.keys
            }
            batch_sample = OfflineSample(data, batch_size, max_ep_t, device=self.device)
        else:
            data = {
                k: th.cat( [d[k] for d in episode_data], dim=0 ) for k in self.keys
            }
            batch_sample = OfflineSample(data, batch_size, max_ep_t, device=self.device)
        return batch_sample
    
    def val_sample(self, batch_size):
        ep_ids = np.random.choice(self.val_buffer_size, batch_size, replace=False) + self.buffer_size
        episode_data = [self.data[i] for i in ep_ids]
        filled = th.cat([d['filled'].to(self.device) for d in episode_data], dim=0)
        max_ep_t = self.max_t_filled(filled).item()

        data = {
            k: th.cat( [d[k] for d in episode_data], dim=0 ) for k in self.keys
        }
        batch_sample = OfflineSample(data, batch_size, max_ep_t, device=self.device)
        return batch_sample


class OfflineBuffer():
    def __init__(self, env, map_name, quality, data_folder=None, dataset_folder='dataset', offline_data_size=2000, device="cpu", random_sample=True, val=False, val_size=0.8):
        datapaths = []
        if quality == 'medium-expert':
            datapaths.append(self._load_data_sources(dataset_folder, env, map_name, 'medium', data_folder))
            datapaths.append(self._load_data_sources(dataset_folder, env, map_name, 'expert', data_folder))
        else:
            datapaths.append(self._load_data_sources(dataset_folder, env, map_name, quality, data_folder))
        
        if all([all(['pkl' in f for f in paths]) for paths in datapaths]):
            self.buffer = OfflineBufferPickle(datapaths, offline_data_size=offline_data_size, device=device, random_sample=random_sample, val=val, val_size=val_size)
        elif all([all(['h5' in f for f in paths]) for paths in datapaths]):
            self.buffer = OfflineBufferH5(datapaths, offline_data_size=offline_data_size, device=device, random_sample=random_sample, val=val, val_size=val_size)
        else:
            raise ValueError("Cannot find parser for data files including {}".format(datapaths))

        self.buffer_size = self.buffer.buffer_size
        self.batch_size = self.buffer.buffer_size
        self.episodes_in_buffer = self.buffer.buffer_size
        self.device = device
    
    def _load_data_sources(self, dataset_folder, env, map_name, quality, data_folder):
        if env == 'gymma':
            env, map_name = map_name.split(':')
        datapath = os.path.join(dataset_folder, env, map_name, quality)
        assert os.path.exists(datapath), "Offline data path {} does not exist".format(datapath)

        if data_folder is None or data_folder == '':
            existing_folders = [ f for f in sorted(os.listdir(datapath)) if os.path.isdir(os.path.join(datapath, f)) ]
            assert len(existing_folders) > 0
            data_folder = existing_folders[-1]
        
        dataset_path = os.path.join(datapath, data_folder)
        assert os.path.exists(dataset_path), 'Offline data path {} does not exist'.format(dataset_path)
        self.dataset_path = dataset_path
        print('Load dataset from {}'.format(dataset_path), file=stderr)

        filenames = os.listdir(dataset_path)
        if any(['part' in f for f in filenames]):
            datafiles = [f for f in filenames if 'part' in f]
            max_parts = max([ int(re.match(r'part_(\d+)\..*', file).group(1)) for file in datafiles if re.match(r'part_(\d+)\..*', file) is not None ])
            ext_name = os.path.splitext(datafiles[0])[1]
            datafiles = [ 'part_{}{}'.format(i, ext_name) for i in range(max_parts + 1) ]
        else:
            datafiles = filenames

        datapaths = [os.path.join(dataset_path, f) for f in datafiles]
        assert len(datapaths) > 0, 'dataset path {} contains no readable data files'.format(dataset_path)
        return datapaths

    def max_t_filled(self, filled):
        return self.buffer.max_t_filled(filled)

    def can_sample(self, batch_size):
        return self.buffer.can_sample(batch_size)

    def sample(self, batch_size, seq_length=0):
        return self.buffer.sample(batch_size, seq_length)
    
    def val_sample(self, batch_size):
        return self.buffer.val_sample(batch_size)
    
    def get_stats(self, gamma=0.99):
        rewards = self.buffer.data['reward']
        returns = compute_q_values(rewards, discount_factor=gamma)[:, 0]
        return_mean = np.mean(returns)
        return_std = np.std(returns)
        length_mean = np.mean(np.sum(self.buffer.data['filled'], axis=1))
        
        infos = {
            'length': length_mean,
            'return_mean': return_mean,
            'return_std': return_std,
            'size': self.buffer_size, # TODO win_Rate for sc2
        }
        
        return infos

class PartialOfflineBuffer():
    def __init__(self, tasks: list[str], task2offline_buffer: dict[str, OfflineBuffer], n_sample_per_task=20):
        self.tasks = tasks
        self.task2id = {t: i for i, t in enumerate(self.tasks)}
        self.n_task = len(self.tasks)
        self.n_sample_per_task = n_sample_per_task
        if n_sample_per_task == 0: return
        self.n_sample = n_sample_per_task * self.n_task
        
        # [ task 1 samples (train) | task 2 samples (trans) | ... | task K samples (trans) ]
        self.task2parbuf = {}
        for t in self.tasks:
            self.task2parbuf[t] = task2offline_buffer[t].sample(n_sample_per_task)
            
    def sample(self, batch_size, cur_task):  # sample only from preceding tasks: t = 1...(task-1)
        if self.n_sample_per_task == 0:
            return None, cur_task
        if batch_size > self.n_sample_per_task:
            batch_size = self.n_sample_per_task
            
        task_id = self.task2id[cur_task]
        if task_id == 0:
            return None, cur_task
        prev_task_id = np.random.randint(0, task_id)
        prev_task = self.tasks[prev_task_id]

        # Efficiently create a buffer with only the selected indices
        parbuf = self.task2parbuf[prev_task]
        idx = np.random.choice(self.n_sample_per_task, batch_size, replace=False)
        # Only slice the data, no deepcopy
        sliced_data = {k: v[idx] for k, v in parbuf.data.items()}
        offsample = OfflineSample(sliced_data, batch_size, parbuf.max_seq_length, device=parbuf.device)
        return offsample, prev_task

class DataSaver():
    def __init__(self, datadir, max_size=2000):
        os.makedirs(datadir, exist_ok=True)
        self.datadir = datadir
        self.max_size = max_size
        self.data_batch = []
        self.part_no = 0

    def append(self, data):
        self.data_batch.append(data)
        if len(self.data_batch) >= self.max_size:
            self.save_batch()
            
    def save_batch(self):
        if len(self.data_batch) > 0:
            keys = list(self.data_batch[0].keys())
            datadic = {k: [] for k in keys}
            for d in self.data_batch:
                for k in keys:
                    if isinstance(d[k], th.Tensor):
                        datadic[k].append(d[k].numpy())
                    else:
                        datadic[k].append(d[k])
            datadic = {k: np.concatenate(v) for k, v in datadic.items()}

            with h5py.File(os.path.join(self.datadir, "part_{}.h5".format(self.part_no)), 'w') as file:
                for k, v in datadic.items():
                    file.create_dataset(k, data=v, compression='gzip', compression_opts=9)

            self.data_batch.clear()
            self.part_no += 1

    def close(self):
        self.save_batch()
        return self.datadir
