import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data
from collections import defaultdict


class Net1(nn.Module):
    """Relational-GCN encoder plus a small MLP that yields one scalar score.

    Two ``RGCNConv`` layers embed every node of the graph.  The mean embedding
    of all nodes and the mean embedding of the selected cluster are
    concatenated and passed through ``fc1`` -> ReLU -> ``out``.
    """

    def __init__(self,feature_dim,hidden_dim,output_dim,num_relation):
        super().__init__()
        # Both convolutions are constructed with ``num_relation + 1`` relations.
        relation_count = num_relation + 1
        # Layers are created and re-initialised strictly one after another so
        # the random-number stream is consumed in a fixed order.
        self.conv1 = self._normal_weight(RGCNConv(feature_dim, hidden_dim, relation_count))
        self.conv2 = self._normal_weight(RGCNConv(hidden_dim, output_dim, relation_count))
        # ``fc1`` consumes the concatenation of two ``output_dim`` means.
        self.fc1 = self._normal_weight(nn.Linear(output_dim * 2, 100))
        self.out = self._normal_weight(nn.Linear(100, 1))

    @staticmethod
    def _normal_weight(layer):
        """Refill ``layer.weight`` in place from N(mean=0, std=0.1); return the layer."""
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, g:Data, cluster_ind):
        """Score the cluster ``cluster_ind`` inside graph ``g``.

        Args:
            g: graph object providing ``x``, ``edge_index`` and ``edge_type``.
            cluster_ind: index (anything usable to index the node-embedding
                tensor along its first axis) selecting the cluster's nodes.

        Returns:
            A 3-tuple ``(score, head_emb, graph_emb)``: the output of ``out``,
            the concatenated [graph mean, cluster mean] vector, and the node
            embeddings produced by ``conv2``.
        """
        features, edges, relations = g.x, g.edge_index, g.edge_type
        hidden = F.relu(self.conv1(features, edges, relations))
        node_embeddings = self.conv2(hidden, edges, relations)

        graph_mean = torch.mean(node_embeddings, dim=0)
        cluster_mean = torch.mean(node_embeddings[cluster_ind], dim=0)
        head_emb = torch.cat((graph_mean, cluster_mean), dim=0)

        score = self.out(F.relu(self.fc1(head_emb)))
        return score, head_emb, node_embeddings
    

