import random
from abc import ABC, abstractmethod

import numpy as np
import torch

from relnet.evaluation.eval_utils import score_predictions
from relnet.utils.config_utils import get_logger_instance


class Agent(ABC):
    """Abstract base for agents that train on and predict over lists of graphs.

    Subclasses implement ``train``, ``predict`` and ``finalize``. ``setup``
    must be called before ``predict_and_score``, because it attaches the
    graph dataset (``self.graph_ds``), logger, hyperparameters and RNG seeds.
    """

    @abstractmethod
    def train(self, train_g_list, validation_g_list, max_steps, **kwargs):
        """Train on ``train_g_list``, using ``validation_g_list`` for validation.

        ``max_steps`` presumably bounds the number of training steps; its exact
        meaning is defined by the subclass.
        """

    @abstractmethod
    def predict(self, g_list, **kwargs):
        """Return predictions for every graph in ``g_list``."""

    def predict_and_score(self, g_list, predict_kwargs=None):
        """Predict on ``g_list`` and score the predictions against its targets.

        Args:
            g_list: graphs to predict on. They are also passed to
                ``self.graph_ds.get_gts_for_hashes`` to look up the
                ground-truth values the predictions are scored against.
            predict_kwargs: optional dict of keyword arguments forwarded to
                ``predict``. ``None`` (the default) means no extra arguments.

        Returns:
            The result of ``score_predictions(predictions, gts)``.
        """
        gts = self.graph_ds.get_gts_for_hashes(g_list)
        predictions = self.predict(g_list, **(predict_kwargs or {}))
        return score_predictions(predictions, gts)

    def setup(self, options, hyperparams):
        """Configure the agent from an options dict and a hyperparameter dict.

        Recognised ``options`` keys:
            graph_ds (required): dataset object. Its
                ``metadata_dict['global_metadata']`` is cached as
                ``self.gds_metadata``.
            log_filename: passed to ``get_logger_instance`` when logging is on.
            log_progress: enable logging (default ``False``).
            random_seed: seed applied via ``set_random_seeds`` (default 42).

        Args:
            options: dict of the keys above. It is stored as ``self.options``.
            hyperparams: stored unchanged as ``self.hyperparams``.

        Raises:
            KeyError: if ``options`` has no ``'graph_ds'`` entry.
            ValueError: if ``log_progress`` is truthy but ``options`` has no
                ``'log_filename'`` entry.
        """
        self.options = options
        self.graph_ds = options['graph_ds']
        self.gds_metadata = self.graph_ds.metadata_dict['global_metadata']

        # Always define the attribute so later code can read it safely.
        self.log_filename = options.get('log_filename')
        self.log_progress = options.get('log_progress', False)
        if self.log_progress:
            # The logger is created from log_filename; fail early with a
            # clear message when logging is requested without one.
            if 'log_filename' not in options:
                raise ValueError(
                    "options['log_progress'] is set but options['log_filename'] is missing"
                )
            self.logger = get_logger_instance(self.log_filename)
        else:
            self.logger = None

        self.set_random_seeds(options.get('random_seed', 42))
        self.hyperparams = hyperparams

    @abstractmethod
    def finalize(self):
        """Release or finish any resources held by the agent (subclass-defined)."""

    @staticmethod
    def get_default_hyperparameters():
        """Return the default hyperparameter dict (empty for the base class)."""
        return {}

    def set_random_seeds(self, random_seed):
        """Seed the agent-local Python RNG plus the NumPy and torch global RNGs.

        The stdlib module-level ``random`` generator is NOT seeded. Use
        ``self.local_random`` for Python-level randomness.

        Args:
            random_seed: seed value. It is stored as ``self.random_seed``.
        """
        self.random_seed = random_seed
        self.local_random = random.Random(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        # Seeds every CUDA device explicitly. Torch silently ignores this
        # call when CUDA is unavailable.
        torch.cuda.manual_seed_all(random_seed)

