import numpy as np
import pdb
import copy

class SimilarityMetric:
    def __init__(self, mdp, typ, algo, pie, pib, gamma):
        self.typ = typ
        self.algo = algo
        self.mdp = mdp
        self.pie = pie
        self.pib = pib
        self.gamma = gamma
    
        if self.typ == 'state':
            self.distances = np.zeros((mdp.n_state, mdp.n_state))
        elif self.typ == 'state-action':
            self.distances = np.zeros((mdp.n_state * mdp.n_action, mdp.n_state * mdp.n_action))
    
    def learn(self, data, epochs = 100, lr = 1e-1):

        states = data['states']
        rews = data['rewards']
        acts = data['actions']
        next_states = data['next_states']
        dones = data['dones']
        num_samples = len(states)

        prev_distances = np.zeros_like(self.distances)
        updated_tuples = np.zeros_like(self.distances).astype(bool)

        for epoch in range(epochs):
            sub_samples1 = np.random.choice(num_samples, num_samples, replace = False)
            sub_samples2 = np.random.choice(num_samples, num_samples, replace = False)
            
            for idx1, idx2 in zip(sub_samples1, sub_samples2):
                x, ax, rx, nx = states[idx1], acts[idx1], rews[idx1], next_states[idx1]
                y, ay, ry, ny = states[idx2], acts[idx2], rews[idx2], next_states[idx2]

                target = self._get_algo_target(x, ax, rx, nx, y, ay, ry, ny)
                if self.typ == 'state-action':
                    sa1_cord = x * self.mdp.n_action + ax
                    sa2_cord = y * self.mdp.n_action + ay
                    self.distances[sa1_cord][sa2_cord] += lr * (target \
                        - self.distances[sa1_cord][sa2_cord])
                    updated_tuples[sa1_cord][sa2_cord] = True
            
            if ((epoch + 1) % 100 == 0):
                diff = np.abs(self.distances - prev_distances)
                thresh = (diff <= 1e-5)
                count = np.count_nonzero(thresh)
                if count == self.distances.shape[0] * self.distances.shape[1]:
                    print ('converged, done, itr {}'.format(epoch + 1))
                    break
                prev_distances = copy.deepcopy(self.distances)
                print (np.linalg.norm(self.distances))
                lr /= 2.
        self.distances[(1 - updated_tuples).astype(bool)] = -np.inf
        return self.distances
    
    def _get_algo_target(self, x, ax, rx, nx, y, ay, ry, ny):
        if self.algo == 'pie-SA-MICO':
            val = np.abs(rx - ry)
            for a1 in range(self.mdp.n_action):
                for a2 in range(self.mdp.n_action):
                    sa1_cord = nx * self.mdp.n_action + a1
                    sa2_cord = ny * self.mdp.n_action + a2
                    val += self.pie.get_prob(nx, a1) * self.pie.get_prob(ny, a2) * self.gamma * self.distances[sa1_cord][sa2_cord]
        elif self.algo == 'pie-single-SA-MICO':
            val = np.abs(rx - ry)
            for a in range(self.mdp.n_action):
                sa1_cord = nx * self.mdp.n_action + a
                sa2_cord = ny * self.mdp.n_action + a
                val += self.pie.get_prob(nx, a) * self.pie.get_prob(ny, a) * self.gamma * self.distances[sa1_cord][sa2_cord]
        elif self.algo == 'pib-SA-MICO':
            val = np.abs(rx - ry)
            for a1 in range(self.mdp.n_action):
                for a2 in range(self.mdp.n_action):
                    sa1_cord = nx * self.mdp.n_action + a1
                    sa2_cord = ny * self.mdp.n_action + a2
                    val += self.pib.get_prob(nx, a1) * self.pib.get_prob(ny, a2) * self.gamma * self.distances[sa1_cord][sa2_cord]
        elif self.algo == 'rand-SA':
            val = np.abs(rx - ry)
            for a in range(self.mdp.n_action):
                sa1_cord = nx * self.mdp.n_action + a
                sa2_cord = ny * self.mdp.n_action + a
                val += (1. / self.mdp.n_action) * self.gamma * self.distances[sa1_cord][sa2_cord]
        elif self.algo == 'pie-SA-PSM':
            val = np.abs(self.pie.get_prob(x, ax) - self.pie.get_prob(y, ay))
            for a1 in range(self.mdp.n_action):
                for a2 in range(self.mdp.n_action):
                    sa1_cord = nx * self.mdp.n_action + a1
                    sa2_cord = ny * self.mdp.n_action + a2
                    val += self.pie.get_prob(nx, a1) * self.pie.get_prob(ny, a2) * self.gamma * self.distances[sa1_cord][sa2_cord]

        return val
