from tqdm import trange
import copy
import pandas as pd
import pickle
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import os
from Envs.vec_env.envs import make_vec_envs
from models.ppo.model import Policy
from models.ppo.utils import get_vec_normalize
from models.ppo.storage import RolloutStorage
from models.ppo import algo
import time
from collections import deque
from utils import get_scheduler
import glob
from dataset import loadEnvData
from shutil import copyfile
from supervisedImg import train_with_representation
from RSI3.pretext_RSI3 import trainRepresentation
import torch.optim as optim
import gym

from cfg import main_config, gym_register


def _load_pretext_model(config):
    """Build the pretext (representation) model on the GPU and load its weights.

    Instantiates ``config.pretextModel(config)``, moves it to CUDA with
    ``.cuda()`` (so a GPU is required), switches it to eval mode and loads the
    state dict stored at ``config.pretextModelLoadDir``.

    Returns:
        The loaded model.
    """
    model = config.pretextModel(config).cuda().eval()
    weight_path = config.pretextModelLoadDir
    model.load_state_dict(torch.load(weight_path))
    print('Load weights for pretextModel from', weight_path)
    return model


def _rollout_obs_at(rollouts, index):
    """Return the observation stored in ``rollouts`` at time step ``index``.

    ``rollouts.obs`` is either a dict of per-key buffers (each indexed
    separately, giving a dict back) or a single indexable buffer.
    """
    if isinstance(rollouts.obs, dict):
        return {key: buffer[index] for key, buffer in rollouts.obs.items()}
    return rollouts.obs[index]


def _load_medoids(config, device):
    """Load the task medoids saved alongside the pretext-model checkpoint.

    The file is looked up as ``<checkpoint stem>_medoids.pickle`` in the
    directory of ``config.pretextModelLoadDir``. Its content is passed to
    ``torch.from_numpy`` (presumably a numpy array -- confirm with the producer).

    Returns:
        The medoids as a tensor on ``device``.

    Raises:
        FileNotFoundError: if the medoid file does not exist.
    """
    head, tail = os.path.split(config.pretextModelLoadDir)
    medoid_file_path = os.path.join(head, os.path.splitext(tail)[0] + '_medoids.pickle')
    if not os.path.exists(medoid_file_path):
        print("medoids.pickle not found. Please run pretext_RSI3.py to get medoids first")
        raise FileNotFoundError(medoid_file_path)
    print('Found medoids.pickle, load directly')
    # NOTE: pickle can execute arbitrary code; only load medoid files produced locally.
    with open(medoid_file_path, 'rb') as fp:
        medoids = pickle.load(fp)
    return torch.from_numpy(medoids).to(device)


def _select_skill(config, obs, medoids):
    """Return the index (into ``config.skillInfos``) of the skill to run.

    Without a hierarchy the first skill (index 0) is always used. Otherwise the
    ``'goal_sound_feat'`` observation is scored against every medoid row by dot
    product; the best-scoring row is the predicted task ID, which
    ``config.taskID2Skill`` maps to a skill index.
    """
    if not config.hierarchy:
        return 0
    with torch.no_grad():
        dot_products = torch.sum(obs['goal_sound_feat'] * medoids, dim=1)
    predicted_task_ID = torch.argmax(dot_products).item()
    print('Predicted Task ID', predicted_task_ID)
    return config.taskID2Skill[predicted_task_ID]


