import torch
from lpcmdp.algorithm.model import *
from torch.utils.data import DataLoader
import itertools
from tqdm import tqdm
from lpcmdp.env.FrozenLake import *
from lpcmdp.env.Taxi import *
from lpcmdp.algorithm.utils import *


class FrozenLake_DataCollector():
    """Collect and post-process an offline transition dataset for a FrozenLake env.

    Constructing the collector immediately rolls out ``num_trajectories``
    episodes (see :meth:`collect_data`) and stores every transition in
    ``self.offline_dataset``, a dict of equally long column lists keyed by
    ``'observation'``, ``'action'``, ``'reward'``, ``'cost'``,
    ``'new_observation'``, ``'hole'``, ``'goal'`` and ``'is_init'``.

    The environment object must expose ``reset()``, ``action_size``,
    ``state_size``, ``collect_time_step`` and a ``step`` table that is
    *indexed* (not called) as ``env.step[observation][action]`` and yields
    ``(next_observation, reward, cost, done, hole)``.
    """

    # Column names of the offline dataset, in storage order.
    _COLUMNS = ('observation', 'action', 'reward', 'cost',
                'new_observation', 'hole', 'goal', 'is_init')

    def __init__(self,
                 env,
                 num_trajectories=2000,
                 expert_pi=None,
                 percent=0.0,
                 ) -> None:
        """Create the collector and collect the initial dataset.

        Args:
            env: Environment to roll out (see the class docstring for the
                attributes that are read from it).
            num_trajectories: Number of episodes to collect right away.
            expert_pi: Optional per-state action distributions, indexed as
                ``expert_pi[observation]``; ``None`` means purely random data.
            percent: Probability, per step, of sampling the action from
                ``expert_pi`` instead of the uniform distribution.
        """
        self.env = env
        # Starts at zero; collect_data adds the number of episodes it ran.
        self.num_trajectories = 0
        self.offline_dataset = {column: [] for column in self._COLUMNS}
        self.expert_pi = expert_pi
        self.percent = percent
        self.collect_data(self.env, num_trajectories, expert_pi, percent)

    def collect_data(self, env, num_trajectories=2000, expert_pi=None, percent=0.0):
        """Roll out episodes and append every transition to the dataset.

        Each step samples from ``expert_pi[observation]`` with probability
        ``percent`` (only if ``expert_pi`` is given) and uniformly otherwise.
        An episode stops when the env reports ``done`` or after
        ``env.collect_time_step`` steps. A summary line is printed at the end
        and ``self.num_trajectories`` is increased by ``num_trajectories``.
        """
        goal_count = 0
        actions = range(env.action_size)
        # Loop-invariant uniform distribution over the actions.
        uniform = [1 / env.action_size for _ in actions]
        with tqdm(total=num_trajectories) as pbar:
            for _ in range(num_trajectories):
                observation = env.reset()
                done = False
                goal = False
                time_step = 0
                while not done and time_step < env.collect_time_step:
                    # np.random.rand() is only drawn when an expert exists,
                    # so the random stream matches for expert-free runs.
                    if expert_pi is not None and np.random.rand() < percent:
                        probs = expert_pi[observation]
                    else:
                        probs = uniform
                    action = np.random.choice(actions, size=1, p=probs)[0]

                    # env.step is a lookup table, not a callable.
                    next_observation, reward, cost, done, hole = env.step[observation][action]

                    # A terminal transition that is not a hole counts as
                    # reaching the goal.
                    if done and not hole:
                        goal_count += 1
                        goal = True
                    # Every transition leaving state 0 is flagged as initial,
                    # including revisits of state 0 within an episode.
                    is_init = 1 if observation == 0 else 0

                    self.add_data(observation, action, reward, cost,
                                  next_observation, hole, goal, is_init)

                    observation = next_observation
                    time_step += 1

                pbar.update(1)
        print(f'There are {goal_count} data reaching the goal in {len(self.offline_dataset["observation"])} data.')
        self.num_trajectories += num_trajectories

    def add_data(self, observation, action, reward, cost, new_observation, hole, goal, is_init):
        """Append a single transition to every column of the dataset."""
        record = (observation, action, reward, cost,
                  new_observation, hole, goal, is_init)
        for column, value in zip(self._COLUMNS, record):
            self.offline_dataset[column].append(value)

    def get_original_dataset(self):
        """Return the raw dataset dict (the live object, not a copy)."""
        return self.offline_dataset

    @staticmethod
    def _one_hot(index, size):
        """Return a list of ``size`` floats that is 1.0 at ``index`` and 0.0 elsewhere."""
        encoded = [0.0] * size
        encoded[index] = 1.0
        return encoded

    def get_onehot_encode_dataset(self):
        """Return a copy of the dataset with states and actions one-hot encoded.

        ``'observation'`` and ``'new_observation'`` become lists of length
        ``env.state_size``, ``'action'`` becomes a list of length
        ``env.action_size``; all other columns are copied unchanged.
        """
        raw = self.offline_dataset
        n_states, n_actions = self.env.state_size, self.env.action_size
        return {
            'observation': [self._one_hot(s, n_states) for s in raw['observation']],
            'action': [self._one_hot(a, n_actions) for a in raw['action']],
            'reward': list(raw['reward']),
            'cost': list(raw['cost']),
            'new_observation': [self._one_hot(s, n_states) for s in raw['new_observation']],
            'hole': list(raw['hole']),
            'goal': list(raw['goal']),
            'is_init': list(raw['is_init']),
        }

    def get_real_behave_policy(self):
        """Return the analytic behaviour policy used during collection.

        The result is ``percent * expert_pi + (1 - percent) * uniform``,
        renormalised per state. When no expert policy was supplied,
        collection is purely uniform, so a uniform array of shape
        ``(env.state_size, env.action_size)`` is returned instead of failing
        on ``None`` arithmetic.

        Side effects (kept from the original implementation): sets
        ``self.random`` to the uniform component and, when an expert exists,
        converts ``self.expert_pi`` to a NumPy array.
        """
        if self.expert_pi is None:
            self.random = np.full((self.env.state_size, self.env.action_size),
                                  1 / self.env.action_size)
            return self.random.copy()

        self.random = np.ones_like(self.expert_pi) / self.env.action_size
        self.expert_pi = np.array(self.expert_pi)
        mixture = self.percent * self.expert_pi + (1 - self.percent) * self.random
        return mixture / np.sum(mixture, axis=1, keepdims=True)

    def get_behave_policy_prob(self):
        """Estimate the behaviour policy from state-action visit counts.

        Returns:
            Array of shape ``(env.state_size, env.action_size)`` whose rows
            are empirical action frequencies; rows of states that never occur
            in the dataset stay all-zero.
        """
        behave_policy_prob = np.zeros((self.env.state_size, self.env.action_size))
        for s, a in zip(self.offline_dataset['observation'], self.offline_dataset['action']):
            behave_policy_prob[s][a] += 1
        totals = behave_policy_prob.sum(axis=1, keepdims=True)
        visited = totals[:, 0] != 0
        # Only normalise visited states; unvisited rows must not divide by 0.
        behave_policy_prob[visited] /= totals[visited]
        return behave_policy_prob



