import gym
import numpy as np
import torch

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy

import csv
import pickle
import os

import multiprocessing


class DarkroomEnvWrapper(gym.Env):
    """Gym adapter around a Darkroom env with flat observations and optional CSV logging.

    Two-component observations ``(a, b)`` from the wrapped env are encoded as
    the single integer ``a * dim + b`` so they fit ``Discrete(dim ** 2)``.
    Scalar (index) actions are converted to one-hot vectors of length
    ``action_space.n`` before being forwarded to the wrapped env; vector
    actions are passed through unchanged.

    When ``log_path`` is given, the file is truncated on construction and each
    transition is buffered as a row ``obs..., action..., reward, done``. The
    buffer is appended to the CSV on every ``reset()`` and on ``close()``.
    """

    def __init__(self, darkroom, log_path=None):
        self.darkroom = darkroom
        # Flattened 2-D observation, see ``_encode``.
        self.observation_space = gym.spaces.Discrete(darkroom.dim ** 2)
        self.action_space = darkroom.action_space
        print('Goal: ', darkroom.goal)

        self.rows = []
        self.log_path = log_path
        print('Log path: ', self.log_path)
        if log_path is not None:
            # Truncate any previous log so this run starts from an empty file.
            with open(log_path, mode='w'):
                pass

    def _encode(self, obs):
        """Map a two-component observation to its flat index ``obs[0] * dim + obs[1]``."""
        return obs[0] * self.darkroom.dim + obs[1]

    def _flush_rows(self):
        """Append buffered transition rows to the CSV log (if enabled) and clear the buffer."""
        if self.log_path is not None:
            with open(self.log_path, mode='a', newline='') as file:
                csv.writer(file).writerows(self.rows)
        self.rows = []

    def reset(self):
        obs = self.darkroom.reset()
        self._flush_rows()
        return self._encode(obs)

    def step(self, action):
        # Observation before the step, recorded alongside the action taken.
        obs = self.darkroom.get_obs()

        # Any scalar action index (Python int, np.int32/np.int64, 0-d array)
        # becomes a one-hot vector; the wrapped env and the log expect vectors.
        if np.ndim(action) == 0:
            one_hot = np.zeros(self.action_space.n)
            one_hot[int(action)] = 1
            action = one_hot

        next_obs, reward, done, info = self.darkroom.step(action)
        self.rows.append(list(obs) + list(action) + [reward, done])
        return self._encode(next_obs), reward, done, info

    def render(self, mode='human'):
        pass

    def close(self):
        # Persist transitions of a final episode that was never followed by reset().
        if self.rows:
            self._flush_rows()


class SaveOnBestTrainingRewardCallback(BaseCallback):
    """
    Callback for saving a model (the check is done every ``check_freq`` steps)
    based on the training reward (in practice, we recommend using ``EvalCallback``).

    :param check_freq: (int) Number of callback calls between two reward checks.
    :param log_dir: (str) Path to the folder where the model will be saved.
        It must contain the file created by the ``Monitor`` wrapper.
    :param verbose: (int) Verbosity level; values > 0 print progress messages.
    """
    def __init__(self, check_freq: int, log_dir: str, verbose=1):
        super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir
        self.save_path = os.path.join(log_dir, 'best_model')
        self.best_mean_reward = -np.inf

    def _init_callback(self) -> None:
        # Create folder if needed
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:

            # Retrieve training reward from the Monitor log files in log_dir.
            x, y = ts2xy(load_results(self.log_dir), 'timesteps')
            if len(x) > 0:
                # Mean training reward over the last 100 episodes
                mean_reward = np.mean(y[-100:])
                if self.verbose > 0:
                    print(f"Num timesteps: {self.num_timesteps}")
                    print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")

                # New best model: save it.
                if mean_reward > self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    if self.verbose > 0:
                        print(f"Saving new best model to {self.save_path}.zip")
                    self.model.save(self.save_path)

        # Must return True on EVERY call. SB3 aborts training when the callback
        # returns a falsy value, so an implicit None on non-check steps would
        # stop training after the very first step.
        return True


