import numpy as np
import gym
from gym.spaces import MultiBinary, Discrete, Box
import scipy.linalg
import math

import ecos
from scipy.sparse import csr_matrix

#from utils import NashEquilibriumECOSSolver
def NashEquilibriumECOSSolver(M):
    """Solve a two-player zero-sum matrix game with the ECOS solver.

    The row player's mixed strategy ``x`` and the game value ``z`` are
    obtained from the linear program

        max  z
        s.t. x_1 + ... + x_row = 1
             z <= (M.T @ x)_j      for every column j
             0 <= x_i <= 1         for every row i

    which is passed to ECOS in its standard form
    ``min c@v  s.t.  A@v = b,  G@v <= h`` with ``v = [x_1, ..., x_row, z]``
    (see https://github.com/embotech/ecos-python).

    :param M: payoff matrix of shape (row, col); the row player maximises.
        It does not have to be square.
    :return: ``(p1, p2)`` where ``p1`` (length ``row``) is the primal
        solution for the row player and ``p2`` (length ``col``) is built from
        the multipliers ECOS returns for the ``col`` payoff inequalities.
        Both are made non-negative and rescaled to sum to one.
    """
    M = np.asarray(M, dtype=float)
    row, col = M.shape
    n_vars = row + 1  # the strategy entries plus the value variable z

    # Objective: ECOS minimises, so "max z" becomes "min -z".
    c = np.zeros(n_vars)
    c[-1] = -1.

    # Equality constraint: the strategy entries sum to one (z is excluded).
    A = np.ones(n_vars)
    A[-1] = 0.
    A = csr_matrix([A])
    b = np.array([1.])

    # -M.T @ x + z <= 0, one row per column of M.
    G1 = np.ones((col, n_vars))
    G1[:, :row] = -1. * M.T
    # -x <= 0 and x <= 1; the trailing zero column leaves z unconstrained.
    z_column = np.zeros((row, 1))
    G2 = np.hstack((-np.eye(row), z_column))
    G3 = np.hstack((np.eye(row), z_column))
    G = csr_matrix(np.concatenate((G1, G2, G3)))
    # The right-hand side must have one entry per row of G: `col` zeros for
    # G1, `row` zeros for G2 and `row` ones for G3.  (Sizing the zero part as
    # 2*row only matches G when M is square.)
    h = np.concatenate((np.zeros(col + row), np.ones(row)))

    # All col + 2*row inequalities are plain linear ones; no second-order cones.
    dims = {'l': col + 2 * row, 'q': []}

    solution = ecos.solve(c, G, h, dims, A, b, verbose=False)

    p1_value = solution['x'][:row]
    # The first `col` multipliers belong to the payoff inequalities (G1).
    p2_value = solution['z'][:col]

    # The solver output may violate the constraints slightly (entries a bit
    # below zero, or a sum a bit off one), so take absolute values and
    # renormalise both vectors.
    abs_p1_value = np.abs(p1_value)
    abs_p2_value = np.abs(p2_value)
    p1_value = abs_p1_value / np.sum(abs_p1_value)
    p2_value = abs_p2_value / np.sum(abs_p2_value)

    return (p1_value, p2_value)

'''
Fast row-wise weighted sampling. Credit: https://stackoverflow.com/questions/34187130/fast-random-weighted-selection-across-all-rows-of-a-stochastic-matrix/34190035
'''
def sample(prob_matrix, items, n):
    """Pick one entry of ``items`` per row of ``prob_matrix``.

    Row ``k`` of ``prob_matrix`` holds the weights used for the ``k``-th draw:
    a uniform number from ``np.random.random`` is located in that row's
    cumulative sum with ``np.searchsorted``.

    :param prob_matrix: 2-D array of per-row weights; only the first ``n``
        rows are used.
    :param items: array that is indexed with the selected positions.
    :param n: number of draws (rows) to process.
    :return: ``items`` indexed by the ``n`` selected positions.
    """
    cumulative = np.cumsum(prob_matrix, axis=1)
    # Draw every uniform number in a single call.
    draws = np.random.random(size=n)
    # Each row has its own cumulative distribution, so look them up one by one.
    picks = np.array(
        [np.searchsorted(cumulative[k], draws[k]) for k in range(n)],
        dtype=int,
    )
    # A draw above the last cumulative value would point one past the end;
    # clamp it to the last valid position.
    last_position = len(items) - 1
    return items[np.minimum(picks, last_position)]


