import numpy as np
import torch
import torch.nn.functional as F
import time
import itertools
import tqdm

def np_sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    Works on Python scalars and, elementwise, on numpy arrays.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator



def maxentirl_loss_sigmoid_validation(div: str, agent_samples, expert_samples, reward_func, reward_optimizer,  device, regularization='vanilla', epochs=1):
    '''Regress the reward onto graded targets along agent -> expert interpolations.

    For every mixing weight ``mag`` in 0.0, 0.2, ..., 1.0, each agent trajectory is
    blended with an expert trajectory drawn at random (``sA + mag * (sE - sA)``). Every
    state of the blend is labelled with a target that grows with ``mag`` according to
    ``regularization``. ``reward_func.r`` is fitted to those targets with an MSE loss on a
    random 90% of the states; the remaining 10% form a holdout set.

    Args:
        div: objective name; only ``'rank-pal'`` is accepted.
        agent_samples: 3-tuple whose first item holds the agent states, shape (N, T, d).
        expert_samples: expert states, reshaped to (-1, T, d).
        reward_func: object whose ``r`` maps a (B, d) float tensor to rewards
            (presumably shaped (B, 1) -- TODO confirm).
        reward_optimizer: optimizer that updates the reward parameters.
        device: torch device the states are moved to before calling ``reward_func.r``.
        regularization: target schedule over ``mag``. One of ``'sigmoid'``, ``'linear'``,
            or ``'exp-<k>'`` / ``'exp-n<k>'`` (exponential with slope k / -k).
        epochs: ``-1`` enables early stopping (stop after 5 epochs without a >1% relative
            holdout improvement). A value >= 0 trains for ``epochs + 1`` epochs. Training
            never exceeds 100 epochs.

    Returns:
        float: MSE of the last holdout batch of the final epoch (or of the last training
        batch if the holdout set is empty).

    Raises:
        ValueError: if ``regularization`` is not a supported schedule.
    '''
    assert div in ['rank-pal']

    # Validate the schedule before doing any work. Previously an unsupported value
    # (including the default 'vanilla') crashed later with an UnboundLocalError.
    slope = None
    if regularization not in ('sigmoid', 'linear'):
        if 'exp' not in regularization:
            raise ValueError("Unsupported regularization: {}".format(regularization))
        # 'exp-<k>' -> slope k, 'exp-n<k>' -> slope -k
        slope_token = regularization.split('-')[-1]
        slope = -int(slope_token[1:]) if 'n' in slope_token else int(slope_token)

    sigmoid_range = 6

    def target_for(mag):
        """Regression target shared by every state blended with weight ``mag``."""
        if regularization == 'sigmoid':
            val = np_sigmoid(-sigmoid_range + sigmoid_range * mag * 2) * 10 - 5
        elif regularization == 'linear':
            val = -10 + (2.0 - mag) * 20 if mag > 1.0 else -10 + mag * 20
        else:
            # Exponential curve pinned to -10 at mag=0 and +10 at mag=1.
            p = 20 / (np.exp(slope) - 1)
            q = -10 * (1 + np.exp(slope)) / (np.exp(slope) - 1)
            val = q + p * np.exp(slope * mag)
        # NOTE(review): int() truncates toward zero, which coarsens the targets (the
        # sigmoid schedule becomes -4, -4, -2, 2, 4, 4). Kept so that the trained
        # targets are unchanged -- confirm this is intended.
        return int(val)

    sA, _, _ = agent_samples
    _, T, d = sA.shape
    sA = torch.FloatTensor(sA)
    sE = torch.FloatTensor(expert_samples).reshape(-1, T, d)
    loss_fn = torch.nn.MSELoss()

    # Build the interpolated dataset chunk by chunk and concatenate once.
    state_chunks, label_chunks = [], []
    for mag in np.arange(0.0, 1.1, 0.2):
        # Pair every agent trajectory with an expert trajectory drawn with replacement.
        sE_sample = sE[np.random.choice(sE.shape[0], size=sA.shape[0])]
        blended = sA + mag * (sE_sample - sA)
        state_chunks.append(blended.numpy())
        label_chunks.append(np.full(tuple(blended.shape[:2]), target_for(mag), dtype=np.float64))
    print("Dataset prepared")
    state_dataset = np.concatenate(state_chunks, axis=0).reshape(-1, d)
    state_label_dataset = np.concatenate(label_chunks, axis=0).reshape(-1)

    # Random 90/10 train/holdout split. The states are stored grouped by `mag`, so an
    # unshuffled split would hold out only the mag == 0 (pure agent) states.
    perm = np.random.permutation(state_dataset.shape[0])
    holdout_size = int(0.1 * state_dataset.shape[0])
    train_dataset = state_dataset[perm[holdout_size:]]
    train_label = state_label_dataset[perm[holdout_size:]]
    holdout_dataset = state_dataset[perm[:holdout_size]]
    holdout_label = state_label_dataset[perm[:holdout_size]]

    if epochs == -1:
        max_epoch_since_update = 5
        best = None
    else:
        max_epoch_since_update = epochs
        # With best == inf, (inf - x) / inf is nan, so the improvement test below never
        # fires and training stops after exactly epochs + 1 epochs.
        best = np.inf
    epoch_since_update = 0
    batch_size = 1024

    pbar = tqdm.trange(100, desc="Training reward")
    for epoch in pbar:
        order = np.random.permutation(train_dataset.shape[0])
        for start in range(0, order.shape[0], batch_size):
            batch = order[start:start + batch_size]
            pred = reward_func.r(torch.FloatTensor(train_dataset[batch]).to(device))
            target = torch.FloatTensor(train_label[batch].reshape(-1, 1)).to(device)
            loss_val = loss_fn(pred, target)
            reward_optimizer.zero_grad()
            loss_val.backward()
            reward_optimizer.step()

        # Size-weighted mean holdout MSE; evaluation needs no autograd graph.
        holdout_total = 0.0
        with torch.no_grad():
            for start in range(0, holdout_dataset.shape[0], batch_size):
                holdout_x = holdout_dataset[start:start + batch_size]
                holdout_y = holdout_label[start:start + batch_size].reshape(-1, 1)
                pred = reward_func.r(torch.FloatTensor(holdout_x).to(device))
                loss_val = loss_fn(pred, torch.FloatTensor(holdout_y).to(device))
                holdout_total += loss_val.item() * holdout_x.shape[0]
        holdout_loss = holdout_total / max(holdout_dataset.shape[0], 1)
        pbar.set_description("Holdout loss: {}".format(holdout_loss))

        # Require >1% relative improvement; a zero best cannot improve further, so it
        # is treated as "no improvement" instead of dividing by zero.
        if best is None or (best > 0 and (best - holdout_loss) / best > 0.01):
            epoch_since_update = 0
            best = holdout_loss
        else:
            epoch_since_update += 1
        if epoch_since_update > max_epoch_since_update:
            break

    return loss_val.item()

