import numpy as np
import torch
from torch.nn import functional as F

class FrozenLakeEnv_nocost:
    """Deterministic FrozenLake grid without a safety cost.

    States are numbered row-major (``state = row * ncol + col``) and actions are
    0: left, 1: down, 2: right, 3: up. Moves that would leave the grid keep the
    agent on the border. Entering a hole or the goal ends the episode; only
    entering the goal yields reward 1. The goal state is absorbing.

    ``self.step[state][action]`` is the tabular model, a tuple
    ``(next_state, reward, cost, done, hole)``; ``cost`` is always 0 here.
    """

    # Per-``nrow`` layout: (goal_state, hole_state, collect_time_step, test_time_step).
    # Indices are row-major; goal_state is nrow * nrow - 1, which is the
    # bottom-right corner only when ncol == nrow.
    _LAYOUTS = {
        3: (8, [4], 10, 10),
        4: (15, [5, 7, 11, 12], 15, 15),
        8: (63, [3, 11, 19, 29, 35, 41, 42, 46, 49, 52, 54, 59], 50, 25),
        12: (143, [3, 7, 13, 20, 27, 30, 45, 47, 48, 50, 55, 72, 73, 77, 80,
                   93, 95, 99, 100, 109, 117, 119, 125], 40, 40),
    }

    def __init__(self, ncol=3, nrow=3, gamma=0.9, cost_limit=0, num_traj=2000):
        """Create the environment and its transition table.

        Args:
            ncol, nrow: grid size. Only nrow in {3, 4, 8, 12} has a predefined layout.
            gamma: discount factor (stored only; not used within this class).
            cost_limit: constraint threshold (stored only; not used within this class).
            num_traj: number of trajectories callers may collect (stored only).

        Raises:
            ValueError: if ``nrow`` has no predefined layout.
        """
        if nrow not in self._LAYOUTS:
            raise ValueError(
                f"Unsupported grid size nrow={nrow}; expected one of {sorted(self._LAYOUTS)}")
        self.ncol = ncol
        self.nrow = nrow
        self.state_size = ncol * nrow
        self.cost_limit = cost_limit
        self.state_dim = 1
        self.action_dim = 1
        self.action_size = 4
        self.gamma = gamma
        self.num_traj = num_traj

        goal_state, hole_state, collect_steps, test_steps = self._LAYOUTS[nrow]
        self.goal_state = goal_state
        # Copy so instances never share (and mutate) the class-level list.
        self.hole_state = list(hole_state)
        self.collect_time_step = collect_steps
        self.test_time_step = test_steps

        self.step = self.createP()

    def createP(self):
        """Build the deterministic transition table.

        Returns:
            A list indexed ``P[state][action]`` whose entries are tuples
            ``(next_state, reward, cost, done, hole)``.
        """
        # (d_row, d_col) for actions 0: left, 1: down, 2: right, 3: up.
        change = [[0, -1], [1, 0], [0, 1], [-1, 0]]
        P = [[[] for _ in range(self.action_size)] for _ in range(self.nrow * self.ncol)]
        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                for a in range(self.action_size):
                    if state == self.goal_state:
                        # Absorbing goal: stay put with no reward.
                        P[state][a] = (state, 0, 0, True, False)
                        continue
                    # Clamp to the grid so off-grid moves leave the agent on the border.
                    next_i = min(self.nrow - 1, max(0, i + change[a][0]))
                    next_j = min(self.ncol - 1, max(0, j + change[a][1]))
                    next_state = next_i * self.ncol + next_j
                    if next_state in self.hole_state:
                        P[state][a] = (next_state, 0, 0, True, True)
                    elif next_state == self.goal_state:
                        P[state][a] = (next_state, 1, 0, True, False)
                    else:
                        P[state][a] = (next_state, 0, 0, False, False)
        return P

    def reset(self):
        """Return the initial state; every episode starts in state 0."""
        return 0



