from configparser import ConfigParser
import argparse

import torch
import gymnasium as gym
import numpy as np
import os

from utils.utils import eval_model
from utils.utils import get_save_dir

from ppo_gini_isclip import PPO_Gini

from utils.utils import make_transition, Dict
# Make sure the checkpoint directory exists; main() loads weights from
# "./model_weights/<name>" when --load is given.
os.makedirs('./model_weights', exist_ok=True)

import sys
# Add the parent directory to the module search path.
# NOTE(review): this runs after the project imports above, so it cannot
# influence how those imports are resolved -- confirm it is still needed.
sys.path.append('..')

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string -- including ``"False"``.  This converter
    understands the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def readParser():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the attributes ``env_name``, ``train``,
        ``render``, ``epochs``, ``load``, ``save_interval``,
        ``print_interval``, ``use_cuda``, ``reward_scaling`` and
        ``experiment_num``.
    """
    parser = argparse.ArgumentParser(description='parameters')
    parser.add_argument("--env_name", type=str, default='InvertedPendulum-v4')
    # Boolean flags go through _str2bool so that "--train False" really
    # yields False.  Help texts use %(default)s so they cannot drift away
    # from the actual default values.
    parser.add_argument('--train', type=_str2bool, default=True,
                        help='(default: %(default)s)')
    parser.add_argument('--render', type=_str2bool, default=False,
                        help='(default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of epochs (default: %(default)s)')
    parser.add_argument("--load", type=str, default='no',
                        help='load network name in ./model_weights')
    parser.add_argument("--save_interval", type=int, default=1000,
                        help='save interval (default: %(default)s)')
    parser.add_argument("--print_interval", type=int, default=100,
                        help='print interval (default: %(default)s)')
    parser.add_argument("--use_cuda", type=_str2bool, default=True,
                        help='cuda usage (default: %(default)s)')
    parser.add_argument("--reward_scaling", type=float, default=1.0,
                        help='reward scaling (default: %(default)s)')
    parser.add_argument("--experiment_num", type=int, default=5,
                        help='experiment number used for the save directory '
                             '(default: %(default)s)')
    return parser.parse_args()


def experiment(args, agent_args, agent, env, eval_env, save_dir, device,
               eval_interval=10000, max_episode_steps=500):
    """Run the collect/train loop and periodically evaluate the agent.

    For each of ``args.epochs`` epochs, ``agent_args.traj_length`` environment
    steps are collected.  Every step is handed to ``agent.put_data`` as a
    transition; every episode that *finishes* inside the epoch is additionally
    collected into a per-episode buffer handed to ``agent.put_episodes``.
    ``agent.train_net(n_epi)`` is then called once per epoch.

    Every ``eval_interval`` environment steps ``eval_model`` is run on
    ``eval_env`` and the running lists of mean results are written to
    ``save_dir + 'eval_r.npy'`` and ``save_dir + 'eval_right.npy'``.

    Args:
        args: parsed command-line options; reads ``epochs``, ``render`` and
            ``reward_scaling``.
        agent_args: agent configuration; reads ``traj_length``.
        agent: object providing ``get_action``, ``put_data``,
            ``put_episodes`` and ``train_net``.
        env: gymnasium environment used for data collection.
        eval_env: environment passed to ``eval_model``.
        save_dir: string prefix the evaluation file names are appended to
            (plain concatenation, so it is presumably expected to end with a
            path separator -- verify against ``get_save_dir``).
        device: torch device the observations are moved to.
        eval_interval: number of environment steps between evaluations.
        max_episode_steps: episodes are cut and the environment is reset
            after this many steps.

    Raises:
        ValueError: if ``eval_interval`` is not positive.
    """
    if eval_interval <= 0:
        raise ValueError('eval_interval must be a positive integer')

    eval_r_lst, eval_right_lst = [], []

    state, _ = env.reset()
    cur_traj = 0  # steps taken in the current episode; persists across epochs
    n = 0         # total number of environment steps so far
    for n_epi in range(args.epochs):

        episodes_buf = {'state': [], 'action': [], 'reward': [], 'log_prob': []}
        # NOTE(review): these per-episode lists are re-created every epoch
        # while the environment is NOT reset, so the steps of an episode that
        # is still running at the end of an epoch never reach episodes_buf
        # (they are still passed to agent.put_data).  Kept as in the original;
        # confirm this is intended.
        state_lst, action_lst, reward_lst, log_prob_lst = [], [], [], []

        for _ in range(agent_args.traj_length):
            if args.render:
                env.render()
            mu, sigma = agent.get_action(torch.from_numpy(state).float().to(device))
            dist = torch.distributions.Normal(mu, sigma[0])
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(-1, keepdim=True)

            # Convert once and reuse for stepping and for both buffers.
            action_np = action.cpu().numpy()
            log_prob_np = log_prob.detach().cpu().numpy()

            # gymnasium returns (obs, reward, terminated, truncated, info).
            next_state, reward, terminated, truncated, _ = env.step(action_np)
            scaled_reward = reward * args.reward_scaling
            cur_traj += 1
            n += 1
            if n % eval_interval == 0:
                eval_r, eval_right = eval_model(eval_env, agent, device)
                print(n / eval_interval, 'eval_return:', eval_r.mean(), eval_right.mean())
                eval_r_lst.append(eval_r.mean())
                eval_right_lst.append(eval_right.mean())

                with open(save_dir + 'eval_r.npy', 'wb') as f:
                    np.save(f, np.array(eval_r_lst))
                with open(save_dir + 'eval_right.npy', 'wb') as f:
                    np.save(f, np.array(eval_right_lst))

            # The stored done flag is the true termination signal only, so a
            # time-limit cut is not recorded as a terminal state.
            transition = make_transition(state,
                                         action_np,
                                         np.array([scaled_reward]),
                                         next_state,
                                         np.array([terminated]),
                                         log_prob_np)
            agent.put_data(transition)
            state_lst.append(state)
            action_lst.append(action_np)
            reward_lst.append(scaled_reward)
            log_prob_lst.append(log_prob_np)

            # Reset on termination, on the environment's own truncation
            # (stepping a truncated env without reset() is not allowed by the
            # gymnasium API), or on the local step limit.
            if terminated or truncated or cur_traj == max_episode_steps:
                state, _ = env.reset()
                cur_traj = 0

                episodes_buf['state'].append(state_lst)
                episodes_buf['action'].append(action_lst)
                episodes_buf['reward'].append(reward_lst)
                episodes_buf['log_prob'].append(log_prob_lst)
                state_lst, action_lst, reward_lst, log_prob_lst = [], [], [], []
            else:
                state = next_state
        agent.put_episodes(episodes_buf)
        agent.train_net(n_epi)

def main():
    """Build the environments and the agent, then run the experiment.

    Reads the command-line options via ``readParser`` and the ``ppo`` section
    of ``config.ini``, creates a training and an evaluation environment,
    constructs a ``PPO_Gini`` agent (optionally restoring weights from
    ``./model_weights/<--load>``) and calls ``experiment``.  Both environments
    are closed when the run ends, including when it ends with an exception.
    """
    args = readParser()
    parser = ConfigParser()
    parser.read('config.ini')
    agent_args = Dict(parser, 'ppo')

    # Use the GPU only when it is both requested and available.
    use_gpu = bool(args.use_cuda) and torch.cuda.is_available()
    device = 'cuda' if use_gpu else 'cpu'

    env = gym.make(args.env_name)
    eval_env = gym.make(args.env_name)
    try:
        action_dim = env.action_space.shape[0]
        state_dim = env.observation_space.shape[0]

        agent = PPO_Gini(device, state_dim, action_dim, agent_args)

        if use_gpu:
            agent = agent.cuda()

        if args.load != 'no':
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only run (and vice versa).
            agent.load_state_dict(
                torch.load("./model_weights/" + args.load, map_location=device))

        print('env:', args.env_name, 'algo:', 'ppo')

        save_dir = get_save_dir(args.env_name, agent_args.actor_lr, args.experiment_num)
        os.makedirs(save_dir, exist_ok=True)

        experiment(args, agent_args, agent, env, eval_env, save_dir, device)
    finally:
        # Release simulator/renderer resources held by the environments.
        env.close()
        eval_env.close()


# Run the training only when executed as a script, so that importing this
# module (e.g. to reuse experiment()) does not start a full training run.
if __name__ == "__main__":
    main()