import argparse
import pickle
import gym
import time
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import d4rl # Import required to register environments
import deepdish as dd
import d4rl.gym_mujoco
import os
from reward_learning.utils_mine import *
# from reward_learning.utils import generate_novice_demos


def relabel_rewards(env, dataset, env_name, relabel='dense'):
    """Replace the dataset's rewards and terminals with goal-distance based ones.

    The goal is ``env.target_goal`` when ``env_name`` contains ``'antmaze'``
    and ``env.goal_locations[0]`` otherwise.  Distances are measured between
    the goal and the first two dimensions of each *next* observation
    (rows 1..N-1), so entry ``t`` describes the state reached after step
    ``t``; the last entry has no next state and is padded with 0.

    Args:
        env: environment exposing ``target_goal`` or ``goal_locations``.
        dataset: mapping with ``'observations'`` (and ``'rewards'`` when
            ``relabel`` is neither ``'dense'`` nor ``'sparse'``).  It is
            modified in place.
        env_name: environment id; selects the goal attribute and enables the
            maze2d episode cut-off.
        relabel: ``'dense'`` -> ``exp(-distance)``; ``'sparse'`` -> 1.0 when
            the next state is within 0.5 of the goal, else 0.0; any other
            value keeps the stored rewards unchanged.

    Returns:
        The same ``dataset`` object with ``'rewards'`` and ``'terminals'``
        overwritten.
    """
    target_goal = env.target_goal if 'antmaze' in env_name else env.goal_locations[0]
    print('Target Goal: ', target_goal)

    all_obs = dataset['observations'][:]

    # Distance of every next state s' to the goal, computed once and shared
    # by the reward and terminal definitions below.
    next_goal_dist = np.linalg.norm(all_obs[1:, :2] - target_goal, axis=1)
    reached_goal = (next_goal_dist <= 0.5).astype(np.float32)
    final_pad = np.array([0])  # entry for the last row, which has no next state

    if relabel == 'dense':
        _rew = np.concatenate([np.exp(-next_goal_dist), final_pad], 0)
    elif relabel == 'sparse':
        _rew = np.concatenate([reached_goal, final_pad], 0)
    else:
        # Stored rewards already have one entry per row; padding them would
        # leave rewards one element longer than terminals/observations.
        _rew = dataset['rewards'][:]

    # A transition is terminal when its next state lies within 0.5 of the goal.
    _terminals = np.concatenate([reached_goal, final_pad], 0)

    if "maze2d" in env_name:
        # Force a terminal after every 200 consecutive non-terminal steps.
        current_length = 0
        for i in range(len(_terminals)):
            if not _terminals[i]:
                current_length += 1
            else:
                current_length = 0
            if current_length >= 200:
                _terminals[i] = 1.0
                current_length = 0

    print('Sum of rewards: ', _rew.sum())
    print('Sum of terminals: ', _terminals.sum())
    dataset['rewards'] = _rew
    dataset['terminals'] = _terminals
    return dataset


