import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler, SequentialSampler
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy
from ATC_helper import * # Average Thresholded Confidence challenger (Garg et al., ICLR 2022)
from mandoline import * # Mandoline challenger (Chen et al., ICML 2021)
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import normalize
from sklearn.neighbors import NearestNeighbors
from optimization import * # model training, evaluation
from sklearn.model_selection import train_test_split

######## utils #####################

def cls_softmax_representations(model, scaled_model, model_type, loader, device):
    """
    Outputs CLS, logits and Softmax representations for a given model and data.

    Args:
    model (torch model): called as model(text, mask, categorical, numerical);
        its output is indexed as (logits, CLS) for "AllTextBERT" and as
        (logits, text CLS, tabular CLS) for "LateFuseBERT".
    scaled_model (torch model): called with the same inputs; its output is used
        directly as the (temperature-scaled) logits.
    model_type (str): "LateFuseBERT" or "AllTextBERT".
    loader (torch DataLoader): each batch is indexed as
        (text, categorical, numerical, y, mask).
    device (str)

    Returns: 5 tensors concatenated over all batches: the CLS representation,
    the logits, the Softmax of the logits, the Softmax of the scaled logits,
    and the true labels.

    Raises:
    ValueError: if model_type is not one of the supported values.

    """
    supported_types = ("LateFuseBERT", "AllTextBERT")
    if model_type not in supported_types:
        # Without this check the first batch would fail with a NameError on `pred`.
        raise ValueError(
            "Unsupported model_type {!r}; expected one of {}".format(model_type, supported_types)
        )

    # evaluation mode (set once: it does not change between batches)
    model.eval()
    scaled_model.eval()

    cls_list = []
    logits_list = []
    softmax_list = []
    scaled_softmax_list = []
    y_list = []

    for batch in loader:
        # inputs and labels, moved to device
        text = batch[0].to(device)
        categorical = batch[1].to(device)
        numerical = batch[2].to(device)
        y = batch[3].to(device)
        mask = batch[4].to(device)
        y_list.append(y)

        # prediction
        with torch.no_grad():
            output = model(text, mask, categorical, numerical.float())
            scaled_pred = scaled_model(text, mask, categorical, numerical.float())

        pred = output[0]
        if model_type == "LateFuseBERT":
            # late fusion: text and tabular CLS vectors are concatenated feature-wise
            cls = torch.cat((output[1], output[2]), dim=1)
        else:  # "AllTextBERT"
            cls = output[1]

        cls_list.append(cls)
        logits_list.append(pred)
        # softmax probabilities (raw and scaled)
        softmax_list.append(F.softmax(pred, dim=1))
        scaled_softmax_list.append(F.softmax(scaled_pred, dim=1))

    cls_list = torch.cat(cls_list)
    logits_list = torch.cat(logits_list)
    softmax_list = torch.cat(softmax_list)
    scaled_softmax_list = torch.cat(scaled_softmax_list)
    y_list = torch.cat(y_list)

    return cls_list, logits_list, softmax_list, scaled_softmax_list, y_list


def cls_softmax_cond_representations(cls_tensor, softmax_tensor, y_tensor, data_index, quantile_for_distance):
    """
    Selects the neighborhood of one data point and returns the CLS, Softmax and
    label rows of its neighbors. This function is used to compute conditional performance.

    A neighbor is any point whose L2 distance (between L2-normalized CLS vectors)
    to the point at data_index is at most the chosen quantile of the non-zero
    pairwise distances.

    Args:
    cls_tensor (Tensor): CLS representation
    softmax_tensor (Tensor): Softmax representation
    y_tensor (Tensor): true labels
    data_index (int): index of data point
    quantile_for_distance (float): quantile of the pairwise distances used as neighbor threshold.

    Returns: number of neighbors and 3 tensors (CLS, Softmax, Y) restricted to the neighbors.

    """
    # pairwise L2 distances between the normalized CLS vectors
    unit_cls = F.normalize(cls_tensor)
    pairwise = torch.cdist(unit_cls, unit_cls, compute_mode='donot_use_mm_for_euclid_dist')

    # keep each pair once: lower triangle, then drop the zeros (the masked-out
    # upper triangle and any exactly-zero distance)
    lower_values = torch.tril(pairwise).flatten()
    nonzero_values = lower_values[lower_values != 0]

    # distance threshold defining the neighborhood
    cutoff = torch.quantile(nonzero_values, q=quantile_for_distance)

    # indices of all points within the threshold of the reference point
    selected = torch.where(pairwise[data_index] <= cutoff)[0]
    num_neighbors = selected.shape[0]

    # rows of the neighbors
    cond_cls = cls_tensor[selected]
    cond_softmax = softmax_tensor[selected]
    cond_y = y_tensor[selected]

    return num_neighbors, cond_cls, cond_softmax, cond_y

