import numpy as np
import gym
from gym.spaces import MultiBinary, Discrete, Box
import scipy.linalg
import math

from solver.eq_LPsolver import CoarseCorrelatedEquilibriumLPSolver

'''
Fast row-wise weighted sampling: one draw per row of a stochastic matrix.
Credit: https://stackoverflow.com/questions/34187130/fast-random-weighted-selection-across-all-rows-of-a-stochastic-matrix/34190035
'''
def sample(prob_matrix, items, n):
    """Draw one item per row of a row-stochastic matrix via inverse-CDF sampling.

    Row ``i`` of ``prob_matrix`` holds the weights over ``items`` used for the
    ``i``-th draw, so the first ``n`` rows are consumed. A matrix with a single
    row is shared by all ``n`` draws.

    :param prob_matrix: 2-D array of non-negative weights, one column per item.
    :param items: 1-D array (or sequence) of candidate values.
    :param n: number of draws.
    :return: NumPy array of the ``n`` sampled items.
    """
    cdf = np.cumsum(np.asarray(prob_matrix)[:n], axis=1)
    # Random numbers are expensive, so draw all of them in one call. This consumes
    # the global NumPy RNG exactly like a per-row implementation would.
    ridx = np.random.random(size=n)
    # Inverse CDF for every row at once: the chosen column is the number of
    # cumulative weights <= r (equivalent to searchsorted(side='right')), so a
    # zero-weight column is not chosen when r lands exactly on a boundary.
    idx = np.sum(cdf <= ridx[:, None], axis=1)
    # Float rounding can leave a row's cumsum slightly below 1, so r may exceed
    # every entry; clamp to the last column instead of raising IndexError.
    idx = np.minimum(idx, cdf.shape[1] - 1)
    return np.asarray(items)[idx]


