from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

from torch_geometric.utils.subgraph import k_hop_subgraph

from torch.nn import MSELoss
import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer

from torch_geometric.utils.num_nodes import maybe_num_nodes

def numpy_combinations(x, sample_size=128):
    """Return up to ``sample_size`` random unordered pairs of distinct elements of ``x``.

    Every index pair ``(i, j)`` with ``i < j`` is a candidate. At most
    ``sample_size`` of them are drawn without replacement, in random order.

    Args:
        x (np.ndarray): Array whose first axis enumerates the elements to pair.
        sample_size (int): Maximum number of pairs to return.

    Returns:
        np.ndarray: Shape ``(min(sample_size, n_pairs), 2) + x.shape[1:]``,
        where ``n_pairs = len(x) * (len(x) - 1) // 2``.
    """
    # All (i, j) index pairs with i < j, one pair per row.
    idx = np.stack(np.triu_indices(len(x), k=1), axis=-1)
    # Use a random permutation of the row indices. Indexing with random
    # floats, as before, is rejected by NumPy.
    random_idxs = idx[np.random.permutation(idx.shape[0])[:sample_size]]
    return x[random_idxs]

def k_hop_subgraph_mine(node_idx, num_hops, edge_index,
                        num_nodes=None, edge_attrs=None, flow='source_to_target'):
    r"""Computes a weighted :math:`k`-hop subgraph of :obj:`edge_index` around
    :attr:`node_idx`, collapsing multi-hop paths into direct edges.

    At every hop, the edges whose endpoint (chosen by :attr:`flow`) lies in
    the current frontier are collected. From the second hop on, every
    edge kept so far (``a -> v``) is composed with every newly collected edge
    (``v -> b``) into a direct edge ``a -> b``. The composed edge's weight is
    the sum of the two weights. After each such hop, self loops are dropped.
    When the same ``(source, target)`` pair occurs more than once, only its
    smallest weight is kept.

    Args:
        node_idx (int, list, tuple or :obj:`torch.Tensor`): The central
            node(s).
        num_hops (int): The number of hops :math:`k`.
        edge_index (LongTensor): The edge indices, unpacked as two rows
            (``edge_index[0]`` and ``edge_index[1]``).
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        edge_attrs (Tensor, optional): One scalar weight per edge, shape
            ``[num_edges]`` or ``[num_edges, 1]``. If :obj:`None`, every edge
            gets weight ``1``, so collapsed weights count hops.
            (default: :obj:`None`)
        flow (string, optional): The flow direction of :math:`k`-hop
            aggregation (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)

    Returns:
        tuple: ``(subset, (sources, targets), weights)``.

        * ``subset`` holds the sorted unique ids of the central nodes and of
          every node reached.
        * ``sources`` / ``targets`` are the kept edges.
        * ``weights`` holds their weights with shape ``[num_kept_edges, 1]``.

        For ``num_hops >= 2``, edges are ordered lexicographically by
        ``(source, target)``. For ``num_hops == 1``, they keep the order of
        :attr:`edge_index`. For ``num_hops < 1``, the edge tensors are empty.
    """
    if num_nodes is None:
        num_nodes = maybe_num_nodes(edge_index)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index
    device = row.device

    if edge_attrs is None:
        # Unit weights: the collapsed weight of a composed edge is its hop count.
        edge_attrs = torch.ones(row.size(0), 1, device=device)
    else:
        # Store weights as a column so they concatenate with composed weights.
        edge_attrs = edge_attrs.reshape(-1, 1)

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
    if isinstance(node_idx, (int, list, tuple)):
        node_idx = torch.tensor([node_idx], device=device).flatten()
    else:
        node_idx = node_idx.to(device)

    subsets = [node_idx]
    # Edges and weights accumulated so far (None until the first hop).
    sources = targets = weights = None
    for _ in range(num_hops):
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        # Select every edge whose `row` endpoint is in the current frontier.
        torch.index_select(node_mask, 0, row, out=edge_mask)

        new_sources = col[edge_mask]
        new_targets = row[edge_mask]
        new_weights = edge_attrs[edge_mask]
        subsets.append(new_sources)

        if sources is None:
            sources, targets, weights = new_sources, new_targets, new_weights
            continue

        comp_sources, comp_targets, comp_weights = _compose_paths(
            sources, targets, weights,
            new_sources, new_targets, new_weights, num_nodes)
        sources, targets, weights = _drop_loops_and_keep_min(
            torch.cat((sources, new_sources, comp_sources)),
            torch.cat((targets, new_targets, comp_targets)),
            torch.cat((weights, new_weights, comp_weights)),
        )

    subset = torch.cat(subsets).unique()
    if sources is None:
        # num_hops < 1: no edges were collected.
        return subset, (row.new_empty(0), row.new_empty(0)), edge_attrs.new_empty((0, 1))
    return subset, (sources, targets), weights


