import time
import matplotlib as mpl
mpl.use('Agg')
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import os
import os.path as osp
import torch
import numpy as np
from spinup import EpochLogger
import json
import yaml
from safe_control_gym.utils.registration import make
import csv


def load_pytorch_policy(fpath, itr, deterministic=False):
    """ Load a pytorch policy saved with Spinning Up Logger.

    Args:
        fpath: run directory that contains ``pyt_save/model<itr>.pt``.
        itr: checkpoint suffix as a string ('' selects ``model.pt``).
        deterministic: forwarded to ``model.step`` when ``fpath`` contains
            'cup'; ignored otherwise.

    Returns:
        ``(get_action, model)`` where ``get_action(obs)`` maps an observation
        to an action (computed under ``torch.no_grad()``) and ``model`` is the
        loaded module moved to CPU.
    """
    import inspect

    fname = osp.join(fpath, 'pyt_save', 'model'+itr+'.pt')
    print('\n\nLoading from %s.\n\n'%fname)

    # map_location lets checkpoints that were saved on a GPU load on a
    # CPU-only machine.
    load_kwargs = {"map_location": "cpu"}
    # The checkpoint is a whole pickled nn.Module, not a state_dict, so it
    # needs weights_only=False on torch versions that accept the argument
    # (torch >= 2.6 defaults it to True and would refuse the file).
    # NOTE: this unpickles arbitrary objects - only load trusted checkpoints.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(fname, **load_kwargs).to("cpu")
    print("model", model)

    def get_action(x):
        """Return the policy's action for observation ``x`` (no gradients)."""
        with torch.no_grad():
            x = torch.as_tensor(x, dtype=torch.float32)
            if 'cup' in fpath:
                # CUP runs are detected by path; their action is the first
                # element returned by model.step(obs, deterministic).
                return model.step(x, deterministic)[0]
            return model.act(x)

    return get_action, model

def load_policy_and_env(fpath, itr='last', deterministic=False):
    """Load a saved policy and, if available, its environment from a run directory.

    Args:
        fpath: run directory containing a ``pyt_save`` sub-directory.
        itr: ``'last'`` to use the highest-numbered ``model<N>.pt`` checkpoint
            in ``pyt_save`` (or ``model.pt`` when no numbered one exists), or
            an int epoch number.
        deterministic: forwarded to ``load_pytorch_policy``.

    Returns:
        ``(env, get_action, model)``. ``env`` is ``state['env']`` from
        ``vars<itr>.pkl``, or ``None`` if that file cannot be loaded.
    """
    # handle which epoch to load from
    if itr == 'last':
        pytsave_path = osp.join(fpath, 'pyt_save')
        prefix, suffix = 'model', '.pt'
        saves = []
        for name in os.listdir(pytsave_path):
            # Only accept exactly "model<digits>.pt"; stray files such as
            # "model_best.pt" or "model.pt.bak" used to crash int().
            if name.startswith(prefix) and name.endswith(suffix):
                number = name[len(prefix):-len(suffix)]
                if number.isdecimal():
                    saves.append(int(number))
        itr = '%d' % max(saves) if saves else ''
    else:
        assert isinstance(itr, int), \
            "Bad value provided for itr (needs to be int or 'last')."
        itr = '%d' % itr

    # load the get_action function
    get_action, model = load_pytorch_policy(fpath, itr, deterministic)

    # The environment pickle is optional; fall back to None if it is missing
    # or unreadable. `except Exception` (not a bare except) so Ctrl-C and
    # SystemExit still propagate.
    # NOTE(review): joblib.load unpickles arbitrary objects - trusted runs only.
    try:
        state = joblib.load(osp.join(fpath, 'vars' + itr + '.pkl'))
        env = state['env']
    except Exception:
        env = None
    return env, get_action, model