class BlockGameGenSum(gym.Env):
    """A randomly generated, finite-horizon, two-player general-sum Markov game.

    For every step ``h`` of the horizon, a transition kernel ``P_h(s' | s, a)`` and
    one reward tensor ``r_h(s, a, s')`` per player are drawn at random, where ``a``
    is the joint action ``a1 * action_dim + a2``. ``num_envs`` copies of the game
    are simulated in parallel. Observations are noisy one-hot encodings of
    (state, step) multiplied by a Hadamard matrix.

    The constructor does nothing; call :meth:`init` to configure and build the game.
    """

    def __init__(self, env_config=None):
        # env_config is accepted for interface compatibility and is not used;
        # the game is configured by init().
        self.initialized = False

    def init(self, horizon=10, action_dim=3, num_players=2, state_dim=3, noise=0.1,
             num_envs=50, verbose=False, rseed=12):
        """Configure the game, generate its dynamics and solve for equilibrium values.

        :param horizon: number of steps per episode.
        :param action_dim: number of actions available to each player.
        :param num_players: number of players; only 2 is supported.
        :param state_dim: number of states.
        :param noise: standard deviation of the Gaussian noise added to observations.
        :param num_envs: number of parallel environment copies handled by step/reset.
        :param verbose: print the solver's per-step values and the average
            initial-state values of both players.
        :param rseed: seed passed to ``self.seed`` and to NumPy's global RNG.
        :raises ValueError: if ``num_players`` is not 2.
        """
        if num_players != 2:
            # step() and NEsolver() encode/decode joint actions for exactly two players.
            raise ValueError(
                "BlockGameGenSum only supports num_players=2, got %r" % (num_players,))
        self.initialized = True
        self.max_reward = 1
        self.horizon = horizon
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_space = Discrete(self.action_dim)
        self.num_envs = num_envs

        # Size of the joint action space (a1 * action_dim + a2 encoding).
        self.total_action_dim = action_dim ** num_players

        self.reward_range = [-1, 1]

        # scipy.linalg.hadamard needs a power-of-two size. One slot beyond
        # horizon + state_dim is needed because make_obs marks index
        # state_dim + h, and h reaches horizon on the final step.
        self.observation_dim = 2 ** int(math.ceil(np.log2(self.horizon + self.state_dim + 1)))
        # np.float (an alias of the builtin float) was removed in NumPy 1.24;
        # np.float64 is the same dtype.
        self.observation_space = Box(low=0.0, high=1.0, shape=(self.observation_dim,),
                                     dtype=np.float64)

        self.noise = noise
        self.rotation = scipy.linalg.hadamard(self.observation_space.shape[0])

        # NOTE(review): gym.Env.seed is absent from newer gym releases — confirm
        # the pinned gym version provides it.
        self.seed(rseed)
        np.random.seed(rseed)  # reseeds NumPy's *global* RNG

        (self.trans_prob_matrices,
         self.reward_matrices1,
         self.reward_matrices2) = self.generate_random_trans_and_rewards()

        nash_v1, nash_v2, _, _, _ = self.NEsolver(verbose=verbose)
        if verbose:
            # Average equilibrium value over initial states, for player 1 and player 2.
            print('oracle nash v1 star: ', np.mean(nash_v1[0], axis=0))
            print('oracle nash v2 star: ', np.mean(nash_v2[0], axis=0))

    def get_nash_strategy(self):
        """Re-run the solver and return ``(values_player1, values_player2, strategies)``.

        The strategies are converted to a NumPy array.
        """
        nash_v1, nash_v2, _, _, nash_strategies = self.NEsolver()
        return nash_v1, nash_v2, np.array(nash_strategies)

    def generate_random_trans_and_rewards(self, SameRewardForNextState=False):
        """Generate random transition probabilities and rewards for both players.

        For every (step, state, joint action), the next-state distribution is a
        vector of Uniform(0, 1) draws normalized to sum to one. Rewards are drawn
        uniformly from ``self.reward_range``.

        :param SameRewardForNextState: if True, rewards depend only on (s, a), i.e.
            one draw per player is repeated for every next state. If False, rewards
            depend on (s, a, s').
        :type SameRewardForNextState: bool
        :return: ``(trans_prob_matrices, reward_matrices1, reward_matrices2)``, each of
            shape ``(horizon, state_dim, total_action_dim, state_dim)``.
        """
        shape = (self.horizon, self.state_dim, self.total_action_dim, self.state_dim)
        trans_prob_matrices = np.empty(shape)
        reward_matrices1 = np.empty(shape)
        reward_matrices2 = np.empty(shape)

        # The order of RNG calls (transition, reward 1, reward 2 per entry) is kept
        # fixed so that a given seed reproduces the same game.
        for h in range(self.horizon):
            for s in range(self.state_dim):
                for a in range(self.total_action_dim):
                    weights = np.random.uniform(0, 1, self.state_dim)
                    trans_prob_matrices[h, s, a] = weights / np.sum(weights)
                    if SameRewardForNextState:  # r(s, a): scalar broadcast over s'
                        reward_matrices1[h, s, a] = np.random.uniform(*self.reward_range)
                        reward_matrices2[h, s, a] = np.random.uniform(*self.reward_range)
                    else:  # r(s, a, s')
                        reward_matrices1[h, s, a] = np.random.uniform(*self.reward_range, self.state_dim)
                        reward_matrices2[h, s, a] = np.random.uniform(*self.reward_range, self.state_dim)

        return trans_prob_matrices, reward_matrices1, reward_matrices2

    def step(self, a, s=None, h=None):
        """Advance all ``num_envs`` environments by one step.

        :param a: sequence of ``num_envs`` action pairs ``(a1, a2)``, one per environment.
        :param s: optional state(s) to step from instead of ``self.state``; when given,
            ``self.h`` is also replaced by ``h``.
        :param h: step index used together with ``s``.
        :return: ``(obs, rewards1, rewards2, done, info)``. The reward arrays have shape
            ``(num_envs, 1)``, and ``done`` is True once the step index reaches the horizon.
        """
        if s is not None:
            self.h = h
            self.state = s
        # Encode each environment's action pair as a single joint-action index.
        joint_actions = np.array(
            [a[n][0] * self.action_dim + a[n][1] for n in range(self.num_envs)])

        # (num_envs, state_dim): next-state distribution of each environment.
        trans_prob = self.trans_prob_matrices[self.h, self.state, joint_actions]
        next_states = sample(trans_prob, np.arange(self.state_dim), self.num_envs)
        rewards1 = self.reward_matrices1[self.h, self.state, joint_actions, next_states].reshape(-1, 1)
        rewards2 = self.reward_matrices2[self.h, self.state, joint_actions, next_states].reshape(-1, 1)

        self.state = next_states
        self.h += 1

        obs = self.make_obs(self.state, self.h)
        done = self.h == self.horizon

        return obs, rewards1, rewards2, done, {}

    def get_state(self):
        """Return the current state(s)."""
        return self.state

    def make_obs(self, s, h):
        """Build noisy, Hadamard-rotated observations for every environment.

        Each row starts with Gaussian noise (std ``self.noise``) on the first
        ``horizon + state_dim`` coordinates. It then gets +1 at index ``s`` (state
        one-hot) and at ``state_dim + h`` (step one-hot), and is multiplied by
        ``self.rotation``.
        """
        n_informative = self.horizon + self.state_dim
        obs = np.zeros((self.num_envs, self.observation_space.shape[0]))
        obs[:, :n_informative] = np.random.normal(0, self.noise, [self.num_envs, n_informative])
        obs[np.arange(self.num_envs), s] += 1
        obs[np.arange(self.num_envs), self.state_dim + h] += 1

        return np.matmul(self.rotation, obs.T).T

    def trim_observation(self, o, h):
        """Return the observation unchanged (``h`` is ignored)."""
        return o

    def reset(self):
        """Draw a uniformly random initial state per environment and return the observations."""
        self.state = np.random.randint(0, self.state_dim, self.num_envs)
        self.h = 0
        return self.make_obs(self.state, self.h)

    def render(self, mode='human'):
        """Print each environment's state as a letter followed by the step index.

        State 0 is shown as ``A``, state 1 as ``B``, and so on, e.g. ``A3``.
        ``mode`` is ignored.
        """
        # self.state is an array after reset/step, so it must be iterated;
        # comparing it to an int directly is ambiguous for num_envs > 1.
        for state in np.atleast_1d(self.state):
            print("%s%d" % (chr(ord('A') + int(state)), self.h))

    def close(self):
        pass

    def NEsolver(self, verbose=False):
        """Solve the game backwards in time for per-state equilibrium values and strategies.

        Backward induction over the horizon, last step first, for each player:

        1. ``Q_h(s, a) = sum_{s'} P_h(s' | s, a) * (r_h(s, a, s') + V_{h+1}(s'))``,
           with no continuation value after the last step. The joint-action axis is
           reshaped into an ``action_dim x action_dim`` payoff matrix per state.
        2. Each state's pair of payoff matrices is passed to
           ``CoarseCorrelatedEquilibriumLPSolver``, and the two values it returns
           become ``V_h(s)`` for players 1 and 2.

        Despite the ``Nash`` naming, the per-state solution is whatever that solver
        computes.

        :param verbose: print each step's player-1 values as they are computed.
        :return: ``(Nash_v1, Nash_v2, Nash_q1, Nash_q2, Nash_strategies)``, ordered
            from the first step to the last. These are also stored on ``self``.
        """
        self.Nash_v1 = []
        self.Nash_v2 = []
        self.Nash_q1 = []
        self.Nash_q2 = []
        self.Nash_strategies = []
        backward = zip(self.trans_prob_matrices[::-1],
                       self.reward_matrices1[::-1],
                       self.reward_matrices2[::-1])
        for tm, rm1, rm2 in backward:
            if self.Nash_v1:
                # Add the next step's values, broadcast over the next-state axis.
                rm1 = np.array(rm1) + np.array(self.Nash_v1[-1])
                rm2 = np.array(rm2) + np.array(self.Nash_v2[-1])
            # Expectation over next states: (state, action, next_state) -> (state, action).
            nash_q_values1 = np.einsum("ijk,ijk->ij", tm, rm1)
            nash_q_values1 = nash_q_values1.reshape(-1, self.action_dim, self.action_dim)
            self.Nash_q1.append(nash_q_values1)
            nash_q_values2 = np.einsum("ijk,ijk->ij", tm, rm2)
            nash_q_values2 = nash_q_values2.reshape(-1, self.action_dim, self.action_dim)
            self.Nash_q2.append(nash_q_values2)

            ne_values1 = []
            ne_values2 = []
            ne_strategies = []
            for nash_q_value1, nash_q_value2 in zip(nash_q_values1, nash_q_values2):
                _, _, ne, v1, v2 = CoarseCorrelatedEquilibriumLPSolver(nash_q_value1, nash_q_value2)
                ne_strategies.append(ne)
                ne_values1.append(v1)
                ne_values2.append(v2)
            if verbose:
                print(ne_values1)
            self.Nash_v1.append(ne_values1)  # one value per state
            self.Nash_v2.append(ne_values2)
            self.Nash_strategies.append(ne_strategies)

        # Lists were built last step first; flip them to run from the first step to the last.
        self.Nash_v1 = self.Nash_v1[::-1]
        self.Nash_v2 = self.Nash_v2[::-1]
        self.Nash_q1 = self.Nash_q1[::-1]  # (dim_transition, dim_state, dim_action, dim_action)
        self.Nash_q2 = self.Nash_q2[::-1]
        # Per step and state: the third value returned by the solver (presumably
        # per-player strategies — confirm against the solver).
        self.Nash_strategies = self.Nash_strategies[::-1]

        return self.Nash_v1, self.Nash_v2, self.Nash_q1, self.Nash_q2, self.Nash_strategies

if __name__ == '__main__':
    # Smoke run: build a small game and report its equilibrium values.
    env = BlockGameGenSum()
    env.init(num_envs=2)

    # get_nash_strategy() returns (player-1 values, player-2 values, strategies).
    nash_v1, nash_v2, nash_strategies = env.get_nash_strategy()
    print(nash_strategies, np.array(nash_strategies).shape)
    # Average equilibrium value over initial states, for each player.
    print('oracle nash v1 star: ', np.mean(nash_v1[0], axis=0))
    print('oracle nash v2 star: ', np.mean(nash_v2[0], axis=0))
