import os
import torch
import gymnasium as gym
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from fed_ppo import FedPPO
from fed_trpo import FedTRPO
from fed_trpo_admm import FedTRPO_ADMM
from stable_baselines3.common import results_plotter
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import VecMonitor, VecNormalize


# def make_env(env_id, rank, seed=0):
#     """
#     Utility function for multiprocessed env.
#
#     :param env_id: (str) the environment ID
#     :param num_env: (int) the number of environments you wish to have in subprocesses
#     :param seed: (int) the initial seed for RNG
#     :param rank: (int) index of the subprocess
#     """
#
#     def _init():
#         env = gym.make(env_id)
#         env.seed(seed + rank)
#         return env
#
#     set_random_seed(seed)
#     return _init


def reward_mixer(n_clients=8, env_id="Swimmer-v4", algo_type="TRPO_admm"):
    """
    Average the per-client monitor logs of one run into a single CSV.

    Reads ``<i>.monitor.csv`` for ``i`` in ``range(n_clients)`` from
    ``./fed_tmp/<algo_type>/<env_id>/clients_<n_clients>/``, takes the first
    column of each file and averages those columns row by row. Clients may
    have logged a different number of rows; each output row is the mean over
    the clients that actually have a value at that row. The result is written
    to ``avg_monitor.csv`` in the same directory (index plus one column).

    :param n_clients: (int) number of client log files to combine
    :param env_id: (str) environment ID used in the log directory name
    :param algo_type: (str) algorithm name used in the log directory name
    :raises ValueError: if ``n_clients`` is smaller than 1
    """
    if n_clients < 1:
        raise ValueError("n_clients must be at least 1, got " + repr(n_clients))

    log_dir_l = "./fed_tmp/" + algo_type + '/' + env_id + '/clients_' + str(n_clients) + '/'
    log_dir_r = '.monitor.csv'

    columns = []
    for i in range(n_clients):
        log_file = log_dir_l + str(i) + log_dir_r
        # Stable-Baselines3 Monitor files begin with a '#'-prefixed JSON
        # metadata line; comment='#' skips it so the real header row is used.
        # Files without such a line are parsed exactly as before.
        first_column = pd.read_csv(log_file, comment='#').iloc[:, 0]
        # Positional index + float dtype so the columns align row by row and
        # the averaging below never has to write floats into an int column.
        columns.append(first_column.reset_index(drop=True).astype(float))

    # Side-by-side layout: shorter logs are padded with NaN, and mean() skips
    # NaN, so every row is divided by the number of clients that reached it.
    data = pd.concat(columns, axis=1).mean(axis=1)
    data.name = columns[0].name
    data.to_csv(log_dir_l + 'avg_monitor.csv')


# Script entry point: train one federated agent on a vectorized environment,
# save it, roll the trained policy out for a fixed number of steps, then plot
# the monitor logs written during training.
if __name__ == "__main__":
    env_id = "Swimmer-v4"
    env_kwargs = {}  # visualize {"render_mode": "human"}
    algo_type = "TRPO_admm"
    iters = 1000
    n_clients = 1
    rho = 1e-3
    # NOTE(review): no value for ``alpha`` is defined anywhere in this file.
    # With None the argument is simply not passed, so FedTRPO_ADMM falls back
    # to its own default (assumed to exist -- TODO confirm); set a number here
    # to override it.
    alpha = None
    log_dir = "./fed_tmp/" + algo_type + '/' + env_id + '/clients_' + str(n_clients)
    timesteps = 1_000_000

    # Fail before any subprocess is spawned if the algorithm name is unknown.
    if algo_type not in ("TRPO", "TRPO_admm", "PPO"):
        raise ValueError("algo not defined! got " + repr(algo_type))

    os.makedirs(log_dir, exist_ok=True)
    # Create the vectorized environment
    # env = SubprocVecEnv([make_env(env_id, i) for i in range(n_clients)])
    env = make_vec_env(env_id, monitor_dir=log_dir, n_envs=n_clients,
                       seed=0, env_kwargs=env_kwargs, vec_env_cls=SubprocVecEnv)
    # env = VecMonitor(env, log_dir)

    try:
        if algo_type == "TRPO":
            # model_path = f"{log_dir}/0.zip"
            # model = TRPO.load(log_dir, env=env)
            model = FedTRPO("MlpPolicy", env, verbose=1, n_clients=n_clients)
        elif algo_type == "TRPO_admm":
            admm_kwargs = {"rho": rho}
            if alpha is not None:
                admm_kwargs["alpha"] = alpha
            model = FedTRPO_ADMM("MlpPolicy", env, verbose=1, n_clients=n_clients,
                                 **admm_kwargs)
        else:  # "PPO" -- the only remaining validated choice
            model = FedPPO("MlpPolicy", env, verbose=1, n_clients=n_clients)

        model.learn(total_timesteps=timesteps*n_clients)
        model.save(log_dir + '/' + env_id + '_' + str(iters))
        # stats_path = os.path.join(log_dir, env_id+'_'+str(iters)+".pkl")
        # env.save(stats_path)  # naive env only
        # del model
        # model = PPO.load(log_dir+'/'+env_id+'_'+str(iters))
        # env = make_vec_env(env_id, n_envs=n_clients, seed=0, vec_env_cls=SubprocVecEnv)
        # env = VecNormalize.load(stats_path, env)

        # Roll out the trained policy for ``iters`` steps.
        obs = env.reset()
        for _ in range(iters):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            env.render()
    finally:
        # Shut down the SubprocVecEnv worker processes even if training fails.
        env.close()

    results_plotter.plot_results([log_dir], timesteps*n_clients, x_axis='episodes',
                                 task_name=algo_type + ' ' + env_id, figsize=(15, 10))
    # plot_results only builds the figure; show it so the script has output.
    plt.show()
