# -*- coding: utf-8 -*-
from __future__ import division
from collections import namedtuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

# One stored step of experience: the episode timestep, a single frame (stored as uint8 by ReplayMemory.append),
# the action taken, the reward received and whether the episode continued afterwards (nonterminal)
Transition = namedtuple('Transition', ('timestep', 'state', 'action', 'reward', 'nonterminal'))
# Placeholder transition (all-zero 84x84 uint8 frame, no action, zero reward, flagged terminal) used to pad frame
# stacks and n-step windows that would otherwise reach across an episode boundary
blank_trans = Transition(0, torch.zeros(84, 84, dtype=torch.uint8), None, 0, False)


# Segment tree data structure where each parent node value is the sum of its children node values
class SegmentTree():
  """Sum tree laid over a fixed-size cyclic buffer of items.

  The tree is stored as a flat array of ``2 * size - 1`` float32 values: the last ``size`` entries are the leaves
  (one priority per buffer slot) and every internal node holds the sum of its two children, so the root holds
  the total priority. The largest priority ever written is tracked separately in ``self.max``.
  """

  def __init__(self, size):
    self.index = 0  # Next buffer slot to be written
    self.size = size
    self.full = False  # Used to track actual capacity
    self.sum_tree = np.zeros((2 * size - 1, ), dtype=np.float32)  # Initialise fixed size tree with all (priority) zeros
    self.data = np.array([None] * size)  # Wrap-around cyclic buffer
    self.max = 1  # Initial max value to return (1 = 1^ω)

  # Propagates value up tree given a tree index
  def _propagate(self, index, value):
    """Recompute the sums of every ancestor of tree node ``index``.

    ``value`` is unused and only kept so the signature stays unchanged. The walk is iterative and stops as
    soon as the root has been reached, so a tree of size 1 (whose only leaf is the root) is handled too.
    """
    while index != 0:
      index = (index - 1) // 2  # Move to the parent
      left = 2 * index + 1
      self.sum_tree[index] = self.sum_tree[left] + self.sum_tree[left + 1]

  # Updates value given a tree index
  def update(self, index, value):
    """Set the priority stored at tree node ``index`` and refresh the sums above it."""
    self.sum_tree[index] = value  # Set new value
    self._propagate(index, value)  # Propagate value
    self.max = max(value, self.max)

  def append(self, data, value):
    """Store ``data`` with priority ``value`` in the next slot, overwriting the oldest entry once full."""
    self.data[self.index] = data  # Store data in underlying data structure
    self.update(self.index + self.size - 1, value)  # Update tree
    self.index = (self.index + 1) % self.size  # Update index
    self.full = self.full or self.index == 0  # Save when capacity reached
    self.max = max(value, self.max)

  # Searches for the location of a value in sum tree
  def _retrieve(self, index, value):
    """Descend from tree node ``index`` to the leaf whose cumulative priority range contains ``value``."""
    tree_size = len(self.sum_tree)
    while True:
      left = 2 * index + 1
      if left >= tree_size:  # No children: this node is a leaf
        return index
      left_sum = self.sum_tree[left]
      if value <= left_sum:
        index = left
      else:
        # Skip the left subtree and search the remainder in the right one
        index, value = left + 1, value - left_sum

  # Searches for a value in sum tree and returns value, data index and tree index
  def find(self, value):
    """Return ``(priority, data index, tree index)`` of the leaf selected by ``value``."""
    index = self._retrieve(0, value)  # Search for index of item from root
    data_index = index - self.size + 1
    return (self.sum_tree[index], data_index, index)  # Return value, data index, tree index

  # Returns data given a data index
  def get(self, data_index):
    """Return the item stored at ``data_index``, wrapping the index around the buffer size."""
    return self.data[data_index % self.size]

  def total(self):
    """Return the sum of all priorities (the value held by the root)."""
    return self.sum_tree[0]


