from abc import ABC, abstractmethod
import torch

from torch_ac.format import default_preprocess_obss
from torch_ac.utils import DictList, ParallelEnv

import torch.nn as nn

import numpy as np
from skimage.util.shape import view_as_windows

from sip import *

def max_value(values):
    """Return the largest element of a non-empty, indexable sequence.

    The scan starts from ``values[0]`` (so an empty sequence raises
    ``IndexError``) and only replaces the running maximum when a strictly
    greater element is met, i.e. on ties the earliest element is kept.
    """
    largest = values[0]
    for candidate in values:
        if candidate > largest:
            largest = candidate
    return largest

def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Intended to be passed to ``module.apply(weight_init)``.

    * ``nn.Linear``: orthogonal weights, zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: delta-orthogonal init - every
      kernel tap is zero except the spatial centre, which receives an
      orthogonal matrix scaled by the ReLU gain.
    * Any other module is left untouched.

    Layers created with ``bias=False`` (``m.bias is None``) are supported;
    only their weights are initialised.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        # The scheme addresses a single centre tap, so the kernel must be square.
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)

class BaseAlgo(ABC):
    """The base class for RL algorithms."""

    def __init__(self, envs, acmodel, device, num_frames_per_proc, discount, lr, gae_lambda, entropy_coef,
                 value_loss_coef, max_grad_norm, recurrence, preprocess_obss, reshape_reward, 
                 use_entropy_reward=False,use_value_condition=False):
        """
        Initializes a `BaseAlgo` instance.

        Besides storing the hyper-parameters, this wraps `envs` in a
        `ParallelEnv`, moves the model to `device`, resets the environments
        and allocates the per-rollout experience buffers.

        Parameters:
        ----------
        envs : list
            a list of environments that will be run in parallel
        acmodel : torch.Module
            the model
        device : torch.device
            the device the model and all rollout buffers are placed on
        num_frames_per_proc : int
            the number of frames collected by every process for an update
        discount : float
            the discount for future rewards
        lr : float
            the learning rate for optimizers
        gae_lambda : float
            the lambda coefficient in the GAE formula
            ([Schulman et al., 2015](https://arxiv.org/abs/1506.02438))
        entropy_coef : float
            the weight of the entropy cost in the final objective
        value_loss_coef : float
            the weight of the value loss in the final objective
        max_grad_norm : float
            gradient will be clipped to be at most this value
        recurrence : int
            the number of steps the gradient is propagated back in time
        preprocess_obss : function
            a function that takes observations returned by the environment
            and converts them into the format that the model can handle;
            `default_preprocess_obss` is used when a falsy value is given
        reshape_reward : function
            a function that shapes the reward, takes an
            (observation, action, reward, done) tuple as an input
        use_entropy_reward : bool
            stored on the instance only; not read by any method of this
            class (presumably consumed by subclasses - confirm)
        use_value_condition : bool
            stored on the instance only; not read by any method of this
            class (presumably consumed by subclasses - confirm)
        """
        self.use_entropy_reward = use_entropy_reward
        self.use_value_condition = use_value_condition

        # Store parameters
        self.envs = envs
        self.env = ParallelEnv(envs)
        self.acmodel = acmodel
        self.device = device
        self.num_frames_per_proc = num_frames_per_proc
        self.discount = discount
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.recurrence = recurrence
        self.preprocess_obss = preprocess_obss or default_preprocess_obss
        self.reshape_reward = reshape_reward

        # Control parameters

        # A non-recurrent model can only be trained with recurrence == 1, and
        # a rollout must split evenly into sub-sequences of `recurrence` steps.
        assert self.acmodel.recurrent or self.recurrence == 1
        assert self.num_frames_per_proc % self.recurrence == 0

        # Configure acmodel

        self.acmodel.to(self.device)
        self.acmodel.train()

        # Neighbour rank used by the k-NN entropy estimators below
        # (they query the (k + 1)-th smallest distance).
        self.k = 3
        # Running mean/std tracker; not referenced by the methods of this class.
        self.s_ent_stats = TorchRunningMeanStd(shape=[1], device=self.device)

        # Small convolutional encoder for 3-channel inputs, left at PyTorch's
        # default initialisation. It is not referenced by the methods of this
        # class (presumably used by subclasses to embed observations - confirm).
        self.random_encoder = nn.Sequential(
            nn.Conv2d(3, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU()
        )

        self.random_encoder.to(self.device)
        self.random_encoder.train()

        # Store helpers values

        self.num_procs = len(envs)
        self.num_frames = self.num_frames_per_proc * self.num_procs

        # Initialize experience values

        # All rollout buffers are laid out as (T, P) = (frames per proc, procs).
        shape = (self.num_frames_per_proc, self.num_procs)

        self.obs = self.env.reset()
        self.obss = [None]*(shape[0])
        self.agent_loc = [None]*(shape[0])
        if self.acmodel.recurrent:
            self.memory = torch.zeros(shape[1], self.acmodel.memory_size, device=self.device)
            self.memories = torch.zeros(*shape, self.acmodel.memory_size, device=self.device)
        # mask is 0 for environments whose previous step ended an episode, else 1.
        self.mask = torch.ones(shape[1], device=self.device)
        self.masks = torch.zeros(*shape, device=self.device)
        self.actions = torch.zeros(*shape, device=self.device, dtype=torch.int)
        self.values = torch.zeros(*shape, device=self.device)
        self.rewards = torch.zeros(*shape, device=self.device)
        self.advantages = torch.zeros(*shape, device=self.device)
        self.log_probs = torch.zeros(*shape, device=self.device)
        
        # Second value/advantage stream fed by the model's extrinsic value head.
        self.extr_advantages = torch.zeros(*shape, device=self.device)
        self.extr_values = torch.zeros(*shape, device=self.device)

        # Initialize log values

        # Per-process accumulators for the episode currently in progress.
        self.log_episode_return = torch.zeros(self.num_procs, device=self.device)
        self.log_episode_reshaped_return = torch.zeros(self.num_procs, device=self.device)
        self.log_episode_num_frames = torch.zeros(self.num_procs, device=self.device)

        # Histories of finished episodes, pre-filled with zeros so that the
        # first logs always contain at least `num_procs` entries.
        self.log_done_counter = 0
        self.log_return = [0] * self.num_procs
        self.log_reshaped_return = [0] * self.num_procs
        self.log_num_frames = [0] * self.num_procs

        # Maps an agent position tuple to an accumulated visit weight.
        self.agent_pos_visits = dict()

    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Two GAE advantage streams are computed from the same stored rewards:
        one bootstrapped from the model's main value output and one from its
        extrinsic value output (``get_extrinsic_value=True``).

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
            Also carries `extr_value`, `extr_advantage` and `extr_returnn`
            for the extrinsic stream.
        logs : dict
            Useful stats about the training process: per-episode returns,
            reshaped returns and frame counts, plus the number of frames
            collected by this call.
        """

        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction

            preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
            with torch.no_grad():
                if self.acmodel.recurrent:
                    # The memory is zeroed for environments that just finished an episode.
                    dist, value, memory, extr_value = self.acmodel(
                        preprocessed_obs, self.memory * self.mask.unsqueeze(1), get_extrinsic_value=True)
                else:
                    dist, value, _, extr_value = self.acmodel(preprocessed_obs, get_extrinsic_value=True)
            action = dist.sample()

            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Accumulate a small visit weight for every environment's agent position.
            # NOTE(review): positions are read from the objects in `self.envs`;
            # whether `ParallelEnv` steps those same objects is not visible here.
            for env in self.envs:
                pos = tuple(env.agent_pos)
                self.agent_pos_visits[pos] = self.agent_pos_visits.get(pos, 0) + 0.0001

            # Update experiences values

            self.obss[i] = self.obs
            self.obs = obs
            if self.acmodel.recurrent:
                self.memories[i] = self.memory
                self.memory = memory
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(done, device=self.device, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value
            self.extr_values[i] = extr_value
            if self.reshape_reward is not None:
                self.rewards[i] = torch.tensor([
                    self.reshape_reward(obs_, action_, reward_, done_)
                    for obs_, action_, reward_, done_ in zip(obs, action, reward, done)
                ], device=self.device)
            else:
                self.rewards[i] = torch.tensor(reward, device=self.device)
            self.log_probs[i] = dist.log_prob(action)

            # Update log values

            self.log_episode_return += torch.tensor(reward, device=self.device, dtype=torch.float)
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += torch.ones(self.num_procs, device=self.device)

            # A dedicated index name is used so the rollout step `i` is not shadowed.
            for proc_idx, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[proc_idx].item())
                    self.log_reshaped_return.append(self.log_episode_reshaped_return[proc_idx].item())
                    self.log_num_frames.append(self.log_episode_num_frames[proc_idx].item())

            # Reset the accumulators of environments whose episode just ended.
            self.log_episode_return *= self.mask
            self.log_episode_reshaped_return *= self.mask
            self.log_episode_num_frames *= self.mask

        # Add advantage and return to experiences

        # Bootstrap values for the observation that follows the last stored frame.
        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            if self.acmodel.recurrent:
                _, next_value, _, extr_next_value = self.acmodel(
                    preprocessed_obs, self.memory * self.mask.unsqueeze(1), get_extrinsic_value=True)
            else:
                _, next_value, _, extr_next_value = self.acmodel(preprocessed_obs, get_extrinsic_value=True)

        # Backward GAE recursion; the last frame uses the bootstrap values above.
        for i in reversed(range(self.num_frames_per_proc)):
            next_mask = self.masks[i+1] if i < self.num_frames_per_proc - 1 else self.mask
            next_value = self.values[i+1] if i < self.num_frames_per_proc - 1 else next_value
            next_advantage = self.advantages[i+1] if i < self.num_frames_per_proc - 1 else 0

            delta = self.rewards[i] + self.discount * next_value * next_mask - self.values[i]
            self.advantages[i] = delta + self.discount * self.gae_lambda * next_advantage * next_mask

            extr_next_value = self.extr_values[i+1] if i < self.num_frames_per_proc - 1 else extr_next_value
            extr_next_advantage = self.extr_advantages[i+1] if i < self.num_frames_per_proc - 1 else 0

            extr_delta = self.rewards[i] + self.discount * extr_next_value * next_mask - self.extr_values[i]
            self.extr_advantages[i] = extr_delta + self.discount * self.gae_lambda * extr_next_advantage * next_mask

        # Define experiences:
        #   the whole experience is the concatenation of the experience
        #   of each process.
        # In comments below:
        #   - T is self.num_frames_per_proc,
        #   - P is self.num_procs,
        #   - D is the dimensionality.

        exps = DictList()
        exps.obs = [self.obss[i][j]
                    for j in range(self.num_procs)
                    for i in range(self.num_frames_per_proc)]

        if self.acmodel.recurrent:
            # T x P x D -> P x T x D -> (P * T) x D
            exps.memory = self.memories.transpose(0, 1).reshape(-1, *self.memories.shape[2:])
            # T x P -> P x T -> (P * T) x 1
            exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1)
        exps.reward = self.rewards.transpose(0, 1).reshape(-1)
        exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
        exps.returnn = exps.value + exps.advantage
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        exps.extr_value = self.extr_values.transpose(0, 1).reshape(-1)
        exps.extr_advantage = self.extr_advantages.transpose(0, 1).reshape(-1)
        exps.extr_returnn = exps.extr_value + exps.extr_advantage

        # Preprocess experiences

        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values

        # Report every episode finished during this call, but never fewer
        # than `num_procs` entries (older ones pad the window).
        keep = max(self.log_done_counter, self.num_procs)

        logs = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, logs

    def soft_update_params(self, net, target_net, tau):
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(
                tau * param.data + (1 - tau) * target_param.data
            )

    def compute_logits(self, z_a, z_pos):
        """
        Uses logits trick for CURL:
        - compute (B,B) matrix z_a (W z_pos.T)
        - positives are all diagonal elements
        - negatives are all other elements
        - to compute loss use multiclass cross entropy with identity matrix for labels
        """
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim,B)
        logits = torch.matmul(z_a, Wz)  # (B,B)
        logits = logits - torch.max(logits, 1)[0][:, None]
        return logits

    @abstractmethod
    def update_parameters(self):
        """Perform one parameter update; concrete algorithms must override this.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def compute_state_entropy(self, src_feats, tgt_feats, average_entropy=False):
        with torch.no_grad():
            dists = []
            for idx in range(len(tgt_feats) // 10000 + 1):
                start = idx * 10000
                end = (idx + 1) * 10000
                dist = torch.norm(
                    src_feats[:, None, :] - tgt_feats[None, start:end, :], dim=-1, p=2
                )
                dists.append(dist)

            dists = torch.cat(dists, dim=1)
            knn_dists = 0.0
            if average_entropy:
                for k in range(5):
                    knn_dists += torch.kthvalue(dists, k + 1, dim=1).values
                knn_dists /= 5
            else:
                knn_dists = torch.kthvalue(dists, k=self.k + 1, dim=1).values
            state_entropy = knn_dists
        return state_entropy.unsqueeze(1)

    def compute_value_condition_state_entropy(self, src_feats, tgt_feats, value, average_entropy=False):
        with torch.no_grad():
            dists = []
            state_dists = []
            value_dists = []
            ds = src_feats.size(1)
            for idx in range(len(tgt_feats) // 10000 + 1):
                start = idx * 10000
                end = (idx + 1) * 10000
                state_dist = torch.norm(
                    src_feats[:, None, :] - tgt_feats[None, start:end, :], dim=-1, p=2
                )
                state_dists.append(state_dist)
                value_dist = torch.norm(
                    value[:, None, :] - value[None, start:end, :], dim=-1, p=2
                )
                value_dists.append(value_dist)
                dist = torch.max(torch.cat((state_dist.unsqueeze(-1),value_dist.unsqueeze(-1)),dim=-1),dim=-1)[0]
                dists.append(dist)
                
            dists = torch.cat(dists, dim=1)
            value_dists = torch.cat(value_dists, dim=1)
            state_dists = torch.cat(state_dists, dim=1)
            knn = min(self.k, dists.shape[0])
            eps = 0.0
            if average_entropy:
                for k in range(min(5, dists.shape[0])):
                    eps += torch.kthvalue(dists, k + 1, dim=1).values
                eps /= 5
            else:
                eps = torch.kthvalue(dists, k=knn + 1, dim=1).values
                
            eps = eps.reshape(-1, 1) # (b1, 1)
            value_dists = value_dists < eps
            state_dists = state_dists <= eps
            n_v = torch.sum(value_dists,dim=1,keepdim = True) # (b1,1)
            n_s = torch.sum(state_dists,dim=1,keepdim = True) # (b1,1)
            reward = torch.special.digamma(n_v+1) / ds + torch.log(eps * 2 + 0.00001)
        return reward
    
    def compute_value_condition_structural_entropy(self, src_feats, tgt_feats, value, average_entropy=False):
        """Value-conditioned entropy reward augmented with a community-level term.

        Steps, as implemented below:

        1. Build a dense similarity graph over the samples: the distance of a
           pair is the largest of the L2 distances between their `src_feats`,
           `tgt_feats` and `value` rows; similarity is
           ``1 - distance / max_distance``.
        2. Hand the matrix to `PartitionTree` (not defined in this file;
           presumably provided by the `sip` star-import - confirm) and build
           an encoding tree with ``k=3``.
        3. For each parent of a leaf node, form entropy-weighted averages of
           its children's feature and value rows ("level 1").
        4. Compute `compute_value_condition_state_entropy` on the original
           samples (`reward_0`) and on the level-1 aggregates (`reward_1`),
           then add each level-1 reward, divided by the community size, onto
           a slice of `reward_0`.

        Returns `reward_0` after the in-place additions.
        """
        # Distances are computed on CPU numpy copies; the torch tensors are
        # still used below when gathering rows.
        sfa, tfa, va = src_feats.cpu().detach().numpy(), tgt_feats.cpu().detach().numpy(), value.cpu().detach().numpy()
        num_vertex = va.shape[0]
        adj_matrix = np.zeros((num_vertex, num_vertex))
        max_max_dist = 0.0
        # Symmetric pairwise distance matrix (upper triangle mirrored).
        for vid_i in range(num_vertex):
            for vid_j in range(vid_i, num_vertex):
                max_dist = max_value([abs(np.linalg.norm(sfa[vid_i] - sfa[vid_j], axis=-1)), abs(np.linalg.norm(tfa[vid_i] - tfa[vid_j], axis=-1)), abs(np.linalg.norm(va[vid_i] - va[vid_j], axis=-1))])
                adj_matrix[vid_i, vid_j] = max_dist
                adj_matrix[vid_j, vid_i] = max_dist
                max_max_dist = max(max_max_dist, max_dist)
        # Turn distances into similarities in [0, 1].
        # NOTE(review): if every pairwise distance is 0 (e.g. a single sample
        # or identical rows) `max_max_dist` stays 0.0 and this divides by zero.
        for vid_i in range(num_vertex):
            for vid_j in range(num_vertex):
                adj_matrix[vid_i, vid_j] = 1.0 - adj_matrix[vid_i, vid_j] / max_max_dist
        y = PartitionTree(adj_matrix=adj_matrix)
        # The return value `x` is never read; the call is presumably needed
        # for its effect on `y.tree_node` - confirm against PartitionTree.
        x = y.build_encoding_tree(k=3)
        sf_level_0, sf_level_1 = list(), list()
        tf_level_0, tf_level_1 = list(), list()
        va_level_0, va_level_1 = list(), list()
        length_level_1 = list()
        tmp_set = set()
        # Leaves are the nodes without children; remember each leaf's parent node.
        for vid, vertex in y.tree_node.items():
            if vertex.children is None:
                sf_level_0.append(src_feats[vid, :])
                tf_level_0.append(tgt_feats[vid, :])
                va_level_0.append(value[vid, :])
                tmp_set.add(y.tree_node[vertex.parent])
        # One aggregated (level-1) row per distinct parent.
        # NOTE(review): `tmp_set` is a set, so the iteration order here is
        # not tied to sample order.
        for vertex in tmp_set:
            tmp_sf, tmp_tf, tmp_va, tmp_ens = list(), list(), list(), list()
            for child in vertex.children:
                tmp_sf.append(src_feats[child, :])
                tmp_tf.append(tgt_feats[child, :])
                tmp_va.append(value[child, :])
                tmp_ens.append(y.node_entropy(child))
            assert sum(tmp_ens) >= 0.0
            if sum(tmp_ens) > 0.0:
                # NOTE(review): `tmp_ens` is a Python list; this in-place
                # division only works if `sum(tmp_ens)` is a numpy scalar
                # (yielding an ndarray) and raises TypeError for plain floats
                # - confirm what `node_entropy` returns.
                tmp_ens /= sum(tmp_ens)
            else:
                assert len(tmp_ens) == 1
                tmp_ens = [1.0]
            # Entropy-weighted sum of the children's rows.
            sf1, tf1, va1 = tmp_ens[0] * tmp_sf[0], tmp_ens[0] * tmp_tf[0], tmp_ens[0] * tmp_va[0]
            for i in range(1, len(tmp_ens)):
                sf1 += tmp_ens[i] * tmp_sf[i]
                tf1 += tmp_ens[i] * tmp_tf[i]
                va1 += tmp_ens[i] * tmp_va[i]
            sf_level_1.append(sf1)
            tf_level_1.append(tf1)
            va_level_1.append(va1)
            length_level_1.append(len(tmp_ens))
        # Stack the collected rows into 2-D tensors. The level-0 stacks are
        # built but not read afterwards (reward_0 uses the original inputs).
        sf_level_0, sf_level_1 = [torch.unsqueeze(tensor, dim=0) for tensor in sf_level_0], [torch.unsqueeze(tensor, dim=0) for tensor in sf_level_1]
        tf_level_0, tf_level_1 = [torch.unsqueeze(tensor, dim=0) for tensor in tf_level_0], [torch.unsqueeze(tensor, dim=0) for tensor in tf_level_1]
        va_level_0, va_level_1 = [torch.unsqueeze(tensor, dim=0) for tensor in va_level_0], [torch.unsqueeze(tensor, dim=0) for tensor in va_level_1]
        sf_level_0, sf_level_1 = torch.cat(sf_level_0, dim=0), torch.cat(sf_level_1, dim=0)
        tf_level_0, tf_level_1 = torch.cat(tf_level_0, dim=0), torch.cat(tf_level_1, dim=0)
        va_level_0, va_level_1 = torch.cat(va_level_0, dim=0), torch.cat(va_level_1, dim=0)
        reward_0 = self.compute_value_condition_state_entropy(src_feats, tgt_feats, value, average_entropy)
        reward_1 = self.compute_value_condition_state_entropy(sf_level_1, tf_level_1, va_level_1, average_entropy)
        # Spread each level-1 reward over `length_level_1[i]` consecutive rows.
        # NOTE(review): this assumes the members of community i occupy the
        # contiguous rows [index, index + length); the children ids gathered
        # above are not used here - verify this matches the intended mapping.
        index = 0
        for i in range(len(reward_1)):
            reward_0[index: index + length_level_1[i]] += (1.0 / length_level_1[i]) * reward_1[i]
            index += length_level_1[i]
        return reward_0

class TorchRunningMeanStd:
    """Running estimate of the mean and variance of a stream of batches.

    Attributes are plain tensors/numbers that are replaced on every update:

    * ``mean`` - running mean, initialised to zeros of `shape`.
    * ``var``  - running variance, initialised to ones of `shape`.
    * ``count`` - effective sample count, initialised to `epsilon` so the
      first merge does not divide by zero.
    """

    def __init__(self, epsilon=1e-4, shape=(), device=None):
        self.mean = torch.zeros(shape, device=device)
        self.var = torch.ones(shape, device=device)
        self.count = epsilon

    def update(self, x):
        """Fold a batch `x` (samples along dim 0) into the running statistics."""
        with torch.no_grad():
            batch_mean = torch.mean(x, dim=0)
            # Population variance: the merge formula works on sums of squared
            # deviations (var * count), and the unbiased estimator would be
            # NaN for a batch with a single row.
            batch_var = torch.var(x, dim=0, unbiased=False)
            batch_count = x.shape[0]
            self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge pre-computed batch moments into the running statistics."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )

    @property
    def std(self):
        """Square root of the running variance."""
        return torch.sqrt(self.var)

def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge running moments with the moments of a new batch.

    Uses the parallel mean/variance combination: with
    ``delta = batch_mean - mean`` and ``n = count + batch_count``,

    * ``new_mean = mean + delta * batch_count / n``
    * ``new_var  = (var * count + batch_var * batch_count
                    + delta**2 * count * batch_count / n) / n``

    Variances are treated as population variances. Inputs may be tensors or
    plain numbers.

    Returns
    -------
    tuple
        ``(new_mean, new_var, new_count)``.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    # The shift of the mean is delta weighted by the batch's share of samples.
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + delta ** 2 * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

