import torch
import torch_geometric
import torch.nn.functional as F

import copy
import random
import ml_collections
import collections.abc
import numpy as np
import json
import time
import ecole


class NodeBipariteWith43VariableFeatures(ecole.observation.NodeBipartite):
    '''
    Adds (mostly global) features to variable node features.

    Adds 24 extra variable features to each variable on top of the standard
    ecole NodeBipartite obs variable features (19), so each variable will have
    43 features in total. The 24 extra values describe the solver state (global
    bounds, search tree, focus node and its best sibling) rather than any
    individual variable, so every variable row receives the same values.

    Ratios whose denominator is zero (e.g. an initial dual bound of 0, a node
    lower bound of 0, or zero LP iterations) are reported as 0 instead of
    raising ZeroDivisionError.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reference bounds, captured on the first extract() after each reset.
        self.init_dual_bound = None
        self.init_primal_bound = None

    def before_reset(self, model):
        super().before_reset(model)

        # Forget the previous episode's reference bounds.
        self.init_dual_bound = None
        self.init_primal_bound = None

    @staticmethod
    def _safe_div(numerator, denominator):
        '''Return numerator / denominator, or 0 when the denominator is zero.'''
        if denominator == 0:
            return 0
        return numerator / denominator

    @staticmethod
    def _same_node(node_a, node_b):
        '''Return 1 if both nodes exist and have the same SCIP node number, else 0.'''
        if node_a is None or node_b is None:
            return 0
        return int(node_a.getNumber() == node_b.getNumber())

    def _bound_features(self, m):
        '''Features 1-5: movement of the global dual/primal bounds and current gap.'''
        dual_bound = m.getDualbound()
        primal_bound = m.getPrimalbound()
        primal_dual_gap = abs(primal_bound - dual_bound)

        # Absolute change of each bound relative to its reference (first-seen) value.
        dual_bound_frac_change = self._safe_div(abs(self.init_dual_bound - dual_bound), self.init_dual_bound)
        primal_bound_frac_change = self._safe_div(abs(self.init_primal_bound - primal_bound),
                                                  self.init_primal_bound)

        # Current primal-dual gap relative to each reference bound.
        max_dual_bound_frac_change = self._safe_div(primal_dual_gap, self.init_dual_bound)
        max_primal_bound_frac_change = self._safe_div(primal_dual_gap, self.init_primal_bound)

        curr_primal_dual_bound_gap_frac = m.getGap()

        # NOTE: the primal variant precedes the dual one in the feature vector.
        return [dual_bound_frac_change,
                primal_bound_frac_change,
                max_primal_bound_frac_change,
                max_dual_bound_frac_change,
                curr_primal_dual_bound_gap_frac]

    def _tree_features(self, m):
        '''Features 6-9: global search-tree statistics normalised by the node count.'''
        n_nodes = m.getNNodes()
        # The original author found getNSolsFound() raising AttributeError here
        # (possibly unsupported through Ecole), so solution-count features are omitted.
        return [self._safe_div(m.getNLeaves(), n_nodes),  # num_leaves_frac
                self._safe_div(m.getNFeasibleLeaves(), n_nodes),  # num_feasible_leaves_frac
                self._safe_div(m.getNInfeasibleLeaves(), n_nodes),  # num_infeasible_leaves_frac
                # NOTE(review): despite its name this is nodes per LP iteration — confirm intent.
                self._safe_div(n_nodes, m.getNLPIterations())]  # num_lp_iterations_frac

    def _focus_node_features(self, m):
        '''Features 10-24: the focus (current) node and its best sibling.'''
        n_nodes = m.getNNodes()
        dual_bound = m.getDualbound()
        curr_node = m.getCurrentNode()
        best_node = m.getBestNode()
        curr_lower_bound = curr_node.getLowerbound()

        num_branching_changes, num_constraint_prop_changes, num_prop_changes = curr_node.getNDomchg()
        total_num_changes = num_branching_changes + num_constraint_prop_changes + num_prop_changes

        node_feats = [
            self._safe_div(m.getNSiblings(), n_nodes),  # num_siblings_frac
            self._same_node(curr_node, best_node),  # is_curr_node_best
            self._same_node(curr_node.getParent(), best_node),  # is_curr_node_parent_best
            self._safe_div(m.getDepth(), n_nodes),  # curr_node_depth
            self._safe_div(self.init_dual_bound, curr_lower_bound),  # curr_node_lower_bound_relative_to_init_dual_bound
            self._safe_div(dual_bound, curr_lower_bound),  # curr_node_lower_bound_relative_to_curr_dual_bound
            self._safe_div(num_branching_changes, total_num_changes),  # branching_changes_frac
            self._safe_div(num_constraint_prop_changes, total_num_changes),  # constraint_prop_changes_frac
            self._safe_div(num_prop_changes, total_num_changes),  # prop_changes_frac
            self._safe_div(curr_node.getNParentBranchings(), n_nodes),  # parent_branching_changes_frac
        ]

        best_sibling = m.getBestSibling()
        if best_sibling is None:
            # is_best_sibling_none=1; every other sibling feature is 0.
            sibling_feats = [1, 0, 0, 0, 0]
        else:
            sibling_lower_bound = best_sibling.getLowerbound()
            sibling_feats = [
                0,  # is_best_sibling_none
                self._same_node(best_sibling, best_node),  # is_best_sibling_best_node
                self._safe_div(self.init_dual_bound, sibling_lower_bound),
                self._safe_div(dual_bound, sibling_lower_bound),
                self._safe_div(sibling_lower_bound, curr_lower_bound),
            ]

        return node_feats + sibling_feats

    def extract(self, model, done):
        '''
        Return the NodeBipartite observation with 24 extra columns appended to
        ``obs.column_features``.

        The dual/primal bounds seen on the first call after before_reset() are
        stored and used as the reference ("init") bounds for later calls.
        '''
        # get the NodeBipartite obs
        obs = super().extract(model, done)

        m = model.as_pyscipopt()

        if self.init_dual_bound is None:
            self.init_dual_bound = m.getDualbound()
            self.init_primal_bound = m.getPrimalbound()

        extra_row = np.asarray(
            self._bound_features(m) + self._tree_features(m) + self._focus_node_features(m),
            dtype=np.float64)

        # Same extra values for every variable; tile keeps shape (n_vars, 24) even when n_vars == 0.
        num_vars = obs.column_features.shape[0]
        feats_to_add = np.tile(extra_row, (num_vars, 1))

        obs.column_features = np.column_stack((obs.column_features, feats_to_add))

        return obs


class StrongBranchingAgent:
    """Baseline agent that branches on the candidate with the highest strong-branching score.

    Args:
        pseudo_candidates (bool): Forwarded to ecole's StrongBranchingScores.
        name (str): Identifier stored on the agent.
    """

    def __init__(self, pseudo_candidates=False, name='sb'):
        self.name = name
        self.pseudo_candidates = pseudo_candidates
        self.strong_branching_function = ecole.observation.StrongBranchingScores(
            pseudo_candidates=pseudo_candidates)

    def before_reset(self, model):
        """
        Called when the environment is initialised (before its dynamics are reset);
        forwards the call to the wrapped scoring function.
        """
        scorer = self.strong_branching_function
        scorer.before_reset(model)

    def extract(self, model, done):
        """Return the strong-branching scores computed by the wrapped scoring function."""
        scorer = self.strong_branching_function
        return scorer.extract(model, done)

    def action_select(self, action_set, model, done, **kwargs):
        """Return (chosen variable from action_set, its position within action_set)."""
        candidate_scores = self.extract(model, done)[action_set]
        best_position = candidate_scores.argmax()
        chosen = action_set[best_position]
        return chosen, best_position


class PseudocostBranchingAgent:
    """Baseline agent that branches on the candidate with the highest pseudocost score.

    Args:
        name (str): Identifier stored on the agent.
    """

    def __init__(self, name='pc'):
        self.name = name
        self.pc_branching_function = ecole.observation.Pseudocosts()

    def before_reset(self, model):
        """Forward the environment's before_reset call to the wrapped scoring function."""
        scorer = self.pc_branching_function
        scorer.before_reset(model)

    def extract(self, model, done):
        """Return the pseudocost scores computed by the wrapped scoring function."""
        scorer = self.pc_branching_function
        return scorer.extract(model, done)

    def action_select(self, action_set, model, done, **kwargs):
        """Return (chosen variable from action_set, its position within action_set)."""
        candidate_scores = self.extract(model, done)[action_set]
        best_position = candidate_scores.argmax()
        chosen = action_set[best_position]
        return chosen, best_position


class BipartiteGCNNoHeads(torch.nn.Module):
    def __init__(self,
                 device,
                 config=None,
                 emb_size=64,
                 num_rounds=1,
                 aggregator='add',
                 activation=None,
                 mask_nan_logits=True,
                 cons_nfeats=5,
                 edge_nfeats=1,
                 var_nfeats=19,
                 name='gnn'):
        '''
        This is the old implementation of the GNN without any DQN heads. Kept here
        so that old models can still be run without difficulty.

        Args:
            device: Device the network is moved to (via ``self.to``) after construction.
            config (str, ml_collections.ConfigDict()): If not None, will initialise
                from config dict and ignore the architecture arguments below. Can be
                either string (path to config.json) or ml_collections.ConfigDict object.
            emb_size (int): Embedding width shared by constraints and variables.
            num_rounds (int): Number of variable->constraint->variable message passing rounds.
            aggregator (str): Aggregation scheme passed to each BipartiteGraphConvolution.
            activation (None, 'sigmoid', 'relu', 'leaky_relu', 'elu', 'hard_swish'):
                Optional activation appended to the output MLP.
            mask_nan_logits (bool): Stored on the instance. NOTE(review): forward()
                never reads this flag nor calls _mask_nan_logits — confirm whether
                callers apply the masking themselves.
            cons_nfeats, edge_nfeats, var_nfeats (int): Input feature sizes.
            name (str): Name stored on the network.
        '''
        super().__init__()
        self.device = device

        if config is not None:
            self.init_from_config(config)
        else:
            self.mask_nan_logits = mask_nan_logits
            self.name = name
            self.init_nn_modules(emb_size=emb_size, num_rounds=num_rounds, cons_nfeats=cons_nfeats,
                                 edge_nfeats=edge_nfeats, var_nfeats=var_nfeats, aggregator=aggregator,
                                 activation=activation)

        self.printed_warning = False
        self.to(self.device)

    def init_from_config(self, config):
        '''
        Build the network from a config.

        Args:
            config (str, ml_collections.ConfigDict): Path to a config JSON file or a
                ConfigDict. A missing ``mask_nan_logits`` defaults to False; a missing
                ``activation`` defaults to None (written back into ``config``).
        '''
        if isinstance(config, str):
            # load from json
            with open(config, 'r') as f:
                json_config = json.load(f)
            # Saved configs hold a JSON-encoded string of the config, so decode it a
            # second time; a plain JSON object is accepted as well.
            if isinstance(json_config, str):
                json_config = json.loads(json_config)
            config = ml_collections.ConfigDict(json_config)
        try:
            self.mask_nan_logits = config.mask_nan_logits
        except AttributeError:
            self.mask_nan_logits = False
        self.name = config.name
        if 'activation' not in config:
            config.activation = None
        self.init_nn_modules(emb_size=config.emb_size, num_rounds=config.num_rounds, cons_nfeats=config.cons_nfeats,
                             edge_nfeats=config.edge_nfeats, var_nfeats=config.var_nfeats, aggregator=config.aggregator,
                             activation=config.activation)

    def get_networks(self):
        '''Return the trainable networks of this agent keyed by name.'''
        return {'network': self}

    def init_nn_modules(self, emb_size=64, num_rounds=1, cons_nfeats=5, edge_nfeats=1, var_nfeats=19, aggregator='add',
                        activation=None):
        '''Create all sub-modules. Creation order is kept fixed (state_dict keys / RNG stream).'''
        self.emb_size = emb_size
        self.num_rounds = num_rounds
        self.cons_nfeats = cons_nfeats
        self.edge_nfeats = edge_nfeats
        self.var_nfeats = var_nfeats
        self.aggregator = aggregator
        self.activation = activation

        # CONSTRAINT EMBEDDING
        self.cons_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # EDGE EMBEDDING
        self.edge_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(edge_nfeats),
        )

        # VARIABLE EMBEDDING
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        self.conv_v_to_c = BipartiteGraphConvolution(emb_size=emb_size, aggregator=aggregator)
        self.conv_c_to_v = BipartiteGraphConvolution(emb_size=emb_size, aggregator=aggregator)

        output_layers = [
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        ]
        output_activations = {
            'sigmoid': torch.nn.Sigmoid,
            'relu': torch.nn.ReLU,
            'leaky_relu': torch.nn.LeakyReLU,
            'elu': torch.nn.ELU,
            'hard_swish': torch.nn.Hardswish,
        }
        if self.activation is not None:
            if self.activation not in output_activations:
                raise Exception(f'Unrecognised activation {self.activation}')
            output_layers.append(output_activations[self.activation]())
        self.output_module = torch.nn.Sequential(*output_layers)

    def _mask_nan_logits(self, logits, mask_val=-1e8):
        '''Replace NaN entries of ``logits`` in place with ``mask_val`` and return it.'''
        logits[torch.isnan(logits)] = mask_val
        return logits

    def forward(self, *_obs, print_warning=True):
        '''
        Score every variable node.

        Accepts either the four pre-processed tensors
        (constraint_features, edge_indices, edge_features, variable_features) or a
        single ecole NodeBipartite observation, which is converted to tensors here.
        Extra variable feature columns beyond ``var_nfeats`` are sliced off.
        Returns one output per variable (last dimension squeezed).
        '''
        if len(_obs) > 1:
            # no need to pre-process observation features
            constraint_features, edge_indices, edge_features, variable_features = _obs
        else:
            # need to pre-process observation features
            obs = _obs[0]  # unpack
            constraint_features = torch.from_numpy(obs.row_features.astype(np.float32)).to(self.device)
            edge_indices = torch.from_numpy(obs.edge_features.indices.astype(np.int64)).to(self.device)
            edge_features = torch.from_numpy(obs.edge_features.values.astype(np.float32)).view(-1, 1).to(self.device)
            variable_features = torch.from_numpy(obs.column_features.astype(np.float32)).to(self.device)

        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)

        # First step: linear embedding layers to a common dimension (emb_size)
        constraint_features = self.cons_embedding(constraint_features)
        edge_features = self.edge_embedding(edge_features)
        if variable_features.shape[1] != self.var_nfeats:
            if print_warning and not self.printed_warning:
                print(
                    f'WARNING: variable_features is shape {variable_features.shape} but var_nfeats is {self.var_nfeats}. Will index out extra features.')
                self.printed_warning = True
            variable_features = variable_features[:, 0:self.var_nfeats]
        variable_features = self.var_embedding(variable_features)

        # Two half convolutions (message passing round)
        for _ in range(self.num_rounds):
            constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, edge_features,
                                                   constraint_features)
            variable_features = self.conv_c_to_v(constraint_features, edge_indices, edge_features, variable_features)

        # A final MLP on the variable. The original author found the clone necessary
        # to avoid an in-place-operation autograd error.
        output = self.output_module(variable_features).clone().squeeze(-1)

        return output

    def create_config(self):
        '''Returns config dict so that can re-initialise easily.'''
        # Copy every instance attribute except the sub-module registry; excluding
        # '_modules' before deep-copying avoids needlessly copying all network weights.
        network_dict = copy.deepcopy({key: value for key, value in self.__dict__.items() if key != '_modules'})

        # create config dict
        config = ml_collections.ConfigDict(network_dict)

        return config


class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    One half of a bipartite message-passing round, built on PyTorch Geometric's
    MessagePassing base class; only the message form is defined here.

    Messages go from the "left" node set to the "right" node set: row 0 of
    ``edge_indices`` indexes left nodes and row 1 indexes right nodes (PyG's
    default source-to-target flow with ``size=(n_left, n_right)``). Each message
    is the sum of a projected right node, a projected left node and, optionally,
    a projected edge feature, passed through LayerNorm -> LeakyReLU -> Linear.
    Aggregated messages are LayerNorm-ed, concatenated with the unprojected right
    features, and mapped back to ``emb_size`` by a two-layer MLP.

    Args:
        aggregator (str): Aggregation scheme handed to MessagePassing (e.g. 'add').
        emb_size (int): Width of the node embeddings on both sides.
        include_edge_features (bool): If True, edge features are projected with
            Linear(1, emb_size) — i.e. one feature per edge is assumed — and added
            to every message.
    """

    def __init__(self,
                 aggregator='add',
                 emb_size=64,
                 include_edge_features=False):
        super().__init__(aggregator)
        self.include_edge_features = include_edge_features

        # Sub-modules are created in a fixed order: changing it would alter the
        # state_dict / parameter order and the RNG stream used for initialisation.
        self.feature_module_left = torch.nn.Sequential(torch.nn.Linear(emb_size, emb_size))
        if include_edge_features:
            self.feature_module_edge = torch.nn.Sequential(torch.nn.Linear(1, emb_size, bias=False))
        self.feature_module_right = torch.nn.Sequential(torch.nn.Linear(emb_size, emb_size, bias=False))
        self.feature_module_final = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )
        self.post_conv_module = torch.nn.Sequential(torch.nn.LayerNorm(emb_size))
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, emb_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """Propagate messages from the left nodes and return updated right-node embeddings."""
        edge_embeddings = self.feature_module_edge(edge_features) if self.include_edge_features else None
        left_embeddings = self.feature_module_left(left_features)
        right_embeddings = self.feature_module_right(right_features)

        aggregated = self.propagate(
            edge_indices,
            size=(left_features.shape[0], right_features.shape[0]),
            node_features=(left_embeddings, right_embeddings),
            edge_features=edge_embeddings,
        )

        combined = torch.cat([self.post_conv_module(aggregated), right_features], dim=-1)
        return self.output_module(combined)

    def message(self, node_features_i, node_features_j, edge_features=None):
        """Build one message per edge (``_i`` = right/target side, ``_j`` = left/source side)."""
        summed = node_features_i + node_features_j
        if edge_features is not None:
            summed = summed + edge_features
        return self.feature_module_final(summed)


class BipartiteGCN(torch.nn.Module):
    def __init__(self,
                 device,
                 config=None,
                 emb_size=64,
                 num_rounds=1,
                 aggregator='add',
                 activation=None,
                 cons_nfeats=5,
                 edge_nfeats=1,
                 var_nfeats=19,
                 num_heads=1,
                 head_depth=1,
                 linear_weight_init=None,
                 linear_bias_init=None,
                 layernorm_weight_init=None,
                 layernorm_bias_init=None,
                 head_aggregator=None,
                 include_edge_features=False,
                 use_old_heads_implementation=False,
                 profile_time=False,
                 print_warning=True,
                 name='gnn',
                 **kwargs):
        '''
        Bipartite GNN with one or more output heads.

        When ``config`` is given, the architecture (and ``name``) are taken from it
        and the architecture keyword arguments are ignored. ``print_warning`` and any
        extra ``**kwargs`` are accepted but not used by this constructor.

        Args:
            device: Device the network is moved to (via ``self.to``) after construction.
            config (str, ml_collections.ConfigDict()): If not None, will initialise
                from config dict. Can be either string (path to config.json) or
                ml_collections.ConfigDict object.
            activation (None, 'sigmoid', 'relu', 'leaky_relu', 'inverse_leaky_relu', 'elu', 'hard_swish',
                'softplus', 'mish', 'softsign')
            num_heads (int): Number of heads (final layers) to use. Will use
                head_aggregator to reduce all heads.
            head_depth (int): Number of hidden Linear+LeakyReLU layers per head.
            linear_weight_init (None, 'uniform', 'normal',
                'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal')
            linear_bias_init (None, 'zeros', 'normal')
            layernorm_weight_init (None, 'normal')
            layernorm_bias_init (None, 'zeros', 'normal')
            head_aggregator: How to aggregate output of heads.
                int: Will index head outputs with heads[int]
                'add': Sum heads to get output
                None: Will not aggregate heads
                dict: Specify different head aggregation for training and testing
                    e.g. head_aggregator={'train': None, 'test': 0} to not aggregate
                    heads during training, but at test time only return output
                    of 0th index head.
            include_edge_features (bool): Whether the convolutions use edge features.
            use_old_heads_implementation (bool): Build heads with the legacy layout.
            profile_time (bool): Stored on the instance as ``profile_time``.
            name (str): Name stored on the network (ignored when ``config`` is given).
        '''
        super().__init__()
        self.device = device

        if config is None:
            # Build directly from the keyword arguments.
            self.name = name
            self.init_nn_modules(
                emb_size=emb_size,
                num_rounds=num_rounds,
                cons_nfeats=cons_nfeats,
                edge_nfeats=edge_nfeats,
                var_nfeats=var_nfeats,
                aggregator=aggregator,
                activation=activation,
                num_heads=num_heads,
                head_depth=head_depth,
                linear_weight_init=linear_weight_init,
                linear_bias_init=linear_bias_init,
                layernorm_weight_init=layernorm_weight_init,
                layernorm_bias_init=layernorm_bias_init,
                head_aggregator=head_aggregator,
                include_edge_features=include_edge_features,
                use_old_heads_implementation=use_old_heads_implementation,
            )
        else:
            # Architecture comes entirely from the saved config.
            self.init_from_config(config)

        self.profile_time = profile_time
        self.printed_warning = False
        self.to(self.device)

    def init_from_config(self, config):
        '''
        Build the network from a saved config.

        Args:
            config (str, ml_collections.ConfigDict): Path to a config JSON file, or a
                ConfigDict. The JSON file may hold either a JSON-encoded string of the
                config (the double-encoded form the loader originally required) or a
                plain JSON object. Keys missing from the config (presumably configs
                saved before these options existed) are filled with defaults, and the
                defaults are written back into ``config``.
        '''
        if isinstance(config, str):
            # load from json
            with open(config, 'r') as f:
                json_config = json.load(f)
            if isinstance(json_config, str):
                # double-encoded: the file contained a JSON string of the JSON config
                json_config = json.loads(json_config)
            config = ml_collections.ConfigDict(json_config)
        self.name = config.name

        # Defaults for optional keys, applied in this order when absent.
        optional_defaults = {
            'activation': None,
            'num_heads': 1,
            'linear_weight_init': None,
            'linear_bias_init': None,
            'layernorm_weight_init': None,
            'layernorm_bias_init': None,
            'head_aggregator': None,
            'head_depth': 1,
            'include_edge_features': False,
            'use_old_heads_implementation': False,
        }
        for key, default in optional_defaults.items():
            if key not in config:
                setattr(config, key, default)

        self.init_nn_modules(emb_size=config.emb_size,
                             num_rounds=config.num_rounds,
                             cons_nfeats=config.cons_nfeats,
                             edge_nfeats=config.edge_nfeats,
                             var_nfeats=config.var_nfeats,
                             aggregator=config.aggregator,
                             activation=config.activation,
                             num_heads=config.num_heads,
                             head_depth=config.head_depth,
                             linear_weight_init=config.linear_weight_init,
                             linear_bias_init=config.linear_bias_init,
                             layernorm_weight_init=config.layernorm_weight_init,
                             layernorm_bias_init=config.layernorm_bias_init,
                             head_aggregator=config.head_aggregator,
                             include_edge_features=config.include_edge_features,
                             use_old_heads_implementation=config.use_old_heads_implementation)

        if isinstance(self.head_aggregator, ml_collections.ConfigDict):
            # convert to standard dictionary
            self.head_aggregator = self.head_aggregator.to_dict()

    def get_networks(self):
        '''Return the trainable networks of this agent, keyed by name.'''
        networks = {'networks': self}
        return networks

    def init_model_parameters(self, init_gnn_params=True, init_heads_params=True):
        '''
        (Re-)initialise Linear and LayerNorm parameters according to the
        ``linear_*_init`` / ``layernorm_*_init`` settings stored on the network.

        Args:
            init_gnn_params (bool): Initialise every sub-module except ``heads_module``.
            init_heads_params (bool): Initialise the modules in ``heads_module``.

        Raises:
            Exception: If an init setting is not recognised.

        For the 'xavier_*' and 'kaiming_*' weight inits, the nonlinearity is taken
        from ``self.activation`` when ``torch.nn.init.calculate_gain`` supports it.
        'inverse_leaky_relu' is treated as 'leaky_relu', since it is built as a
        LeakyReLU module. Otherwise (e.g. ``None``, 'elu', 'mish') torch's own
        default is used: gain=1 for xavier, nonlinearity='leaky_relu' for kaiming.
        '''
        # Nonlinearity names torch.nn.init.calculate_gain accepts (conv variants omitted).
        gain_nonlinearities = {'linear', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'}
        nonlinearity = 'leaky_relu' if self.activation == 'inverse_leaky_relu' else self.activation
        if nonlinearity not in gain_nonlinearities:
            nonlinearity = None
        xavier_gain = 1.0 if nonlinearity is None else torch.nn.init.calculate_gain(nonlinearity)
        kaiming_kwargs = {} if nonlinearity is None else {'nonlinearity': nonlinearity}

        def init_params(m):
            if isinstance(m, torch.nn.Linear):
                # weights
                if self.linear_weight_init is None:
                    pass
                elif self.linear_weight_init == 'uniform':
                    torch.nn.init.uniform_(m.weight, a=0.0, b=1.0)
                elif self.linear_weight_init == 'normal':
                    torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
                elif self.linear_weight_init == 'xavier_uniform':
                    torch.nn.init.xavier_uniform_(m.weight, gain=xavier_gain)
                elif self.linear_weight_init == 'xavier_normal':
                    torch.nn.init.xavier_normal_(m.weight, gain=xavier_gain)
                elif self.linear_weight_init == 'kaiming_uniform':
                    torch.nn.init.kaiming_uniform_(m.weight, **kaiming_kwargs)
                elif self.linear_weight_init == 'kaiming_normal':
                    torch.nn.init.kaiming_normal_(m.weight, **kaiming_kwargs)
                else:
                    raise Exception(f'Unrecognised linear_weight_init {self.linear_weight_init}')

                # biases
                if m.bias is not None:
                    if self.linear_bias_init is None:
                        pass
                    elif self.linear_bias_init == 'zeros':
                        torch.nn.init.zeros_(m.bias)
                    elif self.linear_bias_init == 'uniform':
                        torch.nn.init.uniform_(m.bias)
                    elif self.linear_bias_init == 'normal':
                        torch.nn.init.normal_(m.bias)
                    else:
                        raise Exception(f'Unrecognised bias initialisation {self.linear_bias_init}')

            elif isinstance(m, torch.nn.LayerNorm):
                # weights
                if self.layernorm_weight_init is None:
                    pass
                elif self.layernorm_weight_init == 'normal':
                    torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    raise Exception(f'Unrecognised layernorm_weight_init {self.layernorm_weight_init}')

                # biases
                if self.layernorm_bias_init is None:
                    pass
                elif self.layernorm_bias_init == 'zeros':
                    torch.nn.init.zeros_(m.bias)
                elif self.layernorm_bias_init == 'normal':
                    torch.nn.init.normal_(m.bias)
                else:
                    raise Exception(f'Unrecognised layernorm_bias_init {self.layernorm_bias_init}')

        if init_gnn_params:
            # init base GNN params. self.apply() would also recurse into heads_module,
            # ignoring init_heads_params=False, so skip the heads explicitly.
            for child in self.children():
                if child is not self.heads_module:
                    child.apply(init_params)

        if init_heads_params:
            # init head output params
            for h in self.heads_module:
                h.apply(init_params)

    def init_nn_modules(self,
                        emb_size=64,
                        num_rounds=1,
                        cons_nfeats=5,
                        edge_nfeats=1,
                        var_nfeats=19,
                        aggregator='add',
                        activation=None,
                        num_heads=1,
                        head_depth=1,
                        linear_weight_init=None,
                        linear_bias_init=None,
                        layernorm_weight_init=None,
                        layernorm_bias_init=None,
                        head_aggregator='add',
                        include_edge_features=False,
                        use_old_heads_implementation=False):
        self.emb_size = emb_size
        self.num_rounds = num_rounds
        self.cons_nfeats = cons_nfeats
        self.edge_nfeats = edge_nfeats
        self.var_nfeats = var_nfeats
        self.aggregator = aggregator
        self.activation = activation
        self.num_heads = num_heads
        self.head_depth = head_depth
        self.linear_weight_init = linear_weight_init
        self.linear_bias_init = linear_bias_init
        self.layernorm_weight_init = layernorm_weight_init
        self.layernorm_bias_init = layernorm_bias_init
        self.head_aggregator = head_aggregator
        self.include_edge_features = include_edge_features
        self.use_old_heads_implementation = use_old_heads_implementation

        # CONSTRAINT EMBEDDING
        self.cons_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            # torch.nn.LayerNorm(emb_size, emb_size), # added
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
            # torch.nn.LayerNorm(emb_size, emb_size), # added
            torch.nn.LeakyReLU(),
        )

        # EDGE EMBEDDING
        if self.include_edge_features:
            self.edge_embedding = torch.nn.Sequential(
                torch.nn.LayerNorm(edge_nfeats),
            )

        # VARIABLE EMBEDDING
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            # torch.nn.LayerNorm(emb_size, emb_size), # added
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
            # torch.nn.LayerNorm(emb_size, emb_size), # added
            torch.nn.LeakyReLU(),
        )

        self.conv_v_to_c = BipartiteGraphConvolution(emb_size=emb_size, aggregator=aggregator,
                                                     include_edge_features=self.include_edge_features)
        self.conv_c_to_v = BipartiteGraphConvolution(emb_size=emb_size, aggregator=aggregator,
                                                     include_edge_features=self.include_edge_features)

        # HEADS
        if self.use_old_heads_implementation:
            # OLD
            self.heads_module = torch.nn.ModuleList([
                torch.nn.Sequential(
                    torch.nn.Linear(emb_size, emb_size),
                    torch.nn.LeakyReLU(),
                    torch.nn.Linear(emb_size, 1, bias=True)
                )
                for _ in range(self.head_depth)
                for _ in range(self.num_heads)
            ])
        else:
            # NEW
            heads = []
            for _ in range(self.num_heads):
                head = []
                for _ in range(self.head_depth):
                    head.append(torch.nn.Linear(emb_size, emb_size))
                    head.append(torch.nn.LeakyReLU())
                head.append(torch.nn.Linear(emb_size, 1, bias=True))
                heads.append(torch.nn.Sequential(*head))
            self.heads_module = torch.nn.ModuleList(heads)

        if self.activation is None:
            self.activation_module = None
        elif self.activation == 'sigmoid':
            self.activation_module = torch.nn.Sigmoid()
        elif self.activation == 'relu':
            self.activation_module = torch.nn.ReLU()
        elif self.activation == 'leaky_relu' or self.activation == 'inverse_leaky_relu':
            self.activation_module = torch.nn.LeakyReLU()
        elif self.activation == 'elu':
            self.activation_module = torch.nn.ELU()
        elif self.activation == 'hard_swish':
            self.activation_module = torch.nn.Hardswish()
        elif self.activation == 'softplus':
            self.activation_module = torch.nn.Softplus()
        elif self.activation == 'mish':
            self.activation_module = torch.nn.Mish()
        elif self.activation == 'softsign':
            self.activation_module = torch.nn.Softsign()
        else:
            raise Exception(f'Unrecognised activation {self.activation}')

        self.init_model_parameters()

    def forward(self, *_obs, print_warning=True):
        '''Returns output of each head.

        Accepts either:
            - four positional arguments ``(constraint_features, edge_indices,
              edge_features, variable_features)``; any of them given as a numpy
              array is converted to a tensor on ``self.device``; or
            - a single observation object exposing ``row_features``,
              ``edge_features.indices``, ``edge_features.values`` and
              ``column_features``, which is converted here.

        Args:
            print_warning: if True, the first time the variable feature width
                differs from ``self.var_nfeats`` the user is prompted on stdin
                whether to continue; the features are then sliced to
                ``self.var_nfeats`` columns.

        Returns:
            list of tensors, one value per variable node. The list holds one
            entry per head when the resolved head aggregator is None, otherwise
            a single aggregated entry. If ``self.activation`` is
            'inverse_leaky_relu' the activated outputs are negated.

        Raises:
            Exception: if the user declines the feature-width prompt, or the
                head aggregator is not recognised.
        '''
        forward_start = time.time()
        if len(_obs) > 1:
            # features passed individually; only numpy inputs need converting
            constraint_features, edge_indices, edge_features, variable_features = _obs

            if isinstance(constraint_features, np.ndarray):
                constraint_features = torch.from_numpy(constraint_features).to(self.device)
            if isinstance(edge_indices, np.ndarray):
                # accept any integer dtype; message passing needs int64 indices
                edge_indices = torch.as_tensor(edge_indices, dtype=torch.long).to(self.device)
            if isinstance(edge_features, np.ndarray):
                edge_features = torch.from_numpy(edge_features).to(self.device).unsqueeze(1)
            if isinstance(variable_features, np.ndarray):
                variable_features = torch.from_numpy(variable_features).to(self.device)

        else:
            # single raw observation object: extract and convert its arrays
            obs = _obs[0]
            start = time.time()
            constraint_features = torch.from_numpy(obs.row_features.astype(np.float32)).to(self.device)
            # Edge indices are node positions and must be cast to int64: a narrow
            # type such as int16 silently wraps any index >= 32768 to a negative value.
            edge_indices = torch.from_numpy(obs.edge_features.indices.astype(np.int64)).to(self.device)
            edge_features = torch.from_numpy(obs.edge_features.values.astype(np.float32)).view(-1, 1).to(self.device)
            variable_features = torch.from_numpy(obs.column_features.astype(np.float32)).to(self.device)
            if self.profile_time:
                print(f'var feat: {variable_features[0][0]}')
                t = time.time() - start
                print(f'to_t: {t * 1e3:.3f} ms')

        # swapped rows give the variable -> constraint message direction
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)

        # First step: linear embedding layers to a common dimension (emb_size)
        first_step_start = time.time()
        constraint_features = self.cons_embedding(constraint_features)
        if self.include_edge_features:
            edge_features = self.edge_embedding(edge_features)
        if variable_features.shape[1] != self.var_nfeats:
            if print_warning:
                if not self.printed_warning:
                    ans = None
                    while ans not in {'y', 'n'}:
                        ans = input(
                            f'WARNING: variable_features is shape {variable_features.shape} but var_nfeats is {self.var_nfeats}. Will index out extra features. Continue? (y/n): ')
                    if ans != 'y':
                        raise Exception('User stopped programme.')
                self.printed_warning = True
            variable_features = variable_features[:, 0:self.var_nfeats]
        variable_features = self.var_embedding(variable_features)
        if self.profile_time:
            print(variable_features[0][0])
            first_step_t = time.time() - first_step_start
            print(f'first_step_t: {first_step_t * 1e3:.3f} ms')

        # Two half convolutions per message passing round
        conv_start = time.time()
        for _ in range(self.num_rounds):
            constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, edge_features,
                                                   constraint_features)
            variable_features = self.conv_c_to_v(constraint_features, edge_indices, edge_features, variable_features)
        if self.profile_time:
            print(f'{variable_features[0][0]}')
            conv_t = time.time() - conv_start
            print(f'conv_t: {conv_t * 1e3:.3f} ms')

        # get output for each head (only the first num_heads entries are used)
        head_output_start = time.time()
        head_output = [self.heads_module[head](variable_features).squeeze(-1) for head in range(self.num_heads)]
        if self.profile_time:
            print(f'{head_output[0][0]}')
            head_output_t = time.time() - head_output_start
            print(f'head_output_t: {head_output_t * 1e3:.3f} ms')

        # resolve head aggregator (may differ between train and eval mode)
        head_output_agg_start = time.time()
        if isinstance(self.head_aggregator, dict):
            if self.training:
                head_aggregator = self.head_aggregator['train']
            else:
                head_aggregator = self.head_aggregator['test']
        else:
            head_aggregator = self.head_aggregator

        # None means keep every head's output separately
        if head_aggregator is not None:
            if head_aggregator == 'add':
                head_output = [torch.stack(head_output, dim=0).sum(dim=0)]
            elif head_aggregator == 'mean':
                head_output = [torch.stack(head_output, dim=0).mean(dim=0)]
            elif isinstance(head_aggregator, int):
                head_output = [head_output[head_aggregator]]
            else:
                raise Exception(f'Unrecognised head_aggregator {head_aggregator}')

        if self.profile_time:
            print(f'{head_output[0][0]}')
            head_output_agg_t = time.time() - head_output_agg_start
            print(f'head_output_agg_t: {head_output_agg_t * 1e3:.3f} ms')

        # activation
        activation_start = time.time()
        if self.activation_module is not None:
            head_output = [self.activation_module(head) for head in head_output]
            if self.activation == 'inverse_leaky_relu':
                # invert
                head_output = [-1 * head for head in head_output]
        if self.profile_time:
            print(f'{head_output[0][0]}')
            activation_t = time.time() - activation_start
            print(f'activation_t: {activation_t * 1e3:.3f}')

        if self.profile_time:
            print(f'{head_output[0][0]}')
            forward_t = time.time() - forward_start
            print(f'>>> total forward time: {forward_t * 1e3:.3f} ms <<<')

        return head_output

    def create_config(self):
        '''Returns config dict so that can re-initialise easily.

        The config holds deep copies of the instance attributes stored in
        ``self.__dict__`` (e.g. emb_size, num_heads, activation). The
        ``_modules`` entry (the submodules) is excluded. This avoids circular
        references and avoids pointlessly copying every weight tensor.

        Returns:
            ml_collections.ConfigDict built from the copied attributes.
        '''
        # Skip '_modules' *before* copying: deep-copying it first would
        # duplicate all parameters (possibly on the GPU) only to discard them.
        # A shared memo keeps objects referenced by several attributes shared,
        # exactly as a single deepcopy of the whole dict would.
        memo = {}
        network_dict = {
            key: copy.deepcopy(value, memo)
            for key, value in self.__dict__.items()
            if key != '_modules'
        }

        return ml_collections.ConfigDict(network_dict)


class Agent:
    '''Wraps a (pre-trained) network for test-time action selection.

    Use this class for loading a pre-trained network and doing test-time
    inference with it. The network can have been trained with any method.

    To select an action, passes the observation to the network and selects the
    action with the highest (aggregated) logit output.
    '''

    def __init__(self,
                 network=None,
                 config=None,
                 device=None,
                 head_aggregator='add',
                 network_name='networks',
                 print_forward_dim_warning=True,
                 name='agent'):
        '''
        Args:
            network: pre-built torch network. Required when config is None.
            config: config object or path to a JSON config file. When given,
                the network, head aggregator and network name are all taken
                from it, and the network/head_aggregator/network_name
                arguments are ignored.
            device: torch device; defaults to CUDA if available, else CPU.
            head_aggregator: how action_select combines per-head outputs:
                None, 'add', 'mean', an int head index, or a dict with
                'train'/'test' keys holding one of these.
            network_name: key used for the network in get_networks().
            print_forward_dim_warning: passed to the network as print_warning.
            name: agent name.

        Raises:
            Exception: if neither network nor config is provided.
        '''
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        self.print_forward_dim_warning = print_forward_dim_warning

        if config is not None:
            self.init_from_config(config)
        else:
            if network is None:
                raise Exception('Must provide networks.')
            self.network = network.to(self.device)
            self.head_aggregator = head_aggregator
            self.network_name = network_name
        self.name = name

    def init_from_config(self, config):
        '''Builds self.network, self.network_name and self.head_aggregator from a config.

        Args:
            config: ConfigDict-like config, or a path to a JSON file. The file
                may hold either a JSON-encoded string of the config (decoded
                twice) or a plain JSON object.
        '''
        if isinstance(config, str):
            # load from json
            with open(config, 'r') as f:
                loaded = json.load(f)
            # double-encoded configs (a JSON string inside the file) need a second decode
            if isinstance(loaded, str):
                loaded = json.loads(loaded)
            config = ml_collections.ConfigDict(loaded)

        # find the network config inside the (possibly agent-level) config
        if 'networks' in config.keys():
            self.network_name, net_config = 'networks', config.networks
        elif 'policy_network' in config.keys():
            self.network_name, net_config = 'policy', config.policy_network
        elif 'value_network' in config.keys():
            self.network_name, net_config = 'value_network', config.value_network
        elif 'value_network_1' in config.keys():
            self.network_name, net_config = 'value_network_1', config.value_network_1
        elif 'actor_network' in config.keys():
            self.network_name, net_config = 'actor_network', config.actor_network
        else:
            # config is the network config itself (e.g. supervised learning, where no specific agent was trained)
            self.network_name, net_config = 'networks', config

        # TEMPORARY: choose implementation by whether the config defines heads
        if 'num_heads' in net_config.keys():
            NET = BipartiteGCN
        else:
            NET = BipartiteGCNNoHeads

        self.network = NET(device=self.device,
                           config=net_config)
        self.network.to(self.device)

        self.head_aggregator = config.head_aggregator if 'head_aggregator' in config else None

    def get_networks(self):
        return {self.network_name: self.network}

    def forward(self, obs, **kwargs):
        '''Useful for compatability with some DQN custom test scripts.

        A plain tuple is unpacked into positional network arguments; anything
        else is passed as a single observation.
        '''
        if type(obs) is tuple:
            return self.network(*obs, print_warning=self.print_forward_dim_warning)
        return self.network(obs, print_warning=self.print_forward_dim_warning)

    def before_reset(self, model):
        pass

    def train(self):
        self.network.train()

    def eval(self):
        self.network.eval()

    def _mask_nan_logits(self, logits, mask_val=-1e8):
        '''Replaces NaN entries in place with mask_val; accepts a tensor or a list of per-head tensors.'''
        if isinstance(logits, list):
            for head_logits in logits:
                # NaN is the only value not equal to itself
                head_logits[head_logits != head_logits] = mask_val
        else:
            logits[logits != logits] = mask_val
        return logits

    def parameters(self):
        return self.network.parameters()

    def _resolve_head_aggregator(self):
        '''Returns the head aggregator for the network's current train/eval mode.'''
        if isinstance(self.head_aggregator, dict):
            return self.head_aggregator['train' if self.network.training else 'test']
        return self.head_aggregator

    def action_select(self, **kwargs):
        '''Selects the candidate action with the highest (aggregated) network output.

        Call with either:
            state: batched graph object with constraint_features, edge_index,
                edge_attr, variable_features, candidates, raw_candidates and
                num_candidates attributes; or
            action_set and obs: candidate indices (numpy array or tensor) and
                the observation passed to the network.

        Returns:
            (action, action_idx): the chosen action and its position in the
            candidate set. For ``state`` inputs these are stacked per-sample
            tensors.

        Raises:
            Exception: if neither state nor both action_set and obs are given,
                or the head aggregator is not recognised.
        '''
        # check args valid: need state, or BOTH action_set and obs
        if 'state' not in kwargs and ('action_set' not in kwargs or 'obs' not in kwargs):
            raise Exception('Must provide either state or action_set and obs as kwargs.')

        # process observation
        if 'state' in kwargs:
            state = kwargs['state']
            self.obs = (state.constraint_features, state.edge_index, state.edge_attr,
                        state.variable_features)
            self.action_set = torch.as_tensor(state.candidates)
        else:
            self.action_set, self.obs = kwargs['action_set'], kwargs['obs']
            if isinstance(self.action_set, np.ndarray):
                self.action_set = torch.as_tensor(self.action_set)

        # forward pass through NN
        self.logits = self.forward(self.obs)

        # keep only the outputs of valid (candidate) actions
        if isinstance(self.logits, list):
            # multi-head output: aggregate heads into one value per action
            self.preds = [head_logits[self.action_set] for head_logits in self.logits]

            head_aggregator = self._resolve_head_aggregator()
            if head_aggregator is None:
                self.preds = torch.stack(self.preds).squeeze(0)
            elif head_aggregator == 'add':
                self.preds = torch.stack(self.preds, dim=0).sum(dim=0)
            elif head_aggregator == 'mean':
                self.preds = torch.stack(self.preds, dim=0).mean(dim=0)
            elif isinstance(head_aggregator, int):
                self.preds = self.preds[head_aggregator]
            else:
                raise Exception(f'Unrecognised head_aggregator {head_aggregator}')
        else:
            # no heads
            self.preds = self.logits[self.action_set]

        # get agent action (greedy)
        if 'state' in kwargs:
            # batch of observations: split flat predictions per sample
            state = kwargs['state']
            self.preds = self.preds.split_with_sizes(tuple(state.num_candidates))
            self.action_set = state.raw_candidates.split_with_sizes(tuple(state.num_candidates))

            self.action_idx = torch.stack([q.argmax() for q in self.preds])
            self.action = torch.stack([_action_set[idx] for _action_set, idx in zip(self.action_set, self.action_idx)])
        else:
            # single observation
            self.action_idx = torch.argmax(self.preds)
            self.action = self.action_set[self.action_idx.item()]

        return self.action, self.action_idx


