import torch
import numpy as np

class PriorBuffer:
    """Fixed-capacity store of (state, action, target, weight) records.

    States, targets and weights live in pre-allocated tensors of shape
    ``(capacity, dim_s)``, ``(capacity, dim_s + 1)`` and ``(capacity, 1)``.
    Actions are kept in a plain Python list because their leading dimension
    may differ from record to record (``sample`` filters on it).

    While the buffer has free slots new records are appended.  Once it is
    full, every incoming record first decays all stored weights by ``decay``
    and then replaces the lowest-weight record, but only if its own weight is
    strictly larger than that minimum.
    """

    # Exponent parameter of the rank-based sampling distribution in ``sample``.
    _RANK_BETA = 3

    def __init__(self, capacity: int, dim_s: int, max_arm_step: int, decay: float = 0.99,
                 device="cpu", dtype=torch.float32):
        """Store the configuration and allocate empty buffers.

        Args:
            capacity: Maximum number of records kept.
            dim_s: Width of one state row; target rows are ``dim_s + 1`` wide.
            max_arm_step: Stored on the instance; not used by the buffer
                logic itself.
            decay: Factor applied to all stored weights each time a record
                arrives while the buffer is full.
            device: Device on which the tensors are allocated.
            dtype: Dtype of the state, target and weight tensors.
        """
        self.capacity = capacity
        self.dim_s = dim_s
        self.device = device
        self.dtype = dtype
        self.decay = decay
        self.max_arm_step = max_arm_step
        # Allocation is shared with ``reset`` so both always agree.
        self.reset()

    @torch.no_grad()
    def add(self, temp):
        """Insert a batch of records.

        Args:
            temp: A 4-tuple ``(s, a, trgr, w)``.  ``s`` must be a tensor whose
                first dimension is the batch size; ``a``, ``trgr`` and ``w``
                are indexed by the same batch position.  ``a[i]`` and
                ``trgr[i]`` must be tensors (they are moved to the buffer's
                device).
        """
        s, a, trgr, w = temp
        batch = s.size(0)
        for i in range(batch):
            if self.size < self.capacity:
                # Free slot available: append.
                idx = self.size
                self.size += 1
            else:
                # Buffer full: age every stored weight (once per incoming
                # record), then evict the weakest record only if the
                # newcomer is strictly better.
                self.w_buf = self.w_buf * self.decay
                min_w, min_idx = torch.min(self.w_buf[:self.size], dim=0)
                if w[i] <= min_w:
                    continue
                idx = min_idx.item()

            self.s_buf[idx] = s[i]
            self.w_buf[idx] = w[i]
            self.a_buf[idx] = a[i].to(self.device)
            self.trgr_buf[idx] = trgr[i].to(self.device)

    @torch.no_grad()
    def get(self, sort_by_weight=True, descending=True):
        """Return all stored records.

        Args:
            sort_by_weight: If true, order the records by their weight.
            descending: Sort direction used when ``sort_by_weight`` is true.

        Returns:
            ``(s, a, trgr, w)`` where ``s``, ``trgr`` and ``w`` are tensors
            with ``self.size`` rows and ``a`` is a list of ``self.size``
            action tensors in the same order.
        """
        s = self.s_buf[:self.size]
        w = self.w_buf[:self.size]
        trgr = self.trgr_buf[:self.size]
        a = self.a_buf[:self.size]

        if sort_by_weight:
            # Squeeze only the trailing axis: a bare ``squeeze()`` would turn
            # a single-record buffer into a 0-dim tensor, which loses the
            # batch dimension and makes ``tolist()`` return a scalar.
            _, order = torch.sort(w.squeeze(-1), descending=descending)
            s = s[order]
            w = w[order]
            trgr = trgr[order]
            a = [a[i] for i in order.tolist()]
        return s, a, trgr, w

    @torch.no_grad()
    def sample(self, batch_size, nstep):
        """Sample up to ``batch_size`` records whose action has ``nstep`` rows.

        Records are taken in descending weight order.  If more than
        ``batch_size`` of them match, ``batch_size`` are drawn without
        replacement with probability proportional to
        ``((n - rank + 1) / n) ** (beta - 1)`` (``beta = 3``), which favours
        high-weight records.  Otherwise every matching record is returned.

        Args:
            batch_size: Maximum number of records to return.
            nstep: Required size of the first dimension of the action tensor.

        Returns:
            ``(s_k, a_k, trgt_k)`` tensors for the selected records, or three
            empty lists when no stored action has ``nstep`` rows.
        """
        s, a, trgt, _ = self.get()

        # Positions (in weight order) of the records with the requested length.
        rows = [i for i, act in enumerate(a) if act.shape[0] == nstep]
        if not rows:
            return [], [], []

        n = len(rows)
        if n > batch_size:
            ranks = np.arange(1, n + 1)
            weights = ((n - ranks + 1) / n) ** (self._RANK_BETA - 1)
            p = weights / weights.sum()
            idx = np.random.choice(n, size=batch_size, replace=False, p=p)
        else:
            idx = np.arange(n)

        # Use explicit integer index tensors instead of relying on implicit
        # conversion of bool lists / NumPy arrays inside tensor indexing.
        pick = torch.as_tensor(idx, dtype=torch.long, device=s.device)
        picked_rows = torch.as_tensor(rows, dtype=torch.long, device=s.device)[pick]

        a_k = torch.stack([a[i] for i in rows], dim=0)[pick]
        s_k = s[picked_rows]
        trgt_k = trgt[picked_rows]
        return s_k, a_k, trgt_k

    @torch.no_grad()
    def reset(self):
        """Discard every record and re-allocate zero-filled buffers."""
        self.size = 0
        self.s_buf = torch.zeros((self.capacity, self.dim_s), dtype=self.dtype, device=self.device)
        self.a_buf = [None] * self.capacity
        self.trgr_buf = torch.zeros((self.capacity, self.dim_s + 1), dtype=self.dtype, device=self.device)
        self.w_buf = torch.zeros((self.capacity, 1), dtype=self.dtype, device=self.device)