# Temperature scaling (Guo et al., 2017)

class ModelWithTemperature(nn.Module):
    """
    Wrapper that divides the logits of a classification model by a single
    learnable temperature (temperature scaling, Guo et al., 2017).

    model (nn.Module):
        A classification neural network, called as
        model(text, mask, categorical, numerical); element [0] of its output
        is taken as the logits.
        NB: that element must be the classification logits,
            NOT the softmax (or log softmax)!
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        # one scalar temperature, initialized at 1.5
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, text, mask, categorical, numerical):
        """Return the temperature-scaled logits of the wrapped model."""
        raw_logits = self.model(text, mask, categorical, numerical.float())[0]
        return self.temperature_scale(raw_logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits (expects a 2-D tensor).
        """
        n_rows, n_cols = logits.size(0), logits.size(1)
        # broadcast the scalar temperature to the shape of the logits
        expanded = self.temperature.unsqueeze(1).expand(n_rows, n_cols)
        return logits / expanded

    def set_temperature(self, valid_loader, device):
        """
        Tune the temperature of the model (using the validation set).
        The temperature is fitted with a single L-BFGS step (max 50 iterations)
        that minimizes the ECE criterion on the scaled validation logits.
        valid_loader (DataLoader): validation set loader; each batch is indexed
            as (text, categorical, numerical, y, mask)

        Returns self.
        """
        self.to(device)
        ece_criterion = _ECELoss().to(device)

        # First: collect all the logits and labels for the validation set
        collected_logits = []
        collected_labels = []
        with torch.no_grad():
            self.model.eval()
            for batch in valid_loader:
                text = batch[0].to(device)
                categorical = batch[1].to(device)
                numerical = batch[2].to(device)
                y = batch[3].to(device)
                mask = batch[4].to(device)
                collected_logits.append(self.model(text, mask, categorical, numerical.float())[0])
                collected_labels.append(y)
            logits = torch.cat(collected_logits).to(device)
            labels = torch.cat(collected_labels).to(device)

        # Next: optimize the temperature w.r.t. the ECE of the scaled logits
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def closure():
            optimizer.zero_grad()
            loss = ece_criterion(self.temperature_scale(logits), labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        return self

class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).

    The input to this loss is the logits of a model, NOT the softmax scores.

    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    We then return a weighted average of the gaps, based on the number
    of samples in each bin

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece
        
def scaled_softmax_representations(model, loader, device):
    """
    Outputs the Softmax representation of a model whose output is the logits.

    Args:
    model (torch model): called as model(text, mask, categorical, numerical);
        its output is passed directly to softmax.
    loader (torch DataLoader): each batch is indexed as
        (text, categorical, numerical, y, mask)
    device (str)

    Returns: 1 tensor with the Softmax probabilities of all batches, concatenated.

    """
    batch_probabilities = []

    # evaluation mode
    model.eval()
    for batch in loader:
        # inputs, moved to device
        text = batch[0].to(device)
        categorical = batch[1].to(device)
        numerical = batch[2].to(device)
        mask = batch[4].to(device)

        # prediction
        with torch.no_grad():
            logits = model(text, mask, categorical, numerical.float())

        # softmax probabilities
        batch_probabilities.append(F.softmax(logits, dim=1))

    return torch.cat(batch_probabilities)
    
# model initialization
def define_nn_model(init_seed, input_shape, output_shape):
    """
    Create and initialize a simple neural network:
    Linear(input, input) -> ReLU -> Dropout(0.1) -> Linear(input, output).

    The torch seed is set to init_seed before the layers are created, so the
    initial weights are reproducible.
    """
    torch.manual_seed(init_seed)
    layers = [
        nn.Linear(input_shape, input_shape),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(input_shape, output_shape),
    ]
    return nn.Sequential(*layers)

