import os
import torch
import ecole as ec
import numpy as np
import collections
import random
from agents.agent_model import GNNPolicyItem, GNNPolicyAno, GNNPolicy


class ObservationFunction(ec.observation.NodeBipartite):
    """Ecole ``NodeBipartite`` observation function used by the agents.

    The ``problem`` argument is accepted for interface compatibility but is
    not used, and ``seed`` is deliberately a no-op.
    """

    def __init__(self, problem):
        # `problem` is intentionally ignored; only the base-class setup runs.
        super(ObservationFunction, self).__init__()

    def seed(self, seed):
        # Override the base seeding hook with a no-op.
        return None


class Policy:
    """Greedy branching policy backed by a pretrained GNN for one problem family.

    The constructor selects a model architecture and checkpoint(s) based on
    ``problem``. Calling the instance scores the candidate variables of an
    observation with the model and returns the highest-scoring candidate.
    """

    # Checkpoint path for each problem family that uses a single ``GNNPolicy``.
    _SINGLE_MODEL_CHECKPOINTS = {
        'setcover': 'agents/model/setcover/setcover_KIDA.pkl',  # best model params
        'indset': 'agents/model/indset/indset_KIDA.pkl',
        'cauctions': 'agents/model/cauctions/cauctions_KIDA.pkl',
        'facilities': 'agents/model/facilities/facilities_KIDA.pkl',
    }

    # 'item_placement' uses a ``GNNPolicyItem`` whose weights are the
    # element-wise average of the weights stored in these checkpoints.
    _ITEM_PLACEMENT_CHECKPOINTS = (
        'agents/model/itemplacement/item0_KIDA.pkl',
        'agents/model/itemplacement/item1_KIDA.pkl',
        'agents/model/itemplacement/item2_KIDA.pkl',
    )

    def __init__(self, problem, device=None):
        """Load the pretrained model for ``problem`` and put it in eval mode.

        Args:
            problem: One of 'item_placement', 'setcover', 'indset',
                'cauctions' or 'facilities'.
            device: Torch device string for the model and input tensors.
                Defaults to 'cuda:0' when CUDA is available, otherwise 'cpu'.

        Raises:
            ValueError: If ``problem`` is not a supported problem name.
        """
        # Fail fast instead of leaving `self.policy` unset, which would only
        # surface later as an AttributeError inside `__call__`.
        if problem != 'item_placement' and problem not in self._SINGLE_MODEL_CHECKPOINTS:
            raise ValueError(f"Unsupported problem: {problem!r}")

        self.rng = np.random.RandomState()
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.problem = problem

        if problem == 'item_placement':
            # Load each checkpoint into its own model first so every file is
            # validated against (and cast to) the model's parameter layout.
            members = []
            for path in self._ITEM_PLACEMENT_CHECKPOINTS:
                member = GNNPolicyItem().to(self.device)
                member.load_state_dict(torch.load(path, map_location=self.device))
                members.append(member)
            self.policy = GNNPolicyItem().to(self.device)
            self.policy.load_state_dict(
                self._average_state_dicts([m.state_dict() for m in members]))
        else:
            self.policy = GNNPolicy().to(self.device)
            self.policy.load_state_dict(
                torch.load(self._SINGLE_MODEL_CHECKPOINTS[problem],
                           map_location=self.device))
        self.policy.eval()

    @staticmethod
    def _average_state_dicts(state_dicts):
        """Return the element-wise mean of ``state_dicts`` (keys of the first one).

        NOTE: `/` is true division, so integer-typed entries (if any) come back
        as floating point; `load_state_dict` copies them into the existing
        buffers, which casts them back to the buffer dtype.
        """
        count = len(state_dicts)
        averaged = collections.OrderedDict()
        for key in state_dicts[0].keys():
            averaged[key] = sum(sd[key] for sd in state_dicts) / count
        return averaged

    def seed(self, seed):
        """Reseed the instance's numpy random state."""
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Return the element of ``action_set`` with the highest model logit.

        Args:
            action_set: numpy array of candidate variable indices.
            observation: Object exposing ``row_features``, ``column_features``
                and ``edge_features.indices`` / ``edge_features.values`` as
                numpy arrays (the attributes read below).
        """
        # Columns 13 and 14 are removed before scoring; presumably the
        # checkpoints were trained without them — TODO confirm.
        variable_features = np.delete(observation.column_features, [13, 14], axis=1)

        device = self.device
        inputs = (
            torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
            torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
            torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device),
            torch.from_numpy(variable_features.astype(np.float32)).to(device),
        )

        candidates = torch.from_numpy(action_set.astype(np.int64))
        # Inference only: skip autograd bookkeeping on every call.
        with torch.no_grad():
            logits = self.policy(*inputs)
            best = logits[candidates.to(logits.device)].argmax().item()
        return action_set[best]

