import matplotlib.pyplot as plt
import json
import os
import argparse
import rlhf.src.train_behavior_model_new as train_behavior_model_new
import torch
import numpy as np
def load(path):
    """Read one run's outputs JSON file and return its two performance curves.

    Args:
        path: Path to a JSON file containing the keys ``'true_pfms'`` and
            ``'sim_pfms'``.

    Returns:
        A ``(true_pfms, sim_pfms)`` tuple holding the values stored under
        those two keys, in that order.

    Raises:
        KeyError: if either key is absent from the JSON object.
    """
    with open(path, 'r') as handle:
        record = json.load(handle)
    true_curve = record['true_pfms']
    sim_curve = record['sim_pfms']
    return true_curve, sim_curve

def behavior_clone(env, traj, seed):
    """Load a saved behavior-cloning policy for one seed and return its test result.

    The checkpoint is read from
    ``../behavior/<env>_<traj>_uniform_regular_100000_<seed>`` (relative to
    the current working directory).

    Args:
        env: Environment name, passed to ``behavior_model(env_name=...)``.
        traj: Trajectory-type tag used in the checkpoint filename.
        seed: Integer seed used in the checkpoint filename.

    Returns:
        Whatever ``behavior_model.test()`` returns. The caller averages these
        with ``sum(...) / len(...)``, so presumably a number — confirm.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    behavior_path = '../behavior/%s_%s_%s_%s_%d_%d' % (
        env, traj, 'uniform', 'regular', 100000, seed)
    # Check for the checkpoint before building the model so a missing file
    # fails fast and the error names the exact path that was looked up.
    if not os.path.isfile(behavior_path):
        raise FileNotFoundError('behavior clone not done: %s' % behavior_path)
    behavior_model = train_behavior_model_new.behavior_model(
        env_name=env, ac_kwargs=dict(hidden_sizes=[64] * 3))
    behavior_model.action_nn.load_state_dict(torch.load(behavior_path))
    return behavior_model.test()

def _result_path(learn, env, traj, seed, data=None, suffix=''):
    """Return the outputs JSON path for one seed of a run.

    When ``data`` is given the file lives under ``../result_<data>/`` and
    ``suffix`` is appended to the run directory name; otherwise it lives
    under ``../result/`` and ``suffix`` is ignored (matching the original
    per-branch path formats).
    """
    if data is not None:
        return '../result_%d/%s_%s_%s%s/outputs_%d.json' % (
            data, learn, env, traj, suffix, seed)
    return '../result/%s_%s_%s/outputs_%d.json' % (learn, env, traj, seed)


def _mean_curve(curves):
    """Return the element-wise float mean of equal-length per-seed curves.

    Using a float array avoids numpy's in-place int/float casting error when
    the JSON curves contain only integers.
    """
    return np.asarray(curves, dtype=float).mean(axis=0)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Walker2d-v3')
    parser.add_argument("--learn", type=str, default='behavior')
    parser.add_argument("--traj", type=str, default='expert')
    # Restrict to the handled modes so a typo errors out instead of silently
    # doing nothing.
    parser.add_argument("--proc", type=str, default='pfm',
                        choices=['pfm', 'bc_pfm', 'plot'])
    parser.add_argument("--data", type=int, default=None)
    args = parser.parse_args()

    if args.proc == 'pfm':
        # Average, over the seeds that have results, of each seed's best
        # true performance.
        best_pfm = []
        for seed in range(5):
            # NOTE(review): the data-specific path uses a '_30' directory
            # suffix here but not in the 'plot' branch — confirm intended.
            path = _result_path(args.learn, args.env, args.traj, seed,
                                data=args.data, suffix='_30')
            if os.path.exists(path):
                true_pfms, _ = load(path)
                best_pfm.append(max(true_pfms))
        if not best_pfm:
            raise SystemExit('no result files found for learn-env-traj: %s %s %s'
                             % (args.learn, args.env, args.traj))
        print('learn-env-traj:', args.learn, args.env, args.traj, 'pfm:', sum(best_pfm)/len(best_pfm))

    elif args.proc == 'bc_pfm':
        bc_pfm = []
        for seed in range(5):
            bc_pfm.append(behavior_clone(args.env, args.traj, seed))
        print('env-traj:', args.env, args.traj, 'bc_pfm:', sum(bc_pfm) / len(bc_pfm))

    elif args.proc == 'plot':
        # Collect every available seed's curves, then average them.
        true_curves, sim_curves = [], []
        for seed in range(5):
            path = _result_path(args.learn, args.env, args.traj, seed,
                                data=args.data)
            if os.path.exists(path):
                true_pfms, sim_pfms = load(path)
                true_curves.append(true_pfms)
                sim_curves.append(sim_pfms)
        if not true_curves:
            raise SystemExit('no result files found for learn-env-traj: %s %s %s'
                             % (args.learn, args.env, args.traj))
        true_pfm_mean = _mean_curve(true_curves)
        sim_pfm_mean = _mean_curve(sim_curves)

        if args.data is not None:
            fig_dir = '../figures_%d/%s_%s_%s' % (args.data, args.learn, args.env, args.traj)
        else:
            fig_dir = '../figures/%s_%s_%s' % (args.learn, args.env, args.traj)
        # makedirs also creates a missing '../figures_<data>' parent.
        os.makedirs(fig_dir, exist_ok=True)

        for curve, ylabel, filename in (
                (true_pfm_mean, 'Actual Performance', 'true_training_log.png'),
                (sim_pfm_mean, 'Evaluated Performance', 'sim_training_log.png')):
            plt.plot(curve)
            plt.xlabel('Training Epoch')
            plt.ylabel(ylabel)
            plt.savefig(os.path.join(fig_dir, filename))
            plt.close()




