import os
import sys
from importlib import import_module

import gymnasium as gym
import matplotlib.pyplot as plt
import stable_baselines3.common.vec_env
import yaml
from matplotlib import animation
from stable_baselines3.common.env_util import make_vec_env

from src.model_based_agents.CAPPO import CAPPO
from src.model_based_agents.CASAC import CASAC

# The wildcard import stays ahead of the explicit imports below so that those
# explicit names take precedence over anything the wildcard re-exports.
from minigrid.wrappers import *
import numpy as np
import torch as th
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize


def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object."""
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)


def wrap_env(env, wrappers):
    """Apply wrappers, given as dotted import paths, to ``env`` in order.

    Args:
        env: The object to wrap.
        wrappers: Iterable of strings of the form ``"package.module.Class"``.
            Each is imported and called on the result of the previous step.

    Returns:
        The outermost wrapper, or ``env`` itself when ``wrappers`` is empty.

    Raises:
        ValueError: If a wrapper string contains no ``.`` separator.
        ImportError / AttributeError: If the module or class cannot be found.
    """
    # These two are constructed from a list of env factories, not an env.
    vec_env_paths = (
        "stable_baselines3.common.vec_env.DummyVecEnv",
        "stable_baselines3.common.vec_env.VecEnv",
    )
    vec_normalize_path = "stable_baselines3.common.vec_env.VecNormalize"

    for wrapper_path in wrappers:
        module_name, class_name = wrapper_path.rsplit('.', 1)
        wrapper_class = getattr(import_module(module_name), class_name)
        if wrapper_path in vec_env_paths:
            # Bind the current env as a default argument: a plain
            # ``lambda: env`` would look up ``env`` when called, and ``env``
            # is rebound to the vectorised wrapper on this very line.
            env = wrapper_class([lambda inner=env: inner])
        elif wrapper_path == vec_normalize_path:
            env = wrapper_class(env, norm_reward=False)
        else:
            env = wrapper_class(env)
    return env


def save_frames_as_gif(frames, path='./', filename='gym_animation.gif'):
    """Write a sequence of image frames to an animated GIF.

    Args:
        frames: Non-empty indexable sequence of image arrays accepted by
            ``plt.imshow``. The first frame's ``shape`` fixes the figure size.
        path: Prefix prepended verbatim to ``filename``. No separator is
            inserted, so a directory must include its trailing slash.
        filename: Name of the output file.

    Raises:
        IndexError: If ``frames`` is empty.
    """
    first_frame = frames[0]

    # Figure size is given in inches; at 72 dpi the figure has the same pixel
    # dimensions as a frame. Change these to alter the output frame size.
    fig = plt.figure(
        figsize=(first_frame.shape[1] / 72.0, first_frame.shape[0] / 72.0),
        dpi=72,
    )
    try:
        image = plt.imshow(first_frame)
        plt.axis('off')

        def animate(i):
            image.set_data(frames[i])

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        anim.save(path + filename, writer='imagemagick', fps=60)
    finally:
        # Always release the figure, even when the writer fails.
        plt.close(fig)


def load_env(env_name="DynamicObstaclesSwitch-6x6-v0", render_mode='rgb_array'):
    """Build the environment described under ``env_name`` in the CAPPO config.

    Args:
        env_name: Top-level key in ``parameters_cappo.yml``.
        render_mode: Forwarded to ``gym.make``.

    Returns:
        The created environment, wrapped with the entry's ``wrappers`` list
        when one is configured.

    Raises:
        KeyError: If ``env_name`` or its ``environment.name`` field is missing.
    """
    # Keep the config entry separate from the gym id so the wrappers lookup
    # uses the config key, not the environment's registered name.
    entry = load_config('parameters_cappo.yml')[env_name]
    env = gym.make(entry['environment']['name'], render_mode=render_mode)

    # Wrappers are optional; only their absence is tolerated, so errors raised
    # while applying them are not hidden.
    wrapper_classes = entry.get('wrappers')
    if wrapper_classes:
        env = wrap_env(env, wrapper_classes)
    return env


def load_agent(model_file, env_file, game, algorithm, seed, n_episodes=10):
    """Load a saved agent, roll it out, and print summary statistics.

    Only the ``environment`` and ``wrappers`` sections of the algorithm's
    config file are consulted here.

    Args:
        model_file: Path of the saved model passed to ``<Algorithm>.load``.
        env_file: Path of the saved ``VecNormalize`` statistics. Only read
            when the config lists ``VecNormalize`` among the game's wrappers.
        game: Config key, also used as the environment id for
            ``make_vec_env``.
        algorithm: ``"CAPPO"`` or ``"CASAC"``.
        seed: Seed forwarded to ``make_vec_env``.
        n_episodes: Number of finished episodes (of the first sub-environment)
            to evaluate.

    Raises:
        ValueError: If ``algorithm`` is not a supported name.
        KeyError: If ``game`` or its ``environment`` section is missing.
    """
    # Each algorithm has its own config file and model class.
    algorithms = {
        "CAPPO": ('parameters_cappo.yml', CAPPO),
        "CASAC": ('parameters_casac.yml', CASAC),
    }
    if algorithm not in algorithms:
        raise ValueError("Invalid algorithm name")
    config_file, model_cls = algorithms[algorithm]

    game_cfg = load_config(config_file)[game]
    env_cfg = game_cfg['environment']

    # Only roundabout forwards an explicit environment configuration.
    env_kwargs = {'config': env_cfg['config']} if game == "roundabout-v0" else {}
    num_envs = env_cfg.get('num_envs', 1)

    # These games have their observations flattened before vectorisation.
    flattened_games = ("roundabout-v0", "highway-fast-v0",
                       "DynamicObstaclesSwitch-8x8-v0", "SlipperyDistShift-v0")
    wrapper_class = gym.wrappers.FlattenObservation if game in flattened_games else None
    env = make_vec_env(game, seed=seed, n_envs=num_envs, vec_env_cls=DummyVecEnv,
                       wrapper_class=wrapper_class, env_kwargs=env_kwargs)

    try:
        wrappers = game_cfg.get('wrappers') or []
        if 'stable_baselines3.common.vec_env.VecNormalize' in wrappers:
            # Restore the saved normalisation statistics and freeze them:
            # evaluation must not keep updating the running averages, and
            # reported rewards should stay unnormalised.
            env = VecNormalize.load(env_file, env)
            env.training = False
            env.norm_reward = False

        model = model_cls.load(model_file, device="cpu", env=env)

        entropies = []
        rewards = []
        episode_reward = 0
        finished = 0
        obs = env.reset()
        while finished < n_episodes:
            action, _states = model.predict(obs, deterministic=True)
            obs_tensor = obs_as_tensor(obs, model.device).float()
            action_tensor = th.as_tensor(action, device=model.device).float()
            prediction = (model.dynamic_model.predict_mean(obs_tensor, action_tensor)
                          .detach().cpu().numpy())
            # log2 of the summed absolute difference between the dynamics
            # model's mean prediction and the current observation.
            entropies.append(np.log2(np.abs(prediction - obs).sum()))

            obs, reward, done, info = env.step(action)
            # Rebinding (not in-place add) keeps stored episode totals unaliased.
            episode_reward = episode_reward + reward
            # Episode boundaries are tracked on the first sub-environment only.
            if done[0]:
                finished += 1
                rewards.append(episode_reward)
                episode_reward = 0
                # NOTE(review): this resets every sub-environment, including
                # ones still mid-episode when num_envs > 1 - confirm intended.
                obs = env.reset()
    finally:
        env.close()

    # Print summary
    print("Results for ",model_file, "with k=", str(model.cm_w))
    print("Entropy: ",np.mean(entropies), "+-", np.std(entropies))
    print("Rewards: ",np.mean(rewards), "+-", np.std(rewards))



if __name__ == '__main__':
    # Usage: python <script> <folder>
    # Evaluates every saved model (*.zip) found in <folder>.
    if len(sys.argv) < 2:
        sys.exit("usage: " + sys.argv[0] + " <folder>")
    folder = sys.argv[1]

    files = os.listdir(folder)
    models = sorted(f for f in files if f.endswith('.zip'))
    env_files = sorted(f for f in files if f.endswith('.pkl'))

    for mi, model_file in enumerate(models):
        # Model file names are expected to look like <game>_<algorithm>_<seed>.zip
        name_parts = model_file.split('_')
        game = name_parts[0]
        algorithm = name_parts[1]
        seed = int(name_parts[2].split('.')[0])
        # NOTE(review): models and .pkl files are paired by sorted position;
        # this assumes one .pkl per model with matching sort order - confirm.
        load_agent(os.path.join(folder, model_file),
                   os.path.join(folder, env_files[mi]),
                   game, algorithm, seed)