from qlearning import *
from conjugategrad import *
from spsa import *
import numpy as np
import torch
import torch.nn as nn
import copy
from scipy.linalg import sqrtm

class reward_estimator(nn.Module):
    """Linear reward model: one ``nn.Linear(num_inputs, 1)`` layer plus its own optimiser.

    Args:
        num_inputs: length of the feature vector given to ``forward``.

    Attributes:
        fc: ``nn.Sequential`` holding the single linear layer; its parameters
            therefore appear in ``state_dict()`` as ``'fc.0.weight'`` and
            ``'fc.0.bias'`` (the training script edits ``'fc.0.weight'``).
        mls: an ``nn.MSELoss`` criterion (not used within this file).
        opt: Adam optimiser over this module's parameters, lr = 0.001.
    """

    def __init__(self, num_inputs):
        super().__init__()
        layers = [nn.Linear(num_inputs, 1)]
        self.fc = nn.Sequential(*layers)
        self.mls = nn.MSELoss()
        # Built after ``fc`` so that ``self.parameters()`` already contains the layer.
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Apply the linear layer to ``x``; the output's last dimension has size 1."""
        estimate = self.fc(x)
        return estimate


def feature_expectation(trajs, reward_function):
    """Average over trajectories of the summed reward gradients w.r.t. the first weight row.

    Each step ``(state, action)`` is encoded as ``state[2:]`` followed by a
    10-slot one-hot of ``action[0]`` and a 10-slot one-hot of ``action[1]``.
    For every step, the gradient of ``reward_function`` at that input with
    respect to its first parameter is taken and row 0 of it is kept. These
    rows are summed over all steps and divided by the number of trajectories.
    For a linear model (e.g. ``reward_estimator``), the gradient w.r.t. the
    weight row equals the input, so this is the empirical feature expectation.

    Args:
        trajs: sequence of trajectories; each trajectory is a sequence of
            steps whose element 0 is the state and element 1 the action pair.
        reward_function: a torch module; its first parameter determines the
            output shape, and its device is used for the network input.

    Returns:
        NumPy array shaped like row 0 of the first parameter (26 values for
        the 26-input model used in this file). If ``trajs`` is empty, this is
        an all-zero array.
    """
    first_param = next(reward_function.parameters())
    # Put inputs where the model lives instead of guessing "cuda if available".
    device = first_param.device
    # Size the accumulator from the model rather than hard-coding 26.
    f_e = np.zeros(tuple(first_param.shape[1:]))
    if len(trajs) == 0:
        return f_e
    for traj in trajs:
        for step in traj:
            state, action = step[0], step[1]
            d_a = [0] * 10
            a_a = [0] * 10
            d_a[action[0]] = 1
            a_a[action[1]] = 1
            features = list(state[2:]) + d_a + a_a
            net_input = torch.tensor(features, dtype=torch.float, device=device)
            reward = reward_function(net_input)
            # Only the first parameter's gradient is used, so only ask for that one.
            (grad,) = torch.autograd.grad(reward, first_param)
            f_e += grad[0].detach().cpu().numpy()
    return f_e / len(trajs)