# model training
def train_nn_model(epochs, seed, loader, model, criterion, optimizer, verbose, device):
    """
    Train simple neural network model.

    Args:
    epochs (int): number of passes over loader (exactly `epochs` passes are run).
    seed (int): torch seed, re-applied at the start of every epoch, so any
        torch-RNG-driven randomness repeats from one epoch to the next.
    loader (iterable): yields batches indexed as (x, y).
    model (torch model): trained in place and moved to device.
    criterion (callable): loss, called as criterion(model(x), y).
    optimizer (torch optimizer): must already hold the parameters of model.
    verbose (bool): if True, prints the duration and mean loss of every epoch.
    device (str)

    Returns: the trained model (on device).
    """
    # Local import: `time` is not imported explicitly at the top of this module.
    import time

    model = model.to(device)

    # range(1, epochs + 1) so that `epochs` passes are run, numbered from 1
    for epoch in range(1, epochs + 1):
        start = time.time()
        running_loss = 0.0  # sum of per-sample losses
        total = 0  # number of samples
        torch.manual_seed(seed)
        model.train()
        for batch in loader:
            # 1. inputs and labels, moved to device
            x = batch[0].to(device)
            y = batch[1].to(device)

            # 2. clear gradients
            optimizer.zero_grad()

            # 3. forward pass and compute loss
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # 4. backward pass
            loss.backward()

            # 5. optimization
            optimizer.step()

            # 6. record loss (weighted by batch size)
            running_loss += loss.item() * y.shape[0]
            total += y.shape[0]

        end = time.time()
        # an empty loader yields no samples: report NaN instead of dividing by zero
        train_loss = running_loss / total if total else float("nan")
        if verbose:
            print("---------training time (s):", round(end-start,0), "---------")
            print("epoch:", epoch, "training loss:", round(train_loss,5))

    return model
######## methods #####################

# Jensen Shannon Distance (Lin, 1991)
def max_proba_jsd(src_proba, tgt_proba, num_bins = 10):
    """
    Computes the Jensen Shannon Distance (base 2) between the distributions of
    maximum probabilities (Source vs Target).

    The bins are the Source percentiles of the maximum probability, with the
    outer edges forced to 0 and 1.

    Args:
    src_proba (pandas dataframe): model Softmax outputs on Source dataset
    tgt_proba (pandas dataframe): model Softmax outputs on Target dataset
    num_bins (int) : number of bins to discretize the max probability distributions.
    """
    # maximum probability of each row
    src_max = np.max(src_proba, axis=1)
    tgt_max = np.max(tgt_proba, axis=1)

    # bin edges: Source percentiles, stretched to cover [0, 1]
    percentile_levels = 100 * np.arange(0, num_bins + 1) / num_bins
    edges = np.percentile(src_max, percentile_levels)
    edges[0], edges[-1] = 0, 1

    # share of samples per bin for Source and Target
    src_counts, _ = np.histogram(src_max, edges)
    tgt_counts, _ = np.histogram(tgt_max, edges)
    src_share = src_counts / len(src_max)
    tgt_share = tgt_counts / len(tgt_max)

    return jensenshannon(src_share, tgt_share, base = 2.0)
    
# Maximum Mean Discrepancy (Gretton et al., JMLR 2012)
def mmd(src, tgt, device):
    """
    Computes the Maximum Mean Discrepancy between two datasets (Source vs Target),
    using an RBF kernel on the L2-normalized rows.

    The kernel bandwidth sigma is half the median pairwise distance of the
    pooled data, and 1e-5 is added to the diagonal of the kernel matrix.

    Args:
    src (torch tensor): shape n x d
    tgt (torch tensor): shape m x d
    device: cpu or gpu

    Returns: the MMD statistic as a 0-dim tensor.

    Reference: A. Gretton et al.: A kernel two-sample test, JMLR 13 (2012)
    """
    x = F.normalize(src)
    y = F.normalize(tgt)

    n, d = x.shape
    m, d2 = y.shape
    assert d == d2

    # distance between each pair of pooled datapoints
    pooled = torch.cat([x.detach(), y.detach()], dim=0).to(device)
    dists = torch.cdist(pooled, pooled, p=2.0)

    # sigma parameter: half the median distance between points
    sigma = torch.pdist(torch.cat([x, y], dim=0)).median() / 2

    # RBF kernel, with a small constant on the diagonal
    diagonal_jitter = (torch.eye(n + m) * 1e-5).to(device)
    kernel = torch.exp((-1 / (2 * sigma**2)) * dists**2) + diagonal_jitter
    source_block = kernel[:n, :n]
    target_block = kernel[n:, n:]
    cross_block = kernel[:n, n:]

    # MMD statistic
    return (
        source_block.sum() / (n * (n - 1))
        + target_block.sum() / (m * (m - 1))
        - 2 * cross_block.sum() / (n * m)
    )

