from abc import ABC
from pathlib import Path

import numpy as np
import torch
from torch_geometric.loader import DataLoader

from relnet.agent.pytorch_agent import PyTorchAgent
from relnet.agent.gnn.feature_processing import to_pyg_datapoint_demand, get_feats_dim
from relnet.agent.gnn.demand_reg_gnn import DemandRegressorGNN
from relnet.state.graph_state import get_graph_hash
from relnet.utils.config_utils import get_device_placement


class PredictionAgent(PyTorchAgent, ABC):
    """Base agent that produces predictions for graphs with a ``DemandRegressorGNN``.

    Concrete subclasses pick the model variant through the class attributes
    ``demand_rep``, ``gnn_arch`` and ``edge_coloring``. Graphs are addressed by
    hash; each one is converted to a PyG datapoint on first use and cached
    either in memory (``self.pyg_cache``) or as ``.pt`` files under
    ``self.pyg_cache_dir``, depending on ``options['use_pyg_cache_dir']``.
    """

    is_deterministic = False
    is_trainable = True

    def setup_predictor(self):
        """Build ``self.net`` from the dataset metadata and hyperparameters.

        When absent from ``self.hyperparams``, ``num_layers`` defaults to the
        dataset's max diameter minus ``layers_lt_diam`` (0 if that is also
        absent), and ``batch_size`` defaults to the dataset's max node count.
        The network is moved to CUDA when the device placement is ``'GPU'``.
        """
        nf_dim, ef_dim = get_feats_dim(self.demand_rep, self.hyperparams, self.gds_metadata)
        print(f"<<<FEATURE DIMS>>>: {nf_dim}, {ef_dim}")

        if 'num_layers' not in self.hyperparams:
            if self.log_progress:
                self.logger.info(f"number of layers not provided -- setting depending on max dataset diameter {self.gds_metadata['max_diameter']}")
            # A missing offset previously raised KeyError; treat it as "no
            # reduction" so depth falls back to the max diameter itself.
            self.hyperparams['num_layers'] = (
                self.gds_metadata['max_diameter']
                - self.hyperparams.get('layers_lt_diam', 0)
            )

        if 'batch_size' not in self.hyperparams:
            bs = self.gds_metadata['max_num_nodes']
            self.hyperparams['batch_size'] = bs
            if self.log_progress:
                self.logger.info(f"batch size not provided -- setting to num nodes {bs}")

        self.net = DemandRegressorGNN(nf_dim, ef_dim, self.demand_rep, self.gnn_arch, self.hyperparams, self.gds_metadata)

        if self.log_progress:
            self.logger.info("set up GNN regressor!")
        if get_device_placement() == 'GPU':
            self.net = self.net.cuda()

    def predict(self, g_list, **kwargs):
        """Run ``self.net`` over the graphs in ``g_list`` and return a flat tensor.

        Args:
            g_list: iterable of graph hashes, batched in the given order.
            **kwargs: forwarded unchanged to ``self.net`` on every batch.

        Returns:
            The per-batch network outputs, stacked and flattened into 1-D.

        Raises:
            RuntimeError: from ``torch.stack`` when ``g_list`` is empty.
        """
        # NOTE(review): when restore_model is set, the net is rebuilt and the
        # checkpoint reloaded on every call -- confirm that is intended.
        if self.restore_model:
            self.setup_predictor()
            self.restore_model_from_checkpoint()

        preds = []

        bs = self.hyperparams['batch_size']
        loader = self.load_data(g_list, bs)

        for data in loader:
            batch_preds = self.net(data, **kwargs)
            preds.extend(batch_preds)

        return torch.stack(preds).flatten()

    def _build_datapoint(self, g_hash):
        """Load graph ``g_hash`` plus the dataset's original graph and convert them to a PyG datapoint."""
        # NOTE(review): the original graph is reloaded for every cache miss;
        # hoisting it is only safe if to_pyg_datapoint_demand never mutates it.
        graph = self.graph_ds.load_graph_file(g_hash)
        orig_graph = self.graph_ds.load_graph_file(self.gds_metadata['original_graph_hash'])
        return to_pyg_datapoint_demand(graph, self.demand_rep, self.hyperparams, self.gds_metadata, orig_graph=orig_graph)

    def load_data(self, g_list, bs):
        """Convert graph hashes into a non-shuffling PyG ``DataLoader``.

        Each datapoint comes from the active cache when available. Otherwise it
        is built via ``_build_datapoint`` and then cached: stored in
        ``self.pyg_cache`` (in-memory mode), or saved to
        ``<pyg_cache_dir>/<g_hash>-<model_identifier_prefix>.pt`` (directory mode).
        Datapoints are moved to CUDA when the device placement is ``'GPU'``.

        Args:
            g_list: iterable of graph hashes; order is preserved in the loader.
            bs: batch size for the returned loader.
        """
        # Device placement does not vary within one call; evaluate it once.
        on_gpu = get_device_placement() == 'GPU'

        pyg_data = []
        for g_hash in g_list:
            if self.pyg_cache_dir is None:
                if g_hash in self.pyg_cache:
                    datapoint = self.pyg_cache[g_hash]
                else:
                    datapoint = self._build_datapoint(g_hash)
                    self.pyg_cache[g_hash] = datapoint
            else:
                cache_file = self.pyg_cache_dir / f"{g_hash}-{self.model_identifier_prefix}.pt"
                if g_hash in self.cached_pyg_datapoints:
                    # NOTE(review): torch >= 2.6 defaults torch.load to
                    # weights_only=True, which may refuse to unpickle PyG Data
                    # objects; confirm against the pinned torch/PyG versions.
                    datapoint = torch.load(cache_file)
                else:
                    datapoint = self._build_datapoint(g_hash)
                    torch.save(datapoint, cache_file)
                    self.cached_pyg_datapoints.add(g_hash)

            if on_gpu:
                datapoint = datapoint.cuda()

            pyg_data.append(datapoint)

        return DataLoader(pyg_data, batch_size=bs, shuffle=False)

    def setup(self, options, hyperparams):
        """Run the base setup, then configure ``edge_coloring`` and the PyG cache.

        If the caller's ``hyperparams`` lack ``edge_coloring`` and the class
        defines one, the class value is copied into ``self.hyperparams``. With
        ``options['use_pyg_cache_dir']`` set, datapoints are cached on disk under
        ``options['file_paths'].pyg_cache_dir``; otherwise in an in-memory dict.
        """
        super().setup(options, hyperparams)
        if 'edge_coloring' not in hyperparams:
            if hasattr(self, 'edge_coloring'):
                self.hyperparams['edge_coloring'] = self.edge_coloring

        if options['use_pyg_cache_dir']:
            self.pyg_cache_dir = Path(options['file_paths'].pyg_cache_dir)
            self.cached_pyg_datapoints = set()
        else:
            self.pyg_cache_dir = None
            self.pyg_cache = {}

    @staticmethod
    def get_default_hyperparameters():
        """Return the default hyperparameters: ``lf_dim`` 32 and ``learning_rate`` 0.01."""
        gnn_params = {
            "lf_dim": 32,
            "learning_rate": 0.01,
        }
        return gnn_params

    def post_env_setup(self):
        """No-op: this agent needs no post-environment setup."""
        pass

    def finalize(self):
        """Run the base finalize, then delete this model's on-disk PyG cache files, if any."""
        super().finalize()
        if self.options['use_pyg_cache_dir']:
            for child in self.pyg_cache_dir.glob(f"*{self.model_identifier_prefix}*"):
                if child.is_file():
                    child.unlink(missing_ok=True)