def run_policy(logger, env, get_action, model, epsilon=0.00, nums_ep=10, cbf=False,
               max_ep_len=250, min_ep_len=10):
    """Roll out a policy until ``nums_ep`` episodes have been recorded.

    Args:
        logger: object exposing ``store``, ``log_tabular`` and ``dump_tabular``
            (an ``EpochLogger`` in this script).
        env: environment whose ``reset()`` returns a sequence with the
            observation first, and whose ``step(a)`` returns
            ``(obs, reward, done, info)`` with ``info['constraint_violation']``.
        get_action: callable mapping an observation tensor to an action.
        model: unused here; kept for call-site compatibility.
        epsilon: disturbance level; only used in the per-episode printout.
        nums_ep: number of episodes to record; must be >= 1.
        cbf: unused; kept for call-site compatibility.
        max_ep_len: an episode is cut off after this many steps.
        min_ep_len: episodes ending after this many steps or fewer are
            discarded and do not count toward ``nums_ep``.

    Returns:
        ``(mean return, mean per-step constraint violation, mean length)``
        over the recorded episodes.

    Raises:
        ValueError: if ``nums_ep < 1``.
    """
    if nums_ep < 1:
        raise ValueError(f"nums_ep must be at least 1, got {nums_ep}")

    o = env.reset()[0]
    ep_ret, ep_risk, ep_len, n = 0, 0, 0, 0
    total_ret, total_risk, total_len = 0.0, 0.0, 0.0

    while n < nums_ep:
        a = get_action(torch.as_tensor(o, dtype=torch.float32).to("cpu"))

        o, r, d, infos = env.step(a)
        ep_ret += r
        ep_risk += infos['constraint_violation']
        ep_len += 1

        if not (d or ep_len >= max_ep_len):
            continue

        if ep_len > min_ep_len:
            # Risk is reported as the mean constraint violation per step.
            ep_risk = ep_risk / ep_len
            total_ret += ep_ret
            total_risk += ep_risk
            total_len += ep_len
            logger.store(EpRet=ep_ret, EpRisk=ep_risk, EpLen=ep_len)
            print('Epsilon %.3f \t EpRet %.3f \t EpRisk %.3f \t EpLen %d'%(epsilon, ep_ret, ep_risk, ep_len))
            n += 1
        # Start a fresh episode whether or not the last one was recorded;
        # too-short episodes are simply dropped.
        o = env.reset()[0]
        ep_ret, ep_risk, ep_len = 0, 0, 0

    logger.log_tabular('EpRet', average_only=True)
    logger.log_tabular('EpRisk', average_only=True)
    logger.log_tabular('EpLen', average_only=True)
    logger.dump_tabular()

    return total_ret / nums_ep, total_risk / nums_ep, total_len / nums_ep

def get_file_path(output_dir, filename):
    """Return the first unused ``<stem>_v<k><ext>`` path inside *output_dir*.

    *output_dir* is created (with parents) when it does not exist. The version
    number ``k`` starts at 0 and is incremented until a path that does not yet
    exist is found. The returned file itself is not created.
    """
    if not osp.exists(output_dir):
        os.makedirs(output_dir)
    stem, ext = osp.splitext(filename)

    def candidate(version):
        return osp.join(output_dir, f"{stem}_v{version}{ext}")

    version = 0
    while osp.exists(candidate(version)):
        version += 1
    return candidate(version)

def save_plot(file_path, figure=None):
    """Save *figure* to *file_path* at 300 dpi; do nothing when no figure is given."""
    if figure is None:
        return
    figure.savefig(file_path, dpi=300)

