import torch as th
import numpy as np
from types import SimpleNamespace as SN
from torch.autograd import Variable

class EpisodeBatch:
    def __init__(self,
                 scheme,
                 groups,
                 batch_size,
                 max_seq_length,
                 data=None,
                 preprocess=None,
                 device="cpu"):
        self.scheme = scheme.copy()
        self.groups = groups
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.preprocess = {} if preprocess is None else preprocess
        self.device = device

        if data is not None:
            self.data = data
        else:
            self.data = SN()
            self.data.transition_data = {}
            self.data.episode_data = {}
            self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess)

    def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess):
        if preprocess is not None:
            for k in preprocess:
                assert k in scheme
                new_k = preprocess[k][0]
                transforms = preprocess[k][1]

                vshape = self.scheme[k]["vshape"]
                dtype = self.scheme[k]["dtype"]
                for transform in transforms:
                    vshape, dtype = transform.infer_output_info(vshape, dtype)

                self.scheme[new_k] = {
                    "vshape": vshape,
                    "dtype": dtype
                }
                if "group" in self.scheme[k]:
                    self.scheme[new_k]["group"] = self.scheme[k]["group"]
                if "episode_const" in self.scheme[k]:
                    self.scheme[new_k]["episode_const"] = self.scheme[k]["episode_const"]

        assert "filled" not in scheme, '"filled" is a reserved key for masking.'
        scheme.update({
            "filled": {"vshape": (1,), "dtype": th.long},
        })

        for field_key, field_info in scheme.items():
            assert "vshape" in field_info, "Scheme must define vshape for {}".format(field_key)
            vshape = field_info["vshape"]
            episode_const = field_info.get("episode_const", False)
            group = field_info.get("group", None)
            dtype = field_info.get("dtype", th.float32)

            if isinstance(vshape, int):
                vshape = (vshape,)

            if group:
                assert group in groups, "Group {} must have its number of members defined in _groups_".format(group)
                shape = (groups[group], *vshape)# n_num, v_shape
            else:
                shape = vshape

            if episode_const:
                self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device)
            else:
                self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device)

    def extend(self, scheme, groups=None):
        self._setup_data(scheme, self.groups if groups is None else groups, self.batch_size, self.max_seq_length)

    def to(self, device):
        for k, v in self.data.transition_data.items():
            if k=='obs' or k=='state':
                self.data.transition_data[k] = Variable(v.to(device),requires_grad=True)
            else:
                self.data.transition_data[k] = v.to(device)
        for k, v in self.data.episode_data.items():
            if k=='obs' or k=='state':
                self.data.episode_data[k] = Variable(v.to(device),requires_grad=True)
            else:
                self.data.episode_data[k] = v.to(device)
        self.device = device

    def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True):
        slices = self._parse_slices((bs, ts))
        processed_data = {}
        for k, v in data.items():
            dtype = self.scheme[k].get("dtype", th.float32)

            if not th.is_tensor(v):
                if isinstance(v, list):
                    v = np.array(v, dtype=np.float32 if dtype == th.float32 else None)
                processed_v = th.as_tensor(v, dtype=dtype, device=self.device)
            else:
                if v.device != self.device or v.dtype != dtype:
                    processed_v = v.to(self.device, dtype=dtype, non_blocking=False)
                else:
                    processed_v = v
            
            processed_data[k] = processed_v
        for k, v in processed_data.items():
            if k in self.data.transition_data:
                target = self.data.transition_data
                if mark_filled:
                    target["filled"][slices] = 1
                    mark_filled = False
                _slices = slices
            elif k in self.data.episode_data:
                target = self.data.episode_data
                _slices = slices[0]
            else:
                raise KeyError("{} not found in transition or episode data".format(k))

            special_handling = False
            dest = target[k][_slices]

            if (len(v.shape) == len(dest.shape) == 4 and 
                v.shape[0] == dest.shape[0] and 
                v.shape[1] == dest.shape[2] and
                k == "enemy_actions"):


                self._check_safe_view(v, dest)

                transposed_v = v.permute(0, 2, 1, 3)

                target[k][_slices] = transposed_v
                special_handling = True

            elif (len(v.shape) == 3 and len(dest.shape) == 4 and 
                k == "enemy_actions" and v.numel() < dest.numel()):

                self._check_safe_view(v, dest)

                if v.shape[0] == dest.shape[0] and v.shape[1] == dest.shape[1]:
                    actual_enemies = v.shape[2]
                    target_enemies = dest.shape[2]

                    padded_v = th.zeros_like(dest)

                    if actual_enemies <= target_enemies:
                        padded_v[:, :, :actual_enemies, :] = v.view(v.shape[0], v.shape[1], actual_enemies, 1)
                    else:
                        padded_v = v[:, :, :target_enemies].view(v.shape[0], v.shape[1], target_enemies, 1)

                    target[k][_slices] = padded_v
                    special_handling = True

            if not special_handling:
                self._check_safe_view(v, dest)
                try:
                    target[k][_slices] = v.view_as(target[k][_slices])
                except RuntimeError as e:
                    if "invalid for input of size" in str(e):
                        padded_v = th.zeros_like(target[k][_slices])

                        fill_size = min(v.numel(), padded_v.numel())

                        flat_v = v.reshape(-1)
                        flat_padded = padded_v.reshape(-1)
                        flat_padded[:fill_size] = flat_v[:fill_size]

                        target[k][_slices] = padded_v
                    else:
                        raise

            if k in self.preprocess:
                new_k = self.preprocess[k][0]
                v = target[k][_slices]
                for transform in self.preprocess[k][1]:
                    v = transform.transform(v)
                target[new_k][_slices] = v.view_as(target[new_k][_slices])

    def _check_safe_view(self, v, dest):

        if v.numel() == 0:
            return
            
        if len(v.shape) > 0 and len(dest.shape) > 0:

            if (len(v.shape) == len(dest.shape) == 4 and 
                v.shape[0] == dest.shape[0] and
                v.shape[1] == dest.shape[2] and
                v.shape[0] * v.shape[1] * v.shape[2] * v.shape[3] == dest.shape[0] * dest.shape[1] * dest.shape[2] * dest.shape[3]):
                return

            if (len(v.shape) <= len(dest.shape) and 
                v.shape[0] != dest.shape[0] and
                v.shape[-1] == dest.shape[-2]):
                return

            if (len(v.shape) == 3 and len(dest.shape) == 4 and 
                v.shape[0] == dest.shape[0] and v.shape[1] == dest.shape[1]):
                return

            if (len(v.shape) >= 3 and len(dest.shape) >= 4 and 
                v.shape[0] == dest.shape[0] and 
                v.shape[-1] == dest.shape[-1] and
                v.shape[1] == dest.shape[2] * 2):

                return
        
        # 原有的安全检查逻辑
        idx = len(v.shape) - 1
        for s in dest.shape[::-1]:
            if idx < 0:

                if s != 1:
                    raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape))
            elif v.shape[idx] != s:
                if s != 1:
                    raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape))
            else:
                idx -= 1

    def __getitem__(self, item):
        """Look up a field, select a subset of fields, or slice the batch.

        Args:
            item: one of
                * a string key - returns the stored tensor for that field;
                * a tuple of string keys - returns an EpisodeBatch holding
                  only those fields (the tensors are shared, not copied);
                * a batch index or ``(batch, time)`` index - returns an
                  EpisodeBatch whose tensors are indexed accordingly.

        Raises:
            ValueError: if a single string key is unknown.
            KeyError: if a key inside a tuple of keys is unknown.
        """
        if isinstance(item, str):
            if item in self.data.episode_data:
                return self.data.episode_data[item]
            if item in self.data.transition_data:
                return self.data.transition_data[item]
            raise ValueError("Unrecognised key {}".format(item))

        if isinstance(item, tuple) and all(isinstance(it, str) for it in item):
            new_data = self._new_data_sn()
            for key in item:
                if key in self.data.transition_data:
                    new_data.transition_data[key] = self.data.transition_data[key]
                elif key in self.data.episode_data:
                    new_data.episode_data[key] = self.data.episode_data[key]
                else:
                    raise KeyError("Unrecognised key {}".format(key))

            new_scheme = {key: self.scheme[key] for key in item}
            new_groups = {self.scheme[key]["group"]: self.groups[self.scheme[key]["group"]]
                          for key in item if "group" in self.scheme[key]}
            # Keep only preprocessing whose source and derived field are both
            # present, so update() on the result can write the derived field.
            new_preprocess = {key: spec for key, spec in self.preprocess.items()
                              if key in item and spec[0] in item}
            return EpisodeBatch(new_scheme, new_groups, self.batch_size, self.max_seq_length,
                                data=new_data, preprocess=new_preprocess, device=self.device)

        # Index with a tuple: indexing a tensor with a *list* of slices relies
        # on deprecated non-tuple sequence indexing.
        item = tuple(self._parse_slices(item))
        new_data = self._new_data_sn()
        for k, v in self.data.transition_data.items():
            new_data.transition_data[k] = v[item]
        for k, v in self.data.episode_data.items():
            # Episode-constant fields have no time axis
            new_data.episode_data[k] = v[item[0]]

        ret_bs = self._get_num_items(item[0], self.batch_size)
        ret_max_t = self._get_num_items(item[1], self.max_seq_length)

        # Carry the preprocessing over so that update() on the returned batch
        # keeps derived fields in sync, exactly as on the parent batch.
        return EpisodeBatch(self.scheme, self.groups, ret_bs, ret_max_t,
                            data=new_data, preprocess=self.preprocess, device=self.device)
    def _get_num_items(self, indexing_item, max_size):
        if isinstance(indexing_item, list) or isinstance(indexing_item, np.ndarray):
            return len(indexing_item)
        elif isinstance(indexing_item, slice):
            _range = indexing_item.indices(max_size)
            return 1 + (_range[1] - _range[0] - 1)//_range[2]

    def _new_data_sn(self):
        new_data = SN()
        new_data.transition_data = {}
        new_data.episode_data = {}
        return new_data

    def _parse_slices(self, items):
        parsed = []
        # Only batch slice given, add full time slice
        if (isinstance(items, slice)  # slice a:b
            or isinstance(items, int)  # int i
            or (isinstance(items, (list, np.ndarray, th.LongTensor, th.cuda.LongTensor)))  # [a,b,c]
            ):
            items = (items, slice(None))

        # Need the time indexing to be contiguous
        if isinstance(items[1], list):
            raise IndexError("Indexing across Time must be contiguous")

        for item in items:
            #TODO: stronger checks to ensure only supported options get through
            if isinstance(item, int):
                # Convert single indices to slices
                parsed.append(slice(item, item+1))
            else:
                # Leave slices and lists as is
                parsed.append(item)
        return parsed

    def max_t_filled(self):
        return th.sum(self.data.transition_data["filled"], 1).max(0)[0]

    def __repr__(self):
        """Summarise the batch: its size, sequence length, field keys and groups."""
        template = "EpisodeBatch. Batch Size:{} Max_seq_len:{} Keys:{} Groups:{}"
        field_keys = self.scheme.keys()
        group_keys = self.groups.keys()
        return template.format(self.batch_size, self.max_seq_length, field_keys, group_keys)

    def pop(self, k):
        if k in self.data.transition_data:
            del self.data.transition_data[k]
        if k in self.data.episode_data:
            del self.data.episode_data[k]