class FrozenLakeEnv:
    """Deterministic FrozenLake grid in which holes incur a safety cost.

    States are numbered row-major (``state = row * ncol + col``) and actions are
    0: left, 1: down, 2: right, 3: up. Moves that would leave the grid keep the
    agent on the border. Entering a hole costs 0.5 but does NOT end the episode;
    entering the goal yields reward 1 and ends it. The goal state is absorbing.

    ``self.step[state][action]`` is the tabular model, a tuple
    ``(next_state, reward, cost, done, hole)``.
    """

    # Per-``nrow`` layout: (goal_state, hole_state, collect_time_step, test_time_step).
    # Indices are row-major; goal_state is nrow * nrow - 1, which is the
    # bottom-right corner only when ncol == nrow.
    _LAYOUTS = {
        3: (8, [4], 10, 10),
        4: (15, [5, 7, 11, 12], 15, 15),
        8: (63, [3, 11, 19, 29, 35, 41, 42, 46, 49, 52, 54, 59], 60, 25),
        12: (143, [3, 7, 13, 20, 27, 30, 45, 47, 48, 50, 55, 72, 73, 77, 80,
                   93, 95, 99, 100, 109, 117, 119, 125], 40, 40),
    }

    def __init__(self, ncol=3, nrow=3, gamma=0.9, cost_limit=0, num_traj=2000):
        """Create the environment and its transition table.

        Args:
            ncol, nrow: grid size. Only nrow in {3, 4, 8, 12} has a predefined layout.
            gamma: discount factor (stored only; not used within this class).
            cost_limit: constraint threshold (stored only; not used within this class).
            num_traj: number of trajectories collected by ``getdataset``.

        Raises:
            ValueError: if ``nrow`` has no predefined layout.
        """
        if nrow not in self._LAYOUTS:
            raise ValueError(
                f"Unsupported grid size nrow={nrow}; expected one of {sorted(self._LAYOUTS)}")
        self.ncol = ncol
        self.nrow = nrow
        self.state_size = ncol * nrow
        self.cost_limit = cost_limit
        self.state_dim = 1
        self.action_dim = 1
        self.action_size = 4
        self.gamma = gamma
        self.num_traj = num_traj

        goal_state, hole_state, collect_steps, test_steps = self._LAYOUTS[nrow]
        self.goal_state = goal_state
        # Copy so instances never share (and mutate) the class-level list.
        self.hole_state = list(hole_state)
        self.collect_time_step = collect_steps
        self.test_time_step = test_steps

        self.step = self.createP()

    def createP(self):
        """Build the deterministic transition table.

        Returns:
            A list indexed ``P[state][action]`` whose entries are tuples
            ``(next_state, reward, cost, done, hole)``.
        """
        # (d_row, d_col) for actions 0: left, 1: down, 2: right, 3: up.
        change = [[0, -1], [1, 0], [0, 1], [-1, 0]]
        P = [[[] for _ in range(self.action_size)] for _ in range(self.nrow * self.ncol)]
        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                for a in range(self.action_size):
                    if state == self.goal_state:
                        # Absorbing goal: stay put, no reward, no cost.
                        P[state][a] = (state, 0, 0, True, False)
                        continue
                    # Clamp to the grid so off-grid moves leave the agent on the border.
                    next_i = min(self.nrow - 1, max(0, i + change[a][0]))
                    next_j = min(self.ncol - 1, max(0, j + change[a][1]))
                    next_state = next_i * self.ncol + next_j
                    if next_state in self.hole_state:
                        # Holes are costly but non-terminal in this variant.
                        P[state][a] = (next_state, 0, 0.5, False, True)
                    elif next_state == self.goal_state:
                        P[state][a] = (next_state, 1, 0, True, False)
                    else:
                        P[state][a] = (next_state, 0, 0, False, False)
        return P

    def reset(self):
        """Return the initial state; every episode starts in state 0."""
        return 0

    def plot_policy(self, policy, random=True):
        """Print the grid with one action arrow per state.

        Args:
            policy: 2-D array indexed as ``policy[state, action]``.
            random: if True, sample an action from ``policy[state, :]`` (uniformly
                when that row sums to zero); otherwise print the argmax action.

        Holes print as 'H' and the goal as 'G'. An action is still drawn for those
        cells, so NumPy's global random state advances for every cell.
        """
        action_meaning = ['<', 'v', '>', '^']
        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                pi = policy[state, :]
                if random:
                    if pi.sum() == 0:
                        action = np.random.choice(range(self.action_size), size=1)[0]
                    else:
                        action = np.random.choice(range(self.action_size), size=1, p=pi)[0]
                else:
                    action = np.argmax(pi)
                if state in self.hole_state:
                    print('H', end=' ')
                elif state == self.goal_state:
                    print('G', end=' ')
                else:
                    print(action_meaning[action], end=' ')
            print()

    def state_action_onehot_encode(self):
        """Return one-hot encodings of every (state, action) pair.

        Row ``s * action_size + a`` of both outputs corresponds to the pair (s, a).

        Returns:
            ``(obs_encode, acts_encode)``: float32 tensors of shape
            ``(state_size * action_size, state_size)`` and
            ``(state_size * action_size, action_size)``.
        """
        state_one_hot = F.one_hot(torch.arange(self.state_size),
                                  num_classes=self.state_size).to(torch.float32)
        action_one_hot = F.one_hot(torch.arange(self.action_size),
                                   num_classes=self.action_size).to(torch.float32)
        # Each state row repeated once per action; the action block tiled once per state.
        obs_encode = state_one_hot.repeat_interleave(self.action_size, dim=0)
        acts_encode = action_one_hot.repeat(self.state_size, 1)
        return obs_encode, acts_encode

    @staticmethod
    def _one_hot(index, size):
        """Return a length-``size`` list of 0.0 with a 1 at ``index``."""
        vec = [0.0] * size
        vec[index] = 1
        return vec

    def getdataset(self, expert_pi, percent=1, alg='Importance_Sampling'):
        """Roll out an expert/uniform mixture policy to build an offline dataset.

        Each of ``self.num_traj`` trajectories starts at ``reset()`` and runs until
        ``done`` or ``self.collect_time_step`` steps.

        Args:
            expert_pi: indexable by state; ``np.argmax(expert_pi[s])`` is the expert action.
            percent: per-step probability of taking the expert action; otherwise
                an action is drawn uniformly at random.
            alg: ``'Importance_Sampling'`` stores integer states/actions and also
                returns tabular statistics; ``'BC-Safe'`` stores one-hot encodings
                and drops transitions that enter a hole; any other value stores
                one-hot encodings of every transition.

        Returns:
            For ``'Importance_Sampling'``: ``(offline_dataset, mu_D_count, r_s_a,
            c_s_a, P, mu_0)`` where ``mu_D_count[s*A + a]`` counts visits,
            ``r_s_a``/``c_s_a`` hold the observed reward/cost per pair,
            ``P[s', s*A + a]`` counts observed transitions (unnormalized) and
            ``mu_0`` is the initial-state distribution. Otherwise just
            ``offline_dataset``.
        """
        self.expert_pi = expert_pi
        offline_dataset = {
            'observation': [],
            'action': [],
            'reward': [],
            'cost': [],
            'new_observation': [],
            'hole': [],
            'goal': [],
            'is_init': []
        }
        num_sa = self.state_size * self.action_size
        mu_0 = np.zeros(self.state_size)
        mu_0[0] = 1  # reset() always returns state 0
        mu_D_count = np.zeros(num_sa)
        r_s_a = np.zeros(num_sa)
        c_s_a = np.zeros(num_sa)
        P = np.zeros((self.state_size, num_sa))
        goal_count = 0
        uniform = [1 / self.action_size] * self.action_size

        for _ in range(self.num_traj):
            observation = self.reset()
            done = False
            time_step = 0
            while not done and time_step < self.collect_time_step:
                if np.random.rand() < percent:
                    action = np.argmax(expert_pi[observation])
                else:
                    action = np.random.choice(range(self.action_size), size=1, p=uniform)[0]

                new_observation, reward, cost, done, hole = self.step[observation][action]
                s_a = observation * self.action_size + action
                mu_D_count[s_a] += 1
                P[new_observation, s_a] += 1
                r_s_a[s_a] = reward
                c_s_a[s_a] = cost
                # Holes never set ``done`` here (see createP), so a terminal
                # non-hole transition is exactly one that enters the goal.
                goal = done and not hole
                if goal:
                    goal_count += 1
                is_init = 1 if observation == 0 else 0

                if alg == 'Importance_Sampling':
                    record = (observation, action, new_observation)
                elif alg == 'BC-Safe' and hole:
                    record = None  # BC-Safe keeps only transitions that avoid holes
                else:
                    record = (self._one_hot(observation, self.state_size),
                              self._one_hot(action, self.action_size),
                              self._one_hot(new_observation, self.state_size))

                if record is not None:
                    obs_rec, act_rec, nxt_rec = record
                    offline_dataset['observation'].append(obs_rec)
                    offline_dataset['action'].append(act_rec)
                    offline_dataset['new_observation'].append(nxt_rec)
                    offline_dataset['reward'].append(reward)
                    offline_dataset['cost'].append(cost)
                    offline_dataset['hole'].append(hole)
                    offline_dataset['goal'].append(goal)
                    offline_dataset['is_init'].append(is_init)

                observation = new_observation
                time_step += 1

        print(f'There are {goal_count} data reaching the goal in {len(offline_dataset["observation"])} data.')
        if alg == 'Importance_Sampling':
            return offline_dataset, mu_D_count, r_s_a, c_s_a, P, mu_0
        return offline_dataset
        
        
        