def cumulative_reward_d(
    trajs,
    *,
    defender_cost=(-1,) * 10,
    edge_s=(0, 0, 1, 1, 2, 3, 4, 6, 5, 3),
    edge_g=(3, 4, 2, 5, 3, 4, 6, 7, 6, 5),
    state_reward=(0, 0, 3, 3, 3, 3, 3, 7),
):
    """Return the defender's hand-coded return, averaged over trajectories.

    For each step ``(state, action)``, with ``d = action[0]`` (defender pick)
    and ``a = action[1]`` (attacker pick):

    * ``defender_cost[d]`` is added;
    * if ``d != a``, edge ``a`` runs from node ``edge_s[a]`` to node
      ``edge_g[a]``. When ``state[edge_s[a]] == 1`` and
      ``state[edge_g[a]] == 0``, ``state_reward[edge_g[a]]`` is subtracted.
      No such penalty applies when both pick the same index.

    The defaults describe the 8-node / 10-edge graph used in this file. The
    earlier 5-node / 7-edge graph is ``defender_cost=(-1,) * 7,
    edge_s=(0, 0, 0, 1, 2, 3, 3), edge_g=(1, 2, 3, 2, 4, 1, 2),
    state_reward=(0, 3, 3, 3, 5)``.

    Args:
        trajs: sequence of trajectories; each is a sequence of steps whose
            element 0 is the state and element 1 the ``(d, a)`` action pair.
        defender_cost, edge_s, edge_g, state_reward: graph description
            (keyword-only; defaults match the original hard-coded values).

    Returns:
        NumPy array of shape ``(1,)``: total return divided by ``len(trajs)``.
        It is ``[0.]`` when ``trajs`` is empty, instead of a NaN from 0/0.
    """
    reward = np.zeros(1)
    if len(trajs) == 0:
        return reward
    for traj in trajs:
        for step in traj:
            state, action = step[0], step[1]
            defender_pick, attacker_pick = action[0], action[1]
            reward += defender_cost[defender_pick]
            if defender_pick == attacker_pick:
                # Same index picked by both: presumably the defender blocks that edge.
                continue
            start = edge_s[attacker_pick]
            goal = edge_g[attacker_pick]
            if state[start] == 1 and state[goal] == 0:
                reward -= state_reward[goal]
    return reward / len(trajs)