def maxentirl_loss_sigmoid_validation_offline(div: str, agent_samples, expert_samples, offline_samples, reward_func, reward_optimizer, itr,  device, regularization='vanilla', epochs=1):
    '''Interpolation-based reward regression with extra offline trajectories.

    The agent states are first extended with ``offline_samples``. For every mixing weight
    ``mag`` in 0.0, 0.2, ..., 1.0, each of those trajectories is blended with an expert
    trajectory drawn at random (``sA + mag * (sE - sA)``). Each blended state is labelled
    with a target that grows with ``mag``. ``reward_func.r`` is then fitted with MSE on a
    random 90% of the states and validated on the remaining 10%.

    Args:
        div: objective name; only ``'maxentirl-sigmoid-validation-offline'`` is accepted.
        agent_samples: 3-tuple whose first item holds the agent states, shape (N, T, d).
        expert_samples: expert states, reshaped to (-1, T, d).
        offline_samples: extra trajectories concatenated to the agent states along axis 0
            (presumably shaped (M, T, d) -- TODO confirm).
        reward_func: object whose ``r`` maps a (B, d) float tensor to rewards
            (presumably shaped (B, 1) -- TODO confirm).
        reward_optimizer: optimizer that updates the reward parameters.
        itr: outer iteration index. At ``itr == 0`` early stopping is used (stop after 5
            epochs without a >1% relative holdout improvement). Otherwise training runs
            exactly 6 epochs. Training never exceeds 500 epochs.
        device: torch device the states are moved to before calling ``reward_func.r``.
        regularization: target schedule over ``mag``: ``'sigmoid'`` or ``'linear'``.
        epochs: unused; kept for interface compatibility.

    Returns:
        float: MSE of the last holdout batch of the final epoch (or of the last training
        batch if the holdout set is empty).

    Raises:
        ValueError: if ``regularization`` is not ``'sigmoid'`` or ``'linear'``.
    '''
    assert div in ['maxentirl-sigmoid-validation-offline']
    # Previously an unsupported value (including the default 'vanilla') crashed later
    # with an UnboundLocalError on the target value.
    if regularization not in ('sigmoid', 'linear'):
        raise ValueError("Unsupported regularization: {}".format(regularization))

    sigmoid_range = 6

    def target_for(mag):
        """Regression target shared by every state blended with weight ``mag``."""
        if regularization == 'sigmoid':
            val = np_sigmoid(-sigmoid_range + sigmoid_range * mag * 2) * 10 - 5
        else:  # 'linear'
            val = -10 + (2.0 - mag) * 20 if mag > 1.0 else -10 + mag * 20
        # NOTE(review): int() truncates toward zero and coarsens the targets; kept so
        # that the trained targets are unchanged -- confirm this is intended.
        return int(val)

    sA, _, _ = agent_samples
    _, T, d = sA.shape
    sA = torch.FloatTensor(np.concatenate((sA, offline_samples), axis=0))
    sE = torch.FloatTensor(expert_samples).reshape(-1, T, d)
    loss_fn = torch.nn.MSELoss()

    # Build the interpolated dataset chunk by chunk and concatenate once.
    state_chunks, label_chunks = [], []
    for mag in np.arange(0.0, 1.1, 0.2):
        # Pair every trajectory with an expert trajectory drawn with replacement.
        sE_sample = sE[np.random.choice(sE.shape[0], size=sA.shape[0])]
        blended = sA + mag * (sE_sample - sA)
        state_chunks.append(blended.numpy())
        label_chunks.append(np.full(tuple(blended.shape[:2]), target_for(mag), dtype=np.float64))
    print("Dataset prepared")
    state_dataset = np.concatenate(state_chunks, axis=0).reshape(-1, d)
    state_label_dataset = np.concatenate(label_chunks, axis=0).reshape(-1)

    # Random 90/10 train/holdout split. The states are stored grouped by `mag`, so an
    # unshuffled split would hold out only the mag == 0 states.
    perm = np.random.permutation(state_dataset.shape[0])
    holdout_size = int(0.1 * state_dataset.shape[0])
    train_dataset = state_dataset[perm[holdout_size:]]
    train_label = state_label_dataset[perm[holdout_size:]]
    holdout_dataset = state_dataset[perm[:holdout_size]]
    holdout_label = state_label_dataset[perm[:holdout_size]]

    max_epoch_since_update = 5
    epoch_since_update = 0
    # With best == inf, (inf - x) / inf is nan, so the improvement test below never fires
    # and training stops after max_epoch_since_update + 1 epochs.
    best = None if itr == 0 else np.inf
    batch_size = 1024

    pbar = tqdm.trange(500, desc="Training reward")
    for epoch in pbar:
        order = np.random.permutation(train_dataset.shape[0])
        for start in range(0, order.shape[0], batch_size):
            batch = order[start:start + batch_size]
            pred = reward_func.r(torch.FloatTensor(train_dataset[batch]).to(device))
            target = torch.FloatTensor(train_label[batch].reshape(-1, 1)).to(device)
            loss_val = loss_fn(pred, target)
            reward_optimizer.zero_grad()
            loss_val.backward()
            reward_optimizer.step()

        # Size-weighted mean holdout MSE; evaluation needs no autograd graph.
        holdout_total = 0.0
        with torch.no_grad():
            for start in range(0, holdout_dataset.shape[0], batch_size):
                holdout_x = holdout_dataset[start:start + batch_size]
                holdout_y = holdout_label[start:start + batch_size].reshape(-1, 1)
                pred = reward_func.r(torch.FloatTensor(holdout_x).to(device))
                loss_val = loss_fn(pred, torch.FloatTensor(holdout_y).to(device))
                holdout_total += loss_val.item() * holdout_x.shape[0]
        holdout_loss = holdout_total / max(holdout_dataset.shape[0], 1)
        pbar.set_description("Holdout loss: {}".format(holdout_loss))

        # Require >1% relative improvement; a zero best is treated as "no improvement"
        # instead of dividing by zero.
        if best is None or (best > 0 and (best - holdout_loss) / best > 0.01):
            epoch_since_update = 0
            best = holdout_loss
        else:
            epoch_since_update += 1
        if epoch_since_update > max_epoch_since_update:
            break

    return loss_val.item()