class Taxi_DataCollector():
    """Collect and post-process an offline transition dataset for a Taxi env.

    Constructing the collector immediately rolls out ``num_episodes``
    episodes (see :meth:`collect_data`) and stores every transition in
    ``self.offline_dataset``, a dict of equally long column lists keyed by
    ``'observation'``, ``'action'``, ``'reward'``, ``'cost'``,
    ``'new_observation'``, ``'hole'``, ``'done'``, ``'goal'`` and
    ``'is_init'``.

    The environment object must expose ``reset()``, ``action_size``,
    ``state_size``, ``collect_time_step``, ``init_locate`` and a ``step``
    table that is *indexed* (not called) as ``env.step[state][action]`` and
    yields ``(next_state, reward, cost, done, hole)``.
    """

    # Column names of the offline dataset, in storage order.
    _COLUMNS = ('observation', 'action', 'reward', 'cost',
                'new_observation', 'hole', 'done', 'goal', 'is_init')

    def __init__(self,
                 env,
                 num_episodes=2000,
                 exper_pi=None,
                 percent=0.0,
                 ) -> None:
        """Create the collector and collect the initial dataset.

        Args:
            env: Environment to roll out (see the class docstring for the
                attributes that are read from it).
            num_episodes: Number of episodes to collect right away.
            exper_pi: Optional expert policy, indexed as ``exper_pi[state]``
                to get an action distribution. The misspelled name is kept
                so existing keyword callers keep working.
            percent: Probability, per step, of sampling the action from the
                expert policy instead of the uniform distribution.
        """
        self.env = env
        self.num_episodes = num_episodes
        # Starts at zero; collect_data adds the number of episodes it ran.
        self.num_trajectories = 0
        self.offline_dataset = {column: [] for column in self._COLUMNS}
        self.expert_pi = exper_pi
        self.percent = percent
        self.collect_data(self.env, num_episodes, exper_pi, percent)

    def collect_data(self, env, num_episodes=2000, expert_pi=None, percent=0.0):
        """Roll out episodes and append every transition to the dataset.

        Each step samples from ``expert_pi[state]`` with probability
        ``percent`` (only if ``expert_pi`` is given) and uniformly otherwise.
        An episode stops when the env reports ``done`` or after
        ``env.collect_time_step`` steps. Only the first transition of an
        episode is flagged ``is_init``. A summary line is printed at the end
        and ``self.num_trajectories`` is increased by ``num_episodes``.
        """
        goal_count = 0
        actions = range(env.action_size)
        # Loop-invariant uniform distribution over the actions.
        uniform = [1 / env.action_size for _ in actions]
        with tqdm(total=num_episodes) as pbar:
            for _ in range(num_episodes):
                state = env.reset()
                is_init = 1
                done = False
                time_step = 0
                while not done and time_step < env.collect_time_step:
                    # np.random.rand() is only drawn when an expert exists,
                    # so the random stream matches for expert-free runs.
                    if expert_pi is not None and np.random.rand() < percent:
                        probs = expert_pi[state]
                    else:
                        probs = uniform
                    action = np.random.choice(actions, size=1, p=probs)[0]

                    # env.step is a lookup table, not a callable.
                    next_state, reward, cost, done, hole = env.step[state][action]

                    if done and not hole:
                        # NOTE(review): presumably the state index encodes
                        # taxi_location * 20 + passenger * 4 + destination -
                        # confirm against the Taxi environment.
                        next_locate = next_state // (4 * 5)
                        destination = (state % (4 * 5)) % 4
                        # Action 5 is presumably the drop-off action; a success
                        # needs the taxi at the destination's location.
                        if next_locate == env.init_locate[destination] and action == 5:
                            goal_count += 1

                    self.add_data(state, action, reward, cost, next_state, hole, done, is_init)

                    is_init = 0
                    state = next_state
                    time_step += 1

                pbar.update(1)
        print(f'There are {goal_count} data reaching the goal in {len(self.offline_dataset["observation"])} data')
        self.num_trajectories += num_episodes

    def add_data(self, state, action, reward, cost, next_state, hole, done, is_init):
        """Append a single transition to every column of the dataset.

        NOTE(review): the ``'goal'`` column is always stored as 0 here -
        confirm whether downstream code expects a real goal flag.
        """
        record = (state, action, reward, cost, next_state, hole, done, 0, is_init)
        for column, value in zip(self._COLUMNS, record):
            self.offline_dataset[column].append(value)

    def get_original_dataset(self):
        """Return the raw dataset dict (the live object, not a copy)."""
        return self.offline_dataset

    @staticmethod
    def _one_hot(index, size):
        """Return a list of ``size`` floats that is 1.0 at ``index`` and 0.0 elsewhere."""
        encoded = [0.0] * size
        encoded[index] = 1.0
        return encoded

    def get_onehot_encode_dataset(self):
        """Return a copy of the dataset with states and actions one-hot encoded.

        ``'observation'`` and ``'new_observation'`` become lists of length
        ``env.state_size``, ``'action'`` becomes a list of length
        ``env.action_size``; all other columns are copied unchanged.
        """
        raw = self.offline_dataset
        n_states, n_actions = self.env.state_size, self.env.action_size
        return {
            'observation': [self._one_hot(s, n_states) for s in raw['observation']],
            'action': [self._one_hot(a, n_actions) for a in raw['action']],
            'reward': list(raw['reward']),
            'cost': list(raw['cost']),
            'new_observation': [self._one_hot(s, n_states) for s in raw['new_observation']],
            'hole': list(raw['hole']),
            'done': list(raw['done']),
            'goal': list(raw['goal']),
            'is_init': list(raw['is_init']),
        }

    def get_real_behave_policy(self):
        """Return the analytic behaviour policy used during collection.

        The result is ``percent * expert_pi + (1 - percent) * uniform``,
        renormalised per state. When no expert policy was supplied,
        collection is purely uniform, so a uniform array of shape
        ``(env.state_size, env.action_size)`` is returned instead of failing
        on ``None`` arithmetic.

        Side effects (kept from the original implementation): sets
        ``self.random`` to the uniform component and, when an expert exists,
        converts ``self.expert_pi`` to a NumPy array.
        """
        if self.expert_pi is None:
            self.random = np.full((self.env.state_size, self.env.action_size),
                                  1 / self.env.action_size)
            return self.random.copy()

        self.random = np.ones_like(self.expert_pi) / self.env.action_size
        self.expert_pi = np.array(self.expert_pi)
        mixture = self.percent * self.expert_pi + (1 - self.percent) * self.random
        return mixture / np.sum(mixture, axis=1, keepdims=True)

    def get_behave_policy_prob(self):
        """Estimate the behaviour policy from state-action visit counts.

        Returns:
            Array of shape ``(env.state_size, env.action_size)`` whose rows
            are empirical action frequencies; rows of states that never occur
            in the dataset stay all-zero.
        """
        behave_policy_prob = np.zeros((self.env.state_size, self.env.action_size))
        for s, a in zip(self.offline_dataset['observation'], self.offline_dataset['action']):
            behave_policy_prob[s][a] += 1
        totals = behave_policy_prob.sum(axis=1, keepdims=True)
        visited = totals[:, 0] != 0
        # Only normalise visited states; unvisited rows must not divide by 0.
        behave_policy_prob[visited] /= totals[visited]
        return behave_policy_prob
    

        

if __name__ == '__main__':
    # Smoke run: collect 100 Taxi episodes. With percent=0.0 and no expert
    # policy supplied, every action is sampled uniformly at random.
    # Taxi_nocost is not defined in this file; it presumably comes from the
    # star imports at the top of the module.
    env = Taxi_nocost()
    data_collector = Taxi_DataCollector(env, num_episodes=100, percent=0.0)
    # print(data_collector.get_original_dataset())
    
    
    # Alternative FrozenLake example (disabled): derive an expert policy with
    # value iteration, then collect two purely expert-driven trajectories.
    # env = FrozenLakeEnv_nocost(ncol=8, nrow=8)
    # exp = ValueIteration(env, 0.01, env.gamma)
    # exp.value_iteration()
    # # plot_policy(np.array(exp.pi), env, "w")
    # datacollector = FrozenLake_DataCollector(env, num_trajectories=2, expert_pi=exp.pi, percent=1.0)
    # print(datacollector.get_original_dataset())