import numpy as np
import pdb

class DivergenceMDPs:
    """Shared environment loop and dataset helpers for small tabular MDPs.

    This base class defines no ``__init__``; the methods below read the
    following attributes, which a subclass (or the caller) must set:

    * ``n_state``     -- number of non-terminal states; index ``n_state`` itself
      is treated as the terminal state by ``step`` and the dataset helpers.
    * ``init_state``  -- state returned by ``reset``.
    * ``transitions`` -- array indexed ``[state, action]`` giving a probability
      vector over the ``n_state + 1`` possible next states.
    * ``rewards``     -- array indexed ``[state, action]``.
    """

    # NOTE(review): the name looks like a typo of "divergence"; it is kept
    # unchanged because code outside this file may read the attribute.
    diverage_mdp = True

    def reset(self):
        """Move the environment back to ``init_state``.

        Returns:
            ``(state, info)`` where ``info`` is an empty dict.
        """
        self.state = self.init_state
        return self.init_state, {}

    def step(self, action):
        """Sample one transition from the current state.

        Args:
            action: index of the action to take.

        Returns:
            ``(next_state, reward, done, truncated, info)``; ``done`` is true
            when the sampled next state equals ``n_state``, ``truncated`` is
            always ``False`` and ``info`` is an empty dict.

        Raises:
            ValueError: if the transition row for ``(state, action)`` is not a
                valid probability vector over ``n_state + 1`` outcomes.
        """
        rew = self.rewards[self.state, action]
        candidates = np.arange(self.n_state + 1)
        try:
            n_state = np.random.choice(
                candidates, p=self.transitions[self.state, action])
        except ValueError as err:
            # Fail loudly instead of dropping into a debugger: an interactive
            # breakpoint stalls batch runs and left ``n_state`` unbound.
            raise ValueError(
                'invalid transition distribution for state {} and action {}'
                .format(self.state, action)) from err
        self.state = n_state

        done = self.state == self.n_state
        return n_state, rew, done, False, {}

    def parse_dataset(self, paths, num_trans):
        """Flatten trajectories into a transition dataset and subsample it.

        Args:
            paths: iterable of dicts with the keys ``'obs'``, ``'nobs'``,
                ``'rews'`` and ``'dones'`` holding aligned per-step sequences.
            num_trans: number of transitions to keep; they are drawn without
                replacement, so it must not exceed the number available.

        Returns:
            Dict with the keys ``'states'``, ``'next_states'``, ``'rewards'``,
            ``'dones'`` (each subsampled to ``num_trans`` entries) and
            ``'init_states'`` (the first observation of every non-empty path,
            not subsampled).
        """
        states, n_states, rews, dones, init_state = [], [], [], [], []
        for path in paths:
            steps = zip(path['obs'], path['nobs'], path['rews'], path['dones'])
            for step_idx, (s, ns, r, d) in enumerate(steps):
                if step_idx == 0:
                    init_state.append(s)
                states.append(s)
                n_states.append(ns)
                rews.append(r)
                dones.append(d)

        on_policy_data = {
            'states': np.array(states),
            'next_states': np.array(n_states),
            'rewards': np.array(rews),
            'dones': np.array(dones),
            'init_states': np.array(init_state)
        }
        keep = np.random.choice(
            on_policy_data['states'].shape[0], num_trans, replace=False)
        for key in on_policy_data:
            # Initial states describe whole trajectories, so they are kept in
            # full rather than subsampled with the transitions.
            if 'init' not in key:
                on_policy_data[key] = on_policy_data[key][keep]
        return on_policy_data

    def merge_data(self, on, off, mix_ratio):
        """Replace a fraction of the ``on`` transitions with ``off`` ones.

        Args:
            on: dataset dict whose size determines the size of the result.
            off: dataset dict to draw replacement transitions from.
            mix_ratio: fraction of ``on`` to replace. ``1.`` returns ``off``
                unchanged and ``0.`` returns ``on`` unchanged.

        Returns:
            A new dict with the same keys as ``on``. Keys containing ``'init'``
            are taken from ``on`` as-is; every other key holds the retained
            ``on`` entries followed by the sampled ``off`` entries.
        """
        if mix_ratio == 1.:
            return off
        elif mix_ratio == 0.:
            return on
        total = on['states'].shape[0]
        num_on = int(np.ceil((1 - mix_ratio) * total))
        num_off = total - num_on
        print ('num off samples added ', num_off)

        # random off samples to add
        off_indices = np.random.choice(off['states'].shape[0], num_off, replace = False)
        # for every off added, remove a random on
        removed = np.random.choice(total, num_off, replace = False)
        # A boolean mask keeps the survivors in their original order and stays
        # a valid index even when nothing survives.
        keep_mask = np.ones(total, dtype=bool)
        keep_mask[removed] = False

        merged = {}
        for key in on:
            if 'init' in key:
                merged[key] = on[key]
            else:
                merged[key] = np.concatenate((on[key][keep_mask], off[key][off_indices]))
        return merged

    def count(self, data):
        """Print and return empirical state and transition-pair frequencies.

        Args:
            data: dataset dict with ``'states'`` and ``'next_states'``.

        Returns:
            ``(state_freq, pairs)``. ``state_freq`` has length ``n_state`` and
            holds the normalised visit frequency of each state. ``pairs`` maps
            ``(k1, k2)`` to the product of the empirical probabilities of the
            transitions ``k1`` and ``k2``, where each transition key is the
            sorted ``(state, next_state)`` tuple.
        """
        array = np.zeros(self.n_state)
        for s in data['states']:
            array[s] += 1
        array = array / np.sum(array)
        print (array)

        counts = {}
        for s, ns in zip(data['states'], data['next_states']):
            key = tuple(sorted((s, ns)))
            counts[key] = counts.get(key, 0) + 1
        total = sum(counts.values())
        ordered = sorted(counts)

        print ('prob of trans')
        for k1 in ordered:
            print (k1, (counts[k1] / total))

        print ('prob of pair of')
        pairs = {}
        for k1 in ordered:
            for k2 in ordered:
                joint = (counts[k1] / total) * (counts[k2] / total)
                print (k1, k2, joint)
                pairs[(k1, k2)] = joint

        return array, pairs

    def ratios(self, target, other):
        """Print the ratio ``target / other`` entry by entry.

        When iterating ``target`` yields tuples (for example a dict keyed by
        tuples) each item is used as a key into both containers; otherwise the
        item itself is divided by ``other`` at the same position.
        """
        for idx, i in enumerate(target):
            if isinstance(i, tuple):
                print (i, target[i] / other[i])
            else:
                print (i, i / other[idx])

    def _dataset_from_transitions(self, transitions, num_trans):
        """Build a dataset dict from ``(state, next_state)`` tuples (action 0)."""
        states = [trans[0] for trans in transitions]
        next_states = [trans[1] for trans in transitions]
        return {
            'states': np.array(states),
            'next_states': np.array(next_states),
            'rewards': np.array([self.rewards[s, 0] for s in states]),
            'dones': np.array([ns == self.n_state for ns in next_states]),
            'init_states': np.array([0 for _ in range(num_trans)])
        }

    def generate_pairwise_dataset(self, pw_distribution, num_trans = 1000):
        """Sample two aligned datasets from a distribution over transition pairs.

        Args:
            pw_distribution: dict mapping ``(trans1, trans2)`` to a
                probability, where each ``trans`` is a ``(state, next_state)``
                tuple; the values must sum to one.
            num_trans: number of pairs to sample.

        Returns:
            ``(data, other_data)``: entry ``i`` of ``data`` comes from
            ``trans1`` of the ``i``-th sampled pair and entry ``i`` of
            ``other_data`` from its ``trans2``. Rewards are looked up for
            action 0 and ``'init_states'`` is all zeros.
        """
        pairs = sorted(pw_distribution.keys())
        probs = np.array([pw_distribution[pair] for pair in pairs])
        sampled_pairs = np.random.choice(np.arange(len(pairs)), p=probs, size = num_trans)

        # Split every sampled pair into its two transitions; index i of both
        # datasets therefore belongs to the same sampled pair.
        first = [pairs[sp][0] for sp in sampled_pairs]
        second = [pairs[sp][1] for sp in sampled_pairs]

        data = self._dataset_from_transitions(first, num_trans)
        other_data = self._dataset_from_transitions(second, num_trans)
        return data, other_data

    def merge_feature_datasets(self, dataset, other_dataset):
        """Combine the ``'dataset'`` entries of two containers into one dict.

        Returns:
            ``{'dataset': ..., 'other_dataset': ...}`` holding shallow copies
            of ``dataset['dataset']`` and ``other_dataset['dataset']``.
        """
        return {
            'dataset': dict(dataset['dataset']),
            'other_dataset': dict(other_dataset['dataset'])
        }

    def _append_transitions(self, data, state, next_states):
        """Append transitions that all start in ``state`` (action 0) to ``data``."""
        num_added = len(next_states)
        next_states = np.array(next_states)
        data['states'] = np.append(data['states'], np.full(num_added, state))
        data['next_states'] = np.append(data['next_states'], next_states)
        data['rewards'] = np.append(
            data['rewards'], np.full(num_added, self.rewards[state, 0]))
        data['dones'] = np.append(data['dones'], next_states == self.n_state)
        data['init_states'] = np.append(
            data['init_states'], np.zeros(num_added, dtype=int))

    def add_bad_transitions(self, on_policy_pw, d1, d2, sample_num = 0, mix_ratio = 3000):
        """Append sampled transitions from fixed start states to two datasets.

        ``d1`` receives transitions starting in state 2 and ``d2`` receives
        transitions starting in state ``sample_num``; both use action 0 and an
        initial state of 0. The dicts are updated in place and also returned.

        Args:
            on_policy_pw: unused.
            d1: dataset dict extended with transitions out of state 2.
            d2: dataset dict extended with transitions out of ``sample_num``.
            sample_num: start state for the transitions added to ``d2``.
            mix_ratio: number of transitions to add to each dataset
                (truncated to an int); nothing is added when it is not
                positive.

        Returns:
            ``(d1, d2)``.
        """
        state = 2
        other_state = sample_num
        num_added = int(mix_ratio)
        if num_added <= 0:
            return d1, d2

        candidates = np.arange(self.n_state + 1)
        next_states, other_next_states = [], []
        # Draws stay interleaved (d1 then d2 per iteration) so the sequence of
        # random numbers consumed is unchanged for a given seed.
        for _ in range(num_added):
            next_states.append(
                np.random.choice(candidates, p=self.transitions[state, 0]))
            other_next_states.append(
                np.random.choice(candidates, p=self.transitions[other_state, 0]))

        # Extend each array once rather than re-allocating it per sample.
        self._append_transitions(d1, state, next_states)
        self._append_transitions(d2, other_state, other_next_states)
        return d1, d2

