from maddpgmain import *
import torch
import torch.nn as nn
from scipy.linalg import sqrtm
from conjugategrad import *
from spsa import *

class reward_estimator(nn.Module):
    """Small MLP mapping an observation vector to a scalar reward estimate.

    Architecture: Linear(num_input, 64) -> ReLU -> Linear(64, 10) -> ReLU ->
    Linear(10, 1), stored as ``self.fc`` so parameters are named ``fc.0``,
    ``fc.2`` and ``fc.4`` in the state dict.  The module also owns an Adam
    optimizer (lr=1e-3) over its own parameters and an ``MSELoss`` instance.
    """

    def __init__(self, num_input):
        super().__init__()
        # Layers are built in the same order as listed so weight initialisation
        # draws from the RNG identically.
        layers = [
            nn.Linear(num_input, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        ]
        self.fc = nn.Sequential(*layers)
        self.mls = nn.MSELoss()
        # Created after the layers so it captures every parameter; Module.to()
        # moves parameters in place, so this stays valid after .to(device).
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Return the reward estimate for ``x`` (last dim must be ``num_input``)."""
        estimate = self.fc(x)
        return estimate

def cumulative_reward_adversary(adversary_estimator, dics, num_episodes=10):
    """Return the estimated reward of env.agents[0], averaged over greedy rollouts.

    A fresh ``simple_adversary_v3`` parallel env (N=2, max_cycles=25, continuous
    actions) is created and ``dics[i]`` is loaded into the actor of the i-th
    MADDPG agent.  ``num_episodes`` episodes are played with ``noise_std=0``.
    At every step ``adversary_estimator`` is applied to the *pre-step*
    observation of the first agent in ``env.agents`` (presumably the adversary
    given the function name -- confirm the PettingZoo agent ordering), and the
    outputs are summed.

    Args:
        adversary_estimator: reward network evaluated on that observation.
        dics: per-agent actor state dicts, indexable by 0, 1, 2.
        num_episodes: number of rollouts to average over (default 10).

    Returns:
        The summed estimator output divided by ``num_episodes``.  It is a
        one-element tensor that still carries the autograd graph, so the caller
        can back-propagate into ``adversary_estimator``.

    Raises:
        ValueError: if ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError('num_episodes must be positive')
    # Build inputs on the estimator's own device rather than reading the
    # module-level ``device``, which only exists when this file runs as a script.
    est_device = next(adversary_estimator.parameters()).device

    env = simple_adversary_v3.parallel_env(N=2, max_cycles=25, continuous_actions=True)
    env.reset()
    agent_names = list(env.agents)
    obs_dim = [env.observation_space(name).shape[0] for name in agent_names]
    action_dim = [env.action_space(name).shape[0] for name in agent_names]
    agent_n = [MADDPG(3, obs_dim, action_dim, agent_id) for agent_id in range(3)]
    for i, agent in enumerate(agent_n):
        # load_state_dict copies values into the parameters, so no deepcopy needed.
        agent.actor.load_state_dict(dics[i])

    total = 0
    for _ in range(num_episodes):
        observations, _ = env.reset()
        agent_names = list(env.agents)
        obs_n = [observations[name] for name in agent_names]
        while env.agents:
            actions = [agent.choose_action(obs, noise_std=0) for agent, obs in zip(agent_n, obs_n)]
            # astype already returns a new array; the env expects float32 actions.
            a_in = {name: act.astype('float32') for name, act in zip(agent_names, actions)}
            observations_next = env.step(a_in)[0]

            est_input = torch.tensor(obs_n[0], dtype=torch.float, device=est_device)
            total = total + adversary_estimator(est_input)

            obs_n = [observations_next[name] for name in agent_names]

    return total / num_episodes

def cumulative_reward_agent(dics, num_episodes=10):
    """Return the true env reward of env.agents[1], averaged over greedy rollouts.

    A fresh ``simple_adversary_v3`` parallel env (N=2, max_cycles=25, continuous
    actions) is created and ``dics[i]`` is loaded into the actor of the i-th
    MADDPG agent.  ``num_episodes`` episodes are played with ``noise_std=0``,
    and the environment reward of the second agent in ``env.agents`` is summed
    over every step.  That agent is presumably a good (non-adversary) agent;
    confirm against the PettingZoo agent ordering.

    Args:
        dics: per-agent actor state dicts, indexable by 0, 1, 2.
        num_episodes: number of rollouts to average over (default 10).

    Returns:
        The summed reward divided by ``num_episodes``.

    Raises:
        ValueError: if ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError('num_episodes must be positive')

    env = simple_adversary_v3.parallel_env(N=2, max_cycles=25, continuous_actions=True)
    env.reset()
    agent_names = list(env.agents)
    obs_dim = [env.observation_space(name).shape[0] for name in agent_names]
    action_dim = [env.action_space(name).shape[0] for name in agent_names]
    agent_n = [MADDPG(3, obs_dim, action_dim, agent_id) for agent_id in range(3)]
    for i, agent in enumerate(agent_n):
        # load_state_dict copies values into the parameters, so no deepcopy needed.
        agent.actor.load_state_dict(dics[i])

    total = 0
    for _ in range(num_episodes):
        observations, _ = env.reset()
        agent_names = list(env.agents)
        obs_n = [observations[name] for name in agent_names]
        while env.agents:
            actions = [agent.choose_action(obs, noise_std=0) for agent, obs in zip(agent_n, obs_n)]
            # astype already returns a new array; the env expects float32 actions.
            a_in = {name: act.astype('float32') for name, act in zip(agent_names, actions)}
            observations_next, rewards = env.step(a_in)[:2]
            total += rewards[agent_names[1]]
            obs_n = [observations_next[name] for name in agent_names]

    return total / num_episodes


def feature_expectation(reward_estimator, dics, agent_sim, adversary_sim, num_episodes=10):
    """Average last-layer gradient ("features") of a reward estimator over rollouts.

    Because ``fc[4]`` is ``Linear(10, 1)``, the gradient of the estimator output
    w.r.t. ``fc[4].weight`` equals the 10-dim hidden activation feeding that
    layer.  This function plays ``num_episodes`` greedy episodes in a fresh
    ``simple_adversary_v3`` env (actors loaded from ``dics``) and, at every
    step, accumulates that gradient evaluated on the *pre-step* observations:

    * ``agent_sim``: the mean of the gradients for ``env.agents[1]`` and
      ``env.agents[2]``;
    * ``adversary_sim``: the gradient for ``env.agents[0]``.

    Both contributions are added if both flags are set.

    Args:
        reward_estimator: a ``reward_estimator`` module (needs ``fc[4]``).
        dics: per-agent actor state dicts, indexable by 0, 1, 2.
        agent_sim: include the agents' observations.
        adversary_sim: include the first agent's observation.
        num_episodes: number of rollouts to average over (default 10).

    Returns:
        A float64 numpy array of length ``fc[4].in_features``: the accumulated
        gradients divided by ``num_episodes``.

    Raises:
        ValueError: if ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError('num_episodes must be positive')

    # Equivalent to list(reward_estimator.parameters())[4] for reward_estimator.
    last_w = reward_estimator.fc[4].weight
    # Use the estimator's own device instead of the script-only global ``device``.
    est_device = last_w.device

    def _last_layer_grad(obs):
        """Gradient of the estimator output on ``obs`` w.r.t. ``last_w``, as 1-D numpy."""
        inp = torch.tensor(obs, dtype=torch.float, device=est_device)
        (grad,) = torch.autograd.grad(reward_estimator(inp), last_w)
        return grad[0].detach().cpu().numpy()

    f_e = np.zeros(last_w.shape[1])
    env = simple_adversary_v3.parallel_env(N=2, max_cycles=25, continuous_actions=True)
    env.reset()
    agent_names = list(env.agents)
    obs_dim = [env.observation_space(name).shape[0] for name in agent_names]
    action_dim = [env.action_space(name).shape[0] for name in agent_names]
    agent_n = [MADDPG(3, obs_dim, action_dim, agent_id) for agent_id in range(3)]
    for i, agent in enumerate(agent_n):
        # load_state_dict copies values into the parameters, so no deepcopy needed.
        agent.actor.load_state_dict(dics[i])

    for _ in range(num_episodes):
        observations, _ = env.reset()
        agent_names = list(env.agents)
        obs_n = [observations[name] for name in agent_names]
        while env.agents:
            actions = [agent.choose_action(obs, noise_std=0) for agent, obs in zip(agent_n, obs_n)]
            # astype already returns a new array; the env expects float32 actions.
            a_in = {name: act.astype('float32') for name, act in zip(agent_names, actions)}
            observations_next = env.step(a_in)[0]
            if agent_sim:
                f_e += (_last_layer_grad(obs_n[1]) + _last_layer_grad(obs_n[2])) / 2
            if adversary_sim:
                f_e += _last_layer_grad(obs_n[0])
            obs_n = [observations_next[name] for name in agent_names]
    return f_e / num_episodes

if __name__ == '__main__':
    import os

    # Kept at module scope: code in this file may read the global ``device``.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Input widths must match the observation sizes fed to each estimator.
    agent_reward = reward_estimator(10).to(device)
    adversary_reward = reward_estimator(8).to(device)
    alpha = 1e-3  # step size of the outer update on the agent reward
    c = 0.01      # SPSA perturbation magnitude
    save_dir = 'mpe1'
    # torch.save fails if the target directory does not exist.
    os.makedirs(save_dir, exist_ok=True)

    def _shifted_copy(estimator, num_input, offset):
        """Return a new reward_estimator(num_input) equal to ``estimator`` except
        that ``fc.4.weight`` is shifted by the numpy array ``offset``."""
        params = copy.deepcopy(estimator.state_dict())
        weight = params['fc.4.weight']
        # ``offset`` is a CPU numpy array; move it to the weight's device so the
        # in-place add also works when the model lives on CUDA.
        weight += torch.as_tensor(offset, device=weight.device)
        clone = reward_estimator(num_input).to(device)
        clone.load_state_dict(params)
        return clone

    for i in range(50):
        # Inner loop: gradient steps on the adversary reward via its optimizer,
        # using the difference of estimated returns under two trained policy sets.
        dics_e_ad = trainmaddpg(agent_reward, adversary_reward, True, False)
        for _ in range(int(np.ceil((i + 1) ** 0.25 / 2))):
            re_ea = cumulative_reward_adversary(adversary_reward, dics_e_ad)
            dics_l_ad = trainmaddpg(agent_reward, adversary_reward, True, True)
            re_l = cumulative_reward_adversary(adversary_reward, dics_l_ad)
            loss_l = re_l - re_ea
            print('loss_l', loss_l)
            adversary_reward.opt.zero_grad()
            loss_l.backward()
            adversary_reward.opt.step()

        # SPSA: random +/-1 directions on the last-layer weights of each estimator.
        delta_1 = np.random.choice([-1, 1], size=10)
        delta_2 = np.random.choice([-1, 1], size=10)
        agent_reward_p = _shifted_copy(agent_reward, 10, c * delta_1)
        agent_reward_m = _shifted_copy(agent_reward, 10, -(c * delta_1))
        adversary_reward_p = _shifted_copy(adversary_reward, 8, c * delta_2)
        adversary_reward_m = _shifted_copy(adversary_reward, 8, -(c * delta_2))

        dics_agent_p = trainmaddpg(agent_reward_p, adversary_reward, True, True)
        dics_agent_m = trainmaddpg(agent_reward_m, adversary_reward, True, True)
        dics_adversary_p = trainmaddpg(agent_reward, adversary_reward_p, True, True)
        dics_adversary_m = trainmaddpg(agent_reward, adversary_reward_m, True, True)
        f_agent_p = cumulative_reward_agent(dics_agent_p)
        f_agent_m = cumulative_reward_agent(dics_agent_m)
        f_adversary_p = cumulative_reward_agent(dics_adversary_p)
        f_adversary_m = cumulative_reward_agent(dics_adversary_m)
        l_agent_p = feature_expectation(agent_reward, dics_agent_p, True, False)
        l_agent_m = feature_expectation(agent_reward, dics_agent_m, True, False)
        l_adversary_p = feature_expectation(adversary_reward, dics_adversary_p, False, True)
        l_adversary_m = feature_expectation(adversary_reward, dics_adversary_m, False, True)

        df_agent = spsa(np.array([f_agent_p]), np.array([f_agent_m]), delta_1, c).reshape(10)
        df_adversary = spsa(np.array([f_adversary_p]), np.array([f_adversary_m]), delta_2, c).reshape(10)
        # NOTE(review): l_agent_p/m come from the delta_1 (agent) perturbation but
        # are divided by delta_2 here.  Either they should use dics_adversary_p/m
        # (d l_agent / d adversary) or this should pass delta_1 -- confirm intent.
        dl_agent_adversary = spsa(l_agent_p, l_agent_m, delta_2, c)
        dl_adversary_adversary = spsa(l_adversary_p, l_adversary_m, delta_2, c)
        # Symmetrise the Hessian estimate, then make it positive definite via
        # sqrt(H^2 + eps*I) before solving with conjugate gradient.
        dl_aa_p = (dl_adversary_adversary + dl_adversary_adversary.T) / 2
        dl_aa_pd = np.real(sqrtm(np.dot(dl_aa_p, dl_aa_p) + 0.001 * np.eye(10)))
        inv = conjugate(dl_aa_pd, df_adversary)
        g_f = df_agent - dl_agent_adversary @ inv
        print('ep', i, 'g_f', g_f)

        agent_param = copy.deepcopy(agent_reward.state_dict())
        weight = agent_param['fc.4.weight']
        # g_f is a numpy array; move it to the weight's device before updating.
        weight -= torch.as_tensor(alpha * g_f, device=weight.device)
        agent_reward.load_state_dict(agent_param)
        torch.save(agent_reward.state_dict(), f'{save_dir}/agent_{i}.pt')
        torch.save(adversary_reward.state_dict(), f'{save_dir}/adversary_{i}.pt')