class DQN1(object):
    """DQN agent that scores (graph, cluster) pairs with two ``Net1`` networks.

    Graphs and cluster index sets are kept in ``graph_pool`` / ``cluster_pool``
    under consecutive integer keys; the replay buffers store only those keys
    plus the reward.  Transitions with ``r > 0`` go to ``pos_memory``, all
    other transitions go to ``neg_memory``.
    """

    def __init__(self, feature_dim, hidden_dim, output_dim, num_relation, MEMORY_CAPACITY,
                 TARGET_REPLACE_ITER=10, GAMMA=0.99, EPSILON=0.95, LR=0.01, BATCH_SIZE=8):
        """Build the online/target networks, replay buffers and optimizer.

        Args:
            feature_dim, hidden_dim, output_dim, num_relation: forwarded
                unchanged to both ``Net1`` constructors.
            MEMORY_CAPACITY: number of rows in each replay buffer.
            TARGET_REPLACE_ITER: the target network is synchronised every
                this many ``learn()`` calls.
            GAMMA: discount factor used in the TD target.
            EPSILON, BATCH_SIZE: stored on the instance; no method of this
                class reads them.
            LR: learning rate of the Adam optimizer over ``eval_net``.
        """
        # The online network is created first, then the target network.
        self.eval_net, self.target_net = Net1(feature_dim, hidden_dim, output_dim, num_relation),\
                                         Net1(feature_dim, hidden_dim, output_dim, num_relation)
        self.learn_step_counter = 0

        self.pos_memory_counter = 0
        self.neg_memory_counter = 0
        # Row layout: graph_ind, cluster_ind, reward, next_graph_ind, next_cluster_ind
        self.pos_memory = np.zeros((MEMORY_CAPACITY, 5))
        self.neg_memory = np.zeros((MEMORY_CAPACITY, 5))

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
        self.MEMORY_CAPACITY = MEMORY_CAPACITY
        self.TARGET_REPLACE_ITER = TARGET_REPLACE_ITER
        self.GAMMA = GAMMA
        self.EPSILON = EPSILON
        self.BATCH_SIZE = BATCH_SIZE
        # ``defaultdict()`` without a factory behaves like a plain dict.
        self.graph_pool = defaultdict()
        self.graph_ind = 0
        self.cluster_pool = defaultdict()
        self.cluster_ind = 0

    def get_q_value(self, g: "Data", cluster_ind):
        """Return the online network's output tuple for ``(g, cluster_ind)``."""
        return self.eval_net(g, cluster_ind)

    def get_q_value_next_state(self, g: "Data", cluster_ind):
        """Return the target network's output tuple for ``(g, cluster_ind)``."""
        return self.target_net(g, cluster_ind)

    def store_graph(self, g: "Data"):
        """Register ``g`` in ``graph_pool`` under the next free integer key."""
        self.graph_pool[self.graph_ind] = g
        self.graph_ind += 1

    def store_cluster(self, cluster_index):
        """Register ``cluster_index`` in ``cluster_pool`` under the next free key."""
        self.cluster_pool[self.cluster_ind] = cluster_index
        self.cluster_ind += 1

    def _latest_transition(self, r):
        """Build a replay row from the two most recently stored graphs/clusters."""
        # The second-to-last pool entries are the current state/action, the
        # last entries are the successor state/action.
        s, a = self.graph_ind - 2, self.cluster_ind - 2
        s_, a_ = self.graph_ind - 1, self.cluster_ind - 1
        return np.hstack((s, a, r, s_, a_))

    def store_pos_transition(self, r):
        """Write the latest transition with reward ``r`` into ``pos_memory``."""
        transition = self._latest_transition(r)
        # Ring buffer: the oldest row is overwritten once capacity is reached.
        index = self.pos_memory_counter % self.MEMORY_CAPACITY
        self.pos_memory[index, :] = transition
        self.pos_memory_counter += 1

    def store_neg_transition(self, r):
        """Write the latest transition with reward ``r`` into ``neg_memory``."""
        transition = self._latest_transition(r)
        index = self.neg_memory_counter % self.MEMORY_CAPACITY
        self.neg_memory[index, :] = transition
        self.neg_memory_counter += 1

    def store_transition(self, r):
        """Route the latest transition by the sign of ``r`` (``r > 0`` is positive)."""
        if r > 0:
            self.store_pos_transition(r)
        else:
            self.store_neg_transition(r)

    def learn(self):
        """Run one TD update on a positive and on a negative replay row.

        A single slot index is drawn uniformly from ``[0, MEMORY_CAPACITY)``
        and used for both buffers; each of the two rows gets its own
        optimizer step.  Slots that were never written still hold zeros, so
        callers presumably fill the buffers before calling this -- TODO
        confirm against the training loop.
        """
        if self.learn_step_counter % self.TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        # Without this increment the modulo test above is always true and the
        # target network would be overwritten on every call.
        self.learn_step_counter += 1

        sample_index = np.random.randint(0, self.MEMORY_CAPACITY)
        samples = (self.pos_memory[sample_index], self.neg_memory[sample_index])
        for b_memory in samples:
            # The buffers are float arrays; turn the stored pool keys back
            # into ints and the reward into a plain Python float.
            b_s, b_a = int(b_memory[0]), int(b_memory[1])
            b_r = float(b_memory[2])
            b_s_, b_a_ = int(b_memory[3]), int(b_memory[4])

            # The networks return a tuple; element 0 is the Q-value.
            q_eval = self.eval_net(self.graph_pool[b_s], self.cluster_pool[b_a])[0]
            q_next = self.target_net(self.graph_pool[b_s_], self.cluster_pool[b_a_])[0].detach()
            q_target = b_r + self.GAMMA * q_next
            loss = self.loss_func(q_eval, q_target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()



class Net2(nn.Module):
    """Two-layer MLP producing one Q-value per relation.

    ``fc1`` maps ``input_dim`` features to 100 hidden units, ``out`` maps
    those to ``num_relation`` outputs; a ReLU sits in between.
    """

    def __init__(self, input_dim, num_relation):
        super().__init__()
        layer_specs = (("fc1", input_dim, 100), ("out", 100, num_relation))
        # Each layer is created and immediately re-initialised before the next
        # one is built, so the random-number stream is consumed in a fixed
        # order (fc1, then out).
        for attr_name, n_in, n_out in layer_specs:
            layer = nn.Linear(n_in, n_out)
            layer.weight.data.normal_(0, 0.1)  # weights ~ N(mean=0, std=0.1)
            setattr(self, attr_name, layer)

    def forward(self, x):
        """Return ``out(relu(fc1(x)))``, i.e. one value per relation."""
        return self.out(F.relu(self.fc1(x)))

class DQN2(object):
    """DQN agent over flat state vectors with ``num_rel`` discrete actions.

    Replay rows are laid out as ``[state | action | reward | next_state]``.
    Transitions with ``r > 0`` are written to ``pos_memory``, all others to
    ``neg_memory``; ``learn()`` draws the same slot indices from both.
    """

    def __init__(self, input_dim, num_rel, MEMORY_CAPACITY,
                 TARGET_REPLACE_ITER=10, GAMMA=0.99, EPSILON=0.95, LR=0.01,
                 BATCH_SIZE=8, EPS_START=0.9, EPS_END=0.05, EPS_DECAY=200):
        """Create the online/target ``Net2`` pair, the buffers and the optimizer.

        The ``EPSILON`` and ``EPS_*`` arguments are only stored on the
        instance; no method of this class reads them.
        """
        # Online network first, target network second.
        self.eval_net, self.target_net = Net2(input_dim, num_rel), Net2(input_dim, num_rel)
        self.learn_step_counter = 0

        # Row width: state + action + reward + next state.
        row_width = input_dim*2+2
        self.pos_memory_counter = 0
        self.pos_memory = np.zeros((MEMORY_CAPACITY, row_width))
        self.neg_memory_counter = 0
        self.neg_memory = np.zeros((MEMORY_CAPACITY, row_width))

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()

        self.N_ACTIONS = num_rel
        self.STATE_DIM = input_dim
        self.MEMORY_CAPACITY = MEMORY_CAPACITY
        self.TARGET_REPLACE_ITER = TARGET_REPLACE_ITER
        self.GAMMA = GAMMA
        self.EPSILON = EPSILON
        self.BATCH_SIZE = BATCH_SIZE
        self.EPS_START = EPS_START
        self.EPS_END = EPS_END
        self.EPS_DECAY = EPS_DECAY

    def choose_next_action(self, x):
        """Return the target network's arg-max action for state ``x``."""
        state = torch.FloatTensor(x).unsqueeze(0)
        q_values = self.target_net.forward(state)
        best = q_values.max(1)[1]
        return best.data.numpy()[0]

    def choose_action(self, x, eps_threshold):
        """Pick an action for state ``x``.

        A uniform sample in [0, 1) is drawn; when it exceeds ``eps_threshold``
        the online network's arg-max action is returned, otherwise a uniformly
        random action in ``[0, N_ACTIONS)``.
        """
        state = torch.FloatTensor(x).unsqueeze(0)
        exploit = np.random.uniform() > eps_threshold
        if not exploit:
            return np.random.randint(0, self.N_ACTIONS)
        q_values = self.eval_net.forward(state)
        best = q_values.max(1)[1]
        return best.data.numpy()[0]

    def store_pos_transition(self, s, a, r, s_):
        """Write ``(s, a, r, s_)`` into the positive ring buffer."""
        row = np.hstack((s, a, r, s_))
        # Ring buffer: wrap around once the capacity is reached.
        slot = self.pos_memory_counter % self.MEMORY_CAPACITY
        self.pos_memory[slot, :] = row
        self.pos_memory_counter += 1

    def store_neg_transition(self, s, a, r, s_):
        """Write ``(s, a, r, s_)`` into the negative ring buffer."""
        row = np.hstack((s, a, r, s_))
        slot = self.neg_memory_counter % self.MEMORY_CAPACITY
        self.neg_memory[slot, :] = row
        self.neg_memory_counter += 1

    def store_transition(self, s, a, r, s_):
        """Route the transition by reward sign: ``r > 0`` is positive."""
        store = self.store_pos_transition if r > 0 else self.store_neg_transition
        store(s, a, r, s_)

    def learn(self):
        """Run one optimizer step on ``BATCH_SIZE`` positive plus negative rows.

        Slot indices are drawn from ``[0, MEMORY_CAPACITY)`` with replacement
        and applied to both buffers, so the batch has ``2 * BATCH_SIZE`` rows.
        """
        sync_due = self.learn_step_counter % self.TARGET_REPLACE_ITER == 0
        if sync_due:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        slots = np.random.choice(self.MEMORY_CAPACITY, self.BATCH_SIZE)
        batch = np.concatenate((self.pos_memory[slots, :], self.neg_memory[slots, :]), axis=0)

        n = self.STATE_DIM
        states = torch.FloatTensor(batch[:, :n])
        actions = torch.LongTensor(batch[:, n:n + 1])
        rewards = torch.FloatTensor(batch[:, n + 1 : n + 2])
        next_states = torch.FloatTensor(batch[:, -n:])

        # Q(s, a) of the action that was actually taken.
        q_eval = self.eval_net(states).gather(1, actions)
        # Bootstrapped target from the (frozen) target network.
        q_next = self.target_net(next_states).detach()
        best_next = q_next.max(1)[0].view(self.BATCH_SIZE * 2, 1)
        q_target = rewards + self.GAMMA * best_next

        loss = self.loss_func(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


class Net3(nn.Module):
    """Q-network scoring a (head embedding, operation, node embedding) triple.

    The operation id is looked up in an embedding table, concatenated with the
    head and node embeddings, and mapped to a single value by
    ``fc1`` -> ReLU -> ``out``.
    """

    def __init__(self, num_rel, op_dim, input_dim, node_dim):
        """
        Args:
            num_rel: number of rows in the operation embedding table.
            op_dim: size of each operation embedding.
            input_dim: size of the head embedding fed to ``forward``.
            node_dim: size of the node embedding fed to ``forward``.
        """
        super(Net3,self).__init__()
        # Operation embedding layer.
        self.op_emb_layer = nn.Embedding(num_rel, op_dim)
        # Node scoring module: consumes [head | operation | node] features.
        self.fc1 = nn.Linear(input_dim+op_dim+node_dim, 100)
        self.fc1.weight.data.normal_(0, 0.1)
        self.out = nn.Linear(100, 1)
        self.out.weight.data.normal_(0, 0.1)

    def forward(self, head_emb, op, node_emb):
        """Return the Q-value for the given triple.

        Args:
            head_emb: head embedding; 2-D in the batched path, otherwise
                concatenated along axis 0.
            op: integer tensor of operation ids.  ``DQN3.learn`` passes a
                ``(batch, 1)`` column, which the embedding turns into a 3-D
                tensor.
            node_emb: node embedding with the same leading shape as
                ``head_emb``.

        Returns:
            The output of ``out`` (last dimension of size 1).
        """
        op_emb = self.op_emb_layer(op)
        if op_emb.dim() == 3:
            # Batched call: the embedding carries an extra singleton axis at
            # position 1.  Only that axis is removed -- a bare ``squeeze()``
            # would also drop the batch axis for a batch of one, or the
            # feature axis when ``op_dim == 1``, and break the concatenation.
            op_emb = op_emb.squeeze(1)
            feature_axis = 1
        else:
            feature_axis = 0
        x = torch.cat((head_emb, op_emb, node_emb), dim=feature_axis)
        x = F.relu(self.fc1(x))
        action_value = self.out(x)
        return action_value

class DQN3(object):
    """DQN agent that scores (head embedding, operation, node embedding) triples.

    Replay rows are laid out as
    ``[s | o | a | r | s_ | o_ | a_]`` where ``s``/``s_`` have ``input_dim``
    columns, ``a``/``a_`` have ``node_dim`` columns and ``o``, ``r``, ``o_``
    take one column each.  Transitions with ``r > 0`` go to ``pos_memory``,
    all others to ``neg_memory``.
    """

    def __init__(self, num_rel, op_dim, input_dim, node_dim, MEMORY_CAPACITY,
                 TARGET_REPLACE_ITER=10, GAMMA=0.99, EPSILON=0.95, LR=0.01, BATCH_SIZE=8):
        """Create the online/target ``Net3`` pair, the buffers and the optimizer.

        ``EPSILON`` is only stored on the instance; no method of this class
        reads it.
        """
        # Online network first, target network second.
        self.eval_net, self.target_net = Net3(num_rel, op_dim, input_dim, node_dim),\
                                         Net3(num_rel, op_dim, input_dim, node_dim)

        self.learn_step_counter = 0
        # Row width: two head embeddings, two node embeddings, plus one column
        # each for operation, reward and next operation.
        self.neg_memory_counter = 0
        self.neg_memory = np.zeros((MEMORY_CAPACITY, input_dim*2 +3 +node_dim*2))
        self.pos_memory_counter = 0
        self.pos_memory = np.zeros((MEMORY_CAPACITY, input_dim*2 +3 +node_dim*2))

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
        self.MEMORY_CAPACITY = MEMORY_CAPACITY
        self.TARGET_REPLACE_ITER = TARGET_REPLACE_ITER
        self.GAMMA = GAMMA
        self.EPSILON = EPSILON
        self.BATCH_SIZE = BATCH_SIZE
        self.INPUT_DIM = input_dim
        self.NODE_DIM = node_dim

    def get_q_value(self, head_emb, op, node_emb):
        """Return the online network's value for the given triple."""
        return self.eval_net(head_emb, op, node_emb)

    def get_q_value_next_state(self, head_emb, op, node_emb):
        """Return the target network's value for the given triple."""
        return self.target_net(head_emb, op, node_emb)

    def store_transition(self, s, o, a, r, s_, o_, a_):
        """Store a transition in the buffer selected by the sign of ``r``.

        While the chosen buffer's write counter is still below
        ``MEMORY_CAPACITY`` after the first write, the same transition is
        written a second time.  NOTE(review): presumably this is meant to
        fill the buffer faster -- confirm it is intentional.
        """
        if r > 0:
            self.store_pos_transition(s, o, a, r, s_, o_, a_)
            if self.pos_memory_counter < self.MEMORY_CAPACITY:
                self.store_pos_transition(s, o, a, r, s_, o_, a_)
        else:
            self.store_neg_transition(s, o, a, r, s_, o_, a_)
            if self.neg_memory_counter < self.MEMORY_CAPACITY:
                self.store_neg_transition(s, o, a, r, s_, o_, a_)

    def store_pos_transition(self, s, o, a, r, s_, o_, a_):
        """Write one row into the positive ring buffer."""
        transition = np.hstack((s, o, a, r, s_, o_, a_))
        # Ring buffer: wrap around once the capacity is reached.
        index = self.pos_memory_counter % self.MEMORY_CAPACITY
        self.pos_memory[index, :] = transition
        self.pos_memory_counter += 1

    def store_neg_transition(self, s, o, a, r, s_, o_, a_):
        """Write one row into the negative ring buffer."""
        transition = np.hstack((s, o, a, r, s_, o_, a_))
        index = self.neg_memory_counter % self.MEMORY_CAPACITY
        self.neg_memory[index, :] = transition
        self.neg_memory_counter += 1

    def learn(self):
        """Run one optimizer step on ``BATCH_SIZE`` negative plus positive rows.

        Slot indices are drawn from ``[0, MEMORY_CAPACITY)`` with replacement
        and applied to both buffers, so the batch has ``2 * BATCH_SIZE`` rows.
        """
        if self.learn_step_counter % self.TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1
        sample_index = np.random.choice(self.MEMORY_CAPACITY, self.BATCH_SIZE)

        neg_sample = self.neg_memory[sample_index, :]
        pos_sample = self.pos_memory[sample_index, :]
        b_memory = np.concatenate((neg_sample, pos_sample), axis=0)

        # Explicit column boundaries of [s | o | a | r | s_ | o_ | a_]; using
        # offsets for every field (instead of a negative slice for the last
        # one) keeps the split correct even when NODE_DIM is 0.
        o_col = self.INPUT_DIM
        a_lo = o_col + 1
        r_col = a_lo + self.NODE_DIM
        s2_lo = r_col + 1
        o2_col = s2_lo + self.INPUT_DIM
        a2_lo = o2_col + 1
        a2_hi = a2_lo + self.NODE_DIM

        b_s = torch.FloatTensor(b_memory[:, :o_col])
        # Operation ids index an nn.Embedding; int64 indices are accepted by
        # every PyTorch release, whereas int32 is rejected by older ones.
        b_o = torch.LongTensor(b_memory[:, o_col:a_lo].astype(np.int64))
        b_a = torch.FloatTensor(b_memory[:, a_lo:r_col])
        b_r = torch.FloatTensor(b_memory[:, r_col:s2_lo])
        b_s_ = torch.FloatTensor(b_memory[:, s2_lo:o2_col])
        b_o_ = torch.LongTensor(b_memory[:, o2_col:a2_lo].astype(np.int64))
        b_a_ = torch.FloatTensor(b_memory[:, a2_lo:a2_hi])

        q_eval = self.eval_net(b_s, b_o, b_a)
        # The target network's output is detached so no gradient flows into it.
        q_next = self.target_net(b_s_, b_o_, b_a_).detach()
        q_target = b_r + self.GAMMA * q_next

        loss = self.loss_func(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