def cumulative_reward_a(trajs, reward_function):
    """Return the sum of ``reward_function`` over all steps, divided by ``len(trajs)``.

    Each step ``(state, action)`` is encoded as ``state[2:]`` followed by a
    10-slot one-hot of ``action[0]`` and a 10-slot one-hot of ``action[1]``,
    then fed to ``reward_function``. The outputs are accumulated as torch
    tensors without detaching, so the result stays on the autograd graph and
    callers can back-propagate into ``reward_function``'s parameters.

    Args:
        trajs: sequence of trajectories; each is a sequence of steps whose
            element 0 is the state and element 1 the action pair.
        reward_function: callable mapping a 1-D float tensor to a reward
            tensor. If it is a torch module with parameters, inputs are placed
            on that parameter's device; otherwise "cuda if available, else cpu"
            is used, as before.

    Returns:
        The averaged reward tensor (shape ``(1,)`` for ``reward_estimator``).

    Raises:
        ZeroDivisionError: if ``trajs`` is empty (the running total is still
            the integer 0 at that point).
    """
    try:
        # Match the model's device so a CPU model works on a CUDA machine.
        device = next(reward_function.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reward = 0
    for traj in trajs:
        for step in traj:
            state, action = step[0], step[1]
            d_a = [0] * 10
            a_a = [0] * 10
            d_a[action[0]] = 1
            a_a[action[1]] = 1
            features = list(state[2:]) + d_a + a_a
            net_input = torch.tensor(features, dtype=torch.float, device=device)
            reward = reward + reward_function(net_input)
    return reward / len(trajs)


def perturbed_copy(model, delta, scale, key='fc.0.weight'):
    """Return a deep copy of ``model`` whose ``state_dict()[key]`` is shifted by ``scale * delta``.

    ``delta`` may be a NumPy array or a tensor. The shift is converted to the
    dtype and device of the target entry, so a CPU/float64 NumPy perturbation
    can be applied to a CUDA/float32 model. It is broadcast against that entry.
    ``model`` itself is left unchanged.
    """
    clone = copy.deepcopy(model)
    params = clone.state_dict()
    target = params[key]
    shift = torch.as_tensor(scale * delta, dtype=target.dtype, device=target.device)
    params[key] = target + shift
    clone.load_state_dict(params)
    return clone


if __name__ == '__main__':
    import os

    # Network input: state[2:] (6 values) + two 10-way one-hot action encodings.
    num_input = 6 + 2 * 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    defender_reward = reward_estimator(num_input).to(device)
    attacker_reward = reward_estimator(num_input).to(device)
    alpha = 1e-3  # step size for the defender-weight update
    c = 0.01  # SPSA perturbation magnitude
    # torch.save does not create missing directories.
    os.makedirs('cs_3', exist_ok=True)
    for i in range(50):
        # NOTE(review): the 4th train_q flag presumably requests reference
        # ("expert") trajectories — confirm against qlearning.train_q.
        trajs_ea, _ = train_q(defender_reward, attacker_reward, False, True, False, 0)
        # Lower level: a slowly growing number of Adam steps on the attacker
        # reward, minimising (learner return - reference return).
        for _ in range(int(np.ceil((i + 1) ** 0.25 / 2))):
            re_ea = cumulative_reward_a(trajs_ea, attacker_reward)
            trajs_l, reward_check = train_q(defender_reward, attacker_reward, False, False, False, 0)
            print('reward_check: ', reward_check)
            re_l = cumulative_reward_a(trajs_l, attacker_reward)
            loss_l = re_l - re_ea
            attacker_reward.opt.zero_grad()
            loss_l.backward()
            attacker_reward.opt.step()

        # Upper level: SPSA with random +/-1 directions. delta_1 perturbs the
        # defender weights and delta_2 perturbs the attacker weights.
        delta_1 = np.random.choice([-1, 1], size=num_input)
        delta_2 = np.random.choice([-1, 1], size=num_input)
        defender_reward_p = perturbed_copy(defender_reward, delta_1, c)
        defender_reward_m = perturbed_copy(defender_reward, delta_1, -c)
        attacker_reward_p = perturbed_copy(attacker_reward, delta_2, c)
        attacker_reward_m = perturbed_copy(attacker_reward, delta_2, -c)
        traj_dp, _ = train_q(defender_reward_p, attacker_reward, False, False, False, 0)
        traj_dm, _ = train_q(defender_reward_m, attacker_reward, False, False, False, 0)
        traj_ap, _ = train_q(defender_reward, attacker_reward_p, False, False, False, 0)
        traj_am, _ = train_q(defender_reward, attacker_reward_m, False, False, False, 0)
        f_dp = cumulative_reward_d(traj_dp)
        f_dm = cumulative_reward_d(traj_dm)
        f_ap = cumulative_reward_d(traj_ap)
        f_am = cumulative_reward_d(traj_am)
        l_dp = feature_expectation(traj_dp, defender_reward)
        l_dm = feature_expectation(traj_dm, defender_reward)
        l_ap = feature_expectation(traj_ap, attacker_reward)
        l_am = feature_expectation(traj_am, attacker_reward)
        df_d = spsa(f_dp, f_dm, delta_1, c).reshape(num_input)
        df_a = spsa(f_ap, f_am, delta_2, c).reshape(num_input)
        # traj_dp / traj_dm were generated under the delta_1 (defender)
        # perturbation, so their finite difference must be paired with delta_1.
        dl_da = spsa(l_dp, l_dm, delta_1, c)
        dl_aa = spsa(l_ap, l_am, delta_2, c)
        # Symmetrise the second-derivative estimate, then take sqrt(H^2 + eps*I).
        # This is symmetric positive definite with H's eigenvectors, presumably
        # so that conjugate() (from conjugategrad) gets an SPD system.
        dl_aa_p = (dl_aa + dl_aa.T) / 2
        dl_aa_pd = np.real(sqrtm(np.dot(dl_aa_p, dl_aa_p) + 0.001 * np.eye(num_input)))
        inv = conjugate(dl_aa_pd, df_a)
        # NOTE(review): assumes conjugate(A, b) solves A x = b and that spsa's
        # output orientation makes dl_da @ inv the mixed-term contraction — confirm.
        g_f = df_d - dl_da @ inv
        # Gradient step on the defender weights (dtype/device-safe).
        defender_reward = perturbed_copy(defender_reward, g_f, -alpha)
        torch.save(defender_reward.state_dict(), 'cs_3/defender_' + str(i) + '.pt')
        torch.save(attacker_reward.state_dict(), 'cs_3/attacker_' + str(i) + '.pt')