# https://arxiv.org/pdf/1905.10506
class ModifiedRoy(DivergenceMDPs):
    """Four-state, single-action MDP with a separate terminal index.

    States 0-3 are regular states and index 4 (``n_state``) is terminal:

    * state 0 moves to state 1 with probability 0.8 and to state 2 with 0.2;
    * state 1 always terminates and is the only state with reward 1;
    * state 2 always moves to state 3;
    * state 3 stays in place with probability 0.9 and terminates with 0.1.

    See the reference URL in the comment preceding this class.
    """

    def __init__(self):
        """Build the transition table, rewards and per-state feature vectors."""
        self.n_state = self.num_states = 4
        self.n_action = self.num_actions = 1
        self.init_state = 0
        self.state = self.init_state
        # The last axis has one extra slot for the terminal index.
        self.transitions = np.zeros((self.num_states, self.num_actions, self.num_states + 1))

        self.transitions[0, 0, 1] = 0.8
        self.transitions[0, 0, 2] = 0.2
        self.transitions[1, 0, self.num_states] = 1.
        self.transitions[2, 0, 3] = 1.
        self.transitions[3, 0, 3] = 0.9
        self.transitions[3, 0, self.num_states] = 0.1

        self.rewards = np.zeros((self.num_states, self.num_actions))
        self.rewards[0, 0] = 0
        self.rewards[1, 0] = 1.
        self.rewards[2, 0] = 0
        self.rewards[3, 0] = 0

        # States 2 and 3 share the third feature dimension (with different
        # magnitudes); the terminal index maps to the zero vector.
        self.state_action_features = {
            0: [1., 0., 0.],
            1: [0., 1., 0.],
            2: [0., 0., 1.],
            3: [0., 0., 2.],
            4: [0., 0., 0.]
        }

        # Alternative one-hot features, currently disabled:
        # self.features = {
        #     0: [1., 0., 0., 0.],
        #     1: [0., 1., 0., 0.],
        #     2: [0., 0., 1., 0.],
        #     3: [0., 0., 0., 1.],
        #     4: [0., 0., 0., 0.]
        # }

        # Zero arrays used as shape placeholders for the observation and
        # action spaces.
        self.observation_space = np.zeros((len(self.state_action_features[0]),))
        self.action_space = np.zeros((self.n_action,))

    def get_dataset(self, num_trans, off_type = 'bad'):
        """Sample transitions whose start states follow a chosen distribution.

        Args:
            num_trans: number of transitions to sample.
            off_type: start-state distribution. ``'bad'`` always starts in
                state 2, ``'uniform'`` weights all states equally and
                ``'random'`` uses randomly drawn weights.

        Returns:
            Dict with ``'states'``, ``'next_states'``, ``'rewards'`` (for
            action 0), ``'dones'`` (next state equals ``n_state``) and
            ``'init_states'`` (all zeros), each of length ``num_trans``.

        Raises:
            ValueError: if ``off_type`` is not one of the supported values.
        """
        all_states = np.arange(self.n_state + 1)
        if off_type == 'bad':
            x = np.zeros(self.n_state)
            x[2] = 1.
        elif off_type == 'uniform':
            x = np.ones(self.n_state)
        elif off_type == 'random':
            x = np.abs(np.random.rand(self.num_states))
        else:
            # Previously an unknown value fell through and failed later with
            # an unbound-variable error that did not mention the argument.
            raise ValueError(
                "unknown off_type {!r}; expected 'bad', 'uniform' or 'random'"
                .format(off_type))
        x = x / np.sum(x)

        states = np.random.choice(self.num_states, size = num_trans, p = x)
        rews = self.rewards[states, 0]
        dones = []
        n_states = []
        for st in states:
            n_state = np.random.choice(all_states, p=self.transitions[st, 0])
            n_states.append(n_state)
            dones.append(n_state == self.n_state)

        return {
            'states': states,
            'next_states': np.array(n_states),
            'rewards': rews,
            'dones': np.array(dones),
            'init_states': np.array([0 for _ in range(num_trans)])
        }