def generate_dr_ppo_histories_for_envs(envs, env_indices, logs_exist=False, limit=1000):
    """Build PPO-derived in-context trajectories for several Darkroom environments.

    :param envs: environments, processed in lockstep with ``env_indices``.
    :param env_indices: identifiers used to name each env's log directory.
        Their count also selects the dataset split in ``logs_exist`` mode:
        exactly 80 selects the ``train`` pickle, anything else ``test``.
    :param logs_exist: if True, reuse rollins from a previously built pickle
        dataset; otherwise build trajectories from each env's CSV step log.
    :param limit: number of episodes (or loaded trajectories) used per env.
    :return: concatenated list of trajectory dicts for all envs.
    """
    log_dir = 'ppo_logs_1k_random_init_one_hot'
    os.makedirs(log_dir, exist_ok=True)

    trajs = []
    if logs_exist:
        split = 'train' if len(env_indices) == 80 else 'test'
        filepath = f'ppo_datasets_0506/trajs_darkroom_heldout_envs100000_hists1_samples1_H100_d10_{split}.pkl'
        print(filepath)

        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at trusted, locally generated dataset files.
        with open(filepath, 'rb') as f:
            all_loaded_trajs = pickle.load(f)

        for i, (env_index, env) in enumerate(zip(env_indices, envs)):
            # Assumes the dataset stores ``limit`` consecutive trajectories per
            # env, in the same order as ``envs`` — TODO confirm.
            loaded_trajs = all_loaded_trajs[i * limit:(i + 1) * limit]
            print('Env index: ', env_index, 'length of loaded trajs: ', len(loaded_trajs))
            # ``None`` is the ``return_dict`` slot (only used for multiprocessing).
            trajs += generate_dr_ppo_histories_for_env(
                i, None, env, env_index, log_dir,
                limit_episodes=len(loaded_trajs), loaded_trajs=loaded_trajs)
    else:
        for i, (env_index, env) in enumerate(zip(env_indices, envs)):
            print('Env index: ', env_index, 'generating PPO histories...')
            trajs += generate_dr_ppo_histories_for_env(i, None, env, env_index, log_dir, limit_episodes=limit)

    return trajs


def generate_dr_ppo_histories_for_env(procnum, return_dict, env, env_index, log_dir, limit_episodes, loaded_trajs=None):
    """Build in-context training trajectories for a single Darkroom environment.

    Two modes:

    * ``loaded_trajs is None``: read the step log
      ``<log_dir>/env<env_index>/env<env_index>.csv`` (rows of
      ``obs..., action..., reward, done``) and cut it into episodes of ``env.H``
      steps. For each episode ``i``, produce 10 rollins of exactly ``env.H``
      transitions, each stitching the tail of episode ``i`` to the head of
      episode ``i + 50``. Episode ``i`` itself is used as the fallback when
      ``i + 50`` was not logged. The query state/action label is the step of
      the later episode that immediately follows the rollin.
    * otherwise: reuse each loaded trajectory's rollin and pair it with 10
      freshly sampled query states labelled (one-hot) by the saved PPO model.

    :param procnum: key under which results are stored in ``return_dict``.
    :param return_dict: optional mutable mapping; when not ``None`` the
        trajectories are also stored at ``return_dict[procnum]``.
    :param env: Darkroom environment. Must expose ``H``, ``dx``, ``du``,
        ``dim``, ``goal`` and ``action_space``, plus ``sample_x`` in loaded mode.
    :param env_index: index used to name the per-environment log directory.
    :param log_dir: root log directory; ``best_model.zip`` must already exist
        in ``<log_dir>/env<env_index>``.
    :param limit_episodes: maximum number of episodes / loaded trajectories used.
    :param loaded_trajs: optional list of previously built trajectory dicts.
    :return: list of dicts with keys ``state``, ``action``, ``rollin_xs``,
        ``rollin_us``, ``rollin_xps``, ``rollin_rs``, ``goal`` (and
        ``perm_index`` when the env has one).
    """
    log_dir = os.path.join(log_dir, f'env{env_index}')
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f'env{env_index}.csv') if loaded_trajs is None else None
    # Logging stays disabled on purpose: DarkroomEnvWrapper truncates its log
    # file when constructed, which would wipe the CSV read further down.
    wrapped_env = DarkroomEnvWrapper(env, log_path=None)
    wrapped_env = Monitor(wrapped_env, log_dir)

    H = env.H
    limit = H * limit_episodes
    num_hists = 10
    episode_gap = 50

    # Training is disabled here; a best model saved by an earlier run is reused.
    model = PPO.load(os.path.join(log_dir, 'best_model.zip'), env=wrapped_env)

    if loaded_trajs is None:
        obs, actions, rewards = _read_logged_episodes(log_path, env, limit)
        # The log may hold fewer complete episodes than requested.
        num_episodes = min(limit_episodes, obs.shape[0])
    else:
        print("Using loaded rollin data")
        num_episodes = min(limit_episodes, len(loaded_trajs))

    trajs = []
    for i in range(num_episodes):
        if loaded_trajs is None:
            samples = _stitch_logged_episodes(
                obs, actions, rewards, i, num_episodes, episode_gap, num_hists, H)
        else:
            samples = _label_loaded_rollin(loaded_trajs[i], env, model, num_hists)

        for x, u, rollin_xs, rollin_us, rollin_xps, rollin_rs in samples:
            traj = {
                'state': x,
                'action': u,
                'rollin_xs': rollin_xs,
                'rollin_us': rollin_us,
                'rollin_xps': rollin_xps,
                'rollin_rs': rollin_rs,
                'goal': env.goal,
            }
            # Add perm_index for DarkroomEnvPermuted
            if hasattr(env, 'perm_index'):
                traj['perm_index'] = env.perm_index

            trajs.append(traj)

    if return_dict is not None:
        return_dict[procnum] = trajs
    return trajs