class Imitate11x11Env:
    """11x11 grid with walls and holes used for the imitation experiments.

    States are numbered row-major (``state = row * ncol + col``); actions are
    0: left, 1: down, 2: right, 3: up. Moving off the grid or into a wall keeps
    the agent in place; moving into a hole costs 1 but does not end the episode.
    ``self.step[state][action]`` is ``(next_state, cost, hole)``.

    ``self.target_dist`` is a normalized distribution over flattened
    ``state * action_size + action`` indices: every pair has base weight 10 and
    two preferred actions in each ``imit_state`` cell get weight 110.
    """

    def __init__(self, ncol=11, nrow=11, gamma=0.9, cost_limit=0):
        """Create the environment, its transition table and the target distribution.

        The hole/wall/imitation layouts are fixed row-major indices for the
        default 11x11 grid.
        """
        self.ncol = ncol
        self.nrow = nrow

        self.state_size = ncol * nrow
        self.action_size = 4
        self.state_dim = 1
        self.action_dim = 1

        self.gamma = gamma
        self.cost_limit = cost_limit
        self.collect_time_step = 50
        self.test_time_step = 25

        self.goal_state = []  # this environment has no goal
        self.hole_state = [63, 93]
        self.wall_state = [5, 16, 38, 49, 55, 56, 58, 59, 60, 61, 62, 64, 65, 71, 82, 104, 115]
        # Four groups of 8 cells; within each group the order matches the
        # preferred-action table in _build_target_dist.
        self.imit_state = [12, 13, 14, 23, 25, 34, 35, 36,
                           18, 19, 20, 29, 31, 40, 41, 42,
                           78, 79, 80, 89, 91, 100, 101, 102,
                           84, 85, 86, 95, 97, 106, 107, 108]

        self.step = self.createP()
        self.target_dist = self._build_target_dist()

    def _build_target_dist(self):
        """Return the normalized (state, action) target distribution over ``imit_state``."""
        left, down, right, up = 0, 1, 2, 3
        # Two preferred actions for each of the 8 cells in one imit_state group.
        preferred = [[down, right], [left, right], [left, down],
                     [up, down], [up, down],
                     [up, right], [left, right], [left, up]]
        group_size = len(self.imit_state) // 4
        dist = np.zeros(self.state_size * self.action_size)
        for g in range(4):
            for j, actions in enumerate(preferred):
                s = self.imit_state[g * group_size + j]
                for a in actions:
                    dist[s * self.action_size + a] = 100
        # Smoothing so every (state, action) pair has non-zero probability.
        dist += 10
        return dist / np.sum(dist)

    def createP(self):
        """Build the deterministic table ``P[state][action] -> (next_state, cost, hole)``."""
        # (d_row, d_col) for actions 0: left, 1: down, 2: right, 3: up.
        change = [[0, -1], [1, 0], [0, 1], [-1, 0]]
        P = [[[] for _ in range(self.action_size)] for _ in range(self.nrow * self.ncol)]
        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                for a in range(self.action_size):
                    next_i = min(self.nrow - 1, max(0, i + change[a][0]))
                    next_j = min(self.ncol - 1, max(0, j + change[a][1]))
                    next_state = next_i * self.ncol + next_j
                    if next_state in self.hole_state:
                        P[state][a] = (next_state, 1, True)
                    elif next_state in self.wall_state:
                        # Walls block movement: the agent stays where it is.
                        P[state][a] = (state, 0, False)
                    else:
                        P[state][a] = (next_state, 0, False)
        return P

    def reset(self):
        """Return the initial state; every episode starts in state 0."""
        return 0

    def getdataset(self, num_traj=5000):
        """Collect ``num_traj`` uniform-random rollouts of ``collect_time_step`` steps each.

        Returns:
            ``(offline_dataset, mu_D_count, c_s_a, P, mu_0)`` where
            ``offline_dataset`` holds per-step lists, ``mu_D_count[s*A + a]``
            counts visits, ``c_s_a`` holds the observed cost per pair,
            ``P[s', s*A + a]`` counts observed transitions (unnormalized) and
            ``mu_0`` is the initial-state distribution.
        """
        offline_dataset = {
            'observation': [],
            'action': [],
            'cost': [],
            'new_observation': [],
            'hole': []
        }
        num_sa = self.state_size * self.action_size
        mu_0 = np.zeros(self.state_size)
        mu_0[0] = 1  # reset() always returns state 0
        mu_D_count = np.zeros(num_sa)
        c_s_a = np.zeros(num_sa)
        P = np.zeros((self.state_size, num_sa))
        uniform = [1 / self.action_size] * self.action_size

        for _ in range(num_traj):
            observation = self.reset()
            # Holes are not terminal, so every rollout runs the full horizon.
            for _ in range(self.collect_time_step):
                action = np.random.choice(range(self.action_size), size=1, p=uniform)[0]

                new_observation, cost, hole = self.step[observation][action]
                s_a = observation * self.action_size + action
                mu_D_count[s_a] += 1
                P[new_observation, s_a] += 1
                c_s_a[s_a] = cost

                offline_dataset['observation'].append(observation)
                offline_dataset['action'].append(action)
                offline_dataset['cost'].append(cost)
                offline_dataset['new_observation'].append(new_observation)
                offline_dataset['hole'].append(hole)

                observation = new_observation

        print(f'There are {len(offline_dataset["observation"])} data.')
        return offline_dataset, mu_D_count, c_s_a, P, mu_0
    