class Bairds(DivergenceMDPs):
    """Seven-state, single-action MDP with no terminal state.

    Every state transitions deterministically to the last state (index 6),
    which then loops onto itself. Episodes never report ``done``.
    NOTE(review): presumably modelled on Baird's counterexample -- confirm.
    """

    def __init__(self):
        """Build the transition table, rewards and per-state feature vectors."""
        self.n_state = self.num_states = 7
        self.n_action = self.num_actions = 1
        self.init_state = np.random.randint(0, self.n_state)
        self.state = self.init_state
        # Unlike the base-class convention there is no extra terminal slot.
        self.transitions = np.zeros((self.num_states, self.num_actions, self.num_states))

        # top always go down
        self.transitions[np.arange(0, self.num_states - 1), 0, self.num_states - 1] = 1.
        # once bottom, always stay bottom
        self.transitions[self.num_states - 1, 0, self.num_states - 1] = 1.

        self.rewards = np.zeros((self.num_states, self.num_actions))
        # half of them have rewards one (to allow for krope's reward normalization)
        self.rewards[:self.num_states // 2] = 1.
        #self.rewards[self.num_states - 1, 0] = 0.

        self.state_action_features = {
            0: [2., 0., 0., 0., 0., 0., 0., 1.],
            1: [0., 2., 0., 0., 0., 0., 0., 1.],
            2: [0., 0., 2., 0., 0., 0., 0., 1.],
            3: [0., 0., 0., 2., 0., 0., 0., 1.],
            4: [0., 0., 0., 0., 2., 0., 0., 1.],
            5: [0., 0., 0., 0., 0., 2., 0., 1.],
            6: [0., 0., 0., 0., 0., 0., 1., 2.],
        }

        # Zero arrays used as shape placeholders for the observation and
        # action spaces.
        self.observation_space = np.zeros((len(self.state_action_features[0]),))
        self.action_space = np.zeros((self.n_action,))

    def reset(self):
        """Restart from a uniformly random state.

        Returns:
            ``(state, info)`` where ``info`` is an empty dict. ``init_state``
            is not updated by this call.
        """
        self.state = np.random.randint(0, self.n_state)
        return self.state, {}

    def step(self, action):
        """Sample one transition from the current state.

        Args:
            action: index of the action to take.

        Returns:
            ``(next_state, reward, done, truncated, info)``; ``done`` and
            ``truncated`` are always ``False`` and ``info`` is an empty dict.

        Raises:
            ValueError: if the transition row for ``(state, action)`` is not a
                valid probability vector over ``n_state`` outcomes.
        """
        rew = self.rewards[self.state, action]
        candidates = np.arange(self.n_state)
        try:
            n_state = np.random.choice(
                candidates, p=self.transitions[self.state, action])
        except ValueError as err:
            # Fail loudly instead of dropping into a debugger: an interactive
            # breakpoint stalls batch runs and left ``n_state`` unbound.
            raise ValueError(
                'invalid transition distribution for state {} and action {}'
                .format(self.state, action)) from err
        self.state = n_state

        done = False
        return n_state, rew, done, False, {}

    def get_dataset(self, num_trans, off_type = 'uniform'):
        """Sample transitions with uniformly distributed start states.

        Args:
            num_trans: number of transitions to sample.
            off_type: accepted for signature parity with other environments
                but ignored; start states are always sampled uniformly.

        Returns:
            Dict with ``'states'``, ``'next_states'``, ``'rewards'`` (for
            action 0), ``'dones'`` (all ``False``) and ``'init_states'``
            (independent uniform random states), each of length ``num_trans``.
        """
        all_states = np.arange(self.n_state)

        # uniformly sampling
        x = np.ones(self.n_state)
        x = x / np.sum(x)

        states = np.random.choice(self.num_states, size = num_trans, p = x)
        rews = self.rewards[states, 0]
        n_states = np.array(
            [np.random.choice(all_states, p=self.transitions[st, 0]) for st in states])
        dones = np.array([False for _ in states])

        init_state = np.array([np.random.randint(0, self.n_state) for _ in range(num_trans)])

        return {
            'states': states,
            'next_states': n_states,
            'rewards': rews,
            'dones': dones,
            'init_states': init_state
        }

# env = Bairds()
# state, _ = env.reset()
# for i in range(10):
#     next_state, rew, done, _, _ = env.step(0)
#     print (state, next_state, rew, done)
#     state = next_state
#     if done:
#         break

