# Experiment(s) to plot. Each name is read from ../train_log/<name>.log and the
# figures are written under ../train_figures/<names joined by "-">.
# Earlier runs (swap one into the list to re-plot it):
#   'debug_r3', 'debug_r3_p1', 'iclr_v_random', 'smec_lunch_r5',
#   'smec_lunch_00_59_r3'
exp_names = ['iclr_policy1_lunch0006']

# Legend label for each experiment; index-aligned with exp_names.
legend_names = ['RLLift']

# Multiplier turning a sample index into the x coordinate of the loss and
# evaluation plots.
num_step = 128

import matplotlib.pyplot as plt
import os
import numpy as np

# Output directory for the figures, named after the plotted experiment(s).
figure_dir = f'../train_figures/{"-".join(exp_names)}'
# Use the SimHei font so Chinese labels render correctly.
plt.rcParams['font.sans-serif'] = ['SimHei']
# Draw the minus sign as ASCII '-' so negative numbers display with this font.
plt.rcParams['axes.unicode_minus'] = False
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(figure_dir, exist_ok=True)

# Window size (in samples) of the trailing moving average drawn over raw curves.
mean_interval = 10

# Parse each experiment's training log into metric series. The four outer
# lists are index-aligned with exp_names (one inner list per experiment).
rewards = []   # training reward samples
losses = []    # value-loss samples
alosses = []   # action-loss samples
accs = []      # evaluation metric samples (plotted/saved as "awt")

# Number of leading lines skipped in every log (presumably a non-metric
# header -- confirm against the logger).
header_lines = 3

