import os
from ant_utils import *
from gym.envs.mujoco.ant_v3 import AntEnv
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from PPO import PPO


def sim(agent, env, task_params, T, r_tot, device, mode=0):
    """Roll out one episode of at most ``T`` steps and accumulate its reward.

    At every step the first 27 entries of the observation are concatenated
    with ``task_params``, given a leading batch axis and passed to
    ``agent.get_action``. The first row of the returned action is applied to
    ``env``.

    Args:
        agent: Object providing ``get_action(state, test=...)`` and, when
            ``mode == 0``, ``push_batchdata(state, action, logprob, reward, done)``.
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns a 4-tuple ``(obs, reward, done, info)``.
        task_params: 1-D array appended to the truncated observation.
        T: Maximum number of steps; each reward is scaled by ``1 / T``.
        r_tot: Running reward total to which this episode is added.
        device: Torch device the state tensor is moved to.
        mode: ``0`` stores every transition in the agent and marks the last
            step as terminal; ``1`` calls ``get_action`` with ``test=True``.

    Returns:
        ``r_tot`` increased by ``reward / T`` for every step taken.
    """
    collecting = mode == 0
    evaluating = mode == 1
    final_step = T - 1

    obs = env.reset()

    for step in range(T):
        # State = truncated observation followed by the task parameters,
        # with a leading batch axis of size 1.
        features = np.concatenate((obs[:27], task_params))  # TODO: task_params
        st = torch.from_numpy(np.expand_dims(features, 0)).float().to(device)

        at, logprob, _sigma = agent.get_action(st, test=evaluating)
        action = at[0].detach().cpu().numpy()
        obs, reward, done, _ = env.step(action)

        r_tot += reward / T

        if collecting:
            # The time limit counts as episode termination for stored data.
            if step == final_step:
                done = True

            agent.push_batchdata(
                st.detach().cpu(),
                at.detach().cpu(),
                logprob.detach().cpu(),
                reward,
                done,
            )

        if done:
            break

    return r_tot


def _rollout_random_tasks(agent, n_envs, p_min, p_max, T, device, mode):
    """Run ``n_envs`` episodes, each on a freshly randomized Ant model.

    For every episode two task parameters are drawn uniformly from
    ``[p_min, p_max)``, the temporary XML model is rewritten with
    ``change_xml`` and a new ``AntEnv`` is built from it. The environment is
    always closed, even if the rollout raises.

    Returns:
        The reward total accumulated by ``sim`` over all episodes.
    """
    ant_xml_dir = 'ant_xml'
    xml_file = f'{ant_xml_dir}/ant_tmp_replace.xml'

    r_tot = 0
    for _ in range(n_envs):
        task_params = np.random.uniform(size=2) * (p_max - p_min) + p_min
        change_xml(xml_file, ant_xml_dir, task_params[0], task_params[1])
        env = AntEnv(xml_file='/home/alfredo/PycharmProjects/HyperMAML/RL/' + xml_file)
        try:
            r_tot = sim(agent, env, task_params, T, r_tot, device, mode=mode)
        finally:
            # Release the simulator even when the rollout fails.
            env.close()

    return r_tot


def train(writer, device):
    """Train a PPO agent on randomized Ant tasks and log progress.

    Each epoch collects ``envs_batch`` training episodes, calls
    ``agent.update()`` and logs the mean training reward together with the two
    values returned by the update. Every ``test_frq`` epochs the agent is
    evaluated on ``test_epochs`` episodes; when the summed test reward exceeds
    the best seen so far, ``agent.old_policy`` is saved to disk.

    Args:
        writer: Summary writer used for ``add_scalar`` logging. It is closed
            before this function returns, including on error or interruption.
        device: Torch device handed to the agent and to the rollouts.
    """
    max_reward = -10000

    EPOCHS = 100000
    test_frq = 10
    test_epochs = 10
    T = 100
    envs_batch = 10

    p_min = 0.2
    p_max = 0.6

    # NOTE(review): sim() builds states from obs[:27] plus the 2 task
    # parameters, i.e. 29 values, while z_dim is 27+8 -- confirm against PPO.
    z_dim = 27+8
    a_dim = 8
    a_max = 1

    agent = PPO(z_dim, a_dim, a_max, device)

    try:
        # NOTE(review): tqdm is not imported explicitly in this file;
        # presumably it is re-exported by `from ant_utils import *`.
        for epoch in tqdm(range(EPOCHS)):

            # --- collect on-policy data and update the agent ---
            r_tot = _rollout_random_tasks(agent, envs_batch, p_min, p_max, T, device, mode=0)

            v_loss, h_loss = agent.update()
            writer.add_scalar("Train/R_tot_g_train", r_tot / envs_batch, epoch)

            writer.add_scalar("v_loss", v_loss, epoch)
            writer.add_scalar("h_loss", h_loss, epoch)
            agent.clear_batchdata()

            # --- periodic evaluation ---
            if epoch % test_frq == (test_frq-1):

                r_tot = _rollout_random_tasks(agent, test_epochs, p_min, p_max, T, device, mode=1)

                writer.add_scalar("Test/R_tot_g_test", r_tot / test_epochs, epoch)

                # Keep the checkpoint with the best summed test reward.
                if r_tot > max_reward:
                    max_reward = r_tot
                    fname_psi = '/home/alfredo/PycharmProjects/HyperMAML/RL/PPO_model.mdl'
                    torch.save(agent.old_policy.state_dict(), fname_psi)
    finally:
        # Close the writer even if training is interrupted or fails.
        writer.close()


if __name__ == '__main__':

    # Seed the NumPy and PyTorch global RNGs.
    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Use the first CUDA device when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Scalars logged during training are written under this directory.
    writer = SummaryWriter("./logs_rl/1")

    train(writer, device)














