import argparse
import torch
import numpy as np

import random
import os

import ecos
from scipy.sparse import csr_matrix

#from envs.Lock_batch import LockBatch
from envs.block_arb_mdp import BlockGame
from envs.block_arb_mdp_gensum import BlockGameGenSum

#from envs.block_arb_stamdp import BlockGame

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string into True,
    so ``--load False`` would silently enable the flag; this parser does not.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


def parse_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: optional list of argument strings. When None, argparse reads
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--exp_name', default="test", type=str)
    parser.add_argument('--num_threads', default=10, type=int)
    parser.add_argument('--update_frequency', default=1, type=int)

    parser.add_argument('--temp_path', default="temp", type=str)

    parser.add_argument('--recent_size', default=50000, type=int)
    parser.add_argument('--lsvi_recent_size', default=50000, type=int)
    parser.add_argument('--load', default=False, type=_str2bool)
    parser.add_argument('--dense', default=False, type=_str2bool)

    parser.add_argument('--seed', default=123, type=int)
    parser.add_argument('--env_seed', default=123, type=int)
    parser.add_argument('--num_warm_start', default=10, type=int)
    # argparse does not apply ``type`` to non-string defaults, so the default
    # itself must already be an int (``1e5`` would stay a float).
    parser.add_argument('--num_episodes', default=100000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)

    #environment
    parser.add_argument('--num_envs', default=50, type=int)
    parser.add_argument('--num_players', default=2, type=int)
    parser.add_argument('--state_dim', default=3, type=int)
    parser.add_argument('--horizon', default=10, type=int)
    parser.add_argument('--num_actions', default=3, type=int)
    parser.add_argument('--observation_noise', default=0.1, type=float)

    #rep
    parser.add_argument('--rep_num_update', default=30, type=int)
    parser.add_argument('--rep_num_feature_update', default=64, type=int)
    parser.add_argument('--rep_num_adv_update', default=64, type=int)
    parser.add_argument('--discriminator_lr', default=1e-2, type=float)
    parser.add_argument('--discriminator_beta', default=0.9, type=float)
    parser.add_argument('--feature_lr', default=1e-2, type=float)
    parser.add_argument('--feature_beta', default=0.9, type=float)
    parser.add_argument('--linear_lr', default=1e-2, type=float)
    parser.add_argument('--linear_beta', default=0.9, type=float)
    parser.add_argument('--rep_lamb', default=0.01, type=float)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--temperature', default=1, type=float)
    parser.add_argument('--phi0_temperature', default=1, type=float)

    parser.add_argument('--reuse_weights', default=True, type=_str2bool)
    parser.add_argument('--optimizer', default='sgd', type=str)

    parser.add_argument('--softmax', default='vanilla', type=str)

    #lsvi
    parser.add_argument('--alpha', default=0.1, type=float)
    parser.add_argument('--lsvi_lamb', default=1, type=float)

    #eval
    parser.add_argument('--num_eval', default=100, type=int)

    return parser.parse_args(argv)

def make_batch_env(args, gensum=False):
    """Create a training env and an evaluation env that play the same game.

    Both environments are instances of the same class and are initialised
    with ``args.env_seed``. The evaluation env runs ``args.num_eval``
    parallel copies and afterwards has its transition/reward matrices
    replaced by the training env's.

    Args:
        args: parsed options (horizon, num_actions, state_dim,
            observation_noise, num_envs, num_eval, env_seed, seed).
        gensum: use ``BlockGameGenSum`` (two reward matrices) instead of the
            single-reward ``BlockGame``.

    Returns:
        Tuple ``(env, eval_env)``.
    """
    game_cls = BlockGameGenSum if gensum else BlockGame
    common = dict(horizon=args.horizon,
                  action_dim=args.num_actions,
                  state_dim=args.state_dim,
                  noise=args.observation_noise)

    env = game_cls()
    env.init(**common, num_envs=args.num_envs, verbose=True, rseed=args.env_seed)
    env.seed(args.seed)
    env.action_space.seed(args.seed)

    eval_env = game_cls()
    eval_env.init(**common, num_envs=args.num_eval, rseed=args.env_seed)
    eval_env.seed(args.seed)

    # Make the eval env share the training env's dynamics and rewards.
    if gensum:
        shared_attrs = ('trans_prob_matrices', 'reward_matrices1', 'reward_matrices2')
    else:
        shared_attrs = ('trans_prob_matrices', 'reward_matrices')
    for attr in shared_attrs:
        setattr(eval_env, attr, getattr(env, attr))

    return env, eval_env

