import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
import os
import matplotlib.pyplot as plt
import random
import gym
import d4rl
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
import pbrl



def traj_generation_from_d4rl(dataset, clip_len, clip_num):
    """Sample ``clip_num`` clips of ``clip_len`` transitions from ``dataset``.

    Args:
        dataset: Mapping with 2-D ``'observations'`` and ``'actions'`` arrays
            and a ``'rewards'`` array, all indexable by the index sequence
            returned from ``pbrl.get_random_trajectory_reward``.
        clip_len: Number of transitions per clip.
        clip_num: Number of clips to draw.

    Returns:
        Tuple ``(traj_obs, traj_act, traj_rew)`` with shapes
        ``(clip_num, clip_len + 1, obs_dim)``, ``(clip_num, clip_len, act_dim)``
        and ``(clip_num, clip_len, 1)``.
    """
    obs = dataset['observations']
    act = dataset['actions']
    rew = dataset['rewards']

    traj_obs = np.empty([clip_num, clip_len + 1, obs.shape[1]])
    traj_act = np.empty([clip_num, clip_len, act.shape[1]])
    traj_rew = np.empty([clip_num, clip_len, 1])

    last_valid = len(obs) - 1

    for i in range(clip_num):
        # NOTE(review): presumably returns (index sequence, reward); only the
        # indices are used here -- confirm against pbrl.
        traj, _ = pbrl.get_random_trajectory_reward(dataset, clip_len)
        traj_rew[i] = rew[traj].reshape(-1, 1)
        traj_act[i] = act[traj]
        # Each clip stores one extra trailing observation. When the clip ends
        # on the very last transition there is no successor row, so repeat the
        # final observation instead of indexing out of bounds (same policy as
        # traj_generation_from_d4rl_no_overlap).
        next_idx = min(traj[-1] + 1, last_valid)
        traj_obs[i] = obs[np.append(traj, next_idx)]

    return traj_obs, traj_act, traj_rew

def traj_generation_from_d4rl_no_overlap(dataset, clip_len, clip_num):
    """Build ``clip_num`` clips whose start offsets lie on a ``clip_len`` grid.

    Candidate starts are ``0, clip_len, 2 * clip_len, ...`` such that a full
    clip still fits in the dataset; the actual choice is delegated to
    ``pbrl.pick_and_calc_reward``.

    Args:
        dataset: Mapping with ``'observations'``, ``'actions'`` and
            ``'rewards'`` arrays.
        clip_len: Number of transitions per clip.
        clip_num: Number of clips to produce.

    Returns:
        Tuple ``(traj_obs, traj_act, traj_rew)`` with shapes
        ``(clip_num, clip_len + 1, obs_dim)``, ``(clip_num, clip_len, act_dim)``
        and ``(clip_num, clip_len, 1)``.
    """
    observations = dataset['observations']
    actions = dataset['actions']
    rewards = dataset['rewards']
    total_steps = len(observations)

    obs_clips = np.empty([clip_num, clip_len + 1, observations.shape[1]])
    act_clips = np.empty([clip_num, clip_len, actions.shape[1]])
    rew_clips = np.empty([clip_num, clip_len, 1])

    starting_indices = list(range(0, total_steps - clip_len + 1, clip_len))

    for clip in range(clip_num):
        # NOTE(review): whether the helper removes the chosen start from
        # `starting_indices` is not visible from here -- confirm in pbrl.
        indices, _ = pbrl.pick_and_calc_reward(dataset, starting_indices, clip_len)

        rew_clips[clip] = rewards[indices].reshape(-1, 1)
        act_clips[clip] = actions[indices]

        # Append the observation following the clip; if the clip already ends
        # on the last row, repeat that row instead.
        last = indices[-1]
        successor = last + 1 if last + 1 < total_steps else last
        obs_clips[clip] = observations[np.append(indices, successor)]

    return tuple(np.array(buf) for buf in (obs_clips, act_clips, rew_clips))

def add_preference_label(dataset, data_num):
    """Pair trajectories at random and attach a sampled preference label.

    Each trajectory's utility is the sum of its rewards. For a pair
    ``(a, b)`` the first trajectory is preferred with probability
    ``exp(U_a) / (exp(U_a) + exp(U_b))``.

    Args:
        dataset: 3-tuple ``(traj_obs, traj_act, traj_rew)``. Only ``traj_rew``
            is read; its first axis indexes trajectories.
        data_num: Number of pairs to label. Each side of the pairing is drawn
            without replacement, so this must not exceed the number of
            trajectories.

    Returns:
        ``(*dataset, traj_idx_1, traj_idx_2, pref)`` where ``pref[i]`` is ``1``
        if trajectory ``traj_idx_1[i]`` was preferred and ``2`` otherwise.

    Raises:
        ValueError: From ``random.sample`` if ``data_num`` is larger than the
            number of trajectories.
    """
    _, _, traj_rew = dataset
    population = range(traj_rew.shape[0])
    traj_idx_1 = np.array(random.sample(population, data_num))
    traj_idx_2 = np.array(random.sample(population, data_num))

    pref = np.empty([data_num])

    for i, (first, second) in enumerate(zip(traj_idx_1, traj_idx_2)):
        utility_1 = np.sum(traj_rew[first])
        utility_2 = np.sum(traj_rew[second])
        # Same probability as exp(u1) / (exp(u1) + exp(u2)), written as
        # 1 / (1 + exp(u2 - u1)) via logaddexp. The naive form overflows to
        # inf / inf = nan for large returns, and `x < nan` is always False,
        # which silently forced every such label to 2.
        prob = np.exp(-np.logaddexp(0.0, utility_2 - utility_1))
        pref[i] = 1 if random.random() < prob else 2

    return (*dataset, traj_idx_1, traj_idx_2, pref)




def dataset_from_env(config):
    """Load the cached preference dataset for ``config`` or build and cache it.

    The cache lives at ``../dataset/{env}_{data_num}_{seed}.npz`` relative to
    the current working directory.

    Args:
        config: Object exposing ``env``, ``data_num``, ``seed``, ``clip_len``
            and ``bin_label_allow_overlap``. When ``data_num`` is ``None`` the
            number of pairs is derived from the dataset size and ``clip_len``.

    Returns:
        6-tuple ``(obs, act, rew, idx_1, idx_2, pref)``.
    """
    keys = ('obs', 'act', 'rew', 'idx_1', 'idx_2', 'pref')

    # np.savez appends ".npz" when the name lacks it, so the extension has to
    # be part of the path; otherwise the existence check never sees the file
    # that was written and the dataset is regenerated on every call.
    ds_path = Path(f'../dataset/{config.env}_{config.data_num}_{config.seed}.npz')
    if ds_path.exists():
        # Context manager closes the underlying archive once arrays are read.
        with np.load(str(ds_path)) as cached:
            return tuple(cached[key] for key in keys)

    env = gym.make(config.env)
    dataset = d4rl.qlearning_dataset(env)

    len_t = config.clip_len
    num_t = (config.data_num if config.data_num is not None else int(len(dataset['observations'])/len_t//2))

    if config.bin_label_allow_overlap:
        dataset = traj_generation_from_d4rl(dataset, len_t, num_t)
    else:
        # Non-overlapping mode requests twice as many clips as pairs.
        dataset = traj_generation_from_d4rl_no_overlap(dataset, len_t, 2 * num_t)

    dataset = add_preference_label(dataset, num_t)

    ds_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(ds_path, **dict(zip(keys, dataset)))
    return dataset

