import numpy as np
from security import *


class Qagent:
    """Tabular Q-learning agent with epsilon-greedy action selection.

    The Q-table is a ``num_state x num_action`` numpy array initialised to
    zeros; ``state`` arguments are integer row indices into it.
    """

    def __init__(self, num_state, num_action, epsilon, alpha, gamma):
        # One row per state, one column per action, all values start at 0.
        self.Q_table = np.zeros([num_state, num_action])
        self.num_action = num_action
        # Exploration probability, learning rate and discount factor.
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma

    def choose_action(self, state, explore):
        """Pick an action for ``state``.

        When ``explore`` is true, a uniform draw below ``epsilon`` yields a
        uniformly random action. Otherwise the action is taken from the
        Q-values ranked by ``np.argsort``: the best one with probability 0.6,
        the runner-up with probability 0.4.
        """
        # Without exploration no random number is drawn; the sentinel 10 is
        # simply larger than any sensible epsilon.
        roll = np.random.random() if explore else 10
        if roll < self.epsilon:
            return np.random.randint(0, self.num_action)

        ranked = np.argsort(self.Q_table[state, :self.num_action])
        top = self.num_action - 1
        if np.random.random() < 0.6:
            return ranked[top]
        # For a single action, top - 1 == -1 wraps back to that same action.
        return ranked[top - 1]

    def update(self, state, action, reward, next_state):
        """Apply one Q-learning (TD(0)) update for a single transition."""
        target = reward + self.gamma * self.Q_table[next_state].max()
        td_error = target - self.Q_table[state, action]
        self.Q_table[state, action] += self.alpha * td_error


class MQ:
    """Independent multi-agent Q-learning: one ``Qagent`` per player.

    Every agent sees the same (integer-encoded) state but keeps its own
    Q-table and receives its own action and reward.
    """

    def __init__(self, num_agent, num_state, num_action, epsilon, alpha, gamma):
        self.num_agent = num_agent
        # Per-agent table sizes; the hyper-parameters are shared by all agents.
        self.agents = [
            Qagent(num_state[idx], num_action[idx], epsilon, alpha, gamma)
            for idx in range(num_agent)
        ]
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma

    def choose_actions(self, state, explore):
        """Return a list with one action per agent, in agent order."""
        return [member.choose_action(state, explore) for member in self.agents]

    def update(self, state, action, reward, next_state):
        """Update each agent with its own entry of ``action`` and ``reward``."""
        for idx in range(self.num_agent):
            learner = self.agents[idx]
            learner.update(state, action[idx], reward[idx], next_state)


def state_change(state):
    """Encode a sequence of bits (most significant first) as an integer.

    Only elements equal to 1 contribute a set bit; every other value is
    treated as 0.
    """
    encoded = 0
    for bit in state:
        # Shift the accumulated prefix left and append the current bit.
        encoded = (encoded << 1) | (1 if bit == 1 else 0)
    return encoded


def action_change(action, width=10):
    """Expand each integer action into a fixed-width bit list, MSB first.

    Args:
        action: iterable of non-negative integers (Python or numpy ints).
        width: number of bits per output list (default 10, as before).

    Returns:
        A list with one ``width``-long list of 0/1 ints per input action.

    Raises:
        ValueError: if an action is negative or does not fit in ``width`` bits.
            (Previously, values >= 2**width silently wrapped into the wrong
            positions, and negative values failed with an obscure error.)
        TypeError: if an action is not an integer.
    """
    import operator

    limit = 1 << width
    bit_lists = []
    for raw in action:
        value = operator.index(raw)
        if value < 0 or value >= limit:
            raise ValueError(
                f"action {value} does not fit in {width} unsigned bits"
            )
        bit_lists.append([(value >> shift) & 1 for shift in range(width - 1, -1, -1)])
    return bit_lists


def _train_agents(mq, env, num_episodes, horizon):
    """Run exploring Q-learning episodes, updating every agent after each step."""
    for _ in range(num_episodes):
        state = env.reset()
        for _ in range(horizon):
            state_in = state_change(state)
            action = mq.choose_actions(state_in, True)
            next_state, reward = env.step(state, action)
            mq.update(state_in, action, reward, state_change(next_state))
            state = next_state


def _evaluate_agents(mq, env, num_episodes, horizon):
    """Roll out the non-exploring policy.

    Returns the recorded trajectories and the per-agent rewards summed over
    all rollouts.
    """
    trajs = []
    totals = [0, 0]
    for _ in range(num_episodes):
        state = env.reset()
        e_reward = [0, 0]
        traj = []
        for _ in range(horizon):
            action = mq.choose_actions(state_change(state), False)
            # Snapshot the state before env.step sees it.
            traj.append((state.copy(), action))
            next_state, reward = env.step(state, action)
            e_reward[0] += reward[0]
            e_reward[1] += reward[1]
            state = next_state
        trajs.append(traj)
        totals[0] += e_reward[0]
        totals[1] += e_reward[1]
    return trajs, totals


def train_q(defender_reward_net, attacker_reward_net, feedback, physical, save,
            iteration_number, num_episodes=100000, train_horizon=8,
            eval_episodes=10, eval_horizon=5):
    """Train two independent Q-learning agents in ``SENV`` and evaluate them.

    Args:
        defender_reward_net, attacker_reward_net, feedback, physical:
            Passed unchanged to ``SENV`` (presumably provided by the
            ``security`` star import — confirm).
        save, iteration_number: Accepted for interface compatibility; not
            used by this function.
        num_episodes: Number of training episodes.
        train_horizon: Steps per training episode.
        eval_episodes: Number of evaluation rollouts (must be >= 1).
        eval_horizon: Steps per evaluation rollout.

    Returns:
        ``(trajs, reward_out)``:
            - ``trajs`` has one list per evaluation rollout, each holding
              ``(state copy, actions)`` pairs.
            - ``reward_out`` is each agent's summed reward averaged over the
              rollouts.

    Raises:
        ValueError: if ``eval_episodes`` is less than 1.
    """
    if eval_episodes < 1:
        raise ValueError("eval_episodes must be at least 1")

    epsilon = 0.3
    alpha = 0.1
    gamma = 0.9
    # Encoded states are 8-bit integers (see state_change); 10 actions each.
    num_state = [2 ** 8, 2 ** 8]
    num_action = [10, 10]
    num_agent = 2
    mq = MQ(num_agent, num_state, num_action, epsilon, alpha, gamma)
    env = SENV(defender_reward_net, attacker_reward_net, feedback, physical)

    # NOTE(review): training episodes last 8 steps but evaluation rollouts 5;
    # these defaults preserve the original behaviour — confirm it is intended.
    _train_agents(mq, env, num_episodes, train_horizon)
    trajs, totals = _evaluate_agents(mq, env, eval_episodes, eval_horizon)

    # The divisor is tied to the rollout count so the two cannot drift apart.
    reward_out = [total / eval_episodes for total in totals]
    return trajs, reward_out