from MAPPO_SMAC_main import *
import torch
import torch.nn as nn
from scipy.linalg import sqrtm
from conjugategrad import *
from spsa import *
import copy
from spsa import *
from smac.env import StarCraft2Env

class reward_estimator(nn.Module):
    """Small MLP mapping a flat input vector to a scalar reward clamped to [-1, 1].

    Architecture: Linear(num_input, 64) -> ReLU -> Linear(64, 10) -> ReLU -> Linear(10, 1),
    followed by a clamp to [-1, 1]. The stack is stored as ``self.fc``, so the final
    layer's state-dict keys are ``fc.4.weight`` (shape (1, 10)) and ``fc.4.bias``.

    The module also owns an MSE loss (``mls``) and an Adam optimizer (``opt``, lr=1e-3)
    over its own parameters, so callers can step it directly.
    """

    def __init__(self, num_input):
        super().__init__()
        stack = [
            nn.Linear(num_input, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        ]
        self.fc = nn.Sequential(*stack)
        self.mls = nn.MSELoss()
        # Optimizer is created after self.fc so it sees all layer parameters.
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Return the MLP output for ``x``, clamped element-wise to [-1, 1]."""
        return self.fc(x).clamp(min=-1, max=1)


def get_arg(argv=None):
    """Build and parse the MAPPO / SMAC hyperparameter command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    import argparse

    def str2bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` is an argparse trap: ``bool("False")`` is ``True``, so such
        flags could never be switched off from the command line. Accept the usual
        spellings, plus numbers (so ``0``/``1`` still work for the flags that were
        previously parsed as floats).
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', 'off'):
            return False
        try:
            return float(text) != 0.0
        except ValueError:
            raise argparse.ArgumentTypeError(
                "expected a boolean value, got {!r}".format(value)) from None

    parser = argparse.ArgumentParser("Hyperparameters Setting for MAPPO in SMAC environment")
    parser.add_argument("--max_train_steps", type=int, default=int(1e5), help=" Maximum number of training steps")
    parser.add_argument("--evaluate_freq", type=float, default=5000,
                        help="Evaluate the policy every 'evaluate_freq' steps")
    parser.add_argument("--evaluate_times", type=float, default=32, help="Evaluate times")
    parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)")
    parser.add_argument("--mini_batch_size", type=int, default=8, help="Minibatch size (the number of episodes)")
    parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN")
    parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--lamda", type=float, default=0.95, help="GAE parameter")
    parser.add_argument("--epsilon", type=float, default=0.2, help="PPO clip parameter")
    parser.add_argument("--K_epochs", type=int, default=15, help="Number of PPO update epochs per batch")
    parser.add_argument("--use_adv_norm", type=str2bool, default=True, help="Trick 1:advantage normalization")
    parser.add_argument("--use_reward_norm", type=str2bool, default=True, help="Trick 3:reward normalization")
    parser.add_argument("--use_reward_scaling", type=str2bool, default=False,
                        help="Trick 4:reward scaling. Here, we do not use it.")
    parser.add_argument("--entropy_coef", type=float, default=0.01, help="Trick 5: policy entropy")
    parser.add_argument("--use_lr_decay", type=str2bool, default=True, help="Trick 6:learning rate Decay")
    parser.add_argument("--use_grad_clip", type=str2bool, default=True, help="Trick 7: Gradient clip")
    parser.add_argument("--use_orthogonal_init", type=str2bool, default=True, help="Trick 8: orthogonal initialization")
    parser.add_argument("--set_adam_eps", type=str2bool, default=True, help="Trick 9: set Adam epsilon=1e-5")
    parser.add_argument("--use_relu", type=str2bool, default=True, help="Whether to use relu, if False, we will use tanh")
    parser.add_argument("--use_rnn", type=str2bool, default=True, help="Whether to use RNN")
    parser.add_argument("--add_agent_id", type=str2bool, default=False,
                        help="Whether to add agent_id. Here, we do not use it.")
    parser.add_argument("--use_agent_specific", type=str2bool, default=True,
                        help="Whether to use agent specific global state.")
    parser.add_argument("--use_value_clip", type=str2bool, default=False, help="Whether to use value clip.")

    return parser.parse_args(argv)

def train_actor(learner_reward, expert_reward, estimate_learner, estimate_expert):
    """Train a MAPPO actor on a SMAC map and return its weights.

    Builds a ``Runner_MAPPO_SMAC`` from the command-line hyperparameters (``get_arg``).
    Both reward estimators and the two ``estimate_*`` switches are passed straight
    through to the runner. The function runs training to completion and returns
    ``state_dict()`` of the actor object that ``runner.run()`` yields.

    How the estimators and switches affect the training reward is decided inside the
    runner, which is not visible here.
    """
    candidate_maps = ('2s_vs_1sc', '8m', '2s3z')
    map_name = candidate_maps[0]  # only the first map is used
    hyperparams = get_arg()
    runner = Runner_MAPPO_SMAC(
        hyperparams,
        env_name=map_name,
        number=1,
        seed=0,
        learner_reward=learner_reward,
        expert_reward=expert_reward,
        estimate_learner=estimate_learner,
        estimate_expert=estimate_expert,
    )
    trained_actor = runner.run()
    return trained_actor.state_dict()

def cumulative_reward_expert(expert_estimator, actor, n_episodes=10):
    """Average per-step reward that ``expert_estimator`` assigns to agent 1 under ``actor``.

    Loads ``actor`` (an actor state dict) into a fresh MAPPO runner and plays
    ``n_episodes`` episodes on '2s_vs_1sc' with sampled actions. At every step it
    feeds agent 1's (pre-step) observation, concatenated with a one-hot of agent 1's
    action, into ``expert_estimator``. Outputs are averaged within each episode, and
    the per-episode means are then averaged.

    The result is a torch tensor that keeps the autograd graph through
    ``expert_estimator``, so callers can backpropagate into the estimator.
    The SC2 environment is closed even if the rollout raises.

    NOTE(review): the actor's recurrent state (if it has one) is not reset between
    episodes here, and the runner may own an environment of its own that is never
    closed -- confirm against Runner_MAPPO_SMAC.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_actions = 7  # one-hot width; same hard-coded value the original used for 2s_vs_1sc
    env = StarCraft2Env(map_name="2s_vs_1sc")
    try:
        args = get_arg()
        runner = Runner_MAPPO_SMAC(args, env_name='2s_vs_1sc', number=1, seed=0, learner_reward=0,
                                   expert_reward=0, estimate_learner=False, estimate_expert=False)
        runner.agent_n.actor.load_state_dict(copy.deepcopy(actor))

        total = 0
        for _ in range(n_episodes):
            env.reset()
            terminated = False
            episode_reward = 0
            n_steps = 0
            while not terminated:
                obs = env.get_obs()
                avail_a_n = env.get_avail_actions()
                a_n, _ = runner.agent_n.choose_action(obs, avail_a_n, evaluate=False)
                _, terminated, _ = env.step(a_n)
                # Score agent 1's (observation before the step, chosen action) pair.
                obs_in = torch.tensor(obs[1], dtype=torch.float, device=device)
                act_in = torch.zeros(n_actions, dtype=torch.float, device=device)
                act_in[int(a_n[1])] = 1.0
                episode_reward += expert_estimator(torch.cat((obs_in, act_in)))
                n_steps += 1
            # The while-loop runs at least once, so n_steps >= 1.
            total += episode_reward / n_steps
        return total / n_episodes
    finally:
        env.close()

def cumulative_reward_learner(actor, n_episodes=10):
    """Average per-step environment reward obtained by ``actor`` on '2s_vs_1sc'.

    Loads ``actor`` (an actor state dict) into a fresh MAPPO runner and plays
    ``n_episodes`` episodes with sampled actions. For each episode it averages the
    reward returned by ``env.step`` over the episode's steps, then returns the mean
    of those per-episode averages.

    The SC2 environment is closed even if the rollout raises.

    NOTE(review): the actor's recurrent state (if it has one) is not reset between
    episodes here, and the runner may own an environment of its own that is never
    closed -- confirm against Runner_MAPPO_SMAC.
    """
    env = StarCraft2Env(map_name="2s_vs_1sc")
    try:
        args = get_arg()
        runner = Runner_MAPPO_SMAC(args, env_name='2s_vs_1sc', number=1, seed=0, learner_reward=0,
                                   expert_reward=0, estimate_learner=False, estimate_expert=False)
        runner.agent_n.actor.load_state_dict(copy.deepcopy(actor))

        total = 0
        for _ in range(n_episodes):
            env.reset()
            terminated = False
            episode_reward = 0
            n_steps = 0
            while not terminated:
                obs = env.get_obs()
                avail_a_n = env.get_avail_actions()
                a_n, _ = runner.agent_n.choose_action(obs, avail_a_n, evaluate=False)
                r, terminated, _ = env.step(a_n)
                n_steps += 1
                episode_reward += r
            # The while-loop runs at least once, so n_steps >= 1.
            total += episode_reward / n_steps
        return total / n_episodes
    finally:
        env.close()

def feature_expectation(reward_estimator, actor, learner_sim, expert_sim, n_episodes=10):
    """Estimate the average "feature" vector of ``reward_estimator`` under ``actor``.

    The feature of one (observation, action) pair is the gradient of the estimator's
    output with respect to its 5th parameter tensor, taking row 0 of it. For the
    ``reward_estimator`` class in this file that tensor is ``fc.4.weight``, with
    shape (1, 10), so the feature is a length-10 vector.

    Rolls out ``n_episodes`` episodes on '2s_vs_1sc' with ``actor`` (an actor state
    dict) loaded into a fresh MAPPO runner. At every step:
      * if ``learner_sim``, the feature of agent 0's (obs, one-hot action) is added;
      * if ``expert_sim``, the feature of agent 1's (obs, one-hot action) is added.

    Per-episode sums are divided by the episode length. The returned numpy array of
    shape (10,) is the mean of those per-episode averages.

    The SC2 environment is closed even if the rollout raises.

    NOTE: the parameter name shadows the ``reward_estimator`` class inside this function.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_actions = 7  # one-hot width; same hard-coded value the original used for 2s_vs_1sc
    # Only the gradient w.r.t. this tensor is used, so don't compute the others.
    target_param = list(reward_estimator.parameters())[4]

    def _feature(obs, a_n, agent):
        """Row 0 of d(estimator output)/d(target_param) for one agent's (obs, action)."""
        obs_in = torch.tensor(obs[agent], dtype=torch.float, device=device)
        act_in = torch.zeros(n_actions, dtype=torch.float, device=device)
        act_in[int(a_n[agent])] = 1.0
        out = reward_estimator(torch.cat((obs_in, act_in)))
        (grad,) = torch.autograd.grad(out, target_param)
        return grad[0].detach().cpu().numpy()

    f_e = np.zeros(10)
    env = StarCraft2Env(map_name="2s_vs_1sc")
    try:
        args = get_arg()
        runner = Runner_MAPPO_SMAC(args, env_name='2s_vs_1sc', number=1, seed=0, learner_reward=0,
                                   expert_reward=0, estimate_learner=False, estimate_expert=False)
        runner.agent_n.actor.load_state_dict(copy.deepcopy(actor))

        for _ in range(n_episodes):
            env.reset()
            terminated = False
            n_steps = 0
            f_ep = np.zeros(10)
            while not terminated:
                obs = env.get_obs()
                avail_a_n = env.get_avail_actions()
                a_n, _ = runner.agent_n.choose_action(obs, avail_a_n, evaluate=False)
                _, terminated, _ = env.step(a_n)
                n_steps += 1
                if learner_sim:
                    f_ep += _feature(obs, a_n, 0)
                if expert_sim:
                    f_ep += _feature(obs, a_n, 1)
            # The while-loop runs at least once, so n_steps >= 1.
            f_e += f_ep / n_steps
        return f_e / n_episodes
    finally:
        env.close()

if __name__ == '__main__':
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    learner_reward = reward_estimator(24).to(device)
    expert_reward = reward_estimator(24).to(device)
    alpha = 3e-3  # step size for the learner-reward (fc.4.weight) update
    c = 0.01      # SPSA perturbation magnitude
    save_dir = 'smac11'
    # Create the checkpoint directory up front instead of failing at the first
    # torch.save after a full (very expensive) outer iteration.
    os.makedirs(save_dir, exist_ok=True)

    def _shifted(estimator, delta, scale):
        """Return a new reward_estimator equal to ``estimator`` except that its
        fc.4.weight is shifted by ``scale * delta``."""
        params = copy.deepcopy(estimator.state_dict())
        weight = params['fc.4.weight']
        # The numpy step is float64 on the CPU; convert it to the weight's dtype and
        # device, otherwise the in-place add fails when the model lives on CUDA.
        weight += torch.as_tensor(scale * delta, dtype=weight.dtype, device=weight.device)
        shifted = reward_estimator(24).to(device)
        shifted.load_state_dict(params)
        return shifted

    for i in range(60):
        actor_e_e = train_actor(learner_reward, expert_reward, True, False)
        # Inner expert-reward updates; the number of inner steps grows slowly with i.
        for j in range(int(np.ceil((i + 1) ** 0.25 / 2))):
            re_ea = cumulative_reward_expert(expert_reward, actor_e_e)
            actor_l_e = train_actor(learner_reward, expert_reward, True, True)
            re_l = cumulative_reward_expert(expert_reward, actor_l_e)
            loss_l = re_l - re_ea
            print('loss_l', loss_l)
            expert_reward.opt.zero_grad()
            loss_l.backward()
            expert_reward.opt.step()

        # Rademacher directions: delta_1 perturbs the learner reward, delta_2 the expert reward.
        delta_1 = np.random.choice([-1, 1], size=10)
        delta_2 = np.random.choice([-1, 1], size=10)
        learner_reward_p = _shifted(learner_reward, delta_1, c)
        learner_reward_m = _shifted(learner_reward, delta_1, -c)
        expert_reward_p = _shifted(expert_reward, delta_2, c)
        expert_reward_m = _shifted(expert_reward, delta_2, -c)

        actor_learner_p = train_actor(learner_reward_p, expert_reward, True, True)
        actor_learner_m = train_actor(learner_reward_m, expert_reward, True, True)
        actor_expert_p = train_actor(learner_reward, expert_reward_p, True, True)
        actor_expert_m = train_actor(learner_reward, expert_reward_m, True, True)
        f_learner_p = cumulative_reward_learner(actor_learner_p)
        f_learner_m = cumulative_reward_learner(actor_learner_m)
        f_expert_p = cumulative_reward_learner(actor_expert_p)
        f_expert_m = cumulative_reward_learner(actor_expert_m)
        l_learner_p = feature_expectation(learner_reward, actor_learner_p, True, False)
        l_learner_m = feature_expectation(learner_reward, actor_learner_m, True, False)
        l_expert_p = feature_expectation(expert_reward, actor_expert_p, False, True)
        l_expert_m = feature_expectation(expert_reward, actor_expert_m, False, True)

        df_learner = spsa(np.array([f_learner_p]), np.array([f_learner_m]), delta_1, c).reshape(10)
        df_expert = spsa(np.array([f_expert_p]), np.array([f_expert_m]), delta_2, c).reshape(10)
        # l_learner_p / l_learner_m were produced by the delta_1 (learner) perturbation,
        # so their SPSA estimate must use delta_1; delta_2 is unrelated to that pair.
        dl_learner_expert = spsa(l_learner_p, l_learner_m, delta_1, c)
        dl_expert_expert = spsa(l_expert_p, l_expert_m, delta_2, c)
        # Symmetrize, then use sqrt(H^2 + 1e-3 I), a symmetric positive-definite
        # surrogate of H; np.real drops numerical imaginary residue from sqrtm.
        dl_ee_p = (dl_expert_expert + dl_expert_expert.T) / 2
        dl_ee_pd = np.real(sqrtm(np.dot(dl_ee_p, dl_ee_p) + 0.001 * np.eye(10)))
        # Presumably solves dl_ee_pd @ x = df_expert (conjugategrad module) -- confirm.
        inv = conjugate(dl_ee_pd, df_expert)
        g_f = df_learner - dl_learner_expert @ inv
        print('ep', i, 'g_f', g_f)

        learner_param = copy.deepcopy(learner_reward.state_dict())
        weight = learner_param['fc.4.weight']
        # g_f may be a numpy array; torch refuses ndarray operands in tensor
        # arithmetic and the weight may be on CUDA, so convert explicitly.
        weight -= alpha * torch.as_tensor(g_f, dtype=weight.dtype, device=weight.device)
        learner_reward.load_state_dict(learner_param)
        torch.save(learner_reward.state_dict(), os.path.join(save_dir, 'learner_' + str(i) + '.pt'))
        torch.save(expert_reward.state_dict(), os.path.join(save_dir, 'expert_' + str(i) + '.pt'))
        # The inner loop always runs at least once, so actor_l_e is defined here.
        torch.save(actor_l_e, os.path.join(save_dir, 'actor_' + str(i) + '.pt'))