import torch
import numpy as np
import pdb
import random
import urllib.request
import os

class Dataset(torch.utils.data.Dataset):
    """Transition dataset for off-policy evaluation.

    Wraps ``data['dataset']`` (a dict of arrays) and, on indexing, pairs each
    stored transition with a next state-action produced by the evaluation
    policy ``pie``.

    Args:
        data: dict whose ``'dataset'`` entry provides ``state_b``,
            ``action_b``, ``next_state_b``, ``rewards``, ``init_state``,
            ``terminal_masks``, ``num_samples`` and optionally ``q_values``.
            Tabular data additionally needs ``state_action_b``,
            ``next_state_action``, ``init_state_action``,
            ``init_state_action_b`` and optionally ``state``/``next_state``.
        normalize_states: z-score states with the dataset mean/std.
        normalize_rewards: z-score rewards with the dataset mean/std.
        normalize_actions: accepted for interface compatibility; unused.
        normalize_state_actions: z-score the state-action features
            (only applied when ``tabular`` is True).
        eps: lower bound on the std used when (un)normalizing.
        tabular: use the precomputed state-action feature arrays instead of
            concatenating states and actions.
        pie: evaluation policy; must provide ``sample_sa_features`` (tabular)
            or ``batch_sample`` (non-tabular).
        skip_rate: keep every ``skip_rate``-th transition.
    """

    def __init__(self, data,
                 normalize_states=False,
                 normalize_rewards=False,
                 normalize_actions=False,
                 normalize_state_actions=False,
                 eps=1e-5,
                 tabular=False,
                 pie=None,
                 skip_rate=1):

        dataset = data['dataset']
        # Every per-transition array is subsampled with this same stride so
        # that index i refers to the same transition in all of them.
        stride = slice(0, None, skip_rate)

        self.pie = pie
        self.eps = eps
        self.normalize_states_flag = normalize_states
        self.normalize_rewards_flag = normalize_rewards
        self.tabular = tabular
        self.skip_rate = skip_rate

        self.curr_states = dataset['state_b'].astype(np.float32)[stride]
        self.curr_actions = dataset['action_b'].astype(np.float32)[stride]
        self.next_states = dataset['next_state_b'].astype(np.float32)[stride]
        self.rewards = dataset['rewards'].astype(np.float32)[stride]
        # Initial states are not per-transition data, so they are not subsampled.
        self.initial_states = dataset['init_state'].astype(np.float32)
        self.terminal_masks = dataset['terminal_masks'].astype(np.float32)[stride]
        self.num_samples = dataset['num_samples'] // skip_rate
        if 'q_values' in dataset:
            self.q_values = dataset['q_values'].astype(np.float32)[stride]
        else:
            self.q_values = np.zeros(self.num_samples)
        # Number of transitions whose terminal mask is 0
        # (presumably one per trajectory end -- confirm with data producer).
        self.batch_size = np.count_nonzero(1. - self.terminal_masks)

        if self.tabular:
            if 'state' in dataset:
                self.curr_raw_state = dataset['state'][stride]
                self.next_raw_state = dataset['next_state'][stride]
            self.curr_state_actions = dataset['state_action_b'].astype(np.float32)[stride]
            self.next_state_actions = dataset['next_state_action'].astype(np.float32)[stride]
            self.init_state_actions = dataset['init_state_action'].astype(np.float32)
            self.init_state_actions_curr = dataset['init_state_action_b'].astype(np.float32)

            # Empirical frequency of each state-action index, where a row's
            # index is the argmax of its feature vector.
            num_sa = self.curr_state_actions.shape[1]
            counts = np.bincount(np.argmax(self.curr_state_actions, axis=1), minlength=num_sa)
            self.sa_visitation = counts / self.curr_state_actions.shape[0]

        # TODO normalize actions??
        if normalize_states:
            self.state_mean = np.mean(self.curr_states, axis=0)
            self.state_std = np.std(self.curr_states, axis=0)

            self.curr_states = self.normalize_states(self.curr_states)
            self.next_states = self.normalize_states(self.next_states)
            self.initial_states = self.normalize_states(self.initial_states)
        else:
            self.state_mean = 0.
            self.state_std = 1.

        if not self.tabular:
            # Built after state normalization so that curr_sa lives in the same
            # (possibly normalized) space as the next_sa built in __getitem__.
            self.curr_state_actions = np.concatenate((self.curr_states, self.curr_actions), axis=1)

        if normalize_rewards:
            self.reward_mean = np.mean(self.rewards)
            self.reward_std = np.std(self.rewards)

            self.rewards = self.normalize_rewards(self.rewards)
        else:
            self.reward_mean = 0.
            self.reward_std = 1.

        if normalize_state_actions and self.tabular:
            # NOTE(review): init_state_actions_curr is not normalized here -- confirm intended.
            self.state_action_mean = np.mean(self.curr_state_actions, axis=0)
            self.state_action_std = np.std(self.curr_state_actions, axis=0)

            self.curr_state_actions = self.normalize_state_actions(self.curr_state_actions)
            self.next_state_actions = self.normalize_state_actions(self.next_state_actions)
            self.init_state_actions = self.normalize_state_actions(self.init_state_actions)

        # Reward range statistics (computed after any reward normalization).
        self.min_reward = np.min(self.rewards)
        self.max_reward = np.max(self.rewards)

        self.min_abs_reward_diff = 0
        self.max_abs_reward_diff = np.abs(self.max_reward - self.min_reward)

        print('reward stats mean {}, std {}, max {}, min {}'.format(
            self.rewards.mean(), self.rewards.std(),
            self.rewards.max(), self.rewards.min()))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Return transition ``index`` paired with a next state-action from ``pie``."""
        reward = self.rewards[index]
        terminal_mask = self.terminal_masks[index]
        curr_sa = self.curr_state_actions[index]
        next_state = self.next_states[index]

        if self.tabular:
            next_sa = self.pie.sample_sa_features(next_state)
        else:
            # pie is queried on un-normalized states; the stored next_state
            # (normalized when normalize_states=True) is kept in next_sa.
            pie_next_action = self.pie.batch_sample(self.unnormalize_states(next_state))
            next_sa = np.concatenate((next_state, pie_next_action), axis=0)

        return {
            'curr_sa': curr_sa,
            'next_sa': next_sa,
            'next_state': next_state,
            'rewards': reward,
            'terminal_masks': terminal_mask,
            'index': index
        }

    def _get_sarsa_dataset(self):
        """Build SARSA tuples by pairing each step with the next step's action.

        The transitions are split into ``self.batch_size`` equal consecutive
        chunks (``np.split`` raises if they do not divide evenly). The last
        step of each chunk is dropped because it has no successor action.
        """
        c_s, c_a, n_s, n_a, r = [], [], [], [], []
        chunks = zip(np.split(self.curr_states, self.batch_size),
                     np.split(self.curr_actions, self.batch_size),
                     np.split(self.rewards, self.batch_size),
                     np.split(self.next_states, self.batch_size))
        for s_, a_, r_, sn_ in chunks:
            c_s.append(s_[:-1])
            c_a.append(a_[:-1])
            n_s.append(sn_[:-1])
            n_a.append(a_[1:])  # successor action of each kept step
            r.append(r_[:-1])

        return {
            'curr_states': np.vstack(c_s),
            'curr_actions': np.vstack(c_a),
            'next_states': np.vstack(n_s),
            'next_actions': np.vstack(n_a),
            'rewards': np.hstack(r)
        }

    def normalize_state_actions(self, state_actions):
        """Z-score state-action features (requires tabular state-action stats)."""
        return (state_actions - self.state_action_mean) / np.maximum(self.eps, self.state_action_std)

    def unnormalize_state_actions(self, normalized_state_actions):
        """Invert ``normalize_state_actions``."""
        return normalized_state_actions * np.maximum(self.eps, self.state_action_std) + self.state_action_mean

    def normalize_states(self, states):
        """Z-score states with the stored state mean/std."""
        return (states - self.state_mean) / np.maximum(self.eps, self.state_std)

    def unnormalize_states(self, normalized_states):
        """Invert ``normalize_states``."""
        return normalized_states * np.maximum(self.eps, self.state_std) + self.state_mean

    def normalize_rewards(self, rewards):
        """Z-score rewards with the stored reward mean/std."""
        return (rewards - self.reward_mean) / np.maximum(self.eps, self.reward_std)

    def unnormalize_rewards(self, normalized_rewards):
        """Invert ``normalize_rewards``."""
        return normalized_rewards * np.maximum(self.eps, self.reward_std) + self.reward_mean

    # used by OPE methods for evaluation
    def get_initial_states_samples(self, mini_batch_size):
        """Return all stored initial states.

        ``mini_batch_size`` is currently ignored; no subsampling is done.
        """
        return self.initial_states

    def store_pie_info(self, pie_val, pie_path_sa_vals, pie_path_states, pie_path_acts):
        """Cache evaluation-policy reference values on the dataset.

        ``pie_path_sa_vals`` is flattened when given; ``pie_path_sa`` is built
        by concatenating states and actions along axis 1 when both are given.
        """
        self.pie_val = pie_val
        self.pie_path_states = pie_path_states
        self.pie_path_acts = pie_path_acts
        if pie_path_sa_vals is not None:
            self.pie_path_sa_vals = pie_path_sa_vals.reshape(-1)
        if pie_path_states is not None and pie_path_acts is not None:
            self.pie_path_sa = np.concatenate((pie_path_states, pie_path_acts), axis=1)

class PWDataset(torch.utils.data.Dataset):
    """Paired transition dataset over ``data['dataset']`` and ``data['other_dataset']``.

    Index i returns transition i from both datasets, each paired with a next
    state-action produced by the evaluation policy ``pie``. No normalization
    is performed: ``normalize_states``, ``normalize_rewards``,
    ``normalize_actions``, ``normalize_state_actions`` and ``eps`` are
    accepted for interface compatibility but unused.

    Args:
        data: dict with ``'dataset'`` and ``'other_dataset'`` entries, each a
            dict of arrays (``state_b``, ``action_b``, ``next_state_b``,
            ``rewards``, ``terminal_masks``, optional ``q_values``; the
            primary one also needs ``init_state`` and ``num_samples``).
        tabular: use precomputed state-action feature arrays instead of
            concatenating states and actions.
        pie: evaluation policy; must provide ``sample_sa_features`` (tabular)
            or ``batch_sample`` (non-tabular).
    """

    def __init__(self, data,
                 normalize_states=False,
                 normalize_rewards=False,
                 normalize_actions=False,
                 normalize_state_actions=False,
                 eps=1e-5,
                 tabular=False,
                 pie=None):

        dataset = data['dataset']

        self.pie = pie
        self.tabular = tabular
        self.curr_states = dataset['state_b'].astype(np.float32)
        self.curr_actions = dataset['action_b'].astype(np.float32)
        self.next_states = dataset['next_state_b'].astype(np.float32)
        self.rewards = dataset['rewards'].astype(np.float32)
        self.initial_states = dataset['init_state'].astype(np.float32)
        self.terminal_masks = dataset['terminal_masks'].astype(np.float32)
        self.num_samples = dataset['num_samples']
        self.q_values = dataset['q_values'].astype(np.float32) if 'q_values' in dataset else np.zeros(self.num_samples)
        # Number of transitions whose terminal mask is 0.
        self.batch_size = np.count_nonzero(1. - self.terminal_masks)
        if self.tabular:
            self.curr_state_actions = dataset['state_action_b'].astype(np.float32)
            self.next_state_actions = dataset['next_state_action'].astype(np.float32)
            self.init_state_actions = dataset['init_state_action'].astype(np.float32)
            if 'state' in dataset:
                self.curr_raw_state = dataset['state']
                self.next_raw_state = dataset['next_state']
        else:
            self.curr_state_actions = np.concatenate((self.curr_states, self.curr_actions), axis=1)

        other = data['other_dataset']

        self.other_curr_states = other['state_b'].astype(np.float32)
        self.other_curr_actions = other['action_b'].astype(np.float32)
        self.other_next_states = other['next_state_b'].astype(np.float32)
        self.other_rewards = other['rewards'].astype(np.float32)
        self.other_terminal_masks = other['terminal_masks'].astype(np.float32)
        # Fallback is sized to the other dataset itself, not the primary one.
        if 'q_values' in other:
            self.other_q_values = other['q_values'].astype(np.float32)
        else:
            self.other_q_values = np.zeros(len(self.other_rewards))
        if self.tabular:
            self.other_curr_state_actions = other['state_action_b'].astype(np.float32)
            self.other_next_state_actions = other['next_state_action'].astype(np.float32)
            if 'state' in other:
                self.other_curr_raw_state = other['state']
                self.other_next_raw_state = other['next_state']
        else:
            self.other_curr_state_actions = np.concatenate(
                (self.other_curr_states, self.other_curr_actions), axis=1)

        # Reward range statistics of the primary dataset.
        self.min_reward = np.min(self.rewards)
        self.max_reward = np.max(self.rewards)

        self.min_abs_reward_diff = 0
        self.max_abs_reward_diff = np.abs(self.max_reward - self.min_reward)

        print('reward stats mean {}, std {}, max {}, min {}'.format(
            self.rewards.mean(), self.rewards.std(),
            self.rewards.max(), self.rewards.min()))

    def __len__(self):
        # NOTE(review): the other dataset is indexed with the same indices, so
        # it is assumed to have at least num_samples transitions.
        return self.num_samples

    def __getitem__(self, index):
        """Return transition ``index`` from both datasets with next state-actions from ``pie``."""
        reward = self.rewards[index]
        other_reward = self.other_rewards[index]
        terminal_mask = self.terminal_masks[index]
        other_terminal_mask = self.other_terminal_masks[index]

        curr_sa = self.curr_state_actions[index]
        next_state = self.next_states[index]
        other_curr_sa = self.other_curr_state_actions[index]
        other_next_state = self.other_next_states[index]

        if self.tabular:
            next_sa = self.pie.sample_sa_features(next_state)
            other_next_sa = self.pie.sample_sa_features(other_next_state)
        else:
            # States are stored unnormalized here, so pie samples on them directly.
            next_sa = np.concatenate((next_state, self.pie.batch_sample(next_state)), axis=0)
            other_next_sa = np.concatenate(
                (other_next_state, self.pie.batch_sample(other_next_state)), axis=0)

        return {
            'curr_sa': curr_sa,
            'next_sa': next_sa,
            'next_state': next_state,
            'rewards': reward,
            'terminal_masks': terminal_mask,
            'other_curr_sa': other_curr_sa,
            'other_next_sa': other_next_sa,
            'other_next_state': other_next_state,
            'other_rewards': other_reward,
            'other_terminal_masks': other_terminal_mask,
            'index': index
        }

    # used by OPE methods for evaluation
    def get_initial_states_samples(self, mini_batch_size):
        """Return all stored initial states.

        ``mini_batch_size`` is currently ignored; no subsampling is done.
        """
        return self.initial_states

    def store_pie_info(self, pie_val, pie_path_sa_vals, pie_path_states, pie_path_acts):
        """Cache evaluation-policy reference values on the dataset.

        ``pie_path_sa_vals`` is flattened when given; ``pie_path_sa`` is built
        by concatenating states and actions along axis 1 when both are given.
        """
        self.pie_val = pie_val
        self.pie_path_states = pie_path_states
        self.pie_path_acts = pie_path_acts
        if pie_path_sa_vals is not None:
            self.pie_path_sa_vals = pie_path_sa_vals.reshape(-1)
        if pie_path_states is not None and pie_path_acts is not None:
            self.pie_path_sa = np.concatenate((pie_path_states, pie_path_acts), axis=1)