def maxentirl_loss_sigmoid_validation_preferences(div: str, agent_samples, expert_samples, preference_dataset ,reward_func, reward_optimizer,  device, regularization='vanilla', epochs=1):
    '''Interpolation-based reward regression over agent, preference and expert pairs.

    Low/high trajectory pairs are assembled as follows:

    * (agent trajectory, random expert trajectory);
    * (level ``i`` episode k, level ``i+1`` episode k) for consecutive preference levels;
    * (top-level episode, random expert trajectory).

    For every mixing weight ``mag`` in 0.0, 0.2, ..., 1.0, each pair is blended as
    ``low + mag * (high - low)``. Every blended state is labelled with a target that grows
    with ``mag``. ``reward_func.r`` is fitted with MSE on a random 90% of the states and
    validated on the remaining 10%.

    Args:
        div: objective name; only ``'maxentirl-sigmoid-validation-preferences'`` is accepted.
        agent_samples: 3-tuple whose first item holds the agent states, shape (N, T, d).
        expert_samples: expert states, reshaped to (-1, T, d).
        preference_dataset: dict with ``'levels'``, ``'episode_per_levels'`` and
            ``'state'`` (episodes stacked level by level along axis 0, indexed as 3-D).
        reward_func: object whose ``r`` maps a (B, d) float tensor to rewards
            (presumably shaped (B, 1) -- TODO confirm).
        reward_optimizer: optimizer that updates the reward parameters.
        device: torch device the states are moved to before calling ``reward_func.r``.
        regularization: target schedule over ``mag``: ``'sigmoid'`` or ``'linear'``.
        epochs: unused; kept for interface compatibility.

    Returns:
        float: MSE of the last holdout batch of the final epoch (or of the last training
        batch if the holdout set is empty).

    Raises:
        ValueError: if ``regularization`` is not ``'sigmoid'`` or ``'linear'``.
    '''
    assert div in ['maxentirl-sigmoid-validation-preferences']
    # Previously an unsupported value (including the default 'vanilla') crashed later
    # with an UnboundLocalError on the target value.
    if regularization not in ('sigmoid', 'linear'):
        raise ValueError("Unsupported regularization: {}".format(regularization))

    sigmoid_range = 6

    def target_for(mag):
        """Regression target shared by every state blended with weight ``mag``."""
        if regularization == 'sigmoid':
            val = np_sigmoid(-sigmoid_range + sigmoid_range * mag * 2) * 10 - 5
        else:  # 'linear'
            val = -10 + (2.0 - mag) * 20 if mag > 1.0 else -10 + mag * 20
        # NOTE(review): int() truncates toward zero and coarsens the targets; kept so
        # that the trained targets are unchanged -- confirm this is intended.
        return int(val)

    sA, _, _ = agent_samples
    _, T, d = sA.shape
    sA = torch.FloatTensor(sA)
    sE = torch.FloatTensor(expert_samples).reshape(-1, T, d)
    loss_fn = torch.nn.MSELoss()

    # Assemble the (low, high) trajectory pairs.
    levels = preference_dataset['levels']
    per_level = preference_dataset['episode_per_levels']
    pref_states = preference_dataset['state']
    low_parts = [sA]
    high_parts = [sE[np.random.choice(sE.shape[0], size=sA.shape[0]), :, :]]
    for i in range(levels - 1):
        pf1 = torch.FloatTensor(pref_states[i * per_level:(i + 1) * per_level, :, :])
        pf2 = torch.FloatTensor(pref_states[(i + 1) * per_level:(i + 2) * per_level, :, :])
        low_parts.append(pf1)
        high_parts.append(pf2)
        if i == levels - 2:
            # Ground the best preference level against the expert.
            low_parts.append(pf2)
            high_parts.append(sE[np.random.choice(sE.shape[0], size=pf2.shape[0]), :, :])
    sA = torch.cat(low_parts, dim=0)
    sE = torch.cat(high_parts, dim=0)

    # Build the interpolated dataset chunk by chunk and concatenate once.
    state_chunks, label_chunks = [], []
    for mag in np.arange(0.0, 1.1, 0.2):
        blended = sA + mag * (sE - sA)
        state_chunks.append(blended.numpy())
        label_chunks.append(np.full(tuple(blended.shape[:2]), target_for(mag), dtype=np.float64))
    state_dataset = np.concatenate(state_chunks, axis=0).reshape(-1, d)
    state_label_dataset = np.concatenate(label_chunks, axis=0).reshape(-1)

    # Random 90/10 train/holdout split. The states are stored grouped by `mag`, so an
    # unshuffled split would hold out only the mag == 0 states.
    perm = np.random.permutation(state_dataset.shape[0])
    holdout_size = int(0.1 * state_dataset.shape[0])
    train_dataset = state_dataset[perm[holdout_size:]]
    train_label = state_label_dataset[perm[holdout_size:]]
    holdout_dataset = state_dataset[perm[:holdout_size]]
    holdout_label = state_label_dataset[perm[:holdout_size]]

    max_epoch_since_update = 5
    epoch_since_update = 0
    # NOTE(review): with best == inf, (inf - x) / inf is nan, so the improvement test
    # below never fires and this loop always runs exactly 6 epochs. Kept as-is; switch
    # to best = None if real early stopping is wanted.
    best = np.inf
    batch_size = 1024

    for epoch in itertools.count():
        order = np.random.permutation(train_dataset.shape[0])
        for start in range(0, order.shape[0], batch_size):
            batch = order[start:start + batch_size]
            pred = reward_func.r(torch.FloatTensor(train_dataset[batch]).to(device))
            target = torch.FloatTensor(train_label[batch].reshape(-1, 1)).to(device)
            loss_val = loss_fn(pred, target)
            reward_optimizer.zero_grad()
            loss_val.backward()
            reward_optimizer.step()

        # Size-weighted mean holdout MSE; evaluation needs no autograd graph.
        holdout_total = 0.0
        with torch.no_grad():
            for start in range(0, holdout_dataset.shape[0], batch_size):
                holdout_x = holdout_dataset[start:start + batch_size]
                holdout_y = holdout_label[start:start + batch_size].reshape(-1, 1)
                pred = reward_func.r(torch.FloatTensor(holdout_x).to(device))
                loss_val = loss_fn(pred, torch.FloatTensor(holdout_y).to(device))
                holdout_total += loss_val.item() * holdout_x.shape[0]
        holdout_loss = holdout_total / max(holdout_dataset.shape[0], 1)

        if best > 0 and (best - holdout_loss) / best > 0.01:
            epoch_since_update = 0
            best = holdout_loss
        else:
            epoch_since_update += 1
        if epoch_since_update > max_epoch_since_update:
            break

    return loss_val.item()





