import os
import torch
import ecole as ec
import numpy as np
import collections
import random
from agents.agent_model import GNNPolicyItem, GNNPolicyAno, GNNPolicyRL2Branch


class ObservationFunction(ec.observation.NodeBipartite):
    """Bipartite node observation that accepts (and ignores) a problem name.

    Behaves exactly like ``ecole.observation.NodeBipartite``. The ``problem``
    argument is not used; it presumably exists so this class can be built with
    the same ``(problem)`` signature as ``Policy`` — confirm with the caller.
    """

    def __init__(self, problem):
        super(ObservationFunction, self).__init__()

    def seed(self, seed):
        """Intentionally a no-op: the given seed value is discarded."""


class Policy():
    """Branching policy that picks the candidate with the highest GNN logit.

    A ``GNNPolicyRL2Branch`` network is loaded from a per-problem checkpoint
    and evaluated on the bipartite observation at every call.
    """

    # Benchmarks for which a trained checkpoint is expected to exist.
    SUPPORTED_PROBLEMS = ('setcover', 'indset', 'cauctions', 'facilities')
    # Checkpoint variants, each stored as agents/model/<problem>/<name>.pkl.
    POLICY_NAMES = ('il', 'mdp', 'tmdp+DFS', 'tmdp+ObjLim')

    def __init__(self, problem, policy_name='tmdp+ObjLim', device=None):
        """Build the network and load its trained weights.

        Args:
            problem: benchmark name; must be one of ``SUPPORTED_PROBLEMS``.
            policy_name: checkpoint variant to load; one of ``POLICY_NAMES``.
                Defaults to ``'tmdp+ObjLim'``, the previously hard-coded choice.
            device: torch device string. Defaults to ``"cuda:0"`` when CUDA is
                available, otherwise ``"cpu"``.

        Raises:
            ValueError: if ``problem`` is not supported.
            Exception: if ``policy_name`` is not a known checkpoint variant.
        """
        self.rng = np.random.RandomState()

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.problem = problem

        if problem not in self.SUPPORTED_PROBLEMS:
            raise ValueError("ML4CO Competition not support")
        if policy_name not in self.POLICY_NAMES:
            raise Exception(f"Unrecognized GNN policy {policy_name}")

        self.policy = GNNPolicyRL2Branch().to(self.device)
        # map_location lets a checkpoint saved on a GPU load on any device.
        state_dict = torch.load(f'agents/model/{self.problem}/{policy_name}.pkl',
                                map_location=self.device)
        self.policy.load_state_dict(state_dict)
        self.policy.eval()

    def seed(self, seed):
        """Reseed the policy's private NumPy random state."""
        self.rng = np.random.RandomState(seed)

    def _as_tensor(self, array, dtype):
        """Convert a NumPy array to ``dtype`` and move it to the policy device."""
        return torch.from_numpy(array.astype(dtype)).to(self.device)

    def __call__(self, action_set, observation):
        """Return the candidate from ``action_set`` with the largest logit.

        Args:
            action_set: NumPy integer array of candidates; each entry is used
                as an index into the network's output logits.
            observation: bipartite observation exposing ``row_features``,
                ``edge_features.indices``, ``edge_features.values`` and
                ``column_features`` as NumPy arrays.

        Returns:
            The element of ``action_set`` whose logit is largest.
        """
        # Pure inference: skip autograd bookkeeping to save time and memory.
        with torch.no_grad():
            logits = self.policy(
                self._as_tensor(observation.row_features, np.float32),
                self._as_tensor(observation.edge_features.indices, np.int64),
                self._as_tensor(observation.edge_features.values, np.float32).view(-1, 1),
                self._as_tensor(observation.column_features, np.float32),
            )
            candidates = torch.from_numpy(action_set.astype(np.int64)).to(logits.device)
            best = logits[candidates].argmax().item()
        return action_set[best]
