import numpy as np
import torch
from torch.nn import functional as F
import random

class Taxi_nocost:
    """Tabular Taxi-style grid world with a precomputed transition table.

    A state is a single integer that packs three components::

        state = (location * n_passenger_states + passenger) * n_destinations
                + destination

    With the default 4 landmarks and 5 passenger states this is
    ``20 * location + 4 * passenger + destination``.

    * ``location`` is the taxi cell, ``row * ncol + col``.
    * ``passenger`` indexes ``passenger_states``; the last value means the
      passenger is in the taxi, the others index ``init_locate``.
    * ``destination`` indexes ``init_locate``.

    Actions ``0..3`` move the taxi (column - 1, row + 1, column + 1, row - 1),
    clamped at the grid border; action ``4`` is pick-up and ``5`` is drop-off.

    Attributes:
        step: NOTE - despite the name this is a lookup table, not a method.
            ``step[state][action]`` is the tuple
            ``(next_state, reward, cost, done, hole)``.
        init_state: Encoded start state from the most recent draw
            (constructor or ``reset()``).
    """

    # Row/column deltas for the four movement actions (actions 0..3).
    _MOVES = ((0, -1), (1, 0), (0, 1), (-1, 0))
    _PICKUP = 4
    _DROPOFF = 5

    def __init__(self, ncol=5, nrow=5, gamma=0.9, cost_limit=0, num_traj=2000):
        """Build the environment and draw an initial episode configuration.

        Args:
            ncol: Number of grid columns.
            nrow: Number of grid rows.
            gamma: Stored as ``self.gamma``; not used inside this class.
            cost_limit: Stored as ``self.cost_limit``; not used inside this class.
            num_traj: Stored as ``self.num_traj``; not used inside this class.
        """
        self.ncol = ncol
        self.nrow = nrow
        self.grid_size = ncol * nrow
        self.cost_limit = cost_limit
        self.state_dim = 1
        self.action_dim = 1
        self.gamma = gamma
        self.action_size = 6
        self.num_traj = num_traj
        # Landmark cells; passenger origin and destination are indices into this.
        self.init_locate = [0, 20, 4, 23]
        # Cells that incur a cost when entered.
        self.hole_states = [2, 14, 21]
        # 0: Red 1: Yellow 2: Green 3: Blue 4: in taxi
        self.passenger_states = [0, 1, 2, 3, 4]

        # The order of the random draws is kept as in the original code
        # (destination, taxi, passenger) so seeded runs are reproducible.
        self.destination_state = random.choice([0, 1, 2, 3])
        self.destination_locate = self.init_locate[self.destination_state]
        self.taxi_init_locate = self.init_locate[random.choice([0, 1, 2, 3])]
        self.passenger_state = random.choice([0, 1, 2, 3])
        # Make the start state available even before the first reset().
        self.init_state = self._encode(
            self.taxi_init_locate, self.passenger_state, self.destination_state
        )

        self.test_time_step = 20
        self.collect_time_step = 30

        self.state_size = (
            self.grid_size * len(self.init_locate) * len(self.passenger_states)
        )

        self.step = self.createP()

    def _encode(self, location, passenger, destination):
        """Pack (taxi cell, passenger state, destination index) into one int."""
        n_dest = len(self.init_locate)
        n_pass = len(self.passenger_states)
        return (location * n_pass + passenger) * n_dest + destination

    def createP(self):
        """Build the deterministic transition table.

        Returns:
            A list indexed ``[state][action]`` whose entries are tuples
            ``(next_state, reward, cost, done, hole)``.
        """
        n_dest = len(self.init_locate)
        loc_stride = n_dest * len(self.passenger_states)
        in_taxi = len(self.passenger_states) - 1

        table = []
        for state in range(self.state_size):
            location, rest = divmod(state, loc_stride)
            passenger, destination = divmod(rest, n_dest)
            row, col = divmod(location, self.ncol)
            aboard = passenger == in_taxi
            at_destination = location == self.init_locate[destination]
            # Only meaningful while the passenger is still waiting at a landmark.
            at_passenger = (not aboard) and location == self.init_locate[passenger]

            outcomes = []
            for act in range(self.action_size):
                if act == self._DROPOFF:
                    if aboard and at_destination:
                        # Drop the passenger off at the destination: success.
                        landed = self.init_locate.index(location)
                        nxt = self._encode(location, landed, destination)
                        outcome = (nxt, 20, 0, True, False)
                    elif aboard:
                        # Passenger aboard but dropped off away from the
                        # destination: penalty and the episode terminates.
                        outcome = (state, -10, 0, True, False)
                    else:
                        # No passenger aboard: illegal drop-off, no termination.
                        outcome = (state, -10, 0, False, False)
                elif act == self._PICKUP:
                    if at_passenger:
                        # Legal pick-up: the passenger moves into the taxi.
                        nxt = self._encode(location, in_taxi, destination)
                        outcome = (nxt, 0, 0, False, False)
                    else:
                        # Passenger already aboard or not at this cell:
                        # illegal pick-up, no termination.
                        outcome = (state, -10, 0, False, False)
                else:
                    d_row, d_col = self._MOVES[act]
                    # Clamp so that moving into a wall leaves the taxi in place.
                    next_row = min(self.nrow - 1, max(0, row + d_row))
                    next_col = min(self.ncol - 1, max(0, col + d_col))
                    next_loc = next_row * self.ncol + next_col
                    hole = next_loc in self.hole_states
                    nxt = self._encode(next_loc, passenger, destination)
                    # Every move has reward -1; entering a hole adds cost 10
                    # but does not end the episode.
                    outcome = (nxt, -1, 10 if hole else 0, False, hole)
                outcomes.append(outcome)
            table.append(outcomes)
        return table

    def reset(self):
        """Draw a new destination, passenger origin and taxi start cell.

        Returns:
            The encoded initial state, also stored as ``self.init_state``.
        """
        # Draw order kept as in the original (destination, passenger, taxi).
        self.destination_state = random.choice([0, 1, 2, 3])
        self.passenger_state = random.choice([0, 1, 2, 3])
        self.taxi_init_locate = self.init_locate[random.choice([0, 1, 2, 3])]
        # Keep the derived destination cell in sync with the new draw.
        self.destination_locate = self.init_locate[self.destination_state]
        self.init_state = self._encode(
            self.taxi_init_locate, self.passenger_state, self.destination_state
        )
        return self.init_state

    def state_action_onehot_encode(self):
        """Return one-hot encodings for every (state, action) pair.

        Row ``s * action_size + a`` of both outputs corresponds to state ``s``
        and action ``a``.

        Returns:
            A pair ``(obs_encode, acts_encode)`` of float32 tensors with shapes
            ``(state_size * action_size, state_size)`` and
            ``(state_size * action_size, action_size)``.
        """
        # An identity matrix is exactly the one-hot encoding of 0..n-1.
        state_one_hot = torch.eye(self.state_size, dtype=torch.float32)
        action_one_hot = torch.eye(self.action_size, dtype=torch.float32)
        obs_encode = state_one_hot.repeat_interleave(self.action_size, dim=0)
        acts_encode = action_one_hot.repeat(self.state_size, 1)
        return obs_encode, acts_encode
    
    # def plot_policy(self, policy, random=True):
    #     # action_meaning = ['<', 'v', '>', '^']
    #     for i in range(self.nrow):
    #         for j in range(self.ncol):
    #             state = i*self.ncol+j
    #             pi = policy[state,:]
    #             if random:
    #                 if pi.sum() == 0:
    #                     action = np.random.choice(range(self.action_size), size=1)[0]
    #                 else:
    #                     action = np.random.choice(range(self.action_size), size=1, p=pi)[0]
    #             else:
    #                 action = np.argmax(pi)
    #             if state in self.hole_state:
    #                 print('H', end=' ')
    #                 #print(action_meaning[action], end=' ')
    #             elif state == self.goal_state:
    #                 print('G', end=' ')
    #             else:
    #                 print(action_meaning[action], end=' ')
    #         print()  
    