def p_val(src,tgt,mmd_stat,n_permutations,device, seed):
    """
    Computes the permutation p-value related to the MMD statistic.

    The pooled data is reshuffled n_permutations times (each shuffle is applied
    to the result of the previous one); after every shuffle the MMD between the
    first n rows and the remaining rows is computed.

    Args:
    src (torch tensor): shape n x d
    tgt (torch tensor): shape m x d
    mmd_stat (torch tensor): original MMD statistic (before permutations)
    n_permutations (int): number of permutations
    device: cpu or gpu
    seed (int): reproducibility of permutations

    Returns: share of permuted statistics strictly greater than mmd_stat.

    """
    torch.manual_seed(seed)
    n_source = len(src)
    pooled = torch.cat([src.clone(), tgt.clone()], dim=0).double()

    permuted_stats = []
    for _ in range(n_permutations):
        pooled = pooled[torch.randperm(len(pooled))]
        stat = mmd(pooled[:n_source], pooled[n_source:], device)
        permuted_stats.append(stat.item())

    permuted_stats = torch.tensor(permuted_stats).to(device)
    return (mmd_stat < permuted_stats).float().mean()
    
# Average confidence
def average_confidence(target_data, var_list):
    """
    We estimate the Target error rate as 1 minus the average confidence.

    Args:
    target_data (pandas dataframe): Target dataset holding the predicted probabilities
    var_list (list of str): columns with the predicted probability of each class

    Returns: a copy of target_data with the extra column "one_minus_max_proba",
    and the mean of that column.
    """
    scored = target_data.copy()
    max_proba = scored[var_list].values.max(axis=1)
    scored["one_minus_max_proba"] = 1 - max_proba
    return scored, np.mean(scored["one_minus_max_proba"])

# Difference Of Confidence (DOC) (Guillory et al., ICCV 2021)
def doc(model, loader, model_type, source_proba, target_proba, seed, device):
    """
    Difference Of Confidence (Guillory et al. (2021)): the Target error rate is
    estimated as the Source error rate plus the drop in average confidence
    (mean of the row-wise maximum probability) from Source to Target.

    The Source error rate is 1 minus the value returned by
    performance_pretrained(model, loader, model_type, seed, device)
    (presumably the Source accuracy -- confirm against that helper).
    """
    source_performance = performance_pretrained(model, loader, model_type, seed, device)
    avg_source_confidence = source_proba.max(axis=1).mean()
    avg_target_confidence = target_proba.max(axis=1).mean()

    return (1 - source_performance) + avg_source_confidence - avg_target_confidence
    
# Average Thresholded Confidence (ATC) (Garg et al., 2022)
def atc(src_y, src_proba, tgt_proba):
    """
    We estimate the Target error rate with the Average Thresholded Confidence.

    The score of a sample is its maximum predicted probability. The threshold
    is fitted on Source scores and Source correctness with find_ATC_threshold,
    then applied to the Target scores with get_ATC_acc (whose result is divided
    by 100, i.e. presumably a percentage -- confirm against ATC_helper).

    Args:
    src_y (pandas series): Source true labels
    src_proba (pandas dataframe): Softmax outputs on Source
    tgt_proba (pandas dataframe): Softmax outputs on Target

    Returns: 1 minus the estimated Target accuracy.
    """
    source_proba = src_proba.copy()
    labels = src_y.copy().values

    # confidence scores and predictions
    source_scores = source_proba.max(axis=1).values
    predictions = np.argmax(source_proba,axis=1)
    target_scores = tgt_proba.copy().max(axis=1).values

    # threshold fitted on Source, applied to Target
    _, threshold = find_ATC_threshold(source_scores, labels == predictions)
    estimated_accuracy = get_ATC_acc(threshold, target_scores)/100

    return 1 - estimated_accuracy
    
# Mandoline (Chen et al., ICML 2021)
def get_correct(preds, labels):
    """
    Returns, for every sample, whether the prediction equals the label,
    as a boolean np.ndarray with an extra axis inserted at position 1
    (shape (n, 1) for 1-D inputs).
    """
    is_correct = (preds == labels).values
    return np.expand_dims(is_correct, axis=1)

def mandoline_performance(src_y, src_proba, tgt_proba):
    """
    Error rate estimation with Mandoline challenger.

    The Source correctness indicators are passed to estimate_performance
    together with the maximum predicted probability of every Source and Target
    sample; the returned estimate is read from
    .all_estimates[0].weighted[0].

    Args:
    src_y (pandas series): Source true labels
    src_proba (pandas dataframe): Softmax outputs on Source
    tgt_proba (pandas dataframe): Softmax outputs on Target

    Returns: 1 minus the estimated performance.
    """
    source_proba = src_proba.copy()
    target_proba = tgt_proba.copy()
    source_labels = src_y.copy()

    # Source correctness: a simple average of it would equal Source accuracy
    source_preds = np.argmax(source_proba, axis=1)
    correctness_matrices = [get_correct(source_preds, source_labels)]

    # maximum probability of each sample, as (n, 1) arrays
    D_src = np.max(source_proba, axis=1).values[:, np.newaxis]
    D_tgt = np.max(target_proba, axis=1).values[:, np.newaxis]

    # estimate performance with the Mandoline framework
    estimates = estimate_performance(D_src, D_tgt, None, correctness_matrices)
    est_perf = estimates.all_estimates[0].weighted[0]

    return 1 - est_perf
    