def _compose_paths(prev_sources, prev_targets, prev_weights,
                   new_sources, new_targets, new_weights, num_nodes):
    """Join every edge ``a -> v`` with every edge ``v -> b`` into ``a -> b``.

    Returns the composed sources, the composed targets, and the summed
    weights as a column.
    """
    device = prev_targets.device
    # After sorting by source, the new edges leaving node v occupy the
    # contiguous slice [start[v], start[v] + count[v]).
    new_sources, order = torch.sort(new_sources, stable=True)
    new_targets = new_targets[order]
    new_weights = new_weights[order]
    count = torch.bincount(new_sources, minlength=num_nodes)
    start = count.cumsum(0) - count

    # Each previous edge a -> v is paired with all count[v] new edges leaving v.
    reps = count[prev_targets]
    prev_pick = torch.arange(prev_targets.size(0), device=device).repeat_interleave(reps)
    # Offset of each pair inside its previous edge's block: 0, 1, ..., count[v] - 1.
    block_start = reps.cumsum(0) - reps
    within = (torch.arange(int(reps.sum()), device=device)
              - block_start.repeat_interleave(reps))
    new_pick = start[prev_targets].repeat_interleave(reps) + within

    return (prev_sources[prev_pick],
            new_targets[new_pick],
            prev_weights[prev_pick] + new_weights[new_pick])


def _drop_loops_and_keep_min(sources, targets, weights):
    """Drop self loops and keep the smallest weight per (source, target) pair.

    Surviving edges are ordered lexicographically by (source, target). The
    weights are returned as a column.
    """
    keep = sources != targets
    sources = sources[keep]
    targets = targets[keep]
    weights = weights.reshape(-1)[keep]
    if sources.numel() == 0:
        return sources.long(), targets.long(), weights.reshape(-1, 1)

    # Sort by weight first so that, within each group of identical edges,
    # the first member in stable order carries the smallest weight.
    weights, order = torch.sort(weights, stable=True)
    pairs = torch.stack([sources[order], targets[order]], dim=1)
    _, group, counts = torch.unique(pairs, dim=0, sorted=True,
                                    return_inverse=True, return_counts=True)
    # Stable-sorting the group ids lists each group's members contiguously,
    # in weight order. The first member of each group sits at the exclusive
    # cumulative count.
    _, members = torch.sort(group, stable=True)
    first = members[counts.cumsum(0) - counts]
    return pairs[first, 0].long(), pairs[first, 1].long(), weights[first].reshape(-1, 1)