for exp_name in exp_names:
    reward = []
    loss = []
    aloss = []
    test_acc = []
    # Skip the header while reading the log directly, rather than copying the
    # remainder into a scratch file in the working directory and re-reading it.
    with open(f'../train_log/{exp_name}.log', 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if line_no <= header_lines:
                continue
            tokens = line.split()
            # Numeric fields carry one trailing character that [:-1] strips
            # (presumably a ',' separator -- confirm against the log format).
            if 'Mean' in tokens and '[train]' in tokens:
                reward.append(float(tokens[-1][:-1]))
            if 'loss:' in tokens:
                loss.append(float(tokens[8][:-1]))    # value loss
                aloss.append(float(tokens[11][:-1]))  # action loss
            if '[Evaluation]' in tokens and 'Curr' in tokens:
                # Evaluation metric (plotted with y label 'awt').
                test_acc.append(float(tokens[4][:-1]))

    rewards.append(reward)
    losses.append(loss)
    alosses.append(aloss)
    accs.append(test_acc)

# Plot colours, looked up by experiment index (the loss plot also uses an
# offset index for action loss). All entries are distinct, so two experiments
# never share a colour unless the index wraps around.
COLORS = [
    'orangered', 'forestgreen', 'purple', 'dodgerblue', 'magenta', 'salmon',
    'lavender', 'turquoise', 'tan', 'lime', 'teal', 'lightblue',
    'darkgreen', 'gold', 'darkblue', 'olive', 'brown', 'orange',
]


# loss
# Value loss and action loss per experiment: each raw series is drawn faintly
# (alpha=0.3) with its trailing moving average (window mean_interval) on top.
# x = sample index * num_step.
plt.figure()
plt.title('loss vs iteration')
# Chinese-label alternative: plt.title('损失与迭代次数的关系')  ("loss vs. iteration count")
multi_exp = len(losses) > 1
for exp_idx, (vloss, aloss) in enumerate(zip(losses, alosses)):
    # With several experiments, suffix each label with the experiment name so
    # legend entries are distinguishable; a single run keeps the plain labels.
    if multi_exp:
        if exp_idx < len(legend_names):
            name = legend_names[exp_idx]
        else:
            name = exp_names[exp_idx]
        suffix = f' ({name})'
    else:
        suffix = ''
    # Wrap colour indices so more experiments than colours cannot IndexError.
    # NOTE: the action-loss colour (offset 4) can coincide with another
    # experiment's value-loss colour once there are more than 4 experiments.
    vcolor = COLORS[exp_idx % len(COLORS)]
    acolor = COLORS[(exp_idx + 4) % len(COLORS)]
    # Both series come from the same log lines, so they have equal length.
    xs = np.arange(0, len(vloss) * num_step, num_step)

    vloss_avg = [
        np.mean(vloss[max(k - (mean_interval - 1), 0):k + 1])
        for k in range(len(vloss))
    ]
    plt.plot(xs, vloss, color=vcolor, alpha=0.3)
    plt.plot(xs, vloss_avg, label=f'value loss{suffix}', color=vcolor)

    aloss_avg = [
        np.mean(aloss[max(k - (mean_interval - 1), 0):k + 1])
        for k in range(len(aloss))
    ]
    plt.plot(xs, aloss, color=acolor, alpha=0.3)
    plt.plot(xs, aloss_avg, label=f'action loss{suffix}', color=acolor)
plt.grid(True)
plt.legend()
plt.xlabel('iteration')
plt.ylabel('loss')
# Chinese-label alternatives: plt.xlabel('迭代次数')  ("iteration count"),
# plt.ylabel('损失')  ("loss")

plt.savefig(f'{figure_dir}/mean_loss.pdf')
plt.show()
plt.close()

# reward
# Training reward per experiment: raw samples drawn faintly (alpha=0.3) with
# their trailing moving average (window mean_interval) on top.
plt.figure()
plt.title('reward vs iteration')
# Chinese-label alternative: plt.title('奖励与迭代次数的关系')  ("reward vs. iteration count")
for exp_idx, reward in enumerate(rewards):
    # Fall back to the experiment name when legend_names is shorter than exp_names.
    if exp_idx < len(legend_names):
        label = legend_names[exp_idx]
    else:
        label = exp_names[exp_idx]
    color = COLORS[exp_idx % len(COLORS)]  # wrap: more experiments than colours
    reward_avg = [
        np.mean(reward[max(k - (mean_interval - 1), 0):k + 1])
        for k in range(len(reward))
    ]
    # NOTE(review): x is the raw sample index here, whereas the loss and
    # evaluation plots scale it by num_step -- confirm which is intended.
    plt.plot(reward, color=color, alpha=0.3)
    plt.plot(reward_avg, label=label, color=color)
plt.grid(True)
plt.legend()
plt.xlabel('iteration')
plt.ylabel('reward')
# Chinese-label alternatives: plt.xlabel('迭代次数')  ("iteration count"),
# plt.ylabel('奖励')  ("reward")

plt.savefig(f'{figure_dir}/mean_reward.pdf')
plt.show()
plt.close()


# acc
# Evaluation metric per experiment (parsed into accs, labelled and saved as
# "awt"); x = sample index * num_step. No smoothing is applied.
# NOTE(review): the title says 'acc' while the y label and file name say
# 'awt' -- confirm which name is correct.
plt.figure()
plt.title('acc vs iteration')
for exp_idx, acc in enumerate(accs):
    # Fall back to the experiment name when legend_names is shorter than exp_names.
    if exp_idx < len(legend_names):
        label = legend_names[exp_idx]
    else:
        label = exp_names[exp_idx]
    xs = np.arange(0, len(acc) * num_step, num_step)
    plt.plot(xs, acc, label=label, color=COLORS[exp_idx % len(COLORS)])
plt.grid(True)
plt.legend()
plt.xlabel('iteration')
plt.ylabel('awt')
# Chinese-label alternatives: plt.xlabel('迭代次数')  ("iteration count"),
# plt.ylabel('评测时间')  ("evaluation time")
plt.savefig(f'{figure_dir}/eval_awt.pdf')
plt.show()
plt.close()




