import os
import json
import tqdm
import torch
import random
import pprint
import numpy as np
from config import Config
from RetroEnv import RetroEnv
from RetroAgent import RetroAgentPPO
from rdkit.Chem import MolFromSmiles, MolFromSmarts


def eval_success_rate(retro_env, agent, RawVal, additional_mask, num_samples=1000):
    """Estimate the agent's success rate on a random subset of validation data.

    Each sampled entry is rolled out with ``agent.select_action_for_inference``
    until ``retro_env.step`` reports ``done``. An episode counts as a success
    when, in the terminal state, ``set(state[3])`` equals
    ``set(retro_env.label_RcNodeIdx)`` and ``set(state[4])`` equals
    ``set(retro_env.label_LgIdx)``.

    Args:
        retro_env: environment exposing ``reset(data_dict=...)``,
            ``step(action) -> (reward, next_state, done)`` and the
            ``label_RcNodeIdx`` / ``label_LgIdx`` attributes.
        agent: agent exposing
            ``select_action_for_inference(state, additional_mask=...)`` that
            returns a 3-tuple whose first element is the action.
        RawVal: sequence of validation entries passed to ``retro_env.reset``.
        additional_mask: forwarded unchanged to the agent.
        num_samples: maximum number of validation entries to evaluate. It is
            capped at ``len(RawVal)`` so that small validation sets do not
            make ``random.sample`` raise ``ValueError``.

    Returns:
        Fraction of successful episodes (in [0, 1]), or NaN when ``RawVal``
        is empty.
    """
    # random.sample raises ValueError if asked for more items than exist.
    test = random.sample(RawVal, min(num_samples, len(RawVal)))
    success_list = []
    for data_dict in tqdm.tqdm(test, leave=False):
        state = retro_env.reset(data_dict=data_dict)
        done = 0
        while not done:
            action, _, _ = agent.select_action_for_inference(state, additional_mask=additional_mask)
            _, state, done = retro_env.step(action)
        # The loop only exits once `done` is set, so `state` is the terminal state.
        success = (set(state[3]) == set(retro_env.label_RcNodeIdx)
                   and set(state[4]) == set(retro_env.label_LgIdx))
        success_list.append(1 if success else 0)
    if not success_list:
        # np.mean([]) would emit a RuntimeWarning; return NaN explicitly instead.
        return float('nan')
    return np.mean(success_list)


def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _build_additional_mask(processed_path):
    """Build the float32 mask of shape ``(1, hg['num_v'] + 1)`` used at evaluation.

    All entries start at 1. The indices listed in
    ``additional_mask_for_zero.json`` and the final (extra) entry are then set
    to 0. The tensor is returned on the CPU; the caller moves it to a device.
    """
    hg = _load_json(os.path.join(processed_path, 'hypergraph.json'))
    additional_mask_for_zero = _load_json(os.path.join(processed_path, 'additional_mask_for_zero.json'))
    mask = torch.ones(hg['num_v'] + 1, dtype=torch.float32)
    mask[additional_mask_for_zero] = 0
    mask[-1] = 0
    return mask.reshape(1, -1)


def _mean_or_nan(values):
    """Return the mean of ``values``, or NaN (without a RuntimeWarning) if empty."""
    return np.mean(values) if len(values) else float('nan')


