import bisect

import numpy as np
import torch
from torch.utils.data import Dataset


def list2array(input_list):
    """Convert a (possibly nested) sequence into a freshly allocated NumPy array.

    ``np.array`` copies by default, so mutating the result never affects
    ``input_list`` (even when ``input_list`` is already an ndarray).
    """
    converted = np.array(input_list, copy=True)
    return converted


class SequentialDataset(Dataset):
    """Fixed-length windows of consecutive steps drawn from a flat trajectory buffer.

    Every per-step sequence (``states``, ``obss``, ``actions``, ``rewards``,
    ``timesteps``, ``next_*``) is indexed by a global step position, with
    episodes stored back to back. ``done_idxs`` lists, per episode, the
    exclusive end position of that episode: ``done_idx - 1`` is its last step,
    so ``done_idx`` is the first step of the following episode. It may be a
    flat sequence, or one row per episode with one column per agent, in which
    case the first column is used.
    """

    def __init__(self, context_length, states, obss, actions, done_idxs, rewards, timesteps, next_states,
                 next_obss, next_available_actions):
        self.context_length = context_length
        self.states = states
        self.obss = obss
        self.next_states = next_states
        self.next_obss = next_obss
        self.actions = actions
        self.next_available_actions = next_available_actions
        # done_idx - 1 equals the last step's position
        self.done_idxs = done_idxs
        self.rewards = rewards
        self.timesteps = timesteps
        # Computed once instead of on every __getitem__ call:
        # the sorted episode end positions (for bisect), and the same values as
        # a set, because an episode end is exactly where the next episode starts.
        self._episode_ends = self._flatten_episode_ends(done_idxs)
        self._episode_starts = set(self._episode_ends)

    @staticmethod
    def _flatten_episode_ends(done_idxs):
        """Return the episode end positions in ``done_idxs`` as a sorted list of ints."""
        ends = np.asarray(done_idxs)
        if ends.ndim == 2:
            # One column per agent: use the first agent's column.
            ends = ends[:, 0]
        return sorted(int(end) for end in ends.reshape(-1))

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        """Return the window of ``context_length`` steps associated with ``idx``.

        The window ends ``context_length`` steps after ``idx`` or at the end of
        the episode containing ``idx``, whichever comes first. It is then
        shifted back so it always spans ``context_length`` steps.

        Returns:
            ``(states, obss, pre_actions, rewards, timesteps, next_states,
            next_obss, cur_actions, next_available_actions)``.
            ``pre_actions[t]`` is the action stored one position before step
            ``t``, and is zeros for the first step of an episode.
            ``rewards`` gains a trailing singleton axis.
            ``timesteps`` holds only the timestep of the window's first step.
        """
        context_length = self.context_length
        done_idx = idx + context_length
        # Clip the window at the first episode end strictly greater than idx.
        pos = bisect.bisect_right(self._episode_ends, idx)
        if pos < len(self._episode_ends):
            done_idx = min(self._episode_ends[pos], done_idx)
        # NOTE(review): if an episode is shorter than context_length, the window
        # reaches back into the previous episode. For the first episode, idx
        # becomes negative and the slices below wrap around. Episodes are
        # assumed to be at least context_length steps long.
        idx = done_idx - context_length
        states = torch.tensor(np.array(self.states[idx:done_idx]), dtype=torch.float32)
        next_states = torch.tensor(np.array(self.next_states[idx:done_idx]), dtype=torch.float32)
        obss = torch.tensor(np.array(self.obss[idx:done_idx]), dtype=torch.float32)
        next_obss = torch.tensor(np.array(self.next_obss[idx:done_idx]), dtype=torch.float32)

        cur_window = np.asarray(self.actions[idx:done_idx])
        if idx == 0 or idx in self._episode_starts:
            # Nothing precedes the first step of an episode. Shift the action
            # window right by one and zero-pad the front, so the previous
            # episode's last action does not leak into this window.
            pre_window = np.zeros_like(cur_window)
            pre_window[1:] = cur_window[:-1]
            pre_actions = torch.tensor(pre_window, dtype=torch.int64)
        else:
            pre_actions = torch.tensor(self.actions[idx - 1:done_idx - 1], dtype=torch.int64)
        cur_actions = torch.tensor(cur_window, dtype=torch.int64)
        next_available_actions = torch.tensor(self.next_available_actions[idx:done_idx], dtype=torch.int64)

        rewards = torch.tensor(self.rewards[idx:done_idx], dtype=torch.float32).unsqueeze(-1)
        timesteps = torch.tensor(self.timesteps[idx:idx + 1], dtype=torch.int64)

        return states, obss, pre_actions, rewards, timesteps, next_states, next_obss, cur_actions, \
            next_available_actions