class ReplayBuffer(EpisodeBatch):
    """Fixed-capacity episode store that overwrites from the start once full."""

    def __init__(self, scheme, groups, buffer_size, max_seq_length, preprocess=None, device="cpu"):
        """Allocate storage for ``buffer_size`` episodes and reset the write cursor."""
        super().__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device)
        self.buffer_size = buffer_size  # same as self.batch_size but more explicit
        self.buffer_index = 0  # next row to write
        self.episodes_in_buffer = 0  # number of rows sample() may draw from

    def insert_episode_batch(self, ep_batch):
        """Copy the episodes of ``ep_batch`` into the buffer at the write cursor.

        A batch that does not fit before the end of the buffer is split: the
        part that fits is written first, the remainder is inserted after the
        cursor has wrapped around.
        """
        write_end = self.buffer_index + ep_batch.batch_size
        if write_end > self.buffer_size:
            room = self.buffer_size - self.buffer_index
            self.insert_episode_batch(ep_batch[0:room, :])
            self.insert_episode_batch(ep_batch[room:, :])
            return

        # The incoming "filled" mask is copied like any other field, so it
        # must not be marked again here.
        self.update(ep_batch.data.transition_data,
                    slice(self.buffer_index, write_end),
                    slice(0, ep_batch.max_seq_length),
                    mark_filled=False)
        self.update(ep_batch.data.episode_data,
                    slice(self.buffer_index, write_end))

        self.episodes_in_buffer = max(self.episodes_in_buffer, write_end)
        self.buffer_index = write_end % self.buffer_size
        assert self.buffer_index < self.buffer_size

    def can_sample(self, batch_size):
        """Return True when at least ``batch_size`` episodes are stored."""
        return self.episodes_in_buffer >= batch_size

    def sample(self, batch_size):
        """Return ``batch_size`` stored episodes, drawn without replacement."""
        assert self.can_sample(batch_size)
        if self.episodes_in_buffer != batch_size:
            # Uniform sampling only atm
            ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False)
            return self[ep_ids]
        # Exactly batch_size episodes stored: return them all, in order
        return self[:batch_size]

    def __repr__(self):
        """Summarise the buffer: fill level, capacity, field keys and groups."""
        template = "ReplayBuffer. {}/{} episodes. Keys:{} Groups:{}"
        return template.format(self.episodes_in_buffer,
                               self.buffer_size,
                               self.scheme.keys(),
                               self.groups.keys())


