import abc

import gtimer as gt
from rlkit.core.rl_algorithm import BaseRLAlgorithm
from rlkit.data_management.replay_buffer import ReplayBuffer
from rlkit.samplers.data_collector import PathCollector
from torch_geometric.data import Data
from torch_geometric.transforms import Compose, Distance, KNNGraph
import numpy as np
import torch
from torch_geometric.transforms import Compose, Distance, KNNGraph, ToDevice
from rlkit.torch import pytorch_util as ptu



class BatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    """Epoch-based RL loop that alternates path sampling and batched training.

    Each epoch first runs evaluation sampling once, then
    ``num_train_loops_per_epoch`` train loops. Every train loop collects
    exploration paths, stores them in the replay buffer (online epochs only)
    and performs ``num_trains_per_train_loop`` trainer updates on random
    replay batches.

    Epochs run from ``start_epoch`` to ``num_epochs - 1``. Negative epochs are
    offline: exploration paths are still collected but are not added to the
    replay buffer. Epochs >= 0 are online.

    When ``is_manifold_rep_learning`` is True, a k-nearest-neighbour graph
    whose node positions are (a random subset of) the replay buffer's
    observations is rebuilt every train loop and passed to the trainer with
    each batch.
    """

    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            is_manifold_rep_learning=False,
            k_nn_param=6,
            start_epoch=0, # negative epochs are offline, non-negative epochs are online
            max_graph_points=20000
    ):
        """Store the loop configuration.

        Args:
            trainer: Object updated with replay batches. Its ``train(batch)``
                is called, or ``train_pass_replay_buffer(batch,
                all_data_graph=...)`` when ``is_manifold_rep_learning``.
            exploration_env: Forwarded to ``BaseRLAlgorithm``.
            evaluation_env: Forwarded to ``BaseRLAlgorithm``.
            exploration_data_collector: Collector used for exploration paths.
            evaluation_data_collector: Collector used for evaluation paths.
            replay_buffer: Buffer that receives exploration paths and
                supplies training batches.
            batch_size: Number of samples requested per training batch.
            max_path_length: Max path length passed to ``collect_new_paths``.
            num_epochs: Exclusive upper bound on the epoch index.
            num_eval_steps_per_epoch: Evaluation steps collected per epoch.
            num_expl_steps_per_train_loop: Exploration steps per train loop.
            num_trains_per_train_loop: Trainer updates per train loop.
            num_train_loops_per_epoch: Train loops per epoch.
            min_num_steps_before_training: If > 0, exploration steps collected
                once at epoch 0, before evaluation and the first train loop.
            is_manifold_rep_learning: Build an observation kNN graph and pass
                it to the trainer together with the batch indices.
            k_nn_param: ``k`` used by ``KNNGraph``.
            start_epoch: First epoch index. Negative epochs are offline.
            max_graph_points: Maximum number of graph nodes. Larger buffers
                are subsampled uniformly without replacement.
        """
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )
        self.max_graph_points = max_graph_points
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self._start_epoch = start_epoch
        self.is_manifold_rep_learning = is_manifold_rep_learning
        self.k_nn_param = k_nn_param

    def train(self):
        """Run epochs ``start_epoch .. num_epochs - 1``.

        Negative epochs are offline; non-negative epochs (including 0) are
        online.
        """
        for self.epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            # ``_train`` reads this flag to decide whether new paths are stored.
            self.offline_rl = self.epoch < 0
            self._begin_epoch(self.epoch)
            self._train()
            self._end_epoch(self.epoch)

    def _train(self):
        """Run one epoch: optional warm-up, evaluation sampling, train loops."""
        if self.epoch == 0 and self.min_num_steps_before_training > 0:
            self._collect_initial_exploration()

        self.eval_data_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
        gt.stamp('evaluation sampling')

        # The graph transform pipeline does not change between train loops,
        # so build it once instead of once per loop.
        graph_transform = (
            self._make_graph_transform() if self.is_manifold_rep_learning else None
        )

        for _ in range(self.num_train_loops_per_epoch):
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
            )
            gt.stamp('exploration sampling', unique=False)

            # Offline epochs collect paths but never grow the replay buffer.
            if not self.offline_rl:
                self.replay_buffer.add_paths(new_expl_paths)
            gt.stamp('data storing', unique=False)

            obs_dataset = None
            if self.is_manifold_rep_learning:
                obs_dataset = self._build_obs_graph(graph_transform)
                gt.stamp('graph_dataset', unique=False)

            self._run_training_steps(obs_dataset)

    def _collect_initial_exploration(self):
        """Collect ``min_num_steps_before_training`` warm-up exploration steps."""
        init_expl_paths = self.expl_data_collector.collect_new_paths(
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        if not self.offline_rl:
            self.replay_buffer.add_paths(init_expl_paths)
        self.expl_data_collector.end_epoch(-1)

    def _make_graph_transform(self):
        """Return the graph pipeline applied to observation graphs.

        The pipeline moves the data to ``ptu.device``, adds undirected
        k-nearest-neighbour edges (``k = k_nn_param``) computed from ``pos``,
        and attaches edge distances via ``Distance()`` (torch_geometric
        defaults).
        """
        return Compose([
            ToDevice(ptu.device),
            KNNGraph(force_undirected=True, k=self.k_nn_param),
            Distance(),
        ])

    def _build_obs_graph(self, graph_transform):
        """Build a graph whose node positions are replay-buffer observations.

        At most ``max_graph_points`` observations are used. When the buffer
        holds more, a uniformly random subset without replacement is taken.
        """
        # NOTE(review): presumably an (N, obs_dim) numpy array; confirm dtype
        # matches what the trainer expects for ``pos``.
        obs_np = self.replay_buffer.get_populated_obs()
        num_obs = obs_np.shape[0]
        if num_obs > self.max_graph_points:
            # NOTE(review): after subsampling, graph node i is no longer buffer
            # row i. Confirm the trainer does not use replay-buffer indices to
            # address graph nodes.
            keep = np.random.permutation(num_obs)[:self.max_graph_points]
            obs_np = obs_np[keep]
        return graph_transform(Data(pos=torch.from_numpy(obs_np)))

    def _run_training_steps(self, obs_dataset):
        """Perform ``num_trains_per_train_loop`` trainer updates.

        Args:
            obs_dataset: Observation graph from ``_build_obs_graph`` when
                ``is_manifold_rep_learning`` is True, otherwise ``None``.
        """
        self.training_mode(True)
        try:
            for _ in range(self.num_trains_per_train_loop):
                if self.is_manifold_rep_learning:
                    train_data, train_indices = \
                        self.replay_buffer.random_batch_with_indices(self.batch_size)
                    all_data_graph = (train_indices, obs_dataset)
                    self.trainer.train_pass_replay_buffer(
                        train_data, all_data_graph=all_data_graph)
                else:
                    train_data = self.replay_buffer.random_batch(self.batch_size)
                    self.trainer.train(train_data)
            gt.stamp('training', unique=False)
        finally:
            # Leave training mode even if a trainer update raises.
            self.training_mode(False)