class UniformSummedDemandsRGATAgent(PredictionAgent):
    """PredictionAgent variant: ``rgat`` GNN, ``summed`` demands, ``uniform`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "rgat"
    demand_rep = "summed"
    edge_coloring = "uniform"
    # Algorithm name string for this variant.
    algorithm_name = "rgat_summed_uniform"


class UniqueColorEdgeSummedDemandsRGATAgent(PredictionAgent):
    """PredictionAgent variant: ``rgat`` GNN, ``summed`` demands, ``unique_edge`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "rgat"
    demand_rep = "summed"
    edge_coloring = "unique_edge"
    # Algorithm name string for this variant.
    algorithm_name = "rgat_summed_uniqueedge"


class UniformRawDemandsRGATAgent(PredictionAgent):
    """PredictionAgent variant: ``rgat`` GNN, ``raw`` demands, ``uniform`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "rgat"
    demand_rep = "raw"
    edge_coloring = "uniform"
    # Algorithm name string for this variant.
    algorithm_name = "rgat_raw_uniform"


class UniqueColorEdgeRawDemandsRGATAgent(PredictionAgent):
    """PredictionAgent variant: ``rgat`` GNN, ``raw`` demands, ``unique_edge`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "rgat"
    demand_rep = "raw"
    edge_coloring = "unique_edge"
    # Algorithm name string for this variant.
    algorithm_name = "rgat_raw_uniqueedge"


### GCN variants -- uniform only.
class UniformSummedDemandsGCNAgent(PredictionAgent):
    """PredictionAgent variant: ``gcn`` GNN, ``summed`` demands, ``uniform`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "gcn"
    demand_rep = "summed"
    edge_coloring = "uniform"
    # Algorithm name string for this variant.
    algorithm_name = "gcn_summed_uniform"


class UniformRawDemandsGCNAgent(PredictionAgent):
    """PredictionAgent variant: ``gcn`` GNN, ``raw`` demands, ``uniform`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "gcn"
    demand_rep = "raw"
    edge_coloring = "uniform"
    # Algorithm name string for this variant.
    algorithm_name = "gcn_raw_uniform"


### GraphSAGE variants -- uniform only.
class UniformSummedDemandsSAGEAgent(PredictionAgent):
    """PredictionAgent variant: ``sage`` GNN, ``summed`` demands, ``uniform`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "sage"
    demand_rep = "summed"
    edge_coloring = "uniform"
    # Algorithm name string for this variant.
    algorithm_name = "sage_summed_uniform"


class UniformRawDemandsSAGEAgent(PredictionAgent):
    """PredictionAgent variant: ``sage`` GNN, ``raw`` demands, ``uniform`` edge coloring."""

    # Variant selectors read by PredictionAgent.
    gnn_arch = "sage"
    demand_rep = "raw"
    edge_coloring = "uniform"
    # Algorithm name string for this variant.
    algorithm_name = "sage_raw_uniform"


