import numpy as np
from collections import namedtuple, deque
import random
import torch
from config import DTYPE, DEVICE


# Transition record stored by ReplayMemoryContract when it is built with agent=False.
# The typename matches the bound name so that repr() identifies the record type
# and instances can be pickled (pickle resolves classes by name in their module;
# the previous shared typename 'Transition' made that lookup fail).
TransitionPrincipal = namedtuple('TransitionPrincipal',
                                 ('state', 'action', 'reward', 'cost', 'outcome', 'done', 'next_state'))

# Transition record stored by ReplayMemoryContract when it is built with agent=True.
# The typename matches the bound name so that repr() identifies the record type
# and instances can be pickled (pickle resolves classes by name in their module;
# the previous shared typename 'Transition' made that lookup fail).
TransitionAgent = namedtuple('TransitionAgent',
                             ('state', 'action', 'reward', 'contract', 'done', 'next_state', 'next_contract'))


class ReplayMemoryContract(object):
    """Bounded replay buffer with uniform random sampling.

    Transitions are kept in a ``deque`` with ``maxlen=capacity``, so once the
    buffer is full each new push evicts the oldest stored transition.

    The buffer works in one of two modes, chosen by ``agent``:

    * ``agent`` falsy  -> records are ``TransitionPrincipal``
    * ``agent`` truthy -> records are ``TransitionAgent``
    """

    def __init__(self, capacity, agent=False):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.capacity = capacity
        self.reset(agent)

    def push(self, *args):
        """Store one transition.

        ``args`` are passed positionally to the active record type, so they
        must follow that namedtuple's field order.
        """
        self.memory.append(self.TransitionClass(*args))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions and batch them.

        Sampling is done with ``random.sample``, which raises ``ValueError``
        when ``batch_size`` exceeds the number of stored transitions.

        Returns a 7-tuple of tensors moved to ``DEVICE``:

        * principal mode:
          ``(states, actions, rewards, costs, outcomes, dones, next_states)``
        * agent mode:
          ``(states, actions, rewards, contracts, dones, next_states,
          next_contracts)``
        """
        batch = random.sample(self.memory, batch_size)

        def column(field):
            # Gather one namedtuple field across the sampled batch.
            return [getattr(item, field) for item in batch]

        # Fields shared by both modes. States/actions are stacked along a new
        # leading batch axis; original notes give (bs, *state.shape) and (bs, 1)
        # -- the latter assumes each stored action has shape (1,), TODO confirm.
        states = torch.stack(column('state'), dim=0).to(DEVICE)
        actions = torch.stack(column('action'), dim=0).to(DEVICE)
        reward_array = np.array(column('reward'))
        rewards = torch.from_numpy(reward_array).to(DTYPE).to(DEVICE)
        dones = torch.LongTensor(column('done')).to(DEVICE)
        next_states = torch.stack(column('next_state'), dim=0).to(DEVICE)

        if self.agent:
            # Contracts are concatenated along dim 0 (not stacked), so no new
            # batch axis is added here.
            contracts = torch.cat(column('contract'), dim=0).to(DEVICE)
            next_contracts = torch.cat(column('next_contract'), dim=0).to(DEVICE)
            return states, actions, rewards, contracts, dones, next_states, next_contracts

        costs = torch.FloatTensor(column('cost')).to(DTYPE).to(DEVICE)
        # Trailing singleton axis is appended to the outcomes tensor.
        outcomes = torch.LongTensor(column('outcome')).unsqueeze(-1).to(DEVICE)
        return states, actions, rewards, costs, outcomes, dones, next_states

    def reset(self, agent=False):
        """Discard all stored transitions and (re)select the buffer mode.

        NOTE: ``agent`` defaults to ``False``, so calling ``reset()`` without
        an argument switches the buffer to principal mode regardless of how it
        was constructed; pass ``agent=True`` to keep an agent buffer.
        """
        self.memory = deque(maxlen=self.capacity)
        self.agent = agent
        self.TransitionClass = TransitionAgent if agent else TransitionPrincipal

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)