if __name__ == "__main__":
    # Manual smoke test: fill an EpisodeBatch in several ways, push it through
    # a ReplayBuffer twice (the second insert wraps around) and print a sample.
    bs, n_agents = 4, 2
    groups = dict(agents=n_agents)

    scheme = dict(
        actions=dict(vshape=(1,), group="agents", dtype=th.long),
        obs=dict(vshape=(3,), group="agents"),
        state=dict(vshape=(3, 3)),
        epsilon=dict(vshape=(1,), episode_const=True),
    )
    from transforms import OneHot
    preprocess = dict(actions=("actions_onehot", [OneHot(out_dim=5)]))

    ep_batch = EpisodeBatch(scheme, groups, bs, 3, preprocess=preprocess)

    # Data for a single episode at a single timestep
    env_data = dict(
        actions=th.ones(n_agents, 1).long(),
        obs=th.ones(2, 3),
        state=th.eye(3),
    )
    # Data for every episode at a single timestep
    batch_data = dict(
        actions=th.ones(bs, n_agents, 1).long(),
        obs=th.ones(2, 3).unsqueeze(0).repeat(bs, 1, 1),
        state=th.eye(3).unsqueeze(0).repeat(bs, 1, 1),
    )

    ep_batch.update(env_data, 0, 0)
    ep_batch.update(dict(epsilon=th.ones(bs) * .05))

    # Writing through a sliced batch and through explicit indices
    ep_batch[:, 1].update(batch_data)
    ep_batch.update(batch_data, ts=1)
    ep_batch.update(env_data, 0, 1)

    env_data = dict(
        obs=th.ones(2, 3),
        state=th.eye(3) * 2,
    )
    ep_batch.update(env_data, 3, 0)

    b2 = ep_batch[0, 1]
    b2.update(env_data, 0, 0)

    replay_buffer = ReplayBuffer(scheme, groups, 5, 3, preprocess=preprocess)
    for _ in range(2):
        replay_buffer.insert_episode_batch(ep_batch)

    sampled = replay_buffer.sample(3)
    print(sampled["actions_onehot"])