class RewardLearner:
    def __init__(self,env_name,seed,num_queries_per_iter,num_ensembles,env,dataset,traj_length=50,prefix='0'):
        """Record the environment, dataset and ensemble/query configuration.

        Args:
            env_name: environment id; the text before the first ``'-'`` is
                kept as ``env_prefix``.
            seed: base random seed for the ensemble.
            num_queries_per_iter: number of query pairs produced per round.
            num_ensembles: number of reward models in the ensemble.
            env: environment exposing ``observation_space`` and
                ``action_space`` with a ``shape`` attribute.
            dataset: offline dataset, stored as given.
            traj_length: stored as ``self.traj_length``.
            prefix: string embedded in the names of saved reward files.
        """
        # Query / ensemble configuration.
        self.seed = seed
        self.num_ensembles = num_ensembles
        self.num_queries_per_iter = num_queries_per_iter
        self.traj_length = traj_length
        self.prefix = prefix

        # Environment and data handles.
        self.env = env
        self.env_name = env_name
        self.env_prefix = env_name.split('-')[0]
        self.dataset = dataset

        # The reward network input is an observation concatenated with an action.
        obs_dim = self.env.observation_space.shape[0]
        act_dim = self.env.action_space.shape[0]
        self.input_dim = obs_dim + act_dim

    def init_reward_model(self,initial_pairs,num_iter=20,retrain_num_iter=20):
        """Train the initial reward ensemble and score the whole dataset with it.

        One reward network is trained per ensemble member.  Member ``k`` seeds
        torch and numpy with ``self.seed + k``, samples demos from the dataset,
        orders them by the returns reported by the demo generator, builds
        preference pairs with ``create_training_data`` and fits a fresh
        ``Net``.  The demos, returns, rewards, training pairs, models and
        optimizers are kept on ``self`` so ``learn_reward_later`` can continue
        training.  Afterwards every model predicts rewards for the dataset
        (full chunks of 1000 steps plus the remainder) and the stacked
        predictions are written to
        ``<active_reward_root>0_<prefix>_all.npy``.

        Args:
            initial_pairs: number of pairs requested from
                ``create_training_data``; ``initial_pairs * 2000`` demos are
                sampled per member.
            num_iter: iteration count handed to ``learn_reward`` for the
                initial fit.
            retrain_num_iter: iteration count stored for later retraining.

        Returns:
            ``np.ndarray`` stacking the per-member reward predictions along
            axis 0.
        """
        # Imported once per call (the file has no top-level torch.optim import).
        import torch.optim as optim

        self.initial_pairs = initial_pairs
        self.iter_cnt = 0
        self.retrain_num_iter = retrain_num_iter

        path = "./rewards"
        # exist_ok avoids the check-then-create race when several runs start together.
        os.makedirs(path, exist_ok=True)

        run_tag = f'./ensemble_{self.env_name}_initial_pairs_{initial_pairs}_num_queries_{self.num_queries_per_iter}_num_iter_{num_iter}_retrain_num_iter_{retrain_num_iter}_seed_{self.seed}'
        self.reward_model_path = os.path.join(path, run_tag)
        self.active_reward_root = os.path.join(path, run_tag + '_round_num_')

        self.lr = 0.00005
        weight_decay = 0.0
        self.num_iter = num_iter  # num times through training data
        self.l1_reg = 0.0

        self.demo_list = []
        self.returns_list = []
        self.rewards_list = []
        self.models_list = []
        self.training_obs_list, self.training_labels_list = [], []
        self.optimizers = []

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Pretrain the reward models with the initial number of query pairs.
        for seed in range(self.seed, self.seed + self.num_ensembles):
            torch.manual_seed(seed)
            np.random.seed(seed)

            demonstrations, learning_returns, learning_rewards, _ = generate_novice_demos_mine_test(
                self.dataset, initial_pairs * 2000, self.traj_length)
            self.demo_list.append(demonstrations)
            self.returns_list.append(learning_returns)
            self.rewards_list.append(learning_rewards)

            # Order demos by their reported return to simulate ranked demos.
            demonstrations = [x for _, x in sorted(zip(learning_returns, demonstrations), key=lambda pair: pair[0])]
            sorted_returns = sorted(learning_returns)

            training_obs, training_labels, training_returns, _ = create_training_data(
                demonstrations, initial_pairs, 0, 0, 0, sorted_returns)
            print(training_labels, training_returns)
            self.training_obs_list.append(training_obs)
            self.training_labels_list.append(training_labels)

            # Create a reward network and optimize it on the training pairs.
            reward_net = Net(self.input_dim)
            reward_net.to(device)
            optimizer = optim.Adam(reward_net.parameters(), lr=self.lr, weight_decay=weight_decay)
            self.optimizers.append(optimizer)
            learn_reward(reward_net, optimizer, training_obs, training_labels, num_iter, self.l1_reg, self.reward_model_path)

            self.models_list.append(reward_net)

        print('num_queries: ', len(self.training_obs_list[0]))

        # Split the dataset into chunks of 1000 steps for reward prediction.
        npy_traj_length = 1000
        num_steps = self.dataset['observations'].shape[0]
        npy_num_trajs = int(num_steps / npy_traj_length)
        self.npy_demonstrations, _, _, _ = generate_novice_demos(self.dataset, npy_num_trajs, npy_traj_length)

        # Remaining steps when the dataset length is not a multiple of 1000
        # (e.g. hopper-medium-expert has 1999906 steps).
        num_tails = num_steps % npy_traj_length
        step_start = num_steps - num_tails
        self.npy_demonstrations_tail, _, _, _ = generate_novice_demos(self.dataset, num_tails, 1, steps=step_start)

        reward_arr_list = []
        for model_idx, reward_net in enumerate(self.models_list):
            # Result is not consumed here; the call is retained in case the
            # helper reports training-set statistics — TODO confirm.
            compute_reward(reward_net, None, self.training_obs_list[model_idx],
                           self.training_labels_list[model_idx], None,
                           None, None)
            reward_arr = np.array(parallel_predict_reward_sequence(reward_net, self.npy_demonstrations))
            reward_arr_tail = np.array(parallel_predict_reward_sequence(reward_net, self.npy_demonstrations_tail))
            reward_arr_list.append(np.concatenate((reward_arr, reward_arr_tail), axis=0))

        reward_arr_all = np.array(reward_arr_list)
        with open(self.active_reward_root + str(0) +'_'+self.prefix+'_all.npy', 'wb') as f:
            np.save(f, reward_arr_all)

        return reward_arr_all


    # def compute_representation(self):

    def prepare_large_dataset(self,learned_rewards,otr_rewards):
        """Sample a large pool of trajectories and keep the best-ranked ones.

        Trajectories are drawn with ``generate_novice_demos_later``, ordered
        by their return under the learned reward, cut down to the top
        ``int(large_num_pairs / 200) * 199`` entries, and finally re-ordered
        ascending by ``large_learning_returns`` (presumably the ground-truth
        return — confirm against the helper).  The aligned results are stored
        in ``self.large_demonstrations``, ``self.large_learning_returns``,
        ``self.learned_returns``, ``self.otr_returns`` and
        ``self.large_sorted_returns``; ``self.large_learning_rewards`` keeps
        the helper's raw per-trajectory rewards.

        Args:
            learned_rewards: forwarded to ``generate_novice_demos_later``.
            otr_rewards: forwarded to ``generate_novice_demos_later``.
        """
        large_num_trajs = int(int(self.dataset['observations'].shape[0] / self.traj_length) // 1.1)
        large_num_pairs = large_num_trajs * 5

        demos, true_returns, self.large_learning_rewards, _, learned_returns, otr_returns = generate_novice_demos_later(
            self.dataset,
            large_num_pairs,
            self.traj_length, learned_rewards, otr_rewards)

        demos = np.array(demos)
        true_returns = np.array(true_returns)
        learned_returns = np.array(learned_returns)
        otr_returns = np.array(otr_returns)

        # Rank ascending by learned return so the best-ranked entries sit at the end.
        order = np.argsort(learned_returns)
        demos, true_returns = demos[order], true_returns[order]
        learned_returns, otr_returns = learned_returns[order], otr_returns[order]

        # Keep the top slice.  When top_num is 0 nothing is dropped, which is
        # what the slice ``arr[-0:]`` has always done here.
        top_num = int(large_num_pairs/200)*199
        if top_num > 0:
            demos, true_returns = demos[-top_num:], true_returns[-top_num:]
            learned_returns, otr_returns = learned_returns[-top_num:], otr_returns[-top_num:]

        # Final order: ascending large_learning_returns, so a larger index
        # means a larger return.
        order = np.argsort(true_returns)
        self.large_demonstrations = demos[order]
        self.large_learning_returns = true_returns[order]
        self.learned_returns = learned_returns[order]
        self.otr_returns = otr_returns[order]

        # Set last so it shares the indexing of self.large_demonstrations;
        # assigning it before the filter/re-sort left it misaligned.
        self.large_sorted_returns = self.large_learning_returns

    def find_new_queries(self,learned_rewards,otr_rewards):
        self.prepare_large_dataset(learned_rewards,otr_rewards)
        arg_sorted = np.argsort(self.otr_returns)
        query_indexes = []
        for i in range(self.num_queries_per_iter):
            if np.random.rand()>0.5:
                query_indexes.append([arg_sorted[i],arg_sorted[-(1+i)]])
            else:
                query_indexes.append([ arg_sorted[-(1 + i)],arg_sorted[i]])
        query_indexes = []
        for i in range(self.num_queries_per_iter):
            ti,tj=0,0
            while ti==tj:
                ti,tj=np.random.randint(0,len(arg_sorted)),np.random.randint(0,len(arg_sorted))
            query_indexes.append([arg_sorted[ti], arg_sorted[tj]])

        return query_indexes

    def find_new_queries_disagreement(self,learned_rewards,otr_rewards):
        self.prepare_large_dataset(learned_rewards,otr_rewards)
        arg_sorted = np.argsort(self.otr_returns)
        query_indexes = []
        for i in range(self.num_queries_per_iter):
            if np.random.rand()>0.5:
                query_indexes.append([arg_sorted[i],arg_sorted[-(1+i)]])
            else:
                query_indexes.append([ arg_sorted[-(1 + i)],arg_sorted[i]])
        query_indexes = []
        for i in range(self.num_queries_per_iter):
            ti,tj=0,0
            while ti==tj:
                ti,tj=np.random.randint(0,len(arg_sorted)),np.random.randint(0,len(arg_sorted))
            query_indexes.append([arg_sorted[ti], arg_sorted[tj]])

        return query_indexes


    def compute_dataset_representation(self):
        """Return the first ensemble member's representations for the dataset.

        Representations are predicted for ``self.npy_demonstrations`` and,
        when it is non-empty, for ``self.npy_demonstrations_tail``; the two
        results are joined along axis 0.  The shape of the main part is
        printed before the tail is handled.
        """
        encoder = self.models_list[0]

        main_part = np.array(parallel_predict_representation_sequence(encoder, self.npy_demonstrations))
        print(main_part.shape)

        # Nothing to append when the dataset split left no remainder.
        if len(self.npy_demonstrations_tail) == 0:
            return main_part

        tail_part = np.array(parallel_predict_representation_sequence(encoder, self.npy_demonstrations_tail))
        return np.concatenate((main_part, tail_part), axis=0)



    def learn_reward_later(self,query_list):
        """Add newly labelled query pairs and retrain every ensemble member.

        For each member ``i`` the index pairs in ``query_list[i]`` are turned
        into training pairs from ``self.large_demonstrations``.  The label is
        0 when the first index is larger than the second and 1 otherwise,
        which relies on ``self.large_demonstrations`` being ordered so that a
        larger index means the preferred trajectory (``prepare_large_dataset``
        orders it ascending by ``large_learning_returns``).  The pairs are
        appended to the member's accumulated training set, the member is
        retrained for ``self.retrain_num_iter`` iterations with its existing
        optimizer, and all members then re-score the dataset.  The stacked
        predictions are saved to
        ``<active_reward_root><iter_cnt + 1>_<prefix>_all.npy``.

        Args:
            query_list: one sequence of ``(index_a, index_b)`` pairs per
                ensemble member, indexing ``self.large_demonstrations``.

        Returns:
            ``np.ndarray`` stacking the per-member reward predictions along
            axis 0.
        """
        total_large_training_obs, total_large_training_labels, total_large_training_returns = [], [], []
        self.iter_cnt+=1
        for i in range(self.num_ensembles):
            large_training_obs, large_training_labels, large_training_returns = [], [], []
            query_indexes = query_list[i]
            print(len(query_indexes))
            for index_pair in query_indexes:
                first, second = index_pair[0], index_pair[1]
                label = 0 if first > second else 1

                large_training_obs.append((self.large_demonstrations[first], self.large_demonstrations[second]))
                large_training_labels.append(label)
                # Returns of both trajectories in the pair (the second entry
                # previously repeated the first index).  Diagnostic only: not
                # consumed further down.
                large_training_returns.append([self.large_sorted_returns[first], self.large_sorted_returns[second]])

            total_large_training_obs.append(large_training_obs)
            total_large_training_labels.append(large_training_labels)
            total_large_training_returns.append(large_training_returns)

        # Grow each member's accumulated training set with its new pairs.
        for i in range(len(self.training_obs_list)):
            self.training_obs_list[i].extend(total_large_training_obs[i])
            self.training_labels_list[i].extend(total_large_training_labels[i])

        print('num_queries: ', len(self.training_obs_list[0]))

        # Retrain the reward models, reusing each member's network and optimizer.
        for cnt, member_seed in enumerate(range(self.seed, self.seed + self.num_ensembles)):
            torch.manual_seed(member_seed)
            np.random.seed(member_seed)
            learn_reward(self.models_list[cnt], self.optimizers[cnt],
                         self.training_obs_list[cnt], self.training_labels_list[cnt],
                         self.retrain_num_iter, self.l1_reg,
                         self.reward_model_path)

        # NOTE(review): the first model is evaluated on the pairs of the LAST
        # ensemble member, as before (the loop variables used to leak here).
        # Confirm whether member 0's own new pairs were intended instead.
        acc = calc_accuracy(self.models_list[0], total_large_training_obs[-1], total_large_training_labels[-1])
        print(acc)

        reward_arr_list = []
        train_reward_arr_list = []
        for model_idx, reward_net in enumerate(self.models_list):
            train_reward = compute_reward(reward_net, None, self.training_obs_list[model_idx],
                                          self.training_labels_list[model_idx], None, None, None)
            train_reward_arr_list.append(train_reward)
            reward_arr = np.array(parallel_predict_reward_sequence(reward_net, self.npy_demonstrations))
            reward_arr_tail = np.array(parallel_predict_reward_sequence(reward_net, self.npy_demonstrations_tail))
            reward_arr_list.append(np.concatenate((reward_arr, reward_arr_tail), axis=0))

        reward_arr_all = np.array(reward_arr_list)
        train_reward_arr_all = np.array(train_reward_arr_list)
        print(train_reward_arr_all.shape)

        # NOTE(review): iter_cnt was already incremented above, so the first
        # retraining round is saved as round 2; kept so existing file names
        # stay valid — confirm whether round 1 is skipped on purpose.
        with open(self.active_reward_root + str(self.iter_cnt + 1) +'_'+self.prefix+'_all.npy', 'wb') as f:
            np.save(f, reward_arr_all)

        return reward_arr_all

# if __name__=="__main__":
#     parser = argparse.ArgumentParser(description=None)
#     parser.add_argument('--env_name', default='', help='Select the environment name to run, i.e. maze2d-medium-dense-v1')
#     parser.add_argument('--initial_pairs', default = 10, type=int, help="initial number of pairs of trajectories used to train the reward models")
#     parser.add_argument('--num_snippets', default = 0, type = int, help = "number of short subtrajectories to sample")
#     parser.add_argument('--voi', default='', help='Choose between infogain, disagreement, or random')
#     parser.add_argument('--num_rounds', default = 0, type = int, help = "number of rounds of active querying")
#     parser.add_argument('--num_queries', default = 1, type = int, help = "number of queries per round of active querying")
#     parser.add_argument('--num_iter', default = 5, type = int, help = "number of iteration of initial data")
#     parser.add_argument('--retrain_num_iter', default = 1, type = int, help = "number of training iteration after one round of active querying")
#     parser.add_argument('--num_ensembles', default = 7, type = int, help = "number of ensemble of members")
#     parser.add_argument('--seed', default = 0, type = int, help = "random seed")
#     parser.add_argument('--beta', default = 10, type = int, help = "beta as a measure of confidence for info gain")
#
#     args = parser.parse_args()
#
#     # Torch RNG
#     torch.manual_seed(args.seed)
#     torch.cuda.manual_seed(args.seed)
#     torch.cuda.manual_seed_all(args.seed)
#     # Python RNG
#     np.random.seed(args.seed)
#
#     env_name = args.env_name
#     list_env_name = list(env_name.split("-"))
#     maze_name = list_env_name[1]