def _disturbance_config(disturb_part, disturb_type, eps):
    """Build the ``task_config['disturbances']`` mapping for one test level.

    Args:
        disturb_part: 'dynamics', 'observation' or 'action'.
        disturb_type: 'impulse', 'white_noise' or 'periodic'.
        eps: signed disturbance level. Impulses use ``eps * 10`` as a signed
            magnitude; white noise and periodic use ``abs(eps)``.

    Raises:
        ValueError: for an unsupported ``disturb_type``.
    """
    if disturb_type == 'impulse':
        spec = {
            'disturbance_func': disturb_type,
            'magnitude': eps * 10,
            'step_offset': 20,
            'duration': 80,
            # safe_control_gym's impulse disturbance takes `decay_rate`; the
            # previous key 'decary_rate' was a typo and did not set it.
            'decay_rate': 0.9,
        }
    elif disturb_type == 'white_noise':
        spec = {
            'disturbance_func': disturb_type,
            'std': abs(eps),
        }
    elif disturb_type == 'periodic':
        spec = {
            'disturbance_func': disturb_type,
            'scale': abs(eps),
            'frequency': 1,
        }
    else:
        raise ValueError(f"Unsupported disturbance type: {disturb_type!r}")
    return {disturb_part: [spec]}


def _plot_data(all_seed_data, alg_list, test_eps_list, output_dir, filename):
    """Plot mean +/- std (pooled over seeds) of reward, risk and length vs. epsilon.

    The figure is written through ``get_file_path``/``save_plot`` and then
    closed so it does not stay alive in pyplot's figure registry.
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 5))

    stats = ['rewards', 'risks', 'lengths']
    ylabels = ['Average Reward', 'Average Risk', 'Average Length']
    eps_vals = sorted(test_eps_list)

    for ax, stat, ylabel in zip(axs, stats, ylabels):
        for alg in alg_list:
            mean_values = []
            std_values = []
            for eps in eps_vals:
                # Pool this statistic over every seed for (alg, eps).
                ep_stats = []
                for seed_data in all_seed_data.values():
                    ep_stats.extend(seed_data[alg][eps][stat])
                mean_values.append(np.mean(ep_stats))
                std_values.append(np.std(ep_stats))

            min_values = [m - s for m, s in zip(mean_values, std_values)]
            max_values = [m + s for m, s in zip(mean_values, std_values)]

            ax.fill_between(eps_vals, min_values, max_values, alpha=0.3)
            ax.plot(eps_vals, mean_values, label=alg)

        ax.set_xlabel('Epsilon')
        ax.set_ylabel(ylabel)
        ax.set_title(f'{ylabel} at Different Epsilons')
        ax.legend()

    fig.tight_layout()
    file_path = get_file_path(output_dir, filename)
    save_plot(file_path, figure=fig)
    plt.close(fig)


def _save_results_to_csv(all_seed_data, alg_list, test_eps_list, output_dir, filename):
    """Write one CSV row per (seed, algorithm, epsilon) with mean reward, risk and length."""
    csv_file_path = get_file_path(output_dir, filename)

    with open(csv_file_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Seed', 'Alg', 'Epsilon', 'AvgRet', 'AvgRisk', 'AvgLen'])

        # Iterate the seeds present in the data passed in, not a global list.
        for seed, seed_data in all_seed_data.items():
            for alg in alg_list:
                for eps in test_eps_list:
                    cell = seed_data[alg][eps]
                    writer.writerow([
                        seed,
                        alg,
                        eps,
                        np.mean(cell['rewards']),
                        np.mean(cell['risks']),
                        np.mean(cell['lengths']),
                    ])


if __name__ == '__main__':
    fpath_list = {
                'PPOL-CartPole-Stab': './src/data/CartPole-stab/ppolag/ppolag_seed_0_v0',
                'FuzPPOL-CartPole-Stab': './src/data/CartPole-stab/fuzppolag/fuzppolag_seed_0_v0',
                'CPPO-CartPole-Stab': './src/data/CartPole-stab/cppo/cppo_seed_0_v0',
                'FuzCPPO-CartPole-Stab': './src/data/CartPole-stab/fuzcppo/fuzcppo_seed_0_v0',
                'CUP-CartPole-Stab': './src/data/CartPole-stab/cup/cup_seed_0_v0',
                'FuzCUP-CartPole-Stab': './src/data/CartPole-stab/fuzcup/fuzcup_seed_0_v0',
                }

    parser = argparse.ArgumentParser()
    parser.add_argument('--alg_list', nargs='+', default=['PPOL', 'FuzPPOL'], help='List of algorithms to test.')
    parser.add_argument('--agent_list', nargs='+', default=['CartPole'], help='List of agents to test.')
    parser.add_argument('--task_list', nargs='+', default=['Stab'], help='List of tasks to test.')
    parser.add_argument('--disturb_part', type=str, default='dynamics', choices=['dynamics', 'observation', 'action'], help='Part of the system to which disturbances are applied.')
    parser.add_argument('--disturb_type', type=str, default='white_noise', choices=['white_noise', 'impulse', 'periodic'], help='Type of disturbances applied to the system.')
    parser.add_argument('--test_seed', type=int, nargs='+', default=[i for i in range(10)], help='Seed numbers for testing.')
    parser.add_argument('--nums_ep', type=int, default=10, help='Number of episodes to run each policy.')

    args = parser.parse_args()
    alg_list = args.alg_list
    agent_list = args.agent_list
    task_list = args.task_list
    disturb_part = args.disturb_part
    disturb_type = args.disturb_type
    test_seed_list = args.test_seed
    nums_ep = args.nums_ep

    # Keep only the (alg, agent, task) combinations that have a saved run.
    test_task = {}
    for alg in alg_list:
        for agent in agent_list:
            for task in task_list:
                key_name = alg + '-' + agent + '-' + task
                if key_name in fpath_list:
                    test_task[key_name] = fpath_list[key_name]
    print(f"Test: {test_task.keys()}")
    test_eps_list = [-0.02 * i for i in range(6)] + [0.02 * (i + 1) for i in range(5)]
    print("test_eps_list", test_eps_list)
    logger = EpochLogger()
    # alg_data[seed][alg][eps][stat] -> list of per-run averages.
    alg_data = {
                seed: {
                    alg: {
                        eps: {
                            'rewards': [],
                            'risks': [],
                            'lengths': []
                        } for eps in test_eps_list
                    } for alg in test_task
                } for seed in test_seed_list
            }

    for alg in test_task:
        _, agent, task = alg.split('-')

        # The environment pickled with the policy is not used: a fresh env is
        # built for every (seed, epsilon) below so the disturbance applies.
        _, get_action, model = load_policy_and_env(fpath_list[alg], deterministic=True)

        yaml_config_path = f'./envs/safe_control_gym/config_overrides/{str(agent).lower()}/{str(agent).lower()}_{str(task).lower()}.yaml'
        with open(yaml_config_path, 'r') as file:
            yaml_config = yaml.safe_load(file)
        for test_seed in test_seed_list:
            print(f"*********** SEED: {test_seed}***********")
            for eps in test_eps_list:
                yaml_config['task_config']['disturbances'] = _disturbance_config(disturb_part, disturb_type, eps)
                env = make(f'{str(agent).lower()}', **yaml_config['task_config'])
                env.seed(test_seed)
                env.reset()
                print(f"Now we test: {alg}")
                ep_ret, ep_risk, ep_len = run_policy(logger, env, get_action, model, eps, nums_ep=nums_ep)
                cell = alg_data[test_seed][alg][eps]
                cell['rewards'].append(ep_ret)
                cell['risks'].append(ep_risk)
                cell['lengths'].append(ep_len)

    output_dir = f"./imgs/{agent_list[0]}-{task_list[0]}/test_{disturb_part}"

    filename = f"{agent_list[0]}_{task_list[0]}_disturb_{disturb_part}_{disturb_type}_compare.png"
    _plot_data(alg_data, test_task, test_eps_list, output_dir, filename)

    filename_csv = f"{agent_list[0]}_{task_list[0]}_disturb_{disturb_part}_{disturb_type}_compare.csv"
    _save_results_to_csv(alg_data, test_task, test_eps_list, output_dir, filename_csv)
