import argparse
import pickle
import gym
import time
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import d4rl # Import required to register environments
import deepdish as dd
import d4rl.gym_mujoco
import os
from reward_learning.utils import *
from reward_learning.utils import generate_novice_demos

if __name__ == "__main__":
    # ---- Command-line configuration ----
    parser = argparse.ArgumentParser(description=None)
    # Required: with an empty name the env id cannot be parsed or created.
    parser.add_argument('--env_name', required=True, help='Select the environment name to run, i.e. maze2d-medium-dense-v1')
    parser.add_argument('--initial_pairs', default = 10, type=int, help="initial number of pairs of trajectories used to train the reward models")
    parser.add_argument('--num_snippets', default = 0, type = int, help = "number of short subtrajectories to sample")
    parser.add_argument('--voi', default='', help='Choose between infogain, disagreement, or random')
    parser.add_argument('--num_rounds', default = 0, type = int, help = "number of rounds of active querying")
    parser.add_argument('--num_queries', default = 1, type = int, help = "number of queries per round of active querying")
    parser.add_argument('--num_iter', default = 5, type = int, help = "number of iteration of initial data")
    parser.add_argument('--retrain_num_iter', default = 1, type = int, help = "number of training iteration after one round of active querying")
    parser.add_argument('--num_ensembles', default = 7, type = int, help = "number of ensemble of members")
    parser.add_argument('--seed', default = 0, type = int, help = "random seed")
    parser.add_argument('--beta', default = 10, type = int, help = "beta as a measure of confidence for info gain")

    args = parser.parse_args()
    # The script reads models_list[0] after training, so the ensemble must be non-empty.
    if args.num_ensembles < 1:
        parser.error('--num_ensembles must be at least 1')

    # ---- Reproducibility: seed every RNG the pipeline may draw from ----
    # Torch RNG (CPU and CUDA devices)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # NumPy RNG and Python's built-in `random` RNG
    np.random.seed(args.seed)
    random.seed(args.seed)

    # ---- Environment and offline dataset ----
    env_name = args.env_name
    # Env ids are dash-separated, e.g. "maze2d-medium-dense-v1".
    list_env_name = env_name.split("-")
    maze_name = list_env_name[1]  # second field of the id, e.g. "medium"
    env = gym.make(args.env_name)

    env_prefix = env_name.split('-')[0]
    dataset = env.get_dataset()

    # Reward-network input size is the flat observation dimension reported by
    # the environment itself, so no per-environment lookup table is needed.
    input_dim = env.observation_space.shape[0]

    # ---- Snippet / trajectory settings ----
    initial_pairs = args.initial_pairs
    num_snippets = args.num_snippets
    min_snippet_length = 25  # min length of trajectory for training comparison
    maximum_snippet_length = 100

    traj_length = 50  # steps per trajectory when slicing the dataset
    num_iter = args.num_iter
    retrain_num_iter = args.retrain_num_iter

    # ---- Active-querying settings ----
    num_queries = args.num_queries
    voi = args.voi
    num_rounds = args.num_rounds

    num_seeds = args.num_ensembles  # number of ensemble members to train

    beta = args.beta

    # ---- Output locations ----
    path = "./rewards"
    # Create the output directory if needed; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(path, exist_ok=True)

    # File-name stem encoding the run configuration; shared by both paths.
    run_stem = f'./ensemble_{env_name}_initial_pairs_{initial_pairs}_num_queries_{num_queries}_num_iter_{num_iter}_retrain_num_iter_{retrain_num_iter}_voi_{voi}_seed_{args.seed}'
    reward_model_path = os.path.join(path, run_stem)
    # Per-round outputs append "<round><suffix>" to this root.
    active_reward_root = os.path.join(path, run_stem + '_round_num_')

    # ---- Reward-model optimisation hyper-parameters ----
    lr = 0.00005
    weight_decay = 0.0
    l1_reg = 0.0
    stochastic = True

    # ---- Accumulators ----
    demo_list = []
    returns_list = []
    rewards_list = []
    models_list = []
    training_obs_list, training_labels_list = [], []





    # ---- Slice the offline dataset into fixed-length trajectories ----
    num_dataset_trajs = int(len(dataset['observations']) / traj_length)
    print(num_dataset_trajs)
    demonstrations, learning_returns, learning_rewards = generate_novice_demos_array(
        dataset, num_dataset_trajs, traj_length)
    demo_list.append(demonstrations)
    returns_list.append(learning_returns)
    rewards_list.append(learning_rewards)

    # Snippets may be no longer than the shortest demo nor maximum_snippet_length.
    demo_lengths = [len(demo) for demo in demonstrations]
    max_snippet_length = min(np.min(demo_lengths), maximum_snippet_length)

    # Rank trajectories by ground-truth return, ascending, to simulate ranked demos.
    ranked = sorted(zip(learning_returns, demonstrations), key=lambda ret_demo: ret_demo[0])
    demonstrations = [demo for _, demo in ranked]

    aranges = np.arange(len(demonstrations))
    array_demonstrations = np.array(demonstrations)
    print(array_demonstrations.shape)
    # NOTE(review): these 300 indices are drawn uniformly *with replacement* from
    # all trajectories, not only from the low-return end of the ranking — confirm intended.
    sampled_idx = np.random.choice(aranges, 300)
    low_return_demos = array_demonstrations[sampled_idx, :, :]
    print(demonstrations[0].shape, len(demonstrations), len(low_return_demos))
    # Last 100 entries of the ascending ranking = highest-return trajectories.
    high_return_demos = array_demonstrations[-100:, :, :]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    low_return_demos, high_return_demos = (
        torch.from_numpy(batch).float().to(device)
        for batch in (low_return_demos, high_return_demos)
    )
    # All time steps of all trajectories stacked along axis 0.
    total_demos = torch.from_numpy(np.concatenate(demonstrations, 0)).float().to(device)
    print(high_return_demos.shape, low_return_demos.shape, total_demos.shape)

    # ---- Pretrain the reward-model ensemble ----
    import torch.optim as optim

    # Each member gets its own seed (10, 11, ...), so members differ through
    # their initialisation and any randomness inside training.
    # `device` is the one selected above when the demo tensors were built.
    for seed in range(10, num_seeds + 10):
        torch.manual_seed(seed)
        np.random.seed(seed)

        reward_net = Net(input_dim)
        reward_net.to(device)
        optimizer = optim.Adam(reward_net.parameters(), lr=lr, weight_decay=weight_decay)
        learn_reward_contrastive(reward_net, optimizer, high_return_demos, low_return_demos, total_demos)

        models_list.append(reward_net)

    # Use the first ensemble member as the reference model (it is not saved here).
    reward_net = models_list[0]

    # ---- Comparison pairs for the accuracy check ----
    # Take a tenth of the traj_length-step trajectories and build 5 pairs per trajectory.
    large_num_trajs = int(dataset['observations'].shape[0] / traj_length) // 10
    large_num_pairs = large_num_trajs * 5
    large_demonstrations, large_learning_returns, large_learning_rewards = generate_novice_demos(
        dataset, large_num_trajs, traj_length)

    # Sort by ground-truth return (ascending) to simulate ranked demos.
    large_demo_lengths = [len(d) for d in large_demonstrations]
    large_max_snippet_length = min(np.min(large_demo_lengths), maximum_snippet_length)
    sorted_large_demonstrations = [x for _, x in sorted(zip(large_learning_returns, large_demonstrations), key=lambda pair: pair[0])]

    large_sorted_returns = sorted(large_learning_returns)
    print(large_sorted_returns[0], large_sorted_returns[-1])  # lowest / highest return

    # Snippet length is capped by *these* demos' lengths, not the training demos'.
    # (create_training_data_greedy takes the same arguments if greedy pairing is wanted.)
    large_training_obs, large_training_labels = create_training_data(
        sorted_large_demonstrations, large_num_pairs, num_snippets,
        min_snippet_length, large_max_snippet_length)

    # ---- Per-step reward prediction over the whole dataset ----
    num_dataset_steps = dataset['observations'].shape[0]
    npy_traj_length = 1000
    npy_num_trajs = num_dataset_steps // npy_traj_length
    npy_demonstrations, _, _ = generate_novice_demos(dataset, npy_num_trajs, npy_traj_length)

    acc = calc_accuracy(models_list[0], large_training_obs, large_training_labels)
    print('accuracy', acc)

    # Steps left over after the full 1000-step chunks (e.g. hopper-medium-expert,
    # which has 1999906 steps) are requested with trajectory length 1, starting at step_start.
    num_tails = num_dataset_steps % npy_traj_length
    step_start = num_dataset_steps - num_tails  # for generate_novice_demos
    if num_tails > 0:
        npy_demonstrations_tail, _, _ = generate_novice_demos(dataset, num_tails, 1, steps=step_start)
    else:
        # Dataset length is an exact multiple of npy_traj_length: nothing left over.
        npy_demonstrations_tail = []

    reward_arr_list = []
    train_reward_arr_list = []
    for reward_net in models_list:
        # Rewards on the contrastive training inputs (sampled demos, then top-return demos).
        # Inference only, so skip building an autograd graph.
        with torch.no_grad():
            low_out = reward_net.net(low_return_demos)
            print(low_out.shape)
            high_out = reward_net.net(high_return_demos)
            train_reward = torch.flatten(torch.cat([low_out.squeeze(), high_out.squeeze()], 0))
        print(train_reward.shape)
        train_reward_arr_list.append(train_reward.cpu().numpy())

        # Full chunks first, then the leftover tail (if any), in dataset order.
        reward_arr = np.array(parallel_predict_reward_sequence(reward_net, npy_demonstrations))
        if num_tails > 0:
            reward_arr_tail = np.array(parallel_predict_reward_sequence(reward_net, npy_demonstrations_tail))
            reward_arr_comb = np.concatenate((reward_arr, reward_arr_tail), axis=0)
        else:
            reward_arr_comb = reward_arr
        reward_arr_list.append(reward_arr_comb)

    # Ensemble statistics: axis 0 of the stacked arrays indexes the ensemble member.
    reward_arr_all = np.array(reward_arr_list)
    reward_arr = np.mean(reward_arr_all, axis=0)
    reward_std = np.std(reward_arr_all, axis=0)
    reward_min = np.min(reward_arr_all, axis=0)
    reward_max = np.max(reward_arr_all, axis=0)
    train_reward_arr_all = np.array(train_reward_arr_list)
    print(reward_arr_all.shape, train_reward_arr_all.shape)
    train_reward_arr = np.mean(train_reward_arr_all, axis=0)
    train_reward_std = np.std(train_reward_arr_all, axis=0)
    train_reward_min = np.min(train_reward_arr_all, axis=0)
    train_reward_max = np.max(train_reward_arr_all, axis=0)
    # ---- Persist round-0 ensemble statistics ----
    # Each file is written to active_reward_root + "0" + <suffix>, in this order.
    round0_outputs = (
        ('.npy', reward_arr),
        ('std.npy', reward_std),
        ('max.npy', reward_max),
        ('min.npy', reward_min),
        ('train.npy', train_reward_arr),
        ('trainstd.npy', train_reward_std),
        ('trainmax.npy', train_reward_max),
        ('trainmin.npy', train_reward_min),
    )
    for suffix, values in round0_outputs:
        with open(active_reward_root + str(0) + suffix, 'wb') as f:
            np.save(f, values)