def main(args, have_imitation=True, ckpt=None):
    """Train the retrosynthesis agent: optional imitation warm-up, then PPO.

    Outputs go to ``./outputs/<args.config_name>/``:

    - the saved config;
    - ``retro_imitation.pkl``, saved after every imitation update;
    - ``retro_ppo.pkl``, saved after every PPO update;
    - ``retro_best.pkl``, saved whenever the periodic validation success rate
      (from ``eval_success_rate``) improves.

    Args:
        args: configuration object. It must provide ``to_dict()``,
            ``save(path)`` and every attribute read below (dataset paths,
            model and optimizer hyper-parameters, and loop sizes).
        have_imitation: if True, run ``args.num_imitation`` epochs of imitation
            learning on ground-truth trajectories before PPO.
        ckpt: optional parameter file passed to ``agent.load_param`` right
            before PPO starts. Because it is loaded after the imitation phase,
            it replaces any imitation-trained weights.
    """
    print('configuration: ')
    pprint.pprint(args.to_dict())
    print('----------------------------------')

    dataset_name = args.dataset_name
    processed_path = os.path.join(args.ProcessedDataFile_path, dataset_name)

    print('start: process data')
    config_name = args.config_name
    save_root_path = os.path.join('./outputs', config_name)
    os.makedirs(save_root_path, exist_ok=True)
    # save config
    args.save(os.path.join(save_root_path, config_name))

    RawTrain = _load_json(os.path.join(processed_path, 'raw_train.json'))
    RawVal = _load_json(os.path.join(processed_path, 'raw_val.json'))
    additional_mask = _build_additional_mask(processed_path)
    print('end: process data')

    device = torch.device(args.device)
    additional_mask = additional_mask.to(device)
    # NOTE(review): RetroEnv's RawDataFile_path receives ProcessedDataFile_path;
    # kept as in the original, confirm this is intended.
    retro_env = RetroEnv(data_name=dataset_name, RawDataFile_path=args.ProcessedDataFile_path,
                         max_len=args.trajectory_max_length)
    agent = RetroAgentPPO(data_name=dataset_name, ProcessedDataFile_path=args.ProcessedDataFile_path,
                          hidden_dimension=args.hidden_dimension, num_egat_heads=args.num_egat_heads,
                          num_egat_layers=args.num_egat_layers, num_of_LayerHypergraph=args.num_of_LayerHypergraph,
                          residual=args.residual, have_fp=args.have_fp, have_structure=args.have_structure,
                          gamma=args.gamma, eps_clip=args.eps_clip, value_coefficient=args.value_coefficient,
                          entropy_coefficient=args.entropy_coefficient, learning_rate=args.learning_rate, device=device)
    min_num_transitions = args.min_num_transitions
    num_epochs = args.num_epochs
    batch_size = args.batch_size

    save_imitation_path = os.path.join(save_root_path, 'retro_imitation.pkl')
    save_ppo_path = os.path.join(save_root_path, 'retro_ppo.pkl')
    save_best_path = os.path.join(save_root_path, 'retro_best.pkl')

    # imitation learning
    if have_imitation:
        print('start: imitation learning')
        learning = 0
        for n_epi in range(args.num_imitation):
            random.shuffle(RawTrain)
            loss_list = []
            print('{}-epoch imitation start!***************************'.format(n_epi))
            for idx_data, data_dict in tqdm.tqdm(enumerate(RawTrain), total=len(RawTrain), leave=False):
                state = retro_env.reset(data_dict=data_dict)
                done = 0

                gt_idx = 0
                gt = retro_env.generate_random_gt_trajectory()

                while not done:
                    # The policy is queried only for its distribution; the
                    # executed action comes from the ground-truth trajectory.
                    _, _, rst_policy = agent.select_action(state)
                    action = gt[gt_idx]
                    gt_idx += 1
                    p_of_action = rst_policy[action].item()

                    r, next_state, done = retro_env.step(action)
                    # (state, action, p_of_a, r, next_state, done)  # ([], int, float, float, [], int)
                    agent.store_transition((state, action, p_of_action, r, next_state, done))
                    state = next_state

                if len(agent.buffer) >= min_num_transitions:
                    print('------------------------------------')
                    print('start {}-th imitation learning at {} data in epoch {}.'.format(learning, idx_data, n_epi))

                    mean_loss = agent.update(num_epochs=num_epochs, batch_size=batch_size, imitation=True)
                    loss_list.append(mean_loss)
                    agent.save_param(save_imitation_path)

                    # Report the same index as the "start" line, then advance.
                    print('{}-th learning end, loss is {:.2f}.'.format(learning, mean_loss))
                    learning += 1
                    agent.clear_buffer()

            print('{}-epoch end! final loss_list: {} ***************************'.format(n_epi, _mean_or_nan(loss_list)))
        # Leftover ground-truth transitions (< min_num_transitions) must not be
        # fed into the first PPO update, which assumes on-policy samples.
        agent.clear_buffer()
        print('end: imitation learning')

    # ppo learning
    print('start: ppo learning')
    max_success_rate = -1
    learning = 0
    if ckpt is not None:
        agent.load_param(ckpt)
    for n_epi in range(args.total_num_epoch):
        random.shuffle(RawTrain)
        success_list = []
        loss_list = []
        print('{}-epoch start!***************************'.format(n_epi))
        for idx_data, data_dict in tqdm.tqdm(enumerate(RawTrain), total=len(RawTrain), leave=False):
            state = retro_env.reset(data_dict=data_dict)
            done = 0

            while not done:
                action, p_of_action, _ = agent.select_action(state)
                r, next_state, done = retro_env.step(action)
                # (state, action, p_of_a, r, next_state, done)  # ([], int, float, float, [], int)
                agent.store_transition((state, action, p_of_action, r, next_state, done))
                state = next_state

            # The loop only exits once `done` is set, so `state` is terminal.
            if set(state[3]) == set(retro_env.label_RcNodeIdx) and set(state[4]) == set(retro_env.label_LgIdx):
                success_list.append(1)
                print('success!!!!!!!!!!!!!!!!!')
            else:
                success_list.append(0)

            if len(agent.buffer) >= min_num_transitions:
                print('------------------------------------')
                print('start {}-th learning at {} data in epoch {}.'.format(learning, idx_data, n_epi))

                mean_loss = agent.update(num_epochs=num_epochs, batch_size=batch_size, imitation=False)
                loss_list.append(mean_loss)
                agent.save_param(save_ppo_path)

                print('{}-th learning end, loss is {:.2f}.'.format(learning, mean_loss))
                print('recent 32 final success reward: {:.2f}'.format(np.mean(success_list[-32:])))
                print('test max_success_rate: ', max_success_rate)
                agent.clear_buffer()

                # Validate every `save_at_num_update` updates (including the first).
                if learning % args.save_at_num_update == 0:
                    success_rate = eval_success_rate(retro_env=retro_env, agent=agent,
                                                     RawVal=RawVal, additional_mask=additional_mask)
                    if success_rate > max_success_rate:
                        max_success_rate = success_rate
                        agent.save_param(save_best_path)
                        print('update best!!!!!!!!!!!!!!!', 'max_success_rate: ', max_success_rate)
                learning += 1

        print('{}-epoch end! success rate: {} loss: {} *******************'.format(n_epi, _mean_or_nan(success_list),
                                                                                   _mean_or_nan(loss_list)))
        print('test max_success_rate: ', max_success_rate, '*******************')
    print('end: ppo learning')


if __name__ == '__main__':
    # Script entry point: train with the USPTO-50k configuration, running the
    # imitation-learning warm-up before PPO.
    config_file = './ConfigFile/USPTO-50k.json'
    args = Config(config_path=config_file)
    main(args, have_imitation=True)
