# --coding:utf-8--
import itertools
import random
import numpy as np
import torch

import ray
from collections import namedtuple

from gv.generalized_variance import *


# Transition = namedtuple('Transition', ('state_action', 'flag'))


# @ray.remote marks this class so that Ray can operate on it remotely.
@ray.remote
class ReplayBuffer:
    """Common replay buffer holding one list of transitions per policy.

    Every policy (agent) is identified by an integer ``flag`` in
    ``range(policy_number)`` and owns ``self.memory[flag]``.  The buffer also
    keeps a frozen copy of a transformer model that is used to embed
    trajectories when computing the generalized variance (GV) across policies.
    """

    def __init__(self, policy_number, transformer):
        """Create an empty buffer for ``policy_number`` policies.

        Args:
            policy_number: Number of policies; one transition list is
                allocated for each.
            transformer: Model exposing ``parameters()`` and
                ``representation_forward()`` (presumably a ``torch.nn.Module``
                -- confirm against the caller).  Its parameters are frozen.
        """
        self.size = 0
        self.episode = 0
        self.policy_number = policy_number

        self.transformer = transformer
        self._freeze_transformer()

        self.transformer_training = False

        self.memory = [[] for _ in range(policy_number)]

    def _freeze_transformer(self):
        """Disable gradient tracking on every parameter of the transformer."""
        for param in self.transformer.parameters():
            param.requires_grad = False

    def push(self, flag, transition):
        """Append ``transition`` to the list owned by policy ``flag``."""
        self.memory[flag].append(transition)

    def get_size(self):
        """Recount the stored transitions over all policies and return the total."""
        self.size = sum(len(bucket) for bucket in self.memory)
        return self.size

    def buffer_reset(self):
        """Drop every stored transition while keeping one empty list per policy.

        The per-policy lists are re-created (instead of clearing the outer
        list) so that ``push``/``get_size``/``extract_data`` keep working
        after a reset.
        """
        self.size = 0
        self.memory[:] = [[] for _ in range(self.policy_number)]

    def extract_data(self):
        """Return a shallow copy of each policy's transition list."""
        return [list(self.memory[i]) for i in range(self.policy_number)]

    def get_transformer_training_bool(self):
        """Return the flag set by ``change_transformer_training_bool``."""
        return self.transformer_training

    def change_transformer_training_bool(self, training_flag):
        """Store ``training_flag`` as the transformer-training flag."""
        self.transformer_training = training_flag

    def get_all(self, flag):
        """Return every transition stored for policy ``flag`` as a new list."""
        return self.get(flag, 0, len(self.memory[flag]))

    def get(self, flag, start_idx: int, end_idx: int):
        """Return transitions ``start_idx`` (inclusive) to ``end_idx`` (exclusive).

        Args:
            flag: Index of the policy whose transitions are read.
            start_idx: First position to include; must be non-negative
                (``itertools.islice`` rejects negative values).
            end_idx: Position at which to stop; must be non-negative.

        Returns:
            A new list with the selected transitions.
        """
        return list(itertools.islice(self.memory[flag], start_idx, end_idx))

    def update_transformer(self, model):
        """Replace the stored transformer with ``model`` and freeze its parameters."""
        self.transformer = model
        self._freeze_transformer()

    def get_generalized_variance(self, state_batch, action_batch, flag):
        """Compute the GV shared by all policies, from the view of policy ``flag``.

        One stored trajectory is picked at the same random index for every
        policy; the row belonging to ``flag`` is then replaced by the
        trajectory built from ``state_batch``/``action_batch`` so that
        gradients can flow back through it.

        Args:
            state_batch: Indexable of state tensors; ``state_batch[i]`` is
                concatenated with ``action_batch[i]`` along dim 0.
            action_batch: Tensor of actions; its first dimension gives the
                number of steps.
            flag: Index of the calling policy.

        Returns:
            ``0`` when at least one policy has not stored any data yet,
            otherwise the value of ``torch_generalized_variance(pop, flag)``.
        """
        # Each agent is its own process and they sample at different speeds,
        # so early on some agents may not have stored anything in the common
        # replay buffer yet.  In that case the diversity value is simply 0.
        lengths = [len(self.memory[i]) for i in range(self.policy_number)]
        if not all(lengths):
            print("Agent:{} get GV=0".format(flag))
            return 0

        # Choose which stored trajectory to use; every agent uses the same
        # index, so it is bounded by the shortest per-agent list.
        index = random.randrange(min(lengths))

        steps = [
            torch.cat([state_batch[i], action_batch[i]], dim=0)
            for i in range(action_batch.shape[0])
        ]
        traj = torch.cat(steps)

        # The last element of each stored entry is dropped before stacking.
        matrix = torch.FloatTensor(
            [list(self.memory[j][index][:-1]) for j in range(self.policy_number)]
        )

        # Swap in the live trajectory so gradients can propagate back through it.
        matrix[flag] = traj

        # The forward pass is deliberately NOT wrapped in torch.no_grad(): the
        # transformer's parameters are frozen, so nothing is accumulated in the
        # transformer itself, while the graph through `traj` stays intact.
        pop = self.transformer.representation_forward(matrix)

        return torch_generalized_variance(pop, flag)
