import operator

import numpy as np
import torch


def action_change(action, width=7):
    """Expand each non-negative integer in *action* into a fixed-width bit list.

    Bits are ordered most-significant first, e.g. ``action_change([5])`` gives
    ``[[0, 0, 0, 0, 1, 0, 1]]``.

    Args:
        action: Iterable of non-negative integers (Python ints or any object
            supporting ``__index__``, such as NumPy integer scalars).
        width: Number of bits in every output row (default 7).

    Returns:
        A list with one ``width``-long list of 0/1 ints per input value.

    Raises:
        TypeError: If a value is not an integer (e.g. a float).
        ValueError: If a value is negative or needs more than ``width`` bits;
            previously such values were silently wrapped into the wrong slots.
    """
    rows = []
    for raw in action:
        value = operator.index(raw)  # rejects floats, like bin() did
        if not 0 <= value < (1 << width):
            raise ValueError(f"action value {value} does not fit in {width} bits")
        rows.append([(value >> shift) & 1 for shift in range(width - 1, -1, -1)])
    return rows


class SENV:
    """Two-player attack-graph game between a defender and an attacker.

    The graph has 8 nodes and 10 directed edges; edge ``i`` runs from node
    ``edge_s[i]`` to node ``edge_g[i]``. ``state[n]`` is 1 once node ``n`` is
    compromised and 0 otherwise. On every step each player picks one edge
    index. The defender's pick blocks that edge; the attacker's pick is the
    edge it tries to traverse.

    Args:
        defender_reward_net: Callable taking the float tensor built by
            ``_reward_net_input`` (26 elements for the default graph) and
            returning a single-element tensor (read with ``.item()``). Only
            called when ``feedback`` is falsy.
        attacker_reward_net: Same input/output contract; only called when
            ``physical`` is falsy.
        feedback: If falsy, the defender's reward is the output of
            ``defender_reward_net`` plus the defender's action cost, instead
            of the graph payoff.
        physical: If falsy, the attacker's reward is the output of
            ``attacker_reward_net`` instead of the graph payoff.
    """

    def __init__(self, defender_reward_net, attacker_reward_net, feedback, physical):
        self.defender_reward_net = defender_reward_net
        self.attacker_reward_net = attacker_reward_net
        self.feedback = feedback
        self.physical = physical
        # Directed edges of the attack graph: edge i goes edge_s[i] -> edge_g[i].
        self.edge_s = [0, 0, 2, 3, 4, 6, 1, 3, 5, 1]
        self.edge_g = [3, 4, 3, 4, 6, 7, 2, 5, 6, 5]
        # Cost charged to each player for picking edge i (one entry per edge).
        self.attacker_cost = [-1] * len(self.edge_s)
        self.defender_cost = [-1] * len(self.edge_s)
        # Value moved from defender to attacker when node n is captured.
        self.state_reward = [0, 0, 3, 3, 3, 3, 3, 7]
        self.state = self._initial_state()

    @staticmethod
    def _initial_state():
        """Return a fresh initial state: nodes 0 and 1 compromised, the rest clean."""
        return [1, 1, 0, 0, 0, 0, 0, 0]

    def _reward_net_input(self, state, defender_pick, attacker_pick):
        """Build the flat float tensor fed to the reward networks.

        Layout: ``state[2:]``, then a one-hot vector of the defender's pick,
        then a one-hot vector of the attacker's pick (each ``len(self.edge_s)``
        long).
        """
        n_edges = len(self.edge_s)
        defender_onehot = [0] * n_edges
        attacker_onehot = [0] * n_edges
        defender_onehot[defender_pick] = 1
        attacker_onehot[attacker_pick] = 1
        # Nodes 0 and 1 are never an edge goal, so they always keep their
        # reset value of 1. Presumably that is why they are left out.
        features = [*state[2:], *defender_onehot, *attacker_onehot]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.tensor(features, dtype=torch.float, device=device)

    def step(self, state, action):
        """Advance the game by one move.

        Args:
            state: Mutable sequence of 0/1 node flags. NOTE: it is updated in
                place when an attack succeeds and becomes ``self.state``. Pass
                a copy if the caller needs to keep the pre-step value.
            action: Pair ``(defender_pick, attacker_pick)`` of edge indices.

        Returns:
            ``(state, [defender_reward, attacker_reward])``. ``state`` is the
            same object that was passed in. Rewards are Python floats.
        """
        defender_pick, attacker_pick = action[0], action[1]
        reward = np.zeros(2)
        reward[0] += self.defender_cost[defender_pick]
        reward[1] += self.attacker_cost[attacker_pick]

        # The attack only proceeds on an edge the defender did not block.
        if defender_pick != attacker_pick:
            start = self.edge_s[attacker_pick]
            goal = self.edge_g[attacker_pick]
            # Only a compromised node can spread to a still-clean neighbour.
            if state[start] == 1 and state[goal] == 0:
                state[goal] = 1
                reward[1] += self.state_reward[goal]
                reward[0] -= self.state_reward[goal]
        self.state = state

        if not self.feedback or not self.physical:
            # Both nets see the same post-move input, so build it once.
            net_input = self._reward_net_input(state, defender_pick, attacker_pick)
            with torch.no_grad():
                if not self.feedback:
                    reward[0] = (
                        self.defender_reward_net(net_input).item()
                        + self.defender_cost[defender_pick]
                    )
                if not self.physical:
                    # NOTE(review): unlike the defender branch, no action cost
                    # is added here. Confirm this asymmetry is intended.
                    reward[1] = self.attacker_reward_net(net_input).item()
        return self.state, reward.tolist()

    def reset(self):
        """Restore the initial state (nodes 0 and 1 compromised) and return it."""
        self.state = self._initial_state()
        return self.state
