import numpy as np
import copy
import time
import torch as th


def padding_instance(reward_instance, latest_instance, total_length, exclude_last_instance):
    """Left-pad ``reward_instance`` with history so the result has ``total_length`` items.

    The padding is sliced out of ``latest_instance`` from the entries that
    sit directly before the trailing ``len(reward_instance)`` entries (one
    entry further back when ``exclude_last_instance`` is true).  If that
    slice is shorter than required, the front is filled with a single shared
    ``Instance`` that has been passed through ``zeros_like`` using
    ``reward_instance[0]`` as the template.

    Side effects: every padding element gets ``lstm_gradient_mask = 0`` and
    every element of ``reward_instance`` gets ``lstm_gradient_mask = 1``.
    The slice is shallow, so the padding elements are the same objects that
    are stored in ``latest_instance``.

    Returns the padding sequence extended in place with ``reward_instance``.
    """
    reward_count = len(reward_instance)
    missing = total_length - reward_count
    stop = -reward_count - 1 if exclude_last_instance else -reward_count
    padded = latest_instance[stop - missing: stop]

    shortfall = missing - len(padded)
    if shortfall > 0:
        # Not enough history: prepend the same zero-valued filler repeatedly.
        filler = Instance()
        filler.zeros_like(reward_instance[0])
        padded[:0] = [filler] * shortfall

    # Padding entries are masked out of the gradient; real entries are not.
    # The padding pass runs first, so an object present in both ends up at 1.
    for entry in padded:
        entry.lstm_gradient_mask = 0
    for entry in reward_instance:
        entry.lstm_gradient_mask = 1

    padded.extend(reward_instance)
    return padded


def to_device(data, device="cpu"):
    """Recursively convert a nested structure of numpy arrays to torch tensors.

    Dicts are rebuilt with the same keys, lists and tuples are both rebuilt
    as lists, and every other value is handed to ``th.from_numpy`` and then
    moved to ``device``.
    """
    if isinstance(data, dict):
        return {name: to_device(item, device) for name, item in data.items()}
    if isinstance(data, (list, tuple)):
        return [to_device(item, device) for item in data]
    # Leaf: anything that is not a dict/list/tuple goes through from_numpy.
    return th.from_numpy(data).to(device)


def get_batched_obs(obs_list):
    """Stack a sequence of identically structured observations into batches.

    The structure is taken from ``obs_list[0]``:

    * dict  -> dict with the same keys, each value batched recursively;
    * list/tuple -> list whose k-th item batches the k-th element of every
      observation;
    * anything else -> ``np.asarray(obs_list)``; a 1-D result gets a trailing
      axis of length 1 so the output always has at least two dimensions.
    """
    first = obs_list[0]
    if isinstance(first, dict):
        return {
            name: get_batched_obs([entry[name] for entry in obs_list])
            for name in first
        }
    if isinstance(first, (list, tuple)):
        # zip(*...) regroups the observations position by position.
        return [get_batched_obs(column) for column in zip(*obs_list)]

    stacked = np.asarray(obs_list)
    if stacked.ndim == 1:
        stacked = stacked[:, np.newaxis]
    return stacked


class Instance:
    """Plain record holding the values collected for one sample.

    Every constructor argument is stored unchanged on the attribute of the
    same name.  The one exception is ``data_time``: when it is ``None`` the
    current ``time.time()`` is stored instead.
    """

    def __init__(
            self,
            data_time=None,
            state=None,
            style=None,
            action=None,
            old_log_prob=None,
            old_state_value=None,
            advantage=None,
            q_value=None,
            is_done=None,
            reward=None
    ):
        # Default the timestamp to "now" when the caller does not supply one.
        self.data_time = time.time() if data_time is None else data_time
        self.state = state
        self.style = style
        self.action = action
        self.old_log_prob = old_log_prob
        self.old_state_value = old_state_value
        self.advantage = advantage
        self.q_value = q_value
        self.is_done = is_done
        self.reward = reward

    def zeros_like(self, target_instance):
        """Turn this instance into a zero-valued stand-in for ``target_instance``.

        Sets ``data_time`` to 0 and replaces ``state`` with an all-zero array
        of the same shape and dtype as ``target_instance.state`` (left
        untouched when the target has no state).

        The previous implementation only wrote ``dota_time`` and read
        ``target_instance.state_gf``; neither attribute is defined by
        ``__init__``, so it raised ``AttributeError`` for regular instances
        and never reset ``data_time``.  Both legacy attributes are still
        honoured so objects that carry them behave as before.
        """
        self.data_time = 0
        # Legacy spelling written by the earlier implementation.
        self.dota_time = 0

        template_state = getattr(target_instance, "state", None)
        if template_state is not None:
            self.state = np.zeros_like(template_state)

        # Legacy attribute: only mirrored when the target actually has it.
        legacy_state = getattr(target_instance, "state_gf", None)
        if legacy_state is not None:
            self.state_gf = np.zeros_like(legacy_state)