def _read_logged_episodes(log_path, env, limit):
    """Read the CSV step log and return (obs, actions, rewards) shaped per episode.

    Shapes are ``(E, H, dx)``, ``(E, H, du)`` and ``(E, H, 1)``, truncated to
    the first ``limit`` steps.
    """
    dtypes = [(f'obs{k}', float) for k in range(env.dx)]
    dtypes += [(f'act{k}', float) for k in range(env.du)]
    dtypes += [('reward', float), ('done', bool)]
    data = np.genfromtxt(log_path, delimiter=',', dtype=dtypes)

    H = env.H
    obs = np.array([data[f'obs{k}'] for k in range(env.dx)]).T[:limit].reshape(-1, H, env.dx)
    actions = np.array([data[f'act{k}'] for k in range(env.du)]).T[:limit].reshape(-1, H, env.du)
    rewards = data['reward'][:limit].reshape(-1, H, 1)
    return obs, actions, rewards


def _stitch_logged_episodes(obs, actions, rewards, i, num_episodes, gap, num_hists, H):
    """Return ``num_hists`` (state, action, rollin_xs, rollin_us, rollin_xps, rollin_rs) tuples for episode ``i``."""
    # Transitions (x_t, u_t, x_{t+1}, r_t) for t = 0 .. H-2 of episode i.
    curr = (obs[i, :-1, :], actions[i, :-1, :], obs[i, 1:, :], rewards[i, :-1, :])
    j = i + gap
    # Pair with episode i + gap when it exists (valid up to num_episodes - 1),
    # otherwise fall back to episode i itself.
    if j < num_episodes:
        nxt = (obs[j, :-1, :], actions[j, :-1, :], obs[j, 1:, :], rewards[j, :-1, :])
    else:
        nxt = curr
    next_xs, next_us = nxt[0], nxt[1]

    # start_t <= H - 3 keeps the label index start_t + 1 inside the H - 1 transitions.
    start_ts = [np.random.randint(H - 2) for _ in range(num_hists)]
    samples = []
    for start_t in start_ts:
        # (H - 1 - start_t) transitions from episode i plus (start_t + 1) from
        # the later episode gives exactly H transitions.
        rollin = tuple(
            np.concatenate((c[start_t:], n[:start_t + 1])) for c, n in zip(curr, nxt)
        )
        for part in rollin:
            assert part.shape[0] == H
        # AD labels: the step right after the rollin's head of the later episode.
        samples.append((next_xs[start_t + 1], next_us[start_t + 1].copy()) + rollin)
    return samples


def _label_loaded_rollin(loaded_traj, env, model, num_hists):
    """Pair one loaded rollin with ``num_hists`` query states labelled by ``model``."""
    rollin = (
        loaded_traj['rollin_xs'],
        loaded_traj['rollin_us'],
        loaded_traj['rollin_xps'],
        loaded_traj['rollin_rs'],
    )
    # NOTE(review): assumes env.sample_x() returns a two-component position
    # compatible with the flat encoding x[0] * dim + x[1] — TODO confirm.
    xs = [env.sample_x() for _ in range(num_hists)]
    discrete_xs = np.array([x[0] * env.dim + x[1] for x in xs])
    predicted = model.predict(discrete_xs, deterministic=False)[0]

    samples = []
    for x, a in zip(xs, predicted):
        # One-hot label, matching the action format stored in the step log.
        u = np.zeros(env.action_space.n)
        u[a] = 1
        samples.append((x, u) + rollin)
    return samples