class MRLDDPGTrainer(TorchTrainer):
    """
    Deep Deterministic Policy Gradient with an auxiliary manifold
    representation loss.

    When ``policy_pre_activation_weight`` is 0, the policy loss is
    ``-Q(s, pi(s)) + manifold_rep_weight * mse``. Here ``mse`` compares two
    quantities on edges sampled from a k-hop subgraph (``k_hop_subgraph_mine``)
    of the graph ``all_data_graph[1]``:

    * the distance between the policy's manifold representations of each
      edge's two endpoints, and
    * that edge's collapsed path weight.

    When ``policy_pre_activation_weight`` is positive, the classic DDPG
    pre-activation penalty is used instead and no manifold term is added.
    """
    def __init__(
            self,
            qf,
            target_qf,
            policy,
            target_policy,

            discount=0.99,
            reward_scale=1.0,

            policy_learning_rate=1e-4,
            qf_learning_rate=1e-3,
            qf_weight_decay=0,
            target_hard_update_period=1000,
            tau=1e-2,
            use_soft_update=False,
            qf_criterion=None,
            policy_pre_activation_weight=0.,
            optimizer_class=optim.Adam,

            manifold_rep_weight=0.001,

            min_q_value=-np.inf,
            max_q_value=np.inf,

            num_graph_hops=2,
            batch_size_rep=128,
            num_nodes_rep=64
    ):
        super().__init__()
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.qf = qf
        self.target_qf = target_qf
        self.policy = policy
        self.target_policy = target_policy

        self.discount = discount
        self.reward_scale = reward_scale

        self.policy_learning_rate = policy_learning_rate
        self.qf_learning_rate = qf_learning_rate
        self.qf_weight_decay = qf_weight_decay
        self.target_hard_update_period = target_hard_update_period
        self.tau = tau
        self.use_soft_update = use_soft_update
        self.qf_criterion = qf_criterion
        self.policy_pre_activation_weight = policy_pre_activation_weight
        self.min_q_value = min_q_value
        self.max_q_value = max_q_value

        # Scale of the manifold-representation MSE term in the policy loss.
        self.manifold_rep_weight = manifold_rep_weight

        self.qf_optimizer = optimizer_class(
            self.qf.parameters(),
            lr=self.qf_learning_rate,
        )
        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=self.policy_learning_rate,
        )
        print("Policy parameters:")
        print(self.policy.parameters())

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        # Number of hops used when extracting the subgraph for the manifold loss.
        self.num_graph_hops = num_graph_hops

        # Max number of subgraph edges and number of seed nodes sampled per step.
        self.batch_size_rep = batch_size_rep
        self.num_nodes_rep = num_nodes_rep

        self.rep_mse_loss = MSELoss()

    def train_from_torch(self, batch):
        """Unsupported: this trainer needs the data graph as well as the batch."""
        raise NotImplementedError(
            "Not implemented for MRLDDPGTrainer, call train_from_torch_pass_replay_buffer")

    def _manifold_rep_loss(self, knn_graph):
        """Return the MSE between manifold-rep distances and collapsed edge weights.

        ``knn_graph`` must expose ``pos`` (node features fed to the policy),
        ``edge_index`` and ``edge_attr``.
        """
        batch_nodes_for_rep = np.random.randint(0, high=knn_graph.pos.shape[0], size=self.num_nodes_rep)
        batch_nodes_for_rep = torch.from_numpy(batch_nodes_for_rep)

        _, hop_edges, hop_attrs = k_hop_subgraph_mine(batch_nodes_for_rep, self.num_graph_hops,
                                                      knn_graph.edge_index,
                                                      edge_attrs=knn_graph.edge_attr)
        edge_sources, edge_targets = hop_edges

        batch_edge_indices = torch.randperm(edge_sources.shape[0])[:self.batch_size_rep]
        # The subgraph may hold fewer than batch_size_rep edges.
        num_pairs = batch_edge_indices.shape[0]
        if num_pairs == 0:
            # An MSE over nothing would be NaN and poison the policy update.
            return torch.zeros((), device=knn_graph.pos.device)

        # One policy pass over all endpoints: sources first, then targets.
        edge_nodes = torch.cat([edge_sources[batch_edge_indices], edge_targets[batch_edge_indices]])
        obs_all = knn_graph.pos[edge_nodes].float()
        _, manifold_reps = self.policy(obs_all, return_manifold_rep=True)

        rep_dists = torch.norm(manifold_reps[:num_pairs] - manifold_reps[num_pairs:], dim=1)
        rep_dists = rep_dists.reshape(-1, 1)
        geodesic_dists = hop_attrs[batch_edge_indices]
        return self.rep_mse_loss(rep_dists.float(), geodesic_dists.float())

    def train_from_torch_pass_replay_buffer(self, batch, all_data_graph=None):
        """Run one actor/critic update on ``batch`` and refresh the target networks.

        Args:
            batch: dict with 'rewards', 'terminals', 'observations', 'actions'
                and 'next_observations' tensors.
            all_data_graph: indexable whose element ``[1]`` is a graph with
                ``pos``, ``edge_index`` and ``edge_attr`` attributes. Required
                unless ``policy_pre_activation_weight > 0``.

        Raises:
            ValueError: if the manifold loss is needed but ``all_data_graph`` is None.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        # Policy operations.
        if self.policy_pre_activation_weight > 0:
            policy_actions, pre_tanh_value, _manifold_rep_output = self.policy(
                obs, return_preactivations=True, return_manifold_rep=True
            )
            pre_activation_policy_loss = (
                (pre_tanh_value**2).sum(dim=1).mean()
            )
            q_output = self.qf(obs, policy_actions)
            raw_policy_loss = - q_output.mean()
            policy_loss = (
                    raw_policy_loss +
                    pre_activation_policy_loss * self.policy_pre_activation_weight
            )
            # No manifold term here. Report 0 so the diagnostics below always
            # have a value to log (previously this raised NameError).
            mse_loss = raw_policy_loss.new_zeros(())
        else:
            if all_data_graph is None:
                raise ValueError(
                    "all_data_graph is required when policy_pre_activation_weight <= 0")
            mse_loss = self._manifold_rep_loss(all_data_graph[1])

            policy_actions = self.policy(obs, return_manifold_rep=False)
            q_output = self.qf(obs, policy_actions)
            raw_policy_loss = - q_output.mean()
            policy_loss = raw_policy_loss + self.manifold_rep_weight * mse_loss

        # Critic operations.
        # The target is detached anyway, so skip building its autograd graph.
        with torch.no_grad():
            next_actions = self.target_policy(next_obs)
            target_q_values = self.target_qf(
                next_obs,
                next_actions,
            )
            q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
        q_pred = self.qf(obs, actions)
        bellman_errors = (q_pred - q_target) ** 2
        raw_qf_loss = self.qf_criterion(q_pred, q_target)

        if self.qf_weight_decay > 0:
            reg_loss = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in self.qf.regularizable_parameters()
            )
            qf_loss = raw_qf_loss + reg_loss
        else:
            qf_loss = raw_qf_loss

        # Update networks. qf_optimizer.zero_grad() discards the Q-function
        # gradients accumulated by policy_loss.backward().
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self._update_target_networks()

        # Save some statistics for eval using just one batch.
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics['Manifold Rep Loss'] = np.mean(ptu.get_numpy(
                mse_loss
            ))
            self.eval_statistics['Raw Policy Loss'] = np.mean(ptu.get_numpy(
                raw_policy_loss
            ))
            self.eval_statistics['Preactivation Policy Loss'] = (
                    self.eval_statistics['Policy Loss'] -
                    self.eval_statistics['Raw Policy Loss']
            )
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Predictions',
                ptu.get_numpy(q_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors',
                ptu.get_numpy(bellman_errors),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
        self._n_train_steps_total += 1

    def _update_target_networks(self):
        """Polyak-average (soft) or periodically copy (hard) into the targets."""
        if self.use_soft_update:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.qf, self.target_qf)
                ptu.copy_model_params_from_to(self.policy, self.target_policy)

    def get_diagnostics(self):
        return self.eval_statistics

    def end_epoch(self, epoch):
        # Record statistics again on the first training step of the next epoch.
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        return [
            self.policy,
            self.qf,
            self.target_policy,
            self.target_qf,
        ]

    def get_epoch_snapshot(self):
        return dict(
            qf=self.qf,
            target_qf=self.target_qf,
            trained_policy=self.policy,
            target_policy=self.target_policy,
        )