# Monte Carlo Dropout
def enable_dropout(model):
    """
    Switch to training mode every submodule whose class name starts with
    'Dropout', so dropout stays active at test time.
    """
    dropout_layers = (
        layer for layer in model.modules()
        if type(layer).__name__.startswith('Dropout')
    )
    for layer in dropout_layers:
        layer.train()

def compute_MCD(model, loader, n_simu, seed, device):
    """
    Monte Carlo Dropout: run n_simu stochastic forward passes over loader with
    dropout enabled, average the Softmax outputs, and compute their entropy.

    Args:
    model (torch model): called as model(text, mask, categorical, numerical);
        element [0] of its output is taken as the logits.
    loader (torch DataLoader): each batch is indexed as
        (text, categorical, numerical, y, mask)
    n_simu (int): number of simulations
    seed (int): numpy seed used to draw one torch seed per simulation
    device (str)

    Returns: the Shannon entropy (base 2) of the averaged Softmax of every
    sample, and the mean of these entropies.
    """
    # one torch seed per simulation (for reproducibility)
    np.random.seed(seed)
    seed_list = np.random.randint(0, 1000, n_simu)

    for i, simu_seed in enumerate(seed_list):
        torch.manual_seed(simu_seed)
        batch_probabilities = []

        for batch in loader:
            # evaluation mode everywhere except the dropout layers
            model.eval()
            enable_dropout(model)

            # inputs, moved to device
            text = batch[0].to(device)
            categorical = batch[1].to(device)
            numerical = batch[2].to(device)
            mask = batch[4].to(device)

            # prediction
            with torch.no_grad():
                logits = model(text, mask, categorical, numerical.float())[0]

            # store softmax distributions
            batch_probabilities.append(F.softmax(logits, dim=1))

        # accumulate the softmax of this simulation
        simulation_probabilities = torch.cat(batch_probabilities)
        if i == 0:
            softmax_sum = simulation_probabilities
        else:
            softmax_sum += simulation_probabilities

    # average over simulations
    softmax_avg = softmax_sum / len(seed_list)

    # Shannon entropy of the averaged distribution
    shannon_entropy = entropy(softmax_avg.cpu().numpy(), base=2, axis=1)

    return shannon_entropy, shannon_entropy.mean()

def domain_classifier(source_data, target_data, var_list, seed):
    """
    Fit the Domain Classifier (random forest, 10 trees) with 2-fold cross-fitting
    and return the mean AUROC.

    Source and Target are each halved; a classifier fitted on one half of both
    domains predicts the probability of the Target domain on the other half,
    and vice versa.

    Args:
    source_data (pandas dataframe): Source dataset
    target_data (pandas dataframe): Target dataset
    var_list (list of str): list of features used in the datasets
    seed (int): seed

    Returns:
    source_with_weights (pandas dataframe): Source rows with the columns
        "domain", "domain proba", "weights" (proba / (1 - proba)) and
        "normalized weights" (weights divided by their sum)
    target_with_domain_proba (pandas dataframe): Target rows with the columns
        "domain" and "domain proba"
    mean_auroc (float): mean AUROC of the two fits
    """
    # we halve the Source and Target datasets
    source_a, source_b = train_test_split(source_data.copy(), test_size=0.5, random_state=seed)
    target_a, target_b = train_test_split(target_data.copy(), test_size=0.5, random_state=seed)

    # we label the Source domain as 0 and Target domain as 1
    for frame, domain_label in ((source_a, 0), (source_b, 0), (target_a, 1), (target_b, 1)):
        frame["domain"] = domain_label

    # the two folds, each mixing both domains
    fold_a = pd.concat([source_a, target_a])
    fold_b = pd.concat([source_b, target_b])

    def _fit_and_score(fit_fold, eval_fold, eval_source, eval_target):
        """Fit on fit_fold, write "domain proba" on the held-out frames, return the AUROC on eval_fold."""
        dc = RandomForestClassifier(random_state=seed, n_estimators = 10)
        dc.fit(fit_fold[var_list], fit_fold["domain"])

        eval_source["domain proba"] = dc.predict_proba(eval_source[var_list])[:,1]
        # a predicted proba of 1 would generate infinite weights
        saturated = eval_source["domain proba"] == 1
        if saturated.any():
            eval_source.loc[saturated, "domain proba"] = 0.99

        eval_fold["domain proba"] = dc.predict_proba(eval_fold[var_list])[:,1]
        eval_target["domain proba"] = dc.predict_proba(eval_target[var_list])[:,1]

        return roc_auc_score(eval_fold["domain"], eval_fold["domain proba"])

    # first fit on fold A (scored on fold B), second fit on fold B (scored on fold A)
    auroc_list = [
        _fit_and_score(fold_a, fold_b, source_b, target_b),
        _fit_and_score(fold_b, fold_a, source_a, target_a),
    ]
    mean_auroc = np.mean(auroc_list)

    # put the halves back together
    source_with_weights = pd.concat((source_a, source_b)).sort_index()
    target_with_domain_proba = pd.concat((target_a, target_b)).sort_index()

    # compute weights on source data (used in weighted methods)
    domain_proba = source_with_weights["domain proba"]
    source_with_weights["weights"] = domain_proba/(1-domain_proba)
    source_with_weights["normalized weights"] = source_with_weights["weights"]/(source_with_weights["weights"].sum())

    return source_with_weights, target_with_domain_proba, mean_auroc
  
# Conformal prediction (Sadinle et al., 2019) + (Tibshirani et al., 2019)
def conformal_prediction(validation_softmax_data, validation_label_data, target_data, softmax_var, target_coverage, source_weights):
    """
    Compute the mean interval width by leveraging weighted conformal prediction (Tibshirani et al., 2019).
    The conformal score is 1 minus the probability of true class (Sadinle et al., 2019).

    Args:
    validation_softmax_data (pandas dataframe): predicted probabilities on the validation set
    validation_label_data (pandas series): true labels of the validation set, used as column positions
    target_data (pandas dataframe): Target dataset
    softmax_var (list of str): columns of target_data holding the predicted probabilities
    target_coverage (float): desired coverage
    source_weights (pandas series): weight of each validation sample

    Returns: a copy of target_data with the extra column "interval width"
    (number of classes whose score is below the conformal quantile), and the
    mean of that column.
    """
    val_probabilities = validation_softmax_data.copy().values
    val_labels = validation_label_data.copy().values
    target = target_data.copy()
    target_probabilities = target_data[softmax_var].copy().values
    weights = source_weights.copy().values

    # conformal score (1 - probability of true class) on the validation dataset
    row_ids = np.arange(len(val_probabilities))
    scores = 1 - val_probabilities[row_ids, val_labels]

    # scores and weights in increasing order of score
    order = np.argsort(scores)
    sorted_scores = scores[order]
    sorted_weights = weights[order]

    # weighted empirical distribution of the scores
    cumulative_weights = np.cumsum(sorted_weights)
    cumulative_share = (cumulative_weights - 0.5*sorted_weights)/cumulative_weights[-1]

    # corrected quantile: level (1-alpha) (n+1)/n
    n_scores = len(scores)
    corrected_level = target_coverage * (n_scores+1)/n_scores
    conformity_quantile = np.interp(corrected_level, cumulative_share, sorted_scores)

    # interval width: number of classes kept in the prediction set
    target["interval width"] = ((1-target_probabilities)<=conformity_quantile).sum(axis=1)

    return target, target["interval width"].mean()
    
# Error Classifier
def error_classifier(source_data, target_data, var_list, softmax_var, source_weights, algo_type, seed):
    """
    Fit the Error Classifier with a defined algorithm.

    The classifier is trained on Source to predict whether the model
    prediction (argmax of the softmax_var columns) differs from column 'y',
    and is then applied to Target.

    Args:
    source_data (pandas dataframe): Source dataset, with a 'y' column holding the true labels
    target_data (pandas dataframe): Target dataset
    var_list (list of str): list of features names
    softmax_var (list of str): list of variable names for predicted probability
    source_weights (pandas series): sample weight of each Source row
    algo_type (string): type of sklearn algorithm; only random forest ("rf") is implemented
    seed (int): seed

    Returns: the fitted classifier, a copy of target_data with the extra column
    "pred error" (predicted error probability), and the mean of that column.

    Raises:
    ValueError: if algo_type is not supported.
    """
    # Fail early and explicitly: previously an unknown algo_type left the
    # classifier undefined and surfaced as an UnboundLocalError at fit time.
    if algo_type != "rf":
        raise ValueError(
            "Unsupported algo_type {!r}; only 'rf' is implemented".format(algo_type)
        )

    source = source_data.copy()
    target = target_data.copy()
    weights = source_weights.copy().values

    # add error variable (1 when the predicted class differs from the true label)
    source["error"] = np.where(np.argmax(source[softmax_var], axis=1) != source['y'], 1, 0)

    # fit classification algorithm
    ec = RandomForestClassifier(random_state=seed, n_estimators = 100)
    ec.fit(source[var_list], source["error"], sample_weight = weights)

    # error prediction
    target["pred error"] = ec.predict_proba(target[var_list])[:,1]
    mean_error_rate = target["pred error"].mean()

    return ec, target, mean_error_rate
  
# Deep Nearest Neighbors (Sun et al., ICML 2022)
def deep_nearest_neighbors(k_neighbors, src, tgt, var_list):
    """
    For every Target sample, compute the distance to its k-th nearest Source
    neighbor, both sets being normalized with sklearn's normalize first.

    Args:
    k_neighbors (int): rank of the neighbor whose distance is kept
    src (pandas dataframe): Source dataset
    tgt (pandas dataframe): Target dataset
    var_list (list of str): feature columns used to compute distances

    Returns: a copy of tgt with the extra column "target distances", and the
    mean of these distances.
    """
    source = src.copy()
    target = tgt.copy()

    # index the normalized Source features
    neighbor_index = NearestNeighbors(n_neighbors=k_neighbors)
    neighbor_index.fit(normalize(source[var_list]))

    # distances from every normalized Target sample to its k nearest Source samples
    distances, _ = neighbor_index.kneighbors(normalize(target[var_list]))

    # keep the distance to the k-th neighbor (last column)
    kth_distances = distances[:,k_neighbors-1]
    target["target distances"] = kth_distances

    return target, kth_distances.mean()
    
# Energy score (Liu et al., NeurIPS 2020)
def energy_score(lgts, T):
    """
    We estimate the energy score based on the logits:
    -T * logsumexp(logits / T) over the class dimension (dim=1).

    Args:
    lgts (Tensor): logits
    T (float): temperature

    Returns: the energy score of every sample (Tensor) and their mean (float).
    """
    tempered_logits = lgts.clone() / T
    energy_scores = -T * torch.logsumexp(tempered_logits, dim=1)
    return energy_scores, energy_scores.mean().item()

# True Class Probability (Corbiere et al., NeurIPS 2019)
def true_class_probability(src_cls, src_softmax, src_y, tgt_y, seed, device):
    """
    We estimate one minus the True Class Probability (Corbiere et al. 2019).

    A small regression network is trained on Source to predict, from the CLS
    representation, the Softmax probability assigned to the true class. It is
    then applied to Target.

    Args:
    src_cls (Tensor): Source CLS representation, shape n x d
    src_softmax (Tensor): Source Softmax representation
    src_y (Tensor): Source true labels
    tgt_y (Tensor): despite its name, this is fed to the regression network as
        its input, so it must hold the Target CLS representation (shape m x d)
    seed (int): seed for the network initialization and training
    device (str)

    Returns: 1 minus the predicted TCP of every Target sample (numpy array),
    and 1 minus the mean predicted TCP.
    """
    source_cls = src_cls.clone()
    source_softmax = src_softmax.clone()
    source_y = src_y.clone()
    # the network lives on `device` after training, so its input must too
    target_cls = tgt_y.clone().to(device)

    # compute the TCP on Source data; kept as an (n, 1) column so that it has
    # the same shape as the network output. A flat (n,) target would be
    # broadcast against the (batch, 1) predictions inside MSELoss.
    indices = source_y.long()
    source_tcp = source_softmax.gather(1, indices.view(-1,1))

    # create data loaders for Source: regression task to predict the TCP (true class probability)
    source_dataset = TensorDataset(source_cls, source_tcp)
    source_loader = DataLoader(source_dataset, sampler = RandomSampler(source_dataset), batch_size = 32)

    # create regression model
    nn_model = define_nn_model(init_seed = seed, input_shape = source_cls.shape[1], output_shape = 1)

    # training settings
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(nn_model.parameters())

    # training
    nn_model = train_nn_model(epochs = 10, seed = seed, loader = source_loader, model = nn_model, criterion = criterion, optimizer = optimizer, verbose = False, device = device)

    # prediction (no gradients needed at inference time)
    nn_model = nn_model.to(device)
    nn_model.eval()
    with torch.no_grad():
        tcp_preds = nn_model(target_cls).flatten().cpu().numpy()
    tcp_mean = np.mean(tcp_preds)

    return 1-tcp_preds, 1-tcp_mean