class TrainingSet:
    """Buffer of training samples stored as parallel per-field lists.

    The buffers are ``data_time_list``, ``state_list``, ``style_list``,
    ``action_list``, ``old_log_prob_list``, ``old_state_value_list``,
    ``advantage_list`` and ``q_value_list``; index ``i`` in each of them
    belongs to the same sample.  ``max_capacity`` is only enforced when
    ``fit_max_size`` is called.
    """

    # Names of the parallel buffers, in the order they are created.
    _BUFFER_NAMES = (
        "data_time_list",
        "state_list",
        "style_list",
        "action_list",
        "old_log_prob_list",
        "old_state_value_list",
        "advantage_list",
        "q_value_list",
    )

    def __init__(
            self,
            max_capacity=10000
    ):
        self.max_capacity = max_capacity
        self._reset_buffers()

    def _reset_buffers(self):
        """Bind a fresh empty list to every buffer attribute."""
        for name in self._BUFFER_NAMES:
            setattr(self, name, [])

    def clear(self):
        """Drop all stored samples."""
        self._reset_buffers()

    def len(self):
        """Return the number of stored samples."""
        return len(self.data_time_list)

    def fit_max_size(self):
        """Discard the oldest samples so at most ``max_capacity`` remain."""
        overflow = len(self.data_time_list) - self.max_capacity
        if overflow <= 0:
            return
        for name in self._BUFFER_NAMES:
            setattr(self, name, getattr(self, name)[overflow:])

    def append_instance(self, instances):
        """Append the fields of every item in ``instances`` to the buffers.

        ``data_time``, ``state``, ``advantage`` and ``q_value`` are each
        wrapped in a one-element list; the other fields are stored as they
        are.  ``instances`` may be any iterable: it is materialised once so
        that a generator or iterator fills every buffer rather than being
        exhausted by the first one.
        """
        batch = list(instances)
        self.data_time_list.extend([item.data_time] for item in batch)
        self.state_list.extend([item.state] for item in batch)
        self.style_list.extend(item.style for item in batch)
        self.action_list.extend(item.action for item in batch)
        self.old_log_prob_list.extend(item.old_log_prob for item in batch)
        self.old_state_value_list.extend(item.old_state_value for item in batch)
        self.advantage_list.extend([item.advantage] for item in batch)
        self.q_value_list.extend([item.q_value] for item in batch)

    def slice(self, batch_size):
        """Draw ``batch_size`` distinct random samples and return them batched.

        Returns a dict with the keys ``states``, ``styles``,
        ``old_state_value``, ``advantages`` and ``q_values`` as numpy arrays
        reshaped to ``(batch_size, -1)``, plus ``actions`` and
        ``old_log_prob`` batched through ``get_batched_obs``.  Sampling is
        without replacement, so ``np.random.choice`` raises ``ValueError``
        when ``batch_size`` exceeds the number of stored samples.
        """
        picks = self._generate_random_index(batch_size)

        def gather(buffer):
            return [buffer[i] for i in picks]

        def as_matrix(buffer):
            return np.array(gather(buffer)).reshape(batch_size, -1)

        return {
            "states": as_matrix(self.state_list),
            "styles": as_matrix(self.style_list),
            "actions": get_batched_obs(gather(self.action_list)),
            "old_log_prob": get_batched_obs(gather(self.old_log_prob_list)),
            "old_state_value": as_matrix(self.old_state_value_list),
            "advantages": as_matrix(self.advantage_list),
            "q_values": as_matrix(self.q_value_list),
        }

    def _generate_random_index(self, batch_size):
        """Return ``batch_size`` distinct indices drawn from ``range(self.len())``."""
        # An int population is sampled as np.arange(n) without building it here.
        return np.random.choice(self.len(), batch_size, replace=False)
