import numpy as np
import torch
import utils


class ExpBuffer:
    """Fixed-capacity experience buffer for one or more agents.

    Transitions are written sequentially by :meth:`append` into preallocated
    tensors that live on CUDA when available, otherwise on the CPU. When
    ``args.env.use_local_obs`` is set, observations, next observations,
    action masks and running-statistics objects are kept per agent (lists
    indexed by agent id); otherwise a single shared instance is used.

    Args:
        max_len: Capacity, i.e. maximum number of stored transitions.
        state_dim: Length of one observation row.
        action_dim: Length of one action-mask row.
        agent_num: Number of agents; width of each stored action row.
        args: Config object. Reads ``args.algo``, ``args.use_state_norm`` and
            ``args.env.use_local_obs``.
    """

    def __init__(self, max_len, state_dim, action_dim, agent_num, args):
        self.agent_num = agent_num
        self.use_prior = args.algo == "PPOwPrior"
        self.use_state_norm = args.use_state_norm
        self.use_local_obs = args.env.use_local_obs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.now_len = 0

        def obs_buffer():
            # One (max_len, state_dim) float32 observation tensor.
            return torch.empty((max_len, state_dim), dtype=torch.float32, device=self.device)

        def mask_buffer():
            # Rows default to True; append() restores this default when no mask is given.
            return torch.ones((max_len, action_dim), dtype=torch.bool, device=self.device)

        if self.use_local_obs:
            self.state = [obs_buffer() for _ in range(agent_num)]
            self.state_rms = [utils.RunningMeanStd(shape=(state_dim,)) for _ in range(agent_num)]
            self.state_after = [obs_buffer() for _ in range(agent_num)]
            self.mask = [mask_buffer() for _ in range(agent_num)]
        else:
            self.state = obs_buffer()
            self.state_rms = utils.RunningMeanStd(shape=(state_dim,))
            self.state_after = obs_buffer()
            self.mask = mask_buffer()
        self.action = torch.empty((max_len, agent_num), dtype=torch.float32, device=self.device)
        # With the "PPOwPrior" algorithm each reward row is 2 * agent_num wide
        # instead of agent_num (presumably an extra prior reward per agent — confirm).
        reward_width = agent_num * (2 if self.use_prior else 1)
        self.reward = torch.empty((max_len, reward_width), dtype=torch.float32, device=self.device)
        self.done = torch.empty((max_len, 1), dtype=torch.float32, device=self.device)

    def append(self, state, action, state_after, reward, done, mask=None):
        """Store one transition in the next free slot.

        Once ``max_len`` transitions are stored, further calls are silently
        ignored. In local-obs mode ``state``, ``state_after`` and ``mask`` are
        indexed by agent id; otherwise ``state``/``state_after`` are flattened
        into one row. When ``mask`` is None the mask row is reset to all-True,
        so a row written during a previous fill of the buffer is never
        returned alongside new data.
        """
        if self.now_len >= self.max_len:
            return
        idx = self.now_len
        if self.use_local_obs:
            for aid in range(self.agent_num):
                self.state[aid][idx] = torch.as_tensor(state[aid], device=self.device)
                self.state_after[aid][idx] = torch.as_tensor(state_after[aid], device=self.device)
                if mask is None:
                    self.mask[aid][idx] = True
                else:
                    self.mask[aid][idx] = torch.as_tensor(mask[aid], device=self.device)
        else:
            self.state[idx] = torch.as_tensor(state, device=self.device).flatten()
            self.state_after[idx] = torch.as_tensor(state_after, device=self.device).flatten()
            if mask is None:
                self.mask[idx] = True
            else:
                self.mask[idx] = torch.as_tensor(mask, device=self.device)
        self.action[idx] = torch.as_tensor(action, device=self.device).flatten()
        self.reward[idx] = torch.as_tensor(reward, device=self.device).flatten()
        self.done[idx] = torch.as_tensor(done, device=self.device)

        self.now_len += 1

    def update_rms(self):
        """Feed the stored observations into the running statistics and print them."""
        if self.use_local_obs:
            for aid in range(self.agent_num):
                # state_rms is a per-agent list in this mode.
                rms = self.state_rms[aid]
                rms.update(self.state[aid][:self.now_len])
                print("agent " + str(aid) + ": state mean:", rms.mean.cpu().numpy(), ",  state variance: ",
                      rms.var.cpu().numpy())
        else:
            self.state_rms.update(self.state[:self.now_len])
            print("state mean:", self.state_rms.mean.cpu().numpy(), ",  state variance: ", self.state_rms.var.cpu().numpy())

    def normalize_obs(self, state):
        """Standardize observations with the buffer's running mean/variance.

        Does not update the statistics. In local-obs mode ``state`` is a list
        indexed by agent id whose entries are replaced in place (the same list
        is returned); otherwise a new tensor is returned.
        """
        epsilon = 1e-8
        if self.use_local_obs:
            for aid in range(self.agent_num):
                rms = self.state_rms[aid]
                state[aid] = (state[aid] - rms.mean) / torch.sqrt(rms.var + epsilon)
        else:
            state = (state - self.state_rms.mean) / torch.sqrt(self.state_rms.var + epsilon)
        return state

    def sample_batch(self, batch_size):
        """Sample ``batch_size`` stored transitions uniformly with replacement.

        Returns:
            ``(state, reward, action, done)`` for the sampled indices; ``state``
            is a per-agent list in local-obs mode and is normalized when
            ``use_state_norm`` is set.

        Raises:
            RuntimeError: if the buffer is empty.
        """
        if self.now_len == 0:
            raise RuntimeError("cannot sample from an empty ExpBuffer")
        # Every stored index (including the most recent one) is eligible; next
        # states are stored explicitly, so no index needs to be held back.
        indices = torch.randint(self.now_len, size=(batch_size,), device=self.device)
        if self.use_local_obs:
            state = [self.state[aid][indices] for aid in range(self.agent_num)]
        else:
            state = self.state[indices]
        if self.use_state_norm:
            state = self.normalize_obs(state)
        return state, self.reward[indices], self.action[indices], self.done[indices]

    def sample_all(self):
        """Return every stored transition in insertion order.

        Returns:
            ``(state, reward, action, state_after, done, mask)``. ``state``,
            ``state_after`` and ``mask`` are per-agent lists in local-obs mode.
            Only ``state`` is normalized when ``use_state_norm`` is set.
        """
        n = self.now_len
        if self.use_local_obs:
            state = [self.state[aid][:n] for aid in range(self.agent_num)]
            state_after = [self.state_after[aid][:n] for aid in range(self.agent_num)]
            mask = [self.mask[aid][:n] for aid in range(self.agent_num)]
        else:
            state = self.state[:n]
            state_after = self.state_after[:n]
            mask = self.mask[:n]
        if self.use_state_norm:
            state = self.normalize_obs(state)
        return state, self.reward[:n], self.action[:n], state_after, self.done[:n], mask

    def empty_buffer_before_explore(self):
        """Logically clear the buffer; storage is reused and overwritten by later appends."""
        self.now_len = 0