class ReplayMemory():
  """Prioritised n-step replay memory that keeps its transitions in a SegmentTree.

  Only the most recent frame of each state is stored (as uint8); frame stacks of length ``history`` and the
  n-step lookahead are rebuilt from neighbouring buffer slots when a transition is read back.
  """

  def __init__(self, args, capacity):
    self.args = args
    self.device = args.device
    self.capacity = capacity
    self.history = args.history_length
    self.discount = args.discount
    self.n = args.multi_step
    # Initial importance sampling weight β (per the original note it is annealed to 1 during training; that
    # happens outside this class)
    self.priority_weight = args.priority_weight
    self.priority_exponent = args.priority_exponent
    self.t = 0  # Timestep within the current episode
    self.transitions = SegmentTree(capacity)  # Cyclic buffer of transitions inside a sum tree of priorities
    # Per-slot side information, reset to 1 whenever a slot is overwritten
    self.td = np.zeros((capacity), dtype=np.float32)
    self.q = np.zeros((capacity), dtype=np.float32)
    self.prev_score = np.inf
    self.curr_score = np.inf
    self.reward = 0.0

  # Adds state and action at time t, reward and terminal at time t + 1
  def append(self, state, action, reward, terminal):
    """Store one transition with the current maximum priority."""
    # Keep just the newest frame, rescaled to 0..255 and held on the CPU as uint8 to save memory
    frame = state[-1].mul(255).to(dtype=torch.uint8, device=torch.device('cpu'))
    slot = self.transitions.index
    self.td[slot] = 1.0
    self.q[slot] = 1.0
    entry = Transition(self.t, frame, action, reward, not terminal)
    self.transitions.append(entry, self.transitions.max)
    if terminal:
      self.t = 0  # The next transition opens a new episode
    else:
      self.t = self.t + 1

  # Returns a transition with blank states where appropriate
  def _get_transition(self, idx):
    """Return the ``history + n`` transitions around ``idx``, padded with ``blank_trans`` at episode edges."""
    h, n = self.history, self.n
    window = np.array([None] * (h + n))
    window[h - 1] = self.transitions.get(idx)
    # Walk backwards through the frame history, e.g. 2 1 0
    for t in reversed(range(h - 1)):
      if window[t + 1].timestep == 0:
        window[t] = blank_trans  # The later frame opened an episode, so nothing precedes it
      else:
        window[t] = self.transitions.get(idx - h + 1 + t)
    # Walk forwards through the n-step lookahead, e.g. 4 5 6
    for t in range(h, h + n):
      if window[t - 1].nonterminal:
        window[t] = self.transitions.get(idx - h + 1 + t)
      else:
        window[t] = blank_trans  # The earlier frame ended the episode
    return window

  # Returns a valid sample from a segment
  def _get_sample_from_segment(self, segment, i):
    """Draw one usable transition whose cumulative priority lies in the ``i``-th segment of width ``segment``."""
    low, high = i * segment, (i + 1) * segment
    while True:
      sample = np.random.uniform(low, high)  # Uniform draw inside the segment
      prob, idx, tree_idx = self.transitions.find(sample)  # Un-normalised probability and indices
      # Distances (modulo capacity) between the sampled slot and the current write position
      room_ahead = (self.transitions.index - idx) % self.capacity
      room_behind = (idx - self.transitions.index) % self.capacity
      # Accept only samples whose window does not straddle the write position and whose priority is non-zero
      # (conservative around buffer index 0)
      if room_ahead > self.n and room_behind >= self.history and prob != 0:
        break

    # All transitions needed, from t - h to t + n
    transition = self._get_transition(idx)
    # Rebuild float states in [0, 1] for the current and the nth next state
    current_frames = [trans.state for trans in transition[:self.history]]
    state = torch.stack(current_frames).to(device=self.device).to(dtype=torch.float32).div_(255)
    future_frames = [trans.state for trans in transition[self.n:self.n + self.history]]
    next_state = torch.stack(future_frames).to(device=self.device).to(dtype=torch.float32).div_(255)
    # Discrete action to be used as index
    action = torch.tensor([transition[self.history - 1].action], dtype=torch.int64, device=self.device)
    # Truncated n-step discounted return R^n = Σ_k=0->n-1 (γ^k)R_t+k+1 (blank transitions contribute reward 0)
    n_step_return = 0
    for k in range(self.n):
      n_step_return = n_step_return + self.discount ** k * transition[self.history + k - 1].reward
    R = torch.tensor([n_step_return], dtype=torch.float32, device=self.device)
    # Mask for non-terminal nth next states
    nonterminal = torch.tensor([transition[self.history + self.n - 1].nonterminal], dtype=torch.float32, device=self.device)

    return prob, idx, tree_idx, state, action, R, next_state, nonterminal

  def sample(self, batch_size):
    """Draw ``batch_size`` transitions, one per equal-width priority segment.

    Returns ``(tree_idxs, states, actions, returns, next_states, nonterminals, weights)`` where ``weights``
    are importance-sampling weights normalised by the largest weight in the batch.
    """
    p_total = self.transitions.total()  # Sum of all priorities, used to normalise probabilities
    segment = p_total / batch_size  # Width of each of the batch_size segments
    batch = []
    for i in range(batch_size):
      batch.append(self._get_sample_from_segment(segment, i))
    probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch)
    states = torch.stack(states)
    next_states = torch.stack(next_states)
    actions = torch.cat(actions)
    returns = torch.cat(returns)
    nonterminals = torch.stack(nonterminals)
    probs = np.array(probs, dtype=np.float32) / p_total  # Normalised probabilities
    # Number of filled slots
    if self.transitions.full:
      capacity = self.capacity
    else:
      capacity = self.transitions.index
    weights = (capacity * probs) ** -self.priority_weight  # Importance-sampling weights w
    weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device)
    return tree_idxs, states, actions, returns, next_states, nonterminals, weights

  def update_priorities(self, idxs, td, q):
    """Write ``td ** priority_exponent`` as the new priority of each tree index in ``idxs`` (``q`` is unused)."""
    priorities = np.power(td, self.priority_exponent)
    for tree_idx, priority in zip(idxs, priorities):
      self.transitions.update(tree_idx, priority)

  # Set up internal state for iterator
  def __iter__(self):
    self.current_idx = 0
    return self

  # Return valid states for validation
  def __next__(self):
    """Return the stacked float state ending at the current slot, then advance to the next slot."""
    if self.current_idx == self.capacity:
      raise StopIteration
    stored = self.transitions.data
    position = self.current_idx
    # Fill the stack from the newest frame backwards
    frames = [None] * self.history
    frames[-1] = stored[position].state
    prev_timestep = stored[position].timestep
    for t in reversed(range(self.history - 1)):
      if prev_timestep == 0:
        frames[t] = blank_trans.state  # Episode start reached: pad with blank frames
      else:
        frames[t] = stored[position + t - self.history + 1].state
        prev_timestep -= 1
    state = torch.stack(frames, 0).to(dtype=torch.float32, device=self.device).div_(255)  # Agent will turn into batch
    self.current_idx += 1
    return state

  next = __next__  # Alias __next__ for Python 2 compatibility