def snippet_loss_validation_preferences_weighted(div: str, agent_samples, expert_samples, preference_dataset ,reward_func, reward_optimizer,  device, regularization='vanilla', epochs=1):
    '''Fit the reward on random fixed-length snippets of ranked preference trajectories.

    ``preference_dataset['state']`` / ``['done']`` hold ``levels`` consecutive groups of
    ``episode_per_levels`` episodes. The "low" set is levels 0..levels-2 and the "high"
    set is levels 1..levels-1, so high[k] is one level above low[k]. Each update samples a
    low index and a high index at or after it, cuts a random ``snippet_length``-step
    window out of every sampled episode, and trains ``reward_func.r`` with one of:

    * ``'trex'``: cross-entropy over the pair of snippet returns, with the high snippet
      as the correct class (the column order is randomised per pair);
    * ``'linear'``: MSE between snippet returns and targets that grow linearly in
      ``[-scale, scale]`` with the episode's position in the stacked dataset.

    Snippet returns sum the predicted rewards, zeroing steps whose ``done`` flag is set.
    After every epoch the predicted full-episode returns of the low set are printed
    together with the number of adjacent pairs whose return decreases.

    Args:
        div: objective name; only ``'snippet-validation-preferences-weighted'`` is accepted.
        agent_samples: unused; kept for interface compatibility.
        expert_samples: 3-D array; only its last dimension ``d`` (state size) is used.
        preference_dataset: dict with ``'levels'``, ``'episode_per_levels'``, ``'state'``
            and ``'done'`` (episodes stacked level by level along axis 0). ``done`` is
            presumably (E, T, 1) -- TODO confirm; snippets are reshaped to one flag per step.
        reward_func: object whose ``r`` maps a (B, d) float tensor to B rewards.
        reward_optimizer: optimizer that updates the reward parameters.
        device: torch device used for all reward computations.
        regularization: ``'trex'`` or ``'linear'``.
        epochs: number of training epochs (>= 1).

    Returns:
        float: loss of the last training batch.

    Raises:
        ValueError: on an unsupported ``regularization``, ``epochs < 1``, or fewer than
            two preference levels.
    '''
    assert div in ['snippet-validation-preferences-weighted']
    # Previously 'sigmoid' hit an undefined name, and any other value except
    # 'linear'/'trex' (including the default 'vanilla') crashed on an unbound label value.
    if regularization not in ('linear', 'trex'):
        raise ValueError("Unsupported regularization: {}".format(regularization))
    # Previously epochs == -1 produced an empty loop followed by an UnboundLocalError.
    if epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(epochs))

    lamda = 1.0
    snippet_length = 40
    scale = 10 * snippet_length
    batch_size = 1024
    _, T, d = expert_samples.shape
    loss_fn = torch.nn.CrossEntropyLoss() if regularization == 'trex' else torch.nn.MSELoss()

    levels = preference_dataset['levels']
    per_level = preference_dataset['episode_per_levels']
    if levels < 2:
        raise ValueError("preference_dataset needs at least 2 levels, got {}".format(levels))
    # Concatenating level groups 0..levels-2 (resp. 1..levels-1) is one contiguous slice.
    span = (levels - 1) * per_level
    pf_low = np.asarray(preference_dataset['state'][:span], dtype=np.float32)
    pf_high = np.asarray(preference_dataset['state'][per_level:per_level + span], dtype=np.float32)
    pf_low_done = np.asarray(preference_dataset['done'][:span], dtype=np.float32)
    pf_high_done = np.asarray(preference_dataset['done'][per_level:per_level + span], dtype=np.float32)
    n_low, n_high = pf_low.shape[0], pf_high.shape[0]
    # Low then high episodes; used for the per-epoch diagnostics and the batch count.
    traj_dataset_pref = np.concatenate((pf_low, pf_high), axis=0)

    def cut_snippets(traj, done):
        """Cut one random window of ``snippet_length`` steps out of every episode."""
        count = traj.shape[0]
        # Valid starts are 0..horizon - snippet_length inclusive (the old code skipped
        # the last one and crashed when horizon == snippet_length).
        starts = np.random.randint(traj.shape[1] - snippet_length + 1, size=count)
        cols = starts[:, None] + np.arange(snippet_length)
        rows = np.arange(count)[:, None]
        return traj[rows, cols], done[rows, cols]

    def snippet_return(traj, done):
        """Sum predicted rewards over each snippet, masking steps flagged as done."""
        count = traj.shape[0]
        rewards = reward_func.r(torch.FloatTensor(traj.reshape(-1, d)).to(device)).view(count, snippet_length)
        alive = torch.FloatTensor(1 - done).reshape(count, snippet_length).to(device)
        return (rewards * alive).sum(1).view(-1, 1)

    def trex_loss():
        """Cross-entropy on (low, high) snippet-return pairs; the high snippet wins."""
        low_idx = np.random.choice(n_low, size=batch_size)
        # A separate array: the preferred episode sits at or after the low position.
        high_idx = np.random.randint(low_idx, n_high)
        r_low = snippet_return(*cut_snippets(pf_low[low_idx], pf_low_done[low_idx]))
        r_high = snippet_return(*cut_snippets(pf_high[high_idx], pf_high_done[high_idx]))
        # Randomise which logit column holds the preferred snippet; the label points to it.
        flip = torch.as_tensor(np.random.uniform(size=batch_size) > 0.5, device=device)
        logits = torch.where(flip[:, None],
                             torch.cat((r_low, r_high), dim=1),
                             torch.cat((r_high, r_low), dim=1))
        return loss_fn(logits, flip.long())

    def linear_loss():
        """MSE between snippet returns and position-based linear targets."""
        half = batch_size // 2
        low_idx = np.random.randint(0, n_low, size=half)
        # A separate array: the preferred episode sits at or after the low position.
        high_idx = np.random.randint(low_idx, n_high)
        traj = np.concatenate((pf_low[low_idx], pf_high[high_idx]), axis=0)
        done = np.concatenate((pf_low_done[low_idx], pf_high_done[high_idx]), axis=0)
        targets = np.concatenate((-scale + 2 * scale * low_idx / n_low,
                                  -scale + 2 * scale * (high_idx + 1) / n_low)).reshape(-1, 1)
        returns = snippet_return(*cut_snippets(traj, done))
        return loss_fn(returns, torch.FloatTensor(targets).to(device))

    batch_loss = trex_loss if regularization == 'trex' else linear_loss
    n_batches = len(range(0, traj_dataset_pref.shape[0], batch_size))

    for epoch in tqdm.trange(epochs, desc="Training reward"):
        for _ in range(n_batches):
            loss_val = lamda * batch_loss()
            reward_optimizer.zero_grad()
            loss_val.backward()
            reward_optimizer.step()

        # Diagnostics: predicted returns of the low episodes should be non-decreasing.
        with torch.no_grad():
            flat = torch.FloatTensor(traj_dataset_pref.reshape(-1, d)).to(device)
            returns = reward_func.r(flat).view(traj_dataset_pref.shape[0], traj_dataset_pref.shape[1]).sum(1).cpu().numpy()
        low_returns = returns[:returns.shape[0] // 2]
        print("Preference returns: {}".format(low_returns))
        incorrect_pairs = int(np.sum(np.diff(low_returns) < 0))
        if incorrect_pairs == 0:
            print("Converged++++++++++++++++++++")
        else:
            print("Not converged---------------- Incorrect pairs: {}".format(incorrect_pairs))

    return loss_val.item()


def maxentirl_loss_sigmoid_validation_preferences_weighted(div: str, agent_samples, expert_samples, preference_dataset, reward_func, reward_optimizer, device, regularization='vanilla', epochs=1):
    '''Fit the reward to agent->expert interpolation labels plus ranked preference snippets.

    Two MSE objectives are mixed with weight ``lamda``:

    * Interpolation regression: every state of ``sA + mag * (sE_sample - sA)``
      for ``mag`` in ``np.arange(0.0, 1.1, 0.2)`` is labelled with a value
      that increases with ``mag`` (shape chosen by ``regularization``).
    * Preference regression: random ``snippet_length``-step windows of
      trajectories from ``preference_dataset`` are labelled with a return that
      grows linearly with the trajectory's level, from ``-scale`` to ``+scale``.

    Training early-stops on a random 10% holdout of the interpolation states
    (the preference term is not part of the holdout loss).

    Args:
        div: must be ``'pal-preferences-weighted'``.
        agent_samples: 3-tuple whose first element holds agent states of shape
            (N, T, d); the other two elements are ignored.
        expert_samples: expert states, reshaped to (-1, T, d).
        preference_dataset: dict with ``'levels'`` (>= 2),
            ``'episode_per_levels'`` and ``'state'``. ``'state'`` is sliced
            as consecutive level-major blocks of ``episode_per_levels``
            trajectories, each with the same (T, d) trailing shape as the
            expert data.
        reward_func: object whose ``r(x)`` maps an (n, d) float tensor to one
            reward per row (compared against (n, 1) targets).
        reward_optimizer: torch optimizer over ``reward_func``'s parameters.
        device: torch device used for the forward passes.
        regularization: ``'sigmoid'``, ``'linear'``, ``'exp-<k>'`` or
            ``'exp-n<k>'`` (slope ``-k``).
        epochs: unused; kept for interface compatibility.

    Returns:
        float: loss of the last evaluated holdout batch (or of the last
        training batch when the holdout split is empty).

    Raises:
        ValueError: for an unsupported ``regularization`` or fewer than two
            preference levels.
    '''
    assert div in ['pal-preferences-weighted']
    lamda = 0.3                   # weight of the preference term
    snippet_length = 1            # time steps per preference snippet
    scale = snippet_length * 10   # preference labels span [-scale, scale]
    sigmoid_range = 6
    batch_size = 1024
    half = batch_size // 2        # preference pairs per batch
    max_epoch_since_update = 5
    mode = "consecutive_linear"   # "consecutive_linear", "next_random_linear", "next_random_max"

    def interpolation_label(mag):
        """Return the integer regression target for interpolation weight ``mag``."""
        if regularization == 'sigmoid':
            val = np_sigmoid(-sigmoid_range + sigmoid_range * mag * 2) * 10 - 5
        elif regularization == 'linear':
            val = -10 + (2.0 - mag) * 20 if mag > 1.0 else -10 + mag * 20
        elif 'exp' in regularization:
            word = regularization.split('-')[-1]
            slope = -int(word[1:]) if 'n' in word else int(word)
            # p, q chosen so the label is -10 at mag=0 and +10 at mag=1.
            p = 20 / (np.exp(slope) - 1)
            q = -10 * (1 + np.exp(slope)) / (np.exp(slope) - 1)
            val = q + p * np.exp(slope * mag)
        else:
            raise ValueError("unsupported regularization: {!r}".format(regularization))
        # int() truncates toward zero; kept so labels match the sibling losses.
        return int(val)

    levels = preference_dataset['levels']
    episodes = preference_dataset['episode_per_levels']
    if levels < 2:
        raise ValueError("preference_dataset must contain at least 2 levels")

    sA, _, _ = agent_samples
    _, T, d = sA.shape
    sA = torch.FloatTensor(sA)
    sE = torch.FloatTensor(expert_samples).reshape(-1, T, d)
    loss_fn = torch.nn.MSELoss()

    # ---- Interpolation dataset (one chunk per mag, concatenated once) ----
    state_chunks, label_chunks = [], []
    for mag in np.arange(0.0, 1.1, 0.2):
        sE_sample = sE[np.random.choice(sE.shape[0], size=sA.shape[0])]
        sM = sA + mag * (sE_sample - sA)
        state_chunks.append(sM.reshape(-1, d))
        label_chunks.append(np.full(sM.shape[0] * sM.shape[1], interpolation_label(mag), dtype=np.float64))
    state_dataset = torch.cat(state_chunks, dim=0).numpy()
    state_label_dataset = np.concatenate(label_chunks, axis=0)

    # Chunks are ordered by mag, so split on a random permutation; an
    # unshuffled 10% prefix would contain only mag=0 (pure agent) states.
    perm = np.random.permutation(state_dataset.shape[0])
    holdout_size = int(0.1 * state_dataset.shape[0])
    train_dataset = state_dataset[perm[holdout_size:]]
    train_label = state_label_dataset[perm[holdout_size:]]
    holdout_dataset = state_dataset[perm[:holdout_size]]
    holdout_label = state_label_dataset[perm[:holdout_size]]

    # ---- Preference pairs: row i*episodes+e of low is level i, of high is level i+1 ----
    pref_states = preference_dataset['state']
    low_chunks, high_chunks = [], []
    for i in range(levels - 1):
        low_chunks.append(np.asarray(pref_states[i * episodes:(i + 1) * episodes], dtype=np.float32))
        high_chunks.append(np.asarray(pref_states[(i + 1) * episodes:(i + 2) * episodes], dtype=np.float32))
    # Grounding with respect to the expert: top level vs. random expert trajectories.
    # NOTE(review): level indices are drawn from [0, levels-2], so these
    # appended grounding rows are never sampled below — confirm intent.
    top_level = high_chunks[-1]
    low_chunks.append(top_level)
    high_chunks.append(sE[np.random.choice(sE.shape[0], size=top_level.shape[0])].numpy())
    pf_dataset_low = np.concatenate(low_chunks, axis=0)
    pf_dataset_high = np.concatenate(high_chunks, axis=0)
    pref_T = pf_dataset_low.shape[1]

    def sample_preference_batch():
        """Draw ``half`` (worse, better) trajectory pairs; return snippets and return labels."""
        low_level = np.random.randint(0, levels - 1, size=half)
        if "consecutive" in mode:
            high_level = low_level.copy()
        else:
            high_level = np.random.randint(low_level, levels - 1)
        # Map (level, episode) to the row of the level-major pair arrays.
        low_idx = low_level * episodes + np.random.randint(0, episodes, size=half)
        high_idx = high_level * episodes + np.random.randint(0, episodes, size=half)
        trajs = np.concatenate((pf_dataset_low[low_idx], pf_dataset_high[high_idx]), axis=0)

        labels = np.zeros((trajs.shape[0], 1))
        labels[:half, 0] = -scale + 2 * scale * low_level / (levels - 1)
        labels[half:, 0] = -scale + 2 * scale * (high_level + 1) / (levels - 1)

        # One random window per trajectory; +1 so the final window is reachable.
        starts = np.random.randint(0, pref_T - snippet_length + 1, size=trajs.shape[0])
        window = starts[:, None] + np.arange(snippet_length)
        snippets = trajs[np.arange(trajs.shape[0])[:, None], window]
        return snippets, labels

    best = np.inf
    epoch_since_update = 0
    for epoch in itertools.count():
        idx = np.random.permutation(train_dataset.shape[0])
        for start in range(0, idx.shape[0], batch_size):
            batch = idx[start:start + batch_size]
            t1 = reward_func.r(torch.FloatTensor(train_dataset[batch]).to(device))
            label_x = torch.FloatTensor(train_label[batch].reshape(-1, 1)).to(device)

            snippets, pref_label_x = sample_preference_batch()
            pref_t1 = reward_func.r(torch.FloatTensor(snippets.reshape(-1, d)).to(device)).view(-1, snippet_length).sum(1).view(-1, 1)
            if 'max' in mode:
                # Push the predicted returns of worse snippets below better ones.
                pref_loss = (pref_t1[:half] - pref_t1[half:]).mean()
            else:
                pref_loss = loss_fn(pref_t1, torch.FloatTensor(pref_label_x).to(device))

            loss_val = (1 - lamda) * loss_fn(t1, label_x) + lamda * pref_loss
            reward_optimizer.zero_grad()
            loss_val.backward()
            reward_optimizer.step()

        holdout_loss = 0.0
        with torch.no_grad():  # evaluation only
            for start in range(0, holdout_dataset.shape[0], batch_size):
                holdout_x = holdout_dataset[start:start + batch_size]
                holdout_label_x = holdout_label[start:start + batch_size].reshape(-1, 1)
                t1 = reward_func.r(torch.FloatTensor(holdout_x).to(device))
                loss_val = (1 - lamda) * loss_fn(t1, torch.FloatTensor(holdout_label_x).to(device))
                holdout_loss += loss_val.item() * holdout_x.shape[0]
        holdout_loss /= max(holdout_dataset.shape[0], 1)

        # Product form of "improved by more than 1%": unlike (best - loss) / best,
        # it is well-defined for best=inf (first epoch) and best=0.
        if holdout_loss < 0.99 * best:
            epoch_since_update = 0
            best = holdout_loss
        else:
            epoch_since_update += 1
        if epoch_since_update > max_epoch_since_update:
            break

    return loss_val.item()






def contrastive_maxentirl_sigmoid_loss_validation(div: str, cum_agent_samples, expert_samples, reward_func, reward_optimizer, device, regularization='vanilla', epochs=1):
    '''Regress the reward on states interpolated between agent and expert trajectories.

    For every ``mag`` in ``np.arange(0.0, 1.1, 0.2)``, each agent trajectory
    is blended with a randomly drawn expert trajectory,
    ``sA + mag * (sE_sample - sA)``. Every resulting state is labelled with a
    value that increases with ``mag`` (shape set by ``regularization``). The
    reward is fit to these labels with MSE on 90% of the states, and a random
    10% holdout is evaluated after every epoch.

    Args:
        div: must be ``'rank-ral'``.
        cum_agent_samples: agent trajectories of shape (N, T, d).
        expert_samples: expert states, reshaped to (-1, T, d).
        reward_func: object whose ``r(x)`` maps an (n, d) float tensor to one
            reward per row (compared against (n, 1) targets).
        reward_optimizer: torch optimizer over ``reward_func``'s parameters.
        device: torch device used for the forward passes.
        regularization: ``'sigmoid'``, ``'linear'``, ``'exp-<k>'`` or
            ``'exp-n<k>'`` (slope ``-k``).
        epochs: ``-1`` enables early stopping with patience 5. Any other
            value runs ``min(epochs + 1, 100)`` epochs (see the note below).

    Returns:
        tuple: ``(0, loss)``, where ``loss`` is the MSE of the last holdout
        batch (or of the last training batch when the holdout split is empty).

    Raises:
        ValueError: if ``regularization`` is not one of the supported forms.
    '''
    assert div in ['rank-ral']
    sA = cum_agent_samples
    _, T, d = sA.shape
    sE = expert_samples.reshape(-1, T, d)
    loss_fn = torch.nn.MSELoss()
    sigmoid_range = 6
    batch_size = 1024

    def interpolation_label(mag):
        """Return the integer regression target for interpolation weight ``mag``."""
        if regularization == 'sigmoid':
            val = np_sigmoid(-sigmoid_range + sigmoid_range * mag * 2) * 10 - 5
        elif regularization == 'linear':
            val = -10 + (2.0 - mag) * 20 if mag > 1.0 else -10 + mag * 20
        elif 'exp' in regularization:
            word = regularization.split('-')[-1]
            slope = -int(word[1:]) if 'n' in word else int(word)
            # p, q chosen so the label is -10 at mag=0 and +10 at mag=1.
            p = 20 / (np.exp(slope) - 1)
            q = -10 * (1 + np.exp(slope)) / (np.exp(slope) - 1)
            val = q + p * np.exp(slope * mag)
        else:
            raise ValueError("unsupported regularization: {!r}".format(regularization))
        # int() truncates toward zero; kept so labels are unchanged.
        return int(val)

    # Collect one chunk per mag and concatenate once, instead of re-copying
    # the growing (possibly large, cumulative) dataset on every level.
    state_chunks, label_chunks = [], []
    for mag in np.arange(0.0, 1.1, 0.2):
        sE_sample = sE[np.random.choice(sE.shape[0], size=sA.shape[0])]
        sM = sA + mag * (sE_sample - sA)
        state_chunks.append(sM.reshape(-1, d))
        label_chunks.append(np.full(sM.shape[0] * sM.shape[1], interpolation_label(mag), dtype=np.float64))
    state_dataset = np.concatenate(state_chunks, axis=0)
    state_label_dataset = np.concatenate(label_chunks, axis=0)
    del state_chunks, label_chunks

    # The chunks are ordered by mag, so the split uses a random permutation.
    # An unshuffled 10% prefix would contain only mag=0 (pure agent) states.
    perm = np.random.permutation(state_dataset.shape[0])
    holdout_size = int(0.1 * state_dataset.shape[0])
    train_dataset = state_dataset[perm[holdout_size:]]
    train_label = state_label_dataset[perm[holdout_size:]]
    holdout_dataset = state_dataset[perm[:holdout_size]]
    holdout_label = state_label_dataset[perm[:holdout_size]]
    del state_dataset, state_label_dataset

    if epochs == -1:
        max_epoch_since_update = 5
        best = None
    else:
        # NOTE(review): starting from best=inf, (best - loss) / best is nan and
        # never counts as an improvement. This branch therefore runs a fixed
        # min(epochs + 1, 100) epochs. That behaviour is preserved as-is.
        max_epoch_since_update = epochs
        best = np.inf
    epoch_since_update = 0

    pbar = tqdm.trange(100, desc="Training reward")
    for epoch in pbar:
        idx = np.arange(train_dataset.shape[0])
        np.random.shuffle(idx)
        for start in range(0, idx.shape[0], batch_size):
            batch = idx[start:start + batch_size]
            t1 = reward_func.r(torch.FloatTensor(train_dataset[batch]).to(device))
            loss_val = loss_fn(t1, torch.FloatTensor(train_label[batch].reshape(-1, 1)).to(device))
            reward_optimizer.zero_grad()
            loss_val.backward()
            reward_optimizer.step()

        holdout_loss = 0.0
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for start in range(0, holdout_dataset.shape[0], batch_size):
                holdout_x = holdout_dataset[start:start + batch_size]
                holdout_label_x = holdout_label[start:start + batch_size].reshape(-1, 1)
                t1 = reward_func.r(torch.FloatTensor(holdout_x).to(device))
                loss_val = loss_fn(t1, torch.FloatTensor(holdout_label_x).to(device))
                holdout_loss += loss_val.item() * holdout_x.shape[0]
        # Size-weighted mean. max() guards against an empty holdout split.
        holdout_loss /= max(holdout_dataset.shape[0], 1)
        pbar.set_description("Holdout loss: {}".format(holdout_loss))

        # The best > 0 guard avoids a ZeroDivisionError once the loss reaches 0.
        improved = best is None or (best > 0 and (best - holdout_loss) / best > 0.01)
        if improved:
            epoch_since_update = 0
            best = holdout_loss
        else:
            epoch_since_update += 1
        if epoch_since_update > max_epoch_since_update:
            break

    torch.cuda.empty_cache()

    return 0, loss_val.item()