import os
import sys
import time
from datetime import datetime
import shutil
import math
import pickle

import numpy as np
import torch
import random
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Categorical

def disagreement_sampling(args, all_pref_preds):
    """Pick, for every sample, the partner on which the ensemble disagrees most.

    Within each label group listed in ``./pre_gen/{dataset}_indices.pkl``,
    every member of ``all_pref_preds`` induces pairwise probabilities
    ``P(i over j) = exp(s_i) / (exp(s_i) + exp(s_j))``. For each sample ``i``
    the partner ``j`` whose probability has the largest standard deviation
    across the ensemble is selected, and the entry
    ``human_pref[label][i, j]`` from ``./pre_gen/{dataset}_human_pref.pkl``
    is looked up for that pair.

    Args:
        args: namespace providing ``dataset``, used to build the pickle paths.
        all_pref_preds: sequence of per-model score tensors, each indexable by
            the stored candidate indices (assumed 1-D, one score per sample --
            TODO confirm against callers).

    Returns:
        Tuple ``(partner_idx, human_pref)`` of long tensors of length
        ``len(all_pref_preds[0])``. Samples not listed under any label keep 0.
    """
    label_str = ['negative', 'positive', 'neutral']

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated files.
    with open('./pre_gen/{}_human_pref.pkl'.format(args.dataset), 'rb') as f:
        human_preference_all = pickle.load(f)

    with open('./pre_gen/{}_indices.pkl'.format(args.dataset), 'rb') as f:
        indices_all = pickle.load(f)

    selected_idx = torch.zeros(len(all_pref_preds[0]), 2).long()

    for label in label_str:
        candidate_idx = indices_all[label]
        human_preference_label = human_preference_all[label]
        if len(candidate_idx) == 0:
            # Nothing to pair inside an empty group (an argmax over an empty
            # dimension would raise).
            continue

        # Stack of per-model matrices: entry [m, i, j] = P_m(i preferred over j).
        # sigmoid(s_i - s_j) equals exp(s_i) / (exp(s_i) + exp(s_j)) but cannot
        # overflow to inf / inf = NaN when the raw scores are large.
        all_prob_pref = []
        for pref in all_pref_preds:
            scores = pref[candidate_idx]
            all_prob_pref.append(torch.sigmoid(scores.unsqueeze(1) - scores.unsqueeze(0)))
        all_prob_pref = torch.stack(all_prob_pref, dim=0)

        # Ensemble spread for every (i, j); for each row i take the first j
        # with maximal spread. flatten(1) mirrors a per-row flattened argmax.
        spread = all_prob_pref.std(dim=0)
        best = spread.flatten(start_dim=1).argmax(dim=1).tolist()

        for i, selected in enumerate(best):
            orig_idx = candidate_idx[i]
            selected_idx[orig_idx, 0] = candidate_idx[selected]
            selected_idx[orig_idx, 1] = human_preference_label[i, selected]

    return selected_idx[:, 0], selected_idx[:, 1]

def inconsistency_sampling(args, all_pref_preds, all_probs_soft, labels):
    """Sample, for every sample, a partner whose predictions contradict the soft labels.

    For each label group in ``./pre_gen/{dataset}_indices.pkl`` and each
    member ``i``, the probability difference ``p_i - p_j`` is compared with
    the soft-label difference ``s_i - s_j`` for every partner ``j``. A hinge
    penalty is summed over the last dimension: ``max(0, s - p)`` where the
    soft-label difference is non-negative, and ``max(0, p - s)`` where it is
    negative. The penalties are min-max normalised and used as sampling
    weights for ``Categorical``, so the least-violating partner gets weight 0.
    The entry ``human_pref[label][i, j]`` of the drawn ``j`` is returned with
    it.

    Args:
        args: namespace with ``dataset`` (locates files under ``./pre_gen/``)
            and ``ensemble`` (selects how the output length is derived).
        all_pref_preds: only its length is used. With ``args.ensemble`` the
            output length is ``len(all_pref_preds[0])``, otherwise
            ``len(all_pref_preds)``.
        all_probs_soft: tensor of model probabilities indexed by sample;
            assumed ``(num_samples, num_classes)`` matching the soft-label
            array -- TODO confirm.
        labels: accepted for interface compatibility; not used by the
            selection.

    Returns:
        Tuple ``(partner_idx, human_pref)`` of long tensors. Samples not listed
        under any label keep 0 in both.

    Note:
        If all penalties in a row are equal (e.g. a single-member group), the
        partner is drawn uniformly. NaN penalties are not masked and are
        expected to make ``Categorical`` fail.
    """
    label_str = ['negative', 'positive', 'neutral']

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated files.
    with open('./pre_gen/{}_human_pref.pkl'.format(args.dataset), 'rb') as f:
        human_preference_all = pickle.load(f)

    with open('./pre_gen/{}_indices.pkl'.format(args.dataset), 'rb') as f:
        indices_all = pickle.load(f)

    if args.ensemble:
        selected_idx = torch.zeros(len(all_pref_preds[0]), 2).long()
    else:
        selected_idx = torch.zeros(len(all_pref_preds), 2).long()

    # Keep soft labels on the same device as the model probabilities so the
    # pairwise differences below can be combined.
    soft_labels = torch.Tensor(
        np.load('./pre_gen/{}_soft_label.npy'.format(args.dataset))
    ).to(all_probs_soft.device)

    for label in label_str:
        indices_label = indices_all[label]
        human_preference_label = human_preference_all[label]

        # (n, n, K) pairwise differences within this label group.
        prob_delta = (all_probs_soft[indices_label].unsqueeze(1) - all_probs_soft[indices_label].unsqueeze(0))
        soft_labels_delta = (soft_labels[indices_label].unsqueeze(1) - soft_labels[indices_label].unsqueeze(0))

        # Row by row to avoid creating further (n, n, K) temporaries.
        for i in range(len(prob_delta)):
            s_delta, p_delta = soft_labels_delta[i], prob_delta[i]
            mask1, mask2 = (s_delta >= 0).float(), (s_delta < 0).float()
            # clamp(min=0) is the hinge and stays on the inputs' device.
            loss_delta = (mask1 * (s_delta - p_delta).clamp(min=0)).sum(dim=-1)
            loss_delta += (mask2 * (p_delta - s_delta).clamp(min=0)).sum(dim=-1)

            lowest = loss_delta.min()
            span = loss_delta.max() - lowest
            if span == 0:
                # Every partner is equally (in)consistent, always the case for a
                # single-member group: min-max scaling would be 0/0 = NaN, so
                # sample uniformly instead.
                weights = torch.ones_like(loss_delta)
            else:
                weights = (loss_delta - lowest) / span
            select_idx = int(Categorical(weights).sample())

            selected_idx[indices_label[i], 0] = indices_label[select_idx]
            selected_idx[indices_label[i], 1] = human_preference_label[i, select_idx]
    return selected_idx[:, 0], selected_idx[:, 1]