def set_seed_everywhere(seed):
    """Seed torch, NumPy's global RNG and Python's ``random`` with ``seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

def sample(prob_matrix, items, n):
    """Draw one item per row of a row-stochastic matrix (inverse-CDF sampling).

    Credit for the approach:
    https://stackoverflow.com/questions/34187130/fast-random-weighted-selection-across-all-rows-of-a-stochastic-matrix/34190035

    Args:
        prob_matrix: 2-D array-like; row ``i`` holds the probabilities of the
            entries of ``items`` for draw ``i``. Only the first ``n`` rows are
            used.
        items: sequence (list or numpy array) with one entry per column of
            ``prob_matrix``.
        n: number of draws, one per row.

    Returns:
        numpy array of the ``n`` sampled items.

    Raises:
        IndexError: if ``prob_matrix`` has fewer than ``n`` rows.
    """
    cdf = np.cumsum(np.asarray(prob_matrix)[:n], axis=1)
    if cdf.shape[0] < n:
        raise IndexError('prob_matrix has %d rows, need %d' % (cdf.shape[0], n))
    # One uniform draw in [0, 1) per row, generated in a single call.
    ridx = np.random.random(size=n)
    # Item j is chosen iff cdf[j-1] <= r < cdf[j], so an item with zero
    # probability can never be selected (not even when r == 0.0).
    idx = (cdf <= ridx[:, None]).sum(axis=1)
    items = np.asarray(items)
    # Rounding can leave cdf[-1] slightly below 1; clamp to the last item.
    idx = np.minimum(idx, len(items) - 1)

    return items[idx]


class Buffer(object):
    """Unbounded, append-only list of single-step transitions.

    Each action is stored as a float one-hot vector of length ``num_actions``.
    """

    def __init__(self, num_actions):
        self.num_actions = num_actions
        self.obses, self.next_obses = [], []
        self.actions, self.rewards = [], []
        self.idx = 0  # number of transitions added so far

    def add(self, obs, action, reward, next_obs):
        """Append one transition, one-hot encoding ``action``."""
        one_hot = np.zeros(self.num_actions)
        one_hot[action] = 1
        pairs = ((self.obses, obs), (self.actions, one_hot),
                 (self.rewards, reward), (self.next_obses, next_obs))
        for store, item in pairs:
            store.append(item)
        self.idx += 1

    def get_batch(self):
        """Return ``(count, obses, actions, rewards, next_obses)`` as arrays."""
        columns = (self.obses, self.actions, self.rewards, self.next_obses)
        return (self.idx,) + tuple(np.array(column) for column in columns)

    def get(self, h):
        """Return transition ``h`` as ``(obs, action, reward, next_obs)``."""
        columns = (self.obses, self.actions, self.rewards, self.next_obses)
        return tuple(column[h] for column in columns)


class ReplayBuffer(object):
    """Fixed-capacity ring buffer of joint-action transitions.

    The joint action of ``num_players`` players is stored one-hot over
    ``num_actions ** num_players`` slots at its row-major index (for two
    players: ``a_0 * num_actions + a_1``). ``add`` and ``add_batch`` use the
    same encoding.
    """
    def __init__(self, obs_shape, num_actions, num_players, capacity, batch_size, device, recent_size=0):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.num_actions = num_actions
        self.num_players = num_players

        self.obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        # Builtin ``int`` is what the removed ``np.int`` alias referred to.
        self.actions = np.empty((capacity, num_actions ** num_players), dtype=int)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)

        # Window used by ``sample``: draw only from the newest ``recent_size``
        # transitions (0 means the whole buffer).
        self.recent_size = recent_size

        self.idx = 0          # next slot to write
        self.last_save = 0    # first slot not yet written to disk by save()
        self.full = False     # True once writes have wrapped around

    def _joint_index(self, actions):
        """Row-major flat index of joint action(s); the last axis is the player."""
        actions = np.asarray(actions, dtype=int)
        flat = np.zeros(actions.shape[:-1], dtype=int)
        for player_actions in np.moveaxis(actions, -1, 0):
            flat = flat * self.num_actions + player_actions
        return flat

    def _advance(self, count):
        """Move the write pointer ``count`` slots forward, wrapping at capacity."""
        self.full = self.full or self.idx + count >= self.capacity
        self.idx = (self.idx + count) % self.capacity

    def _window(self, recent_size):
        """Index (slice or array) selecting the newest ``recent_size`` rows.

        ``recent_size == 0``, or a window at least as large as the stored data,
        selects everything stored. After the buffer has wrapped, the window may
        straddle the end of storage; a wrapped index array is returned then.
        """
        stored = self.capacity if self.full else self.idx
        if recent_size == 0 or recent_size >= stored:
            return slice(None) if self.full else slice(0, self.idx)
        start = self.idx - recent_size
        if start >= 0:
            return slice(start, self.idx)
        return (start + np.arange(recent_size)) % self.capacity

    @staticmethod
    def _load_chunk(path):
        """Load one chunk written by ``save`` (a list of NumPy arrays)."""
        # Newer PyTorch defaults torch.load to weights_only=True, which refuses
        # pickled NumPy arrays; these files are this buffer's own output.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            return torch.load(path)

    def add(self, obs, action, reward, next_obs):
        """Store one transition; ``action`` holds one action per player."""
        one_hot = np.zeros(self.num_actions ** self.num_players, dtype=int)
        one_hot[self._joint_index(np.atleast_1d(action))] = 1
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], one_hot)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        self._advance(1)

    def add_batch(self, obs, action, reward, next_obs, size):
        """Store ``size`` transitions at once, wrapping around the ring if needed.

        ``action`` must be indexable as ``action[n][p]``: one row of per-player
        actions for each of the ``size`` transitions.
        """
        rows = (self.idx + np.arange(size)) % self.capacity
        joint = self._joint_index(np.asarray(action)[:size].reshape(size, -1))
        one_hot = np.zeros((size, self.num_actions ** self.num_players), dtype=int)
        one_hot[np.arange(size), joint] = 1
        self.obses[rows] = obs
        self.actions[rows] = one_hot
        self.rewards[rows] = reward
        self.next_obses[rows] = next_obs
        self._advance(size)

    def add_from_buffer(self, buf, h):
        """Copy transition ``h`` of a ``Buffer``-like object (``buf.get(h)``)."""
        obs, action, reward, next_obs = buf.get(h)
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        self._advance(1)

    def get_full(self, recent_size=0, device=None):
        """Return the newest ``recent_size`` transitions (all if 0) as tensors.

        Windowed results are in insertion order. When everything is returned
        from a full buffer, rows are in storage order.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses)`` on ``device``
            (defaults to ``self.device``).
        """
        if device is None:
            device = self.device
        window = self._window(recent_size)
        return tuple(torch.as_tensor(arr[window], device=device)
                     for arr in (self.obses, self.actions, self.rewards, self.next_obses))

    def sample(self, batch_size=None):
        """Sample ``batch_size`` transitions (default ``self.batch_size``) uniformly.

        If ``self.recent_size`` is non-zero, only the newest ``recent_size``
        transitions are eligible.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses)`` of tensors.
        """
        if batch_size is None:
            batch_size = self.batch_size

        stored = self.capacity if self.full else self.idx
        if self.recent_size == 0 or self.recent_size >= stored:
            idxs = np.random.randint(0, stored, size=batch_size)
        else:
            # The newest window may straddle the end of the ring once full.
            idxs = np.random.randint(
                self.idx - self.recent_size, self.idx, size=batch_size
            ) % self.capacity

        return tuple(torch.as_tensor(arr[idxs], device=self.device)
                     for arr in (self.obses, self.actions, self.rewards, self.next_obses))

    def save(self, save_dir):
        """Write the transitions added since the last save to ``<start>_<end>.pt``.

        NOTE(review): a save after the ring wrapped (``idx < last_save``) writes
        an empty slice; wrap-around is not handled here.
        """
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        payload = [
            self.obses[self.last_save:self.idx],
            self.next_obses[self.last_save:self.idx],
            self.actions[self.last_save:self.idx],
            self.rewards[self.last_save:self.idx]
        ]
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """Reload the ``.pt`` chunks written by ``save`` in start-index order."""
        chunks = [name for name in os.listdir(save_dir) if name.endswith('.pt')]
        for chunk in sorted(chunks, key=lambda name: int(name.split('_')[0])):
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            payload = self._load_chunk(os.path.join(save_dir, chunk))
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.idx = end
        # Data already on disk must not be re-written by the next save().
        self.last_save = self.idx

class ReplayBufferGenSum(object):
    """Fixed-capacity ring buffer of joint-action transitions with two rewards.

    ``rewards1`` and ``rewards2`` hold one reward per player. The joint action
    is stored one-hot over ``num_actions ** num_players`` slots at its
    row-major index (for two players: ``a_0 * num_actions + a_1``). ``add``
    and ``add_batch`` use the same encoding.
    """
    def __init__(self, obs_shape, num_actions, num_players, capacity, batch_size, device, recent_size=0):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.num_actions = num_actions
        self.num_players = num_players

        self.obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        # Builtin ``int`` is what the removed ``np.int`` alias referred to.
        self.actions = np.empty((capacity, num_actions ** num_players), dtype=int)
        self.rewards1 = np.empty((capacity, 1), dtype=np.float32)
        self.rewards2 = np.empty((capacity, 1), dtype=np.float32)

        # Window used by ``sample``: draw only from the newest ``recent_size``
        # transitions (0 means the whole buffer).
        self.recent_size = recent_size

        self.idx = 0          # next slot to write
        self.last_save = 0    # first slot not yet written to disk by save()
        self.full = False     # True once writes have wrapped around

    def _joint_index(self, actions):
        """Row-major flat index of joint action(s); the last axis is the player."""
        actions = np.asarray(actions, dtype=int)
        flat = np.zeros(actions.shape[:-1], dtype=int)
        for player_actions in np.moveaxis(actions, -1, 0):
            flat = flat * self.num_actions + player_actions
        return flat

    def _advance(self, count):
        """Move the write pointer ``count`` slots forward, wrapping at capacity."""
        self.full = self.full or self.idx + count >= self.capacity
        self.idx = (self.idx + count) % self.capacity

    def _window(self, recent_size):
        """Index (slice or array) selecting the newest ``recent_size`` rows.

        ``recent_size == 0``, or a window at least as large as the stored data,
        selects everything stored. After the buffer has wrapped, the window may
        straddle the end of storage; a wrapped index array is returned then.
        """
        stored = self.capacity if self.full else self.idx
        if recent_size == 0 or recent_size >= stored:
            return slice(None) if self.full else slice(0, self.idx)
        start = self.idx - recent_size
        if start >= 0:
            return slice(start, self.idx)
        return (start + np.arange(recent_size)) % self.capacity

    @staticmethod
    def _load_chunk(path):
        """Load one chunk written by ``save`` (a list of NumPy arrays)."""
        # Newer PyTorch defaults torch.load to weights_only=True, which refuses
        # pickled NumPy arrays; these files are this buffer's own output.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            return torch.load(path)

    def add(self, obs, action, reward1, reward2, next_obs):
        """Store one transition; ``action`` holds one action per player."""
        one_hot = np.zeros(self.num_actions ** self.num_players, dtype=int)
        one_hot[self._joint_index(np.atleast_1d(action))] = 1
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], one_hot)
        np.copyto(self.rewards1[self.idx], reward1)
        np.copyto(self.rewards2[self.idx], reward2)
        np.copyto(self.next_obses[self.idx], next_obs)
        self._advance(1)

    def add_batch(self, obs, action, reward1, reward2, next_obs, size):
        """Store ``size`` transitions at once, wrapping around the ring if needed.

        ``action`` must be indexable as ``action[n][p]``: one row of per-player
        actions for each of the ``size`` transitions.
        """
        rows = (self.idx + np.arange(size)) % self.capacity
        joint = self._joint_index(np.asarray(action)[:size].reshape(size, -1))
        one_hot = np.zeros((size, self.num_actions ** self.num_players), dtype=int)
        one_hot[np.arange(size), joint] = 1
        self.obses[rows] = obs
        self.actions[rows] = one_hot
        self.rewards1[rows] = reward1
        self.rewards2[rows] = reward2
        self.next_obses[rows] = next_obs
        self._advance(size)

    def add_from_buffer(self, buf, h):
        """Copy transition ``h`` of a ``Buffer``-like object (``buf.get(h)``).

        This buffer keeps one reward per player, so the reward stored in
        ``buf`` must be a ``(reward1, reward2)`` pair.
        """
        obs, action, reward, next_obs = buf.get(h)
        reward1, reward2 = reward
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards1[self.idx], reward1)
        np.copyto(self.rewards2[self.idx], reward2)
        np.copyto(self.next_obses[self.idx], next_obs)
        self._advance(1)

    def get_full(self, recent_size=0, device=None):
        """Return the newest ``recent_size`` transitions (all if 0) as tensors.

        Windowed results are in insertion order. When everything is returned
        from a full buffer, rows are in storage order.

        Returns:
            Tuple ``(obses, actions, rewards1, rewards2, next_obses)`` on
            ``device`` (defaults to ``self.device``).
        """
        if device is None:
            device = self.device
        window = self._window(recent_size)
        arrays = (self.obses, self.actions, self.rewards1, self.rewards2, self.next_obses)
        return tuple(torch.as_tensor(arr[window], device=device) for arr in arrays)

    def sample(self, batch_size=None):
        """Sample ``batch_size`` transitions (default ``self.batch_size``) uniformly.

        If ``self.recent_size`` is non-zero, only the newest ``recent_size``
        transitions are eligible.

        Returns:
            Tuple ``(obses, actions, rewards1, next_obses)``. Only player 1's
            reward is returned; use ``get_full`` for both.
        """
        if batch_size is None:
            batch_size = self.batch_size

        stored = self.capacity if self.full else self.idx
        if self.recent_size == 0 or self.recent_size >= stored:
            idxs = np.random.randint(0, stored, size=batch_size)
        else:
            # The newest window may straddle the end of the ring once full.
            idxs = np.random.randint(
                self.idx - self.recent_size, self.idx, size=batch_size
            ) % self.capacity

        return tuple(torch.as_tensor(arr[idxs], device=self.device)
                     for arr in (self.obses, self.actions, self.rewards1, self.next_obses))

    def save(self, save_dir):
        """Write the transitions added since the last save to ``<start>_<end>.pt``.

        Payload order: obses, next_obses, actions, rewards1, rewards2.
        NOTE(review): a save after the ring wrapped (``idx < last_save``) writes
        an empty slice; wrap-around is not handled here.
        """
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        payload = [
            self.obses[self.last_save:self.idx],
            self.next_obses[self.last_save:self.idx],
            self.actions[self.last_save:self.idx],
            self.rewards1[self.last_save:self.idx],
            self.rewards2[self.last_save:self.idx]
        ]
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """Reload the ``.pt`` chunks written by ``save`` in start-index order."""
        chunks = [name for name in os.listdir(save_dir) if name.endswith('.pt')]
        for chunk in sorted(chunks, key=lambda name: int(name.split('_')[0])):
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            payload = self._load_chunk(os.path.join(save_dir, chunk))
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards1[start:end] = payload[3]
            self.rewards2[start:end] = payload[4]
            self.idx = end
        # Data already on disk must not be re-written by the next save().
        self.last_save = self.idx


    
def NashEquilibriumECOSSolver(M):
    """Solve a two-player matrix game with the ECOS linear-program solver.

    The row player picks a mixed strategy ``x`` maximising the payoff ``z``
    it can guarantee against every column::

        max  z
        s.t. M.T @ x >= z
             sum(x) = 1,  0 <= x <= 1

    ECOS solves ``min c@v  s.t.  A@v = b,  G@v <= h`` over ``v = [x, z]``
    (https://github.com/embotech/ecos-python,
    https://github.com/embotech/ecos/wiki/Usage-from-MATLAB). The column
    player's strategy is read from the dual variables of the
    ``M.T @ x >= z`` constraints.

    Args:
        M: 2-D numpy array of shape (row, col); payoff to the row player.

    Returns:
        Tuple ``(p1_value, p2_value)``: non-negative vectors of lengths
        ``row`` and ``col``, each renormalised to sum to 1.
    """
    row, col = M.shape

    # Objective: minimise -z, i.e. maximise z.
    c = np.zeros(row + 1)
    c[-1] = -1.

    # Equality: x_1 + ... + x_row = 1 (z has coefficient 0).
    A = np.ones(row + 1)
    A[-1] = 0.
    A = csr_matrix([A])
    b = np.array([1.])

    # -M.T @ x + z <= 0   (col inequalities)
    G1 = np.ones((col, row + 1))
    G1[:, :row] = -1. * M.T
    # -x <= 0             (row inequalities)
    G2 = -np.eye(row, row + 1)
    # x <= 1              (row inequalities)
    G3 = np.eye(row, row + 1)
    G = csr_matrix(np.concatenate((G1, G2, G3)))
    # One right-hand side per row of G: zeros for G1 and G2, ones for G3.
    h = np.concatenate((np.zeros(col + row), np.ones(row)))

    # All col + 2*row inequalities are linear (non-negative orthant cone).
    dims = {'l': col + 2 * row, 'q': []}

    solution = ecos.solve(c, G, h, dims, A, b, verbose=False)

    p1_value = solution['x'][:row]
    p2_value = solution['z'][:col]  # duals of the G1 (M.T @ x >= z) rows
    # Numerical slack can leave slightly negative entries or sums != 1, so
    # renormalise the absolute values into proper distributions.
    abs_p1_value = np.abs(p1_value)
    abs_p2_value = np.abs(p2_value)
    p1_value = abs_p1_value / np.sum(abs_p1_value)
    p2_value = abs_p2_value / np.sum(abs_p2_value)

    return (p1_value, p2_value)


class ReplayBufferSingle(object):
    """Fixed-capacity ring buffer of one agent's transitions.

    Actions are stored one-hot over ``num_actions`` slots. ``num_players`` is
    kept for constructor parity with the joint-action buffers but does not
    affect the action width.
    """
    def __init__(self, obs_shape, num_actions, num_players, capacity, batch_size, device, recent_size=0):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.num_actions = num_actions
        self.num_players = num_players

        self.obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        # Builtin ``int`` is what the removed ``np.int`` alias referred to.
        self.actions = np.empty((capacity, num_actions), dtype=int)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)

        # Window used by ``sample``: draw only from the newest ``recent_size``
        # transitions (0 means the whole buffer).
        self.recent_size = recent_size

        self.idx = 0          # next slot to write
        self.last_save = 0    # first slot not yet written to disk by save()
        self.full = False     # True once writes have wrapped around

    def _advance(self, count):
        """Move the write pointer ``count`` slots forward, wrapping at capacity."""
        self.full = self.full or self.idx + count >= self.capacity
        self.idx = (self.idx + count) % self.capacity

    def _window(self, recent_size):
        """Index (slice or array) selecting the newest ``recent_size`` rows.

        ``recent_size == 0``, or a window at least as large as the stored data,
        selects everything stored. After the buffer has wrapped, the window may
        straddle the end of storage; a wrapped index array is returned then.
        """
        stored = self.capacity if self.full else self.idx
        if recent_size == 0 or recent_size >= stored:
            return slice(None) if self.full else slice(0, self.idx)
        start = self.idx - recent_size
        if start >= 0:
            return slice(start, self.idx)
        return (start + np.arange(recent_size)) % self.capacity

    @staticmethod
    def _load_chunk(path):
        """Load one chunk written by ``save`` (a list of NumPy arrays)."""
        # Newer PyTorch defaults torch.load to weights_only=True, which refuses
        # pickled NumPy arrays; these files are this buffer's own output.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            return torch.load(path)

    def add(self, obs, action, reward, next_obs):
        """Store one transition; ``action`` is the agent's action index."""
        # The one-hot must match the ``num_actions``-wide action rows.
        one_hot = np.zeros(self.num_actions, dtype=int)
        one_hot[action] = 1
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], one_hot)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        self._advance(1)

    def add_batch(self, obs, action, reward, next_obs, size):
        """Store ``size`` transitions at once, wrapping around the ring if needed.

        ``action`` holds one action index per transition.
        """
        rows = (self.idx + np.arange(size)) % self.capacity
        one_hot = np.zeros((size, self.num_actions), dtype=int)
        one_hot[np.arange(size), action] = 1
        self.obses[rows] = obs
        self.actions[rows] = one_hot
        self.rewards[rows] = reward
        self.next_obses[rows] = next_obs
        self._advance(size)

    def add_from_buffer(self, buf, h):
        """Copy transition ``h`` of a ``Buffer``-like object (``buf.get(h)``)."""
        obs, action, reward, next_obs = buf.get(h)
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        self._advance(1)

    def get_full(self, recent_size=0, device=None):
        """Return the newest ``recent_size`` transitions (all if 0) as tensors.

        Windowed results are in insertion order. When everything is returned
        from a full buffer, rows are in storage order.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses)`` on ``device``
            (defaults to ``self.device``).
        """
        if device is None:
            device = self.device
        window = self._window(recent_size)
        return tuple(torch.as_tensor(arr[window], device=device)
                     for arr in (self.obses, self.actions, self.rewards, self.next_obses))

    def sample(self, batch_size=None):
        """Sample ``batch_size`` transitions (default ``self.batch_size``) uniformly.

        If ``self.recent_size`` is non-zero, only the newest ``recent_size``
        transitions are eligible.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses)`` of tensors.
        """
        if batch_size is None:
            batch_size = self.batch_size

        stored = self.capacity if self.full else self.idx
        if self.recent_size == 0 or self.recent_size >= stored:
            idxs = np.random.randint(0, stored, size=batch_size)
        else:
            # The newest window may straddle the end of the ring once full.
            idxs = np.random.randint(
                self.idx - self.recent_size, self.idx, size=batch_size
            ) % self.capacity

        return tuple(torch.as_tensor(arr[idxs], device=self.device)
                     for arr in (self.obses, self.actions, self.rewards, self.next_obses))

    def save(self, save_dir):
        """Write the transitions added since the last save to ``<start>_<end>.pt``.

        NOTE(review): a save after the ring wrapped (``idx < last_save``) writes
        an empty slice; wrap-around is not handled here.
        """
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        payload = [
            self.obses[self.last_save:self.idx],
            self.next_obses[self.last_save:self.idx],
            self.actions[self.last_save:self.idx],
            self.rewards[self.last_save:self.idx]
        ]
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """Reload the ``.pt`` chunks written by ``save`` in start-index order."""
        chunks = [name for name in os.listdir(save_dir) if name.endswith('.pt')]
        for chunk in sorted(chunks, key=lambda name: int(name.split('_')[0])):
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            payload = self._load_chunk(os.path.join(save_dir, chunk))
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.idx = end
        # Data already on disk must not be re-written by the next save().
        self.last_save = self.idx