class ExpertDataSet(Dataset):
    """(state, action) pairs for supervised learning from expert data.

    ``expert_data[0]`` holds the states, with the first axis indexing samples.
    ``expert_data[1]`` holds the matching actions. Both may be NumPy arrays or
    nested sequences; they are copied into float32 tensors.
    """

    def __init__(self, expert_data):
        # Number of samples (first axis of the states); the attribute name is kept for compatibility.
        self.state_size = np.shape(expert_data[0])[0]
        # torch.tensor(ndarray) copies and casts in one step. Wrapping
        # torch.from_numpy(...) in torch.tensor made a redundant second copy and
        # triggered PyTorch's copy-construct UserWarning. It also rejected plain lists.
        self.state = torch.tensor(np.asarray(expert_data[0]), dtype=torch.float32)
        self.action = torch.tensor(np.asarray(expert_data[1]), dtype=torch.float32)
        # No successor states are provided, so an independent copy of the current state is used.
        self.next_state = self.state.clone()
        self.length = self.state_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.state[idx], self.action[idx]


class StateActionReturnDataset(Dataset):
    """Single-step (state, action, return-to-go, timestep) samples.

    ``block_size // 3`` gives a window length; dividing by 3 presumably matches
    a (return, state, action) token layout (TODO confirm). The window is used
    only to shift the requested index back so it does not run past the end of
    its episode. Only the element at the shifted position is returned.
    """

    def __init__(self, data, block_size, actions, done_idxs, rtgs, timesteps):
        self.block_size = block_size
        self.data = data
        self.actions = actions
        self.done_idxs = done_idxs
        self.rtgs = rtgs
        self.timesteps = timesteps

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        window = self.block_size // 3
        # First entry of done_idxs (in stored order) that lies strictly after idx, if any.
        episode_end = next((int(end) for end in self.done_idxs if end > idx), None)
        stop = idx + window if episode_end is None else min(episode_end, idx + window)
        start = stop - window
        # NOTE(review): the [start, stop) window is computed but only `start` is
        # indexed below. If whole windows were intended, these should be
        # slices; that would change the output shapes callers see.
        return (
            torch.tensor(np.array(self.data[start]), dtype=torch.float32),
            torch.tensor(self.actions[start], dtype=torch.long),
            torch.tensor(self.rtgs[start], dtype=torch.float32).unsqueeze(-1),
            torch.tensor(self.timesteps[start], dtype=torch.int64),
        )