# Entry point. Three modes, selected by the config:
#   * config.RLManualControl: step one env with dummy actions (env debugging);
#   * config.RLTrain:         train a PPO policy, periodically refreshing the
#                             pretext representation;
#   * otherwise:              evaluate one or more trained skill policies.
if __name__ == '__main__':
    config = main_config()
    gym_register(config)

    if config.RLManualControl:  # used for debugging the env
        envs = make_vec_envs(env_name=config.RLEnvName,
                             seed=0,
                             num_processes=1,
                             gamma=None,
                             device=None,
                             randomCollect=False,
                             config=config)
        pretextModel = _load_pretext_model(config)
        envs.venv.pretextModel = pretextModel
        observation = envs.reset()
        for episode in range(10):
            for i in range(config.RLEnvMaxSteps):
                print('step:', i)
                print('step reward', envs.venv.origStepReward)
                envs.render()
                # Dummy action; the true action is decided inside the env.
                action = torch.zeros(config.RLActionDim)
                observation, _, _, _ = envs.step(action)
        envs.close()

    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("Using device:", device)
        cudnn.benchmark = True
        torch.cuda.empty_cache()

        # Running count of epochs the representation has been trained for;
        # advanced every time the pretext model is updated during RL training.
        representationEpoch = config.pretextEpoch
        # Shared with the envs below (per the original note, trained with a triplet loss).
        pretextModel = _load_pretext_model(config)

        if config.RLTrain:
            torch.set_num_threads(1)
            torch.manual_seed(config.RLEnvSeed)
            torch.cuda.manual_seed_all(config.RLEnvSeed)

            os.makedirs(config.RLModelSaveDir, exist_ok=True)
            # Keep a copy of the env config next to the checkpoints.
            copyfile(os.path.join('..', 'Envs', config.envFolder, 'RSI3', 'config.py'),
                     os.path.join(config.RLModelSaveDir, 'config.py'))

            if hasattr(config, 'trainPretrainedModel'):
                if config.usePretrainedModel and config.trainPretrainedModel:
                    train_with_representation(device=device, pretextModel=pretextModel,
                                              dataPath=config.pretextDataDir,
                                              model_save_dir=config.pretrainedModelSaveDir,
                                              batch_size=config.pretrainedBatchSize,
                                              config=config)

            envs = make_vec_envs(env_name=config.RLEnvName,
                                 seed=config.RLEnvSeed,
                                 num_processes=config.RLNumEnvs,
                                 gamma=config.RLGamma,
                                 device=device,
                                 randomCollect=False,
                                 config=config)

            actor_critic = Policy(
                envs.venv.observation_space.spaces,
                envs.action_space,
                config=config,
                base=config.RLPolicyBase,
                base_kwargs={'recurrent': config.RLRecurrentPolicy,
                             'recurrentInputSize': config.RLRecurrentInputSize,
                             'recurrentSize': config.RLRecurrentSize,
                             'actionHiddenSize': config.RLActionHiddenSize
                             })
            actor_critic.to(device)

            # Load pretrained CNN weights into the policy base. strict=False
            # tolerates keys that are missing from / extra in the checkpoint.
            # map_location lets GPU-saved checkpoints load on a CPU-only machine.
            if hasattr(config, 'trainPretrainedModel') and config.usePretrainedModel:
                actor_critic.base.load_state_dict(
                    torch.load(config.pretrainedModelLoadDir, map_location=device), strict=False)
                print('Loaded pretrained weights from', config.pretrainedModelLoadDir)

            if config.RLModelFineTune:
                print("Load the weights from", config.RLModelLoadDir)
                actor_critic.load_state_dict(torch.load(config.RLModelLoadDir, map_location=device))

            agent = algo.PPO(
                actor_critic,
                config.ppoClipParam,
                config.ppoEpoch,
                config.ppoNumMiniBatch,
                config.ppoValueLossCoef,
                config.ppoEntropyCoef,
                lr=config.RLLr,
                eps=config.RLEps,
                max_grad_norm=config.RLMaxGradNorm,
                config=config)

            rollouts = RolloutStorage(config.ppoNumSteps, config.RLNumEnvs,
                                      envs.venv.observation_space.spaces, envs.action_space,
                                      actor_critic.recurrent_hidden_state_size, config=config)

            env_rewards = np.zeros([config.RLNumEnvs, ])  # running return of each env's current episode
            episode_rewards = deque(maxlen=10)  # returns of the 10 most recently finished episodes

            envs.venv.pretextModel = pretextModel
            # glob.glob1 is an undocumented helper deprecated in Python 3.13;
            # escape the directory so it is matched literally, as glob1 did.
            train_dir = os.path.join(config.pretextDataDir[0], 'train')
            envs.venv.fileNum = len(glob.glob(os.path.join(glob.escape(train_dir), "*.pickle")))

            print('Begin RL training')
            obs = envs.reset()

            if isinstance(rollouts.obs, dict):
                for key in rollouts.obs:
                    rollouts.obs[key][0].copy_(obs[key])
            else:
                rollouts.obs[0].copy_(obs)
            rollouts.to(device)

            # progress.csv is started fresh on the first log of this run and
            # appended to afterwards, so no logged update is overwritten.
            progress_started = False
            start = time.time()
            num_updates = int(
                config.RLTotalSteps) // config.ppoNumSteps // config.RLNumEnvs
            for j in range(0, num_updates):
                for step in range(config.ppoNumSteps):
                    # Sample actions
                    with torch.no_grad():
                        value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                            _rollout_obs_at(rollouts, step), rollouts.recurrent_hidden_states[step],
                            rollouts.masks[step])

                    # Observe reward and next obs
                    obs, reward, done, infos = envs.step(action)
                    if config.render:
                        print('step reward', envs.venv.origStepReward)
                        envs.render()

                    # Accumulate the env-reported step reward; move the return
                    # of every env that just finished into episode_rewards.
                    env_rewards = env_rewards + envs.venv.origStepReward
                    for index in np.flatnonzero(done):
                        episode_rewards.append(env_rewards[index])
                        env_rewards[index] = 0.

                    # If done then clean the history of observations.
                    masks = torch.FloatTensor(
                        [[0.0] if done_ else [1.0] for done_ in done])
                    bad_masks = torch.FloatTensor(
                        [[0.0] if 'bad_transition' in info else [1.0]
                         for info in infos])
                    rollouts.insert(obs, recurrent_hidden_states, action,
                                    action_log_prob, value, reward, masks, bad_masks)

                with torch.no_grad():
                    next_value = actor_critic.get_value(
                        _rollout_obs_at(rollouts, -1), rollouts.recurrent_hidden_states[-1],
                        rollouts.masks[-1]).detach()

                rollouts.compute_returns(next_value, config.ppoUseGAE, config.RLGamma,
                                         config.ppoGAELambda, config.RLUseProperTimeLimits)

                value_loss, action_loss, dist_entropy, inSightLoss, exiLoss, soundAuxLoss = agent.update(rollouts)

                rollouts.after_update()

                # save for every interval-th update or for the last one
                if (j % config.RLModelSaveInterval == 0
                        or j == num_updates - 1) and config.RLModelSaveDir != "":
                    save_path = config.RLModelSaveDir
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(actor_critic.state_dict(), os.path.join(save_path, '%.5i' % j + ".pt"),
                               _use_new_zipfile_serialization=False)

                if j % config.RLLogInterval == 0 and len(episode_rewards) > 1:
                    total_num_steps = (j + 1) * config.RLNumEnvs * config.ppoNumSteps
                    end = time.time()
                    fps = int(total_num_steps / (end - start))
                    print(
                        "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                        "inSight: {}, exi: {} \n"
                        "entropy: {}, value loss: {}, action loss: {} \n"
                        .format(j, total_num_steps, fps,
                                len(episode_rewards), np.mean(episode_rewards),
                                np.median(episode_rewards), np.min(episode_rewards),
                                np.max(episode_rewards), inSightLoss, exiLoss, dist_entropy, value_loss,
                                action_loss))

                    df = pd.DataFrame({'misc/nupdates': [j], 'misc/total_timesteps': [total_num_steps],
                                       'fps': fps,
                                       'eprewmean': [np.mean(episode_rewards)],
                                       'min': np.min(episode_rewards),
                                       'max': np.max(episode_rewards),
                                       'loss/policy_entropy': dist_entropy, 'loss/policy_loss': action_loss,
                                       'loss/value_loss': value_loss, 'loss/inSightLoss': inSightLoss,
                                       'loss/exiLoss': exiLoss})

                    progress_path = os.path.join(config.RLModelSaveDir, 'progress.csv')
                    if progress_started:
                        df.to_csv(progress_path, mode='a', header=False, index=False)
                    else:
                        # First log of this run: overwrite any file left by a previous run.
                        df.to_csv(progress_path, mode='w', header=True, index=False)
                        progress_started = True

                if j % config.pretextModelUpdateInterval == 0 and j != 0:
                    print("Update representation at", j, 'policy update')
                    # NOTE(review): the envs share this model object; confirm
                    # trainRepresentation leaves it in the mode the envs expect.
                    trainRepresentation(model=pretextModel,
                                        epoch=config.pretextModelUpdateEpoch, lr=config.pretextModelUpdateLR,
                                        start_ep=representationEpoch)
                    representationEpoch = representationEpoch + config.pretextModelUpdateEpoch

            envs.close()

        else:  # evaluate the policy
            num_processes = 1
            eval_envs = make_vec_envs(env_name=config.RLEnvName,
                                      seed=config.RLEnvSeed,
                                      num_processes=num_processes,
                                      gamma=None,
                                      device=device,
                                      randomCollect=False,
                                      config=config)
            baseEnv = eval_envs.venv.unwrapped.envs[0]
            action_space_kind = eval_envs.action_space.__class__.__name__

            # Build one policy per entry of config.skillInfos; skillList[i]
            # corresponds to config.skillInfos[i].
            skillList = []
            for skill_info in config.skillInfos:
                if not os.path.exists(skill_info['path']):
                    raise FileNotFoundError(skill_info['path'])
                if action_space_kind == "Discrete":
                    action_space = gym.spaces.Discrete(skill_info['actionDim'])
                elif action_space_kind == "Box":
                    high = np.ones(skill_info['actionDim'])
                    action_space = gym.spaces.Box(-high, high, dtype=np.float32)
                else:
                    raise NotImplementedError('Unsupported action space: ' + action_space_kind)
                # NOTE(review): unlike training, config is not passed to Policy
                # here; confirm this is intentional.
                ac = Policy(
                    eval_envs.venv.observation_space.spaces,
                    action_space,
                    base=config.RLPolicyBase,
                    base_kwargs={'recurrent': config.RLRecurrentPolicy,
                                 'recurrentInputSize': config.RLRecurrentInputSize,
                                 'recurrentSize': config.RLRecurrentSize,
                                 'actionHiddenSize': config.RLActionHiddenSize
                                 })
                print("Load the weights from", skill_info['path'])
                ac.load_state_dict(torch.load(skill_info['path'], map_location=device))
                ac.eval()
                print("Weights Loaded!")
                ac.to(device)
                skillList.append(ac)

            if not skillList:
                raise ValueError('config.skillInfos is empty: no skill policy to evaluate')

            eval_envs.venv.pretextModel = pretextModel

            eval_episode_rewards = []
            eval_env_rewards = 0.

            obs = eval_envs.reset()

            eval_recurrent_hidden_states = torch.zeros(
                num_processes, skillList[0].recurrent_hidden_state_size, device=device)
            eval_masks = torch.zeros(num_processes, 1, device=device)

            episode_num = baseEnv.size_per_class_cumsum[-1]

            results = []
            goal_area_count_list = []
            # Task index of each evaluation episode: task k repeated size_per_class[k] times.
            objs = np.repeat(np.arange(config.taskNum, dtype=np.int64), baseEnv.size_per_class)

            medoids = _load_medoids(config, device) if config.hierarchy else None
            skill_idx = _select_skill(config, obs, medoids)
            actor_critic = skillList[skill_idx]

            while baseEnv.episodeCounter < episode_num:

                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # In hierarchical Discrete mode, action indices above 5 are
                # shifted by the active skill's 'actionOffset' (presumably mapping
                # skill-local actions into the env's action space -- confirm).
                if action_space_kind == "Discrete" and config.hierarchy:
                    if action > 5:
                        action = action + config.skillInfos[skill_idx]['actionOffset']
                obs, _, done, infos = eval_envs.step(action)

                if config.render:
                    eval_envs.render()
                    time.sleep(0.5)

                    print('step reward', eval_envs.venv.origStepReward)
                eval_env_rewards = eval_env_rewards + eval_envs.venv.origStepReward

                eval_masks = torch.tensor(
                    [[0.0] if done_ else [1.0] for done_ in done],
                    dtype=torch.float32,
                    device=device)

                if done:
                    goal_area_count = infos[0]['goal_area_count']
                    goal_area_count_list.append(goal_area_count)
                    results.append(int(goal_area_count >= config.success_threshold))
                    eval_episode_rewards.append(float(eval_env_rewards))
                    eval_env_rewards = 0.
                    # Pick the skill for the next episode from its first observation.
                    skill_idx = _select_skill(config, obs, medoids)
                    actor_critic = skillList[skill_idx]

            # save the results
            if not config.render:
                df = pd.DataFrame({'objIdx': objs, 'goal area count': goal_area_count_list,
                                   'rewards': eval_episode_rewards, 'results': results})
                first_skill_path = config.skillInfos[0]['path']
                result_dir = os.path.dirname(first_skill_path)
                if config.hierarchy:
                    save_path = os.path.join(result_dir, 'hierarchy.csv')
                else:
                    save_path = os.path.join(
                        result_dir,
                        'test_' + os.path.splitext(os.path.basename(first_skill_path))[0] + '.csv')
                df.to_csv(save_path, mode='w', header=True, index=False)
                print('results saved to', save_path)
                print('success rate', sum(results) * 1. / baseEnv.size_per_class_cumsum[-1])
            eval_envs.close()