class BlockGame(gym.Env):
    """Batched, randomly generated finite-horizon two-player Markov game.

    ``init`` draws a random transition tensor and a random reward tensor for
    every time step, and ``NEsolver`` computes Nash values and strategies for
    them by backward induction.  ``num_envs`` independent copies are stepped
    at the same time.  Observations are noisy one-hot encodings of
    (state, time step) multiplied by a Hadamard matrix.

    The constructor only marks the instance as uninitialised; all settings
    (horizon, number of states/actions, noise, ...) are passed to ``init``.
    """

    def __init__(self, env_config=None):
        # env_config is accepted but never read; configuration goes through init().
        self.initialized = False

    def init(self, horizon=10, action_dim=3, num_players=2, state_dim=3, noise=0.1, num_envs=50, verbose=False, rseed=12):
        """Configure the game, draw its random dynamics and solve it.

        :param horizon: number of time steps per episode.
        :param action_dim: number of actions available to each player.
        :param num_players: only used to size the joint action space
            (``action_dim ** num_players``); ``step`` and ``NEsolver`` encode
            and reshape joint actions for exactly two players.
        :param state_dim: number of states.
        :param noise: standard deviation of the Gaussian observation noise.
        :param num_envs: number of environment copies stepped in parallel.
        :param verbose: if True, print the mean Nash value of the first step.
        :param rseed: seed passed to ``self.seed`` and ``np.random.seed``.
        """
        self.initialized = True
        self.max_reward = 1
        self.horizon = horizon
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_space = Discrete(self.action_dim)
        self.num_envs = num_envs

        self.total_action_dim = action_dim ** num_players

        self.reward_range = [-1, 1]

        # scipy.linalg.hadamard needs a power-of-two size, so round
        # horizon + state_dim + 1 up to the next power of two.
        self.observation_dim = 2 ** int(math.ceil(np.log2(self.horizon + self.state_dim + 1)))
        # np.float was only an alias of the builtin float and has been removed
        # from NumPy; np.float64 is the same dtype.
        self.observation_space = Box(low=0.0, high=1.0, shape=(self.observation_dim,), dtype=np.float64)

        self.noise = noise
        self.rotation = scipy.linalg.hadamard(self.observation_space.shape[0])

        # NOTE(review): relies on gym.Env.seed, which newer gym releases no
        # longer provide - confirm against the installed gym version.
        self.seed(rseed)
        np.random.seed(rseed)

        self.trans_prob_matrices, self.reward_matrices = self.generate_random_trans_and_rewards()

        nash_v, _, nash_strategies = self.NEsolver()
        #print(nash_strategies, np.array(nash_strategies).shape)
        # np.save('../../../data/nash_dqn_test/oracle_nash.npy', nash_strategies)
        if verbose:
            #print(self.trans_prob_matrices)
            print('oracle nash v star: ', np.mean(nash_v[0], axis=0))  # the average nash value for initial states from max-player's view

    def get_nash_strategy(self):
        """Re-run ``NEsolver`` and return ``(nash_v, strategies_as_array)``."""
        nash_v, _, nash_strategies = self.NEsolver()
        return nash_v, np.array(nash_strategies)

    def generate_random_trans_and_rewards(self, SameRewardForNextState=False):
        """Generate arbitrary transition matrix and reward matrix.

        For every (time step, state, joint action) a next-state distribution
        is obtained by normalising ``state_dim`` uniform samples, and rewards
        are drawn uniformly from ``self.reward_range``.

        :param SameRewardForNextState: r(s,a) if True else r(s,a,s')
        :type SameRewardForNextState: bool
        :return: the transition array and the reward array, both of shape
            (horizon, state_dim, total_action_dim, state_dim)
        :rtype: tuple of two numpy arrays
        """
        trans_prob_matrices = []
        reward_matrices = []
        # The nested loops fix the order in which random numbers are consumed,
        # which keeps a seeded game reproducible.
        for _ in range(self.horizon):
            trans_prob_matrix = []
            reward_matrix = []
            for s in range(self.state_dim):
                trans_prob_matrix_for_s = []
                reward_matrix_for_s = []
                for a in range(self.total_action_dim):
                    #rands = np.random.uniform(0,1, self.state_dim) / 0.1
                    #rand_probs = np.exp(rands)/np.sum(np.exp(rands))
                    #rands = np.zeros(self.state_dim)
                    #index = np.random.randint(0,self.state_dim)
                    #rands[index] = 1
                    rands = np.random.uniform(0, 1, self.state_dim)
                    rand_probs = rands / np.sum(rands)
                    trans_prob_matrix_for_s.append(rand_probs)
                    if SameRewardForNextState:  # r(s,a) this reduces stochasticity in nash value estimation thus work!
                        rs = int(self.state_dim) * [np.random.uniform(*self.reward_range)]
                        reward_matrix_for_s.append(rs)
                    else:  # r(s,a,s')
                        rs = np.random.uniform(*self.reward_range, self.state_dim)
                        reward_matrix_for_s.append(list(rs))

                trans_prob_matrix.append(trans_prob_matrix_for_s)
                reward_matrix.append(reward_matrix_for_s)
            trans_prob_matrices.append(trans_prob_matrix)
            reward_matrices.append(reward_matrix)

        return np.array(trans_prob_matrices), np.array(reward_matrices)

    def step(self, a, s=None, h=None):
        """Advance all ``num_envs`` copies by one time step.

        :param a: per-environment action pairs; ``a[n][0]`` and ``a[n][1]``
            are the two players' actions in environment ``n``.
        :param s: if given, overrides the current states before stepping
            (``h`` is then used as the current time step).
        :param h: time step that accompanies ``s``.
        :return: ``(obs, rewards, done, info)`` where ``rewards`` has shape
            (num_envs, 1), ``done`` is True once ``self.h`` equals the horizon
            and ``info`` is an empty dict.
        """
        if s is not None:
            self.h = h
            self.state = s
        # Flatten each action pair into one joint-action index.
        a = [a[n][0] * self.action_dim + a[n][1] for n in range(self.num_envs)]

        env_ids = np.arange(self.num_envs)
        # One next-state distribution per environment: (num_envs, state_dim).
        trans_prob = self.trans_prob_matrices[self.h][self.state][env_ids, a]
        next_states = sample(trans_prob, np.arange(self.state_dim), self.num_envs)
        rewards = self.reward_matrices[self.h][self.state][env_ids, a][env_ids, next_states].reshape(-1, 1)

        self.state = next_states
        self.h += 1

        obs = self.make_obs(self.state, self.h)

        done = self.h == self.horizon

        return obs, rewards, done, {}

    # Disabled alternative step with hand-coded deterministic transitions.
    # def step(self,a,s=None,h=None):
    #     if s is not None:
    #         self.h = h
    #         self.state = s
    #     #a = [a[n][0]*self.action_dim+a[n][1] for n in range(self.num_envs)]
    #     #a = a[0]*self.action_dim+a[1]
    #     #print(a)

    #     #print(self.trans_prob_matrices[self.h])
    #     next_states = np.zeros(self.num_envs, dtype=int)
    #     rewards = np.zeros([self.num_envs,1])
    #     for n in range(self.num_envs):
    #         an = a[n][0]*self.action_dim+a[n][1]

    #         if an == 0:
    #             next_states[n] = np.maximum(self.state[n] -1,0) 
    #         elif an == 1:
    #             next_states[n] = np.minimum(self.state[n] + 1, self.state_dim-1) 
    #         else: 
    #             next_states[n] = self.state[n]
    #         #next_states[n] = np.argmax(self.trans_prob_matrices[self.h][self.state[n]][an])
    #         rewards = self.reward_matrices[self.h][self.state[n]][an][next_states[n]]

    #     self.state = next_states
    #     self.h += 1

    #     obs = self.make_obs(self.state, self.h)
        
    #     done = self.h == self.horizon

    #     return obs, rewards, done, {}

    def get_state(self):
        """Return the current per-environment states."""
        return self.state

    def make_obs(self, s, h):
        """Build the observations for states ``s`` at time step ``h``.

        Each row starts as Gaussian noise on its first
        ``horizon + state_dim`` entries, gets +1 at the state index and +1 at
        ``state_dim + h``, and is then multiplied by the Hadamard matrix.

        :return: array of shape (num_envs, observation_dim).
        """
        gaussian = np.zeros((self.num_envs, self.observation_space.shape[0]))
        gaussian[:, :(self.horizon + self.state_dim)] = np.random.normal(0, self.noise, [self.num_envs, self.horizon + self.state_dim])
        env_ids = np.arange(self.num_envs)
        gaussian[env_ids, s] += 1
        gaussian[env_ids, self.state_dim + h] += 1

        x = np.matmul(self.rotation, gaussian.T).T

        return x

    def trim_observation(self, o, h):
        """Return the observation unchanged (``h`` is ignored)."""
        return (o)

    def reset(self):
        """Draw uniform random start states, reset the clock and return the observations."""
        self.state = np.random.randint(0, self.state_dim, self.num_envs)
        self.h = 0
        obs = self.make_obs(self.state, self.h)

        return obs

    def render(self, mode='human'):
        """Print ``<label><time step>`` for each environment's state.

        States 0, 1 and 2 are labelled A, B and C; other states print nothing.
        """
        labels = {0: "A", 1: "B", 2: "C"}
        # self.state holds one entry per environment, so it cannot be compared
        # to a scalar inside a single `if`; handle the entries one at a time.
        for state in np.atleast_1d(self.state):
            label = labels.get(int(state))
            if label is not None:
                print("%s%d" % (label, self.h))

    def close(self):
        """Nothing to release."""
        pass

    def NEsolver(self, verbose=False):
        r"""Solve the game by backward induction over the time steps.

        Formulas for calculating Nash equilibrium strategies and values:
        1. Nash strategies: (\pi_a^*, \pi_b^*) = \min \max Q(s,a,b),
            where Q(s,a,b) = r(s,a,b) + \gamma \min \max Q(s',a',b') (this is the definition of Nash Q-value);
        2. Nash value: Nash V(s) = \min \max Q(s,a,b) = \pi_a^* Q(s,a,b) \pi_b^{*T}

        No discount factor is applied in the code (the next-step values are
        added to the rewards as they are).

        :param verbose: if True, print the computed values, Q-values and strategies.
        :return: ``(Nash_v, Nash_q, Nash_strategies)``, lists ordered from the
            first to the last time step; the same lists are stored on ``self``.
        """

        self.Nash_v = []
        self.Nash_q = []
        self.Nash_strategies = []
        for tm, rm in zip(self.trans_prob_matrices[::-1], self.reward_matrices[::-1]):  # walk the time steps backwards
            if len(self.Nash_v) > 0:
                rm = np.array(rm) + np.array(self.Nash_v[-1])  # broadcast sum on rm's last dim, last one in Nash_v is for the next state
            nash_q_values = np.einsum("ijk,ijk->ij", tm, rm)  # transition prob * reward for the last dimension in (state, action, next_state)
            nash_q_values = nash_q_values.reshape(-1, self.action_dim, self.action_dim)  # action list to matrix
            self.Nash_q.append(nash_q_values)
            ne_values = []
            ne_strategies = []
            for nash_q_value in nash_q_values:
                ne = NashEquilibriumECOSSolver(nash_q_value)
                ne_strategies.append(ne)
                ne_value = ne[0] @ nash_q_value @ ne[1].T
                ne_values.append(ne_value)  # each value is a Nash equilibrium value on one state
            self.Nash_v.append(ne_values)  # (trans, state)
            self.Nash_strategies.append(ne_strategies)
        # Results were collected last step first; flip them to chronological order.
        self.Nash_v = self.Nash_v[::-1]
        self.Nash_q = self.Nash_q[::-1]  # (dim_transition, dim_state, dim_action, dim_action)
        self.Nash_strategies = self.Nash_strategies[::-1]  # (dim_transition, dim_state, #players, dim_action)
        if verbose:
            print('Nash values of all states (from start to end): \n', self.Nash_v)
            print('Nash Q-values of all states (from start to end): \n', self.Nash_q)
            print('Nash strategies of all states (from start to end): \n', self.Nash_strategies)

        ## To evaluate the correctness of the above values
        # for v, q, s in zip(self.Nash_v, self.Nash_q, self.Nash_strategies):
        #     for vv,qq,ss in zip(v,q,s):
        #         cal_v = ss[0]@qq@ss[1].T
        #         print(vv, cal_v)

        return self.Nash_v, self.Nash_q, self.Nash_strategies

    #def seed(self, seed=None):
    #    gym.spaces.prng.seed(seed)

if __name__ == '__main__':
    # Smoke test: build a two-copy game with the default settings, solve it
    # and print the oracle Nash strategies and the mean first-step value.
    env = BlockGame()
    env.init(num_envs=2)

    solved = env.NEsolver()
    nash_v = solved[0]
    nash_strategies = solved[2]
    strategies_shape = np.array(nash_strategies).shape
    print(nash_strategies, strategies_shape)
    # np.save('../../../data/nash_dqn_test/oracle_nash.npy', nash_strategies)
    first_step_mean = np.mean(nash_v[0], axis=0)  # the average nash value for initial states from max-player's view
    print('oracle nash v star: ', first_step_mean)
    
    # obs = env.reset()
    # print(obs)
    # done = False
    # while not done:
    #     obs, r, done, _ = env.step([[1,0],[0,1]])
    #     print(obs, r, done)
    # for t in range(20):
    #     o = env.reset()
    #     done = False
    #     h = 0
    #     while not done:
    #         env.render()
    #         #print(env.trim_observation(o,h))
    #         (o,r,done,blah) = env.step(env.action_space.sample())
    #         print(o)
    #         h += 1
    #     print("End of episode: r=%d" % (r))