class NERS(nn.Module, ReplayMemory):
  """Replay memory whose priorities are produced by a learned scoring network.

  Each transition is described by conv features of its current and n-step next state plus its action, n-step
  return, stored TD/Q side information and timestep. A "local" embedding of every transition is combined with
  the batch mean of a "global" embedding and mapped to a score in (0, 1) that is used as its priority.
  The networks are built lazily on the first ``append`` because the conv output size depends on the state.
  """

  def __init__(self, args, capacity):
    nn.Module.__init__(self)
    ReplayMemory.__init__(self, args, capacity)
    self.cnt = 0  # Number of appended transitions; 0 means the networks have not been built yet
    # Data indices whose priorities were updated since the last learn() step.
    # np.int was removed in NumPy 1.24; the builtin int is the type it aliased.
    self.index_for_train = np.zeros((0), dtype=int)

  def make_network(self, state):
    """Build the feature, global, local and score networks plus their optimiser.

    ``state`` is a sample stacked state; it is pushed through the conv stack once to discover the flattened
    feature size.
    """
    self.convs = nn.Sequential(nn.Conv2d(self.args.history_length, 32, 5, stride=5, padding=0), nn.ReLU(),
                               nn.Conv2d(32, 64, 5, stride=5, padding=0), nn.ReLU(), nn.Flatten()).to(self.device)
    # Shape probe only, so no graph needs to be recorded
    with torch.no_grad():
      probe = torch.unsqueeze(state.to(device=self.device), 0)
      flat_features = self.convs(probe).view(-1).shape[0]
    hidden_size = 128
    self.feature_net2 = nn.Sequential(nn.Linear(flat_features, 256), nn.ReLU(),
                                      nn.Linear(256, 64), nn.ReLU(),
                                      nn.Linear(64, 32), nn.ReLU()).to(device=self.device)
    self.input_size = 32 * 2 + 1 + 1 + 1 + 1 + 1  # statex2 + action + reward + td + qval + timestep
    self.global_net = nn.Sequential(nn.Linear(self.input_size, hidden_size), nn.ReLU(),
                                    nn.Linear(hidden_size, hidden_size * 2), nn.ReLU(),
                                    nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(),
                                    nn.Linear(hidden_size, hidden_size // 2)).to(device=self.device)
    self.local_net = nn.Sequential(nn.Linear(self.input_size, hidden_size), nn.ReLU(),
                                   nn.Linear(hidden_size, hidden_size * 2), nn.ReLU(),
                                   nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(),
                                   nn.Linear(hidden_size, hidden_size // 2)).to(device=self.device)
    self.score_net = nn.Sequential(nn.Linear(2 * (hidden_size // 2), hidden_size), nn.ReLU(),
                                   nn.Linear(hidden_size, hidden_size // 2), nn.ReLU(),
                                   nn.Linear(hidden_size // 2, 1), nn.Sigmoid()).to(device=self.device)
    self.optim = optim.Adam(
      self.parameters(),
      lr=0.0001,
      weight_decay=0.0,
    )

  # Adds state and action at time t, reward and terminal at time t + 1
  def append(self, state, action, reward, terminal):
    """Store a transition, building the networks first if this is the very first one."""
    if self.cnt == 0:
      self.make_network(state)
    self.cnt += 1
    ReplayMemory.append(self, state, action, reward, terminal)

  def update_priorities(self, idxs, td, q):
    """Record new TD/Q values for the given tree indices and re-score their priorities with the network.

    ``idxs`` are sum-tree indices (as returned by ``sample``); ``td`` and ``q`` hold one value per index and
    are squashed with tanh before being stored.
    """
    idxes = np.array(idxs) - self.capacity + 1  # Tree (leaf) index -> buffer data index
    self.index_for_train = np.concatenate((self.index_for_train, idxes), 0)
    for i, idx in enumerate(idxes):
      self.td[idx] = np.tanh(td[i])
      self.q[idx] = np.tanh(q[i])
    with torch.no_grad():
      priorities = self.get_prob_by_idxs(idxes)
      priorities = priorities.detach().cpu().numpy()

    priorities = np.power(priorities, self.priority_exponent)
    for tree_idx, priority in zip(idxs, priorities):
      self.transitions.update(tree_idx, priority)

  def get_prob_by_idxs(self, indices):
    """Score the transitions at buffer data indices ``indices``.

    Returns a 1-D tensor with one sigmoid score per index. Because the batch mean of the global embedding
    feeds every score, the result for a transition depends on which other indices are scored with it.
    """
    curr_states = []
    next_states = []
    actions = []
    rewards = []
    timesteps = []
    tds = []
    qs = []
    for idx in indices:
      transition = self._get_transition(idx)
      timesteps.append(transition[0].timestep)
      state = torch.stack([trans.state for trans in transition[:self.history]]).to(device=self.device).to(
        dtype=torch.float32).div_(255)
      next_state = torch.stack([trans.state for trans in transition[self.n:self.n + self.history]]).to(
        device=self.device).to(dtype=torch.float32).div_(255)
      # Discrete action to be used as index
      action = torch.tensor([transition[self.history - 1].action], dtype=torch.int64, device=self.device)
      # Calculate truncated n-step discounted return R^n = Σ_k=0->n-1 (γ^k)R_t+k+1 (note that invalid nth next states have reward 0)
      R = torch.tensor([sum(self.discount ** n * transition[self.history + n - 1].reward for n in range(self.n))],
                       dtype=torch.float32, device=self.device)
      curr_states.append(state)
      next_states.append(next_state)
      actions.append(action)
      rewards.append(R)
      tds.append(self.td[idx])
      qs.append(self.q[idx])
    # Everything is built as float32 directly on self.device (no detour through a CPU FloatTensor)
    curr_states = torch.stack(curr_states)
    next_states = torch.stack(next_states)
    actions = torch.stack(actions).reshape(-1, 1).to(dtype=torch.float32)
    rewards = torch.stack(rewards).reshape(-1, 1)
    timesteps = torch.tensor(timesteps, dtype=torch.float32, device=self.device).reshape(-1, 1)
    tds = torch.tensor(np.array(tds, dtype=np.float32), device=self.device).reshape(-1, 1)
    qs = torch.tensor(np.array(qs, dtype=np.float32), device=self.device).reshape(-1, 1)
    curr_features = self.feature_net2(self.convs(curr_states))
    next_features = self.feature_net2(self.convs(next_states))
    tensor_input = torch.cat(
      (curr_features, actions, rewards, next_features, tds, qs, timesteps),
      -1)
    tensor_input = torch.tanh(tensor_input)  # Squash all inputs into (-1, 1)
    global_out = self.global_net(tensor_input)
    local_out = self.local_net(tensor_input)
    global_out_mean = torch.mean(global_out, 0)  # One shared context vector for the whole batch
    res_out = torch.cat([local_out, torch.unsqueeze(global_out_mean, 0).expand_as(local_out)], -1)
    prob = self.score_net(res_out).reshape(-1)

    return prob

  def learn(self, avg_reward, avg_Q):
    """Take one policy-gradient style step on the scoring network.

    The first call only records ``avg_reward`` as the reference score. Later calls weight the log-scores of a
    random batch of recently updated indices by ``avg_reward - prev_score`` and then clear the index list.
    ``avg_Q`` is unused.
    """
    if self.prev_score == np.inf:
      self.prev_score = avg_reward
      return
    # NOTE(review): prev_score is never advanced after the first call, so every later return is measured
    # against the first recorded score — confirm this is intended.
    self.curr_score = avg_reward
    self.return_val = (self.curr_score - self.prev_score)
    if len(self.index_for_train) == 0:
      return  # No priorities were updated since the last step, so there is nothing to train on
    idxes = self.index_for_train[np.random.choice(len(self.index_for_train), self.args.batch_size)]
    prob = self.get_prob_by_idxs(idxes) + 1e-6  # Offset keeps log() finite
    loss = -torch.mean(torch.log(prob) * self.return_val)
    self.optim.zero_grad()
    loss.backward()
    self.optim.step()
    self.index_for_train = np.zeros((0), dtype=int)