# Deep ensembles (Lakshminarayanan et al., NIPS 2017)
# Train M neural networks
def deep_ensembles(src_cls, src_y, tgt_cls, M, output_shape, seed, device):
    """
    We estimate the entropy with deep ensembles: M networks are trained on
    Source, their Softmax outputs on Target are averaged, and the Shannon
    entropy of the averaged distribution is computed.

    Args:
    src_cls (Tensor): Source CLS representation, shape n x d
    src_y (Tensor): Source true labels
    tgt_cls (Tensor): Target CLS representation, shape m x d
    M (int): number of networks in the ensemble (at least 1)
    output_shape (int): number of classes
    seed (int): kept for interface compatibility; it is not used, the members
        are seeded with 0, 1, ..., M-1
    device (str)

    Returns: the list of trained networks, the Shannon entropy (base 2) of
    every Target sample, and the mean entropy.

    Raises:
    ValueError: if M is smaller than 1.
    """
    if M < 1:
        raise ValueError("M must be at least 1, got {!r}".format(M))

    source_cls = src_cls.clone()
    source_y = src_y.clone()
    # the trained networks live on `device`, so the Target features must too
    target_cls = tgt_cls.clone().to(device)

    source_dataset = TensorDataset(source_cls, source_y)

    nn_model_list = []
    # a dedicated loop variable, so the `seed` argument is no longer shadowed
    for member_seed in range(M):
        # create model
        nn_model = define_nn_model(init_seed = member_seed, input_shape = source_cls.shape[1], output_shape = output_shape)

        # create data loader for Source
        source_loader = DataLoader(source_dataset, sampler = RandomSampler(source_dataset), batch_size = 32)

        # training settings
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(nn_model.parameters())

        # training
        nn_model = train_nn_model(epochs = 10, seed = member_seed, loader = source_loader, model = nn_model, criterion = criterion, optimizer = optimizer, verbose = False, device = device)
        nn_model_list.append(nn_model)

    # sum the softmax probabilities predicted by each network
    softmax_sum = None
    with torch.no_grad():
        for member in nn_model_list:
            member.eval()
            p_hat = F.softmax(member(target_cls), dim=1)
            softmax_sum = p_hat if softmax_sum is None else softmax_sum + p_hat

    # compute average over the ensemble
    softmax_avg = softmax_sum/M

    # compute Shannon entropy
    shannon_entropy = entropy(softmax_avg.cpu().numpy(), base=2, axis=1)
    mean_entropy = shannon_entropy.mean()

    return nn_model_list, shannon_entropy, mean_entropy
    
# Projection norm (Yu et al., ICML 2022)
def projection_norm(src_cls, src_y, tgt_cls, output_shape, seed, device):
    """
    Projection norm (Yu et al., ICML 2022).

    Two networks with the same initialization are trained: the first one on
    Source with the true labels, the second one on Target with the pseudo-labels
    predicted by the first one. The result is the sum of the squared L2
    distances between their parameters.

    Args:
    src_cls (Tensor): Source CLS representation, shape n x d
    src_y (Tensor): Source true labels
    tgt_cls (Tensor): Target CLS representation, shape m x d
    output_shape (int): number of classes
    seed (int): seed for the network initialization and training
    device (str)

    Returns: the squared parameter distance (float).
    """
    source_cls = src_cls.clone()
    source_y = src_y.clone()
    # the first network lives on `device` after training, so its input must too
    target_cls = tgt_cls.clone().to(device)

    # define neural networks with same initialization
    nn_model1 = define_nn_model(init_seed = seed, input_shape = source_cls.shape[1], output_shape = output_shape)
    nn_model2 = define_nn_model(init_seed = seed, input_shape = source_cls.shape[1], output_shape = output_shape)

    # create data loaders for Source
    source_dataset = TensorDataset(source_cls, source_y)
    source_loader = DataLoader(source_dataset, sampler = RandomSampler(source_dataset), batch_size = 32)

    # train first model on Source data
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(nn_model1.parameters())
    nn_model1 = train_nn_model(epochs = 10, seed = seed, loader = source_loader, model = nn_model1, criterion = criterion, optimizer = optimizer, verbose = False, device = device)

    # predict with first model on Target data; evaluation mode so that dropout
    # does not randomly perturb the pseudo-labels
    nn_model1.eval()
    with torch.no_grad():
        preds = torch.argmax(nn_model1(target_cls), dim = 1).to(device)

    # train second model on Target data with first model's pseudo-labels
    target_dataset = TensorDataset(target_cls, preds)
    target_loader = DataLoader(target_dataset, sampler = RandomSampler(target_dataset), batch_size = 32)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(nn_model2.parameters())
    nn_model2 = train_nn_model(epochs = 10, seed = seed, loader = target_loader, model = nn_model2, criterion = criterion, optimizer = optimizer, verbose = False, device = device)

    # sum of squared L2 distances between matching parameters
    diff = 0
    for param1, param2 in zip(nn_model1.parameters(), nn_model2.parameters()):
        diff += (torch.norm(param1.flatten() - param2.flatten()) ** 2).item()

    return diff