class Imitate8x8Env:
    """8x8 FrozenLake-style grid used for the imitation experiments.

    States are numbered row-major (``state = row * ncol + col``); actions are
    0: left, 1: down, 2: right, 3: up. Moves that would leave the grid keep the
    agent on the border. Entering a hole costs 1; neither holes nor the goal
    end an episode in ``getdataset``.
    ``self.step[state][action]`` is ``(next_state, cost, goal, hole)``.
    """

    def __init__(self, ncol=8, nrow=8, gamma=0.9, cost_limit=0):
        """Create the environment and its transition table.

        The goal/hole/imitation layouts are fixed row-major indices for the
        default 8x8 grid.
        """
        self.ncol = ncol
        self.nrow = nrow

        self.state_size = ncol * nrow
        self.action_size = 4
        self.state_dim = 1
        self.action_dim = 1

        self.gamma = gamma
        self.cost_limit = cost_limit

        self.goal_state = 63
        self.wall_state = []
        self.hole_state = [3, 11, 19, 29, 35, 41, 42, 46, 49, 52, 54, 59]
        self.collect_time_step = 50
        self.test_time_step = 25
        # Three state sequences from 0 to 63, concatenated (not used within this class).
        self.imit_state = [0, 8, 16, 24, 32, 40, 48, 56, 57, 58, 50, 51, 43, 44, 45, 53, 61, 62, 63,
                           0, 1, 9, 17, 25, 26, 27, 28, 29, 30, 31, 39, 47, 55, 63,
                           0, 1, 2, 3, 4, 5, 6, 7, 15, 23, 31, 39, 47, 55, 63]

        self.step = self.createP()

    def createP(self):
        """Build the deterministic table ``P[state][action] -> (next_state, cost, goal, hole)``."""
        # (d_row, d_col) for actions 0: left, 1: down, 2: right, 3: up.
        change = [[0, -1], [1, 0], [0, 1], [-1, 0]]
        P = [[[] for _ in range(self.action_size)] for _ in range(self.nrow * self.ncol)]
        for i in range(self.nrow):
            for j in range(self.ncol):
                state = i * self.ncol + j
                for a in range(self.action_size):
                    next_i = min(self.nrow - 1, max(0, i + change[a][0]))
                    next_j = min(self.ncol - 1, max(0, j + change[a][1]))
                    next_state = next_i * self.ncol + next_j
                    if next_state in self.hole_state:
                        P[state][a] = (next_state, 1, False, True)
                    elif next_state == self.goal_state:
                        # goal_state is a single int, so compare with ==, not `in`.
                        P[state][a] = (next_state, 0, True, False)
                    else:
                        P[state][a] = (next_state, 0, False, False)
        return P

    def reset(self):
        """Return the initial state; every episode starts in state 0."""
        return 0

    def getdataset(self, num_traj=5000):
        """Collect ``num_traj`` uniform-random rollouts of ``collect_time_step`` steps each.

        Returns:
            ``(offline_dataset, mu_D_count, c_s_a, P, mu_0)`` where
            ``offline_dataset`` holds per-step lists, ``mu_D_count[s*A + a]``
            counts visits, ``c_s_a`` holds the observed cost per pair,
            ``P[s', s*A + a]`` counts observed transitions (unnormalized) and
            ``mu_0`` is the initial-state distribution.
        """
        offline_dataset = {
            'observation': [],
            'action': [],
            'cost': [],
            'new_observation': [],
            'hole': []
        }
        num_sa = self.state_size * self.action_size
        mu_0 = np.zeros(self.state_size)
        mu_0[0] = 1  # reset() always returns state 0
        mu_D_count = np.zeros(num_sa)
        c_s_a = np.zeros(num_sa)
        P = np.zeros((self.state_size, num_sa))
        uniform = [1 / self.action_size] * self.action_size

        for _ in range(num_traj):
            observation = self.reset()
            for _ in range(self.collect_time_step):
                action = np.random.choice(range(self.action_size), size=1, p=uniform)[0]

                # Transition tuples have four fields; the goal flag is not recorded.
                new_observation, cost, _goal, hole = self.step[observation][action]
                s_a = observation * self.action_size + action
                mu_D_count[s_a] += 1
                P[new_observation, s_a] += 1
                c_s_a[s_a] = cost

                offline_dataset['observation'].append(observation)
                offline_dataset['action'].append(action)
                offline_dataset['cost'].append(cost)
                offline_dataset['new_observation'].append(new_observation)
                offline_dataset['hole'].append(hole)

                observation = new_observation

        print(f'There are {len(offline_dataset["observation"])} data.')
        return offline_dataset, mu_D_count, c_s_a, P, mu_0