class ReplayBuffer:
    """FIFO store of complete multi-agent episodes.

    Steps are accumulated per agent in ``self.episode`` until every entry of
    ``done`` is truthy. The finished episode (a list holding one trajectory per
    agent) is then appended to ``self.data``. Before appending, the oldest
    episodes are evicted so that, for ``buffer_size >= 1``, the buffer never
    holds more than ``buffer_size`` episodes.
    """

    def __init__(self, n_agents, buffer_size, context_length):
        self.n_agents = n_agents
        self.buffer_size = buffer_size
        self.context_length = context_length

        self.data = []
        self.episode = self._empty_episode()

    def _empty_episode(self):
        """Return one empty trajectory list per agent."""
        return [[] for _ in range(self.n_agents)]

    @property
    def cur_size(self):
        """Number of completed episodes currently stored."""
        return len(self.data)

    def insert(self, global_obs, local_obs, action, reward, done, available_actions):
        """Record one environment step for every agent.

        ``global_obs``, ``local_obs``, ``reward``, ``done`` and
        ``available_actions`` are indexed as ``x[0][agent]``, and ``action`` as
        ``action[agent]``. Each agent's step is stored as
        ``[global_obs, local_obs, [action], reward, done, available_actions]``.
        """
        for agent in range(self.n_agents):
            self.episode[agent].append([
                global_obs[0][agent],
                local_obs[0][agent],
                [action[agent]],
                reward[0][agent],
                done[0][agent],
                available_actions[0][agent],
            ])
        if not np.all(done):
            return
        # Make room for the finished episode by dropping the oldest ones.
        overflow = self.cur_size - self.buffer_size + 1
        if overflow > 0:
            del self.data[:overflow]
        self.data.append(list(self.episode))
        self.episode = self._empty_episode()

    def reset(self):
        """Discard all stored episodes; the in-progress ``self.episode`` is left untouched."""
        self.data = []

    # def load_offline_data(self, data_dir, episode_num, bias):
    #     for i in range(episode_num):
    #         idx = i + bias
    #         episode = torch.load(data_dir + str(idx))
    #         self.data.append(episode)

    # def sample(self, batch_size):
    #     ep_ids = np.random.choice(self.cur_size, batch_size, replace=False)
    #     batch = [self.data[idx] for idx in ep_ids]
    #     s, o, a, d_idx, r, t, s_next, o_next, ava_next = self.batch_processing(batch.copy())
    #     dataset = SequentialDataset(self.context_length, s, o, a, d_idx, r, t, s_next, o_next, ava_next)
    #     return dataset

    # def batch_processing(self, batch):
    #     global_states = [[] for i in range(self.n_agents)]
    #     local_obss = [[] for i in range(self.n_agents)]
    #     actions = [[] for i in range(self.n_agents)]
    #     rewards = [[] for i in range(self.n_agents)]
    #     done_idxs = [[] for i in range(self.n_agents)]
    #     time_steps = [[] for i in range(self.n_agents)]
    #     next_global_states = [[] for i in range(self.n_agents)]
    #     next_local_obss = [[] for i in range(self.n_agents)]
    #     next_available_actions = [[] for i in range(self.n_agents)]
    #
    #     for episode in batch:
    #         for j, agent_trajectory in enumerate(episode):
    #             time_step = 0
    #             for i in range(len(agent_trajectory)):
    #                 g, o, a, r, d, ava = agent_trajectory[i]
    #                 if i < len(agent_trajectory) - 1:
    #                     g_next = agent_trajectory[i + 1][0]
    #                     o_next = agent_trajectory[i + 1][1]
    #                     ava_next = agent_trajectory[i + 1][5]
    #                 else:
    #                     g_next = g
    #                     o_next = o
    #                     ava_next = ava
    #
    #                 global_states[j].append(g)
    #                 local_obss[j].append(o)
    #                 actions[j].append(a)
    #                 rewards[j].append(r[0])
    #                 time_steps[j].append(time_step)
    #                 time_step += 1
    #                 next_global_states[j].append(g_next)
    #                 next_local_obss[j].append(o_next)
    #                 next_available_actions[j].append(ava_next)
    #             done_idxs[j].append(len(global_states[j]))
    #
    #     actions = list2array(actions).swapaxes(1, 0).tolist()
    #     done_idxs = list2array(done_idxs).swapaxes(1, 0).tolist()
    #     rewards = list2array(rewards).swapaxes(1, 0).tolist()
    #     time_steps = list2array(time_steps).swapaxes(1, 0).tolist()
    #     next_available_actions = list2array(next_available_actions).swapaxes(1, 0).tolist()
    #     global_states = list2array(global_states).swapaxes(1, 0).tolist()
    #     local_obss = list2array(local_obss).swapaxes(1, 0).tolist()
    #     next_global_states = list2array(next_global_states).swapaxes(1, 0).tolist()
    #     next_local_obss = list2array(next_local_obss).swapaxes(1, 0).tolist()
    #
    #     # [s, o, a, d, r, t, s_next, o_next, ava_next]
    #     return global_states, local_obss, actions, done_idxs, rewards, time_steps, next_global_states, next_local_obss,\
    #            next_available_actions




