import torch
import numpy as np

from adapt.instance_based import *

def _to_float64_numpy(rep):
    """Return tensor ``rep`` as a detached, host-side float64 NumPy array."""
    # detach() first so the device-to-host copy is not recorded by autograd.
    return rep.detach().cpu().numpy().astype(np.float64)


def _adapt_weights(model_cls, offline_rep, offline_label, online_rep):
    """Fit one instance-based model and return its weights as a tensor.

    ``model_cls`` is constructed with ``Xt=<online_rep as float64>`` and
    ``verbose=0``, fitted on ``(<offline_rep as float64>, offline_label)``,
    and the result of ``predict_weights()`` is wrapped with ``torch.tensor``.
    """
    # Labels that arrive as torch tensors are moved to host memory and
    # detached first: NumPy cannot implicitly convert tensors that live on an
    # accelerator or that require grad. Non-tensor labels pass through as-is.
    if torch.is_tensor(offline_label):
        offline_label = offline_label.detach().cpu().numpy()

    adapt_model = model_cls(Xt=_to_float64_numpy(online_rep), verbose=0)
    adapt_model.fit(_to_float64_numpy(offline_rep), offline_label)
    return torch.tensor(adapt_model.predict_weights())


def weights_estimation(estimator, estimator_algorithm, offline_rep, offline_label, online_rep):
    """Compute per-sample weights for ``offline_rep`` with the chosen algorithm.

    Args:
        estimator: Object exposing ``estimate(offline_rep, online_rep)``.
            Only used when ``estimator_algorithm`` is ``'Accous'`` or
            ``'OLR'``; ignored otherwise.
        estimator_algorithm: One of ``'Accous'``, ``'OLR'``, ``'KMM'``,
            ``'KLIEP'``, ``'RULSIF'``, ``'ULSIF'`` or ``'ONES'``.
        offline_rep: Torch tensor of offline representations.
        offline_label: Labels passed to ``fit`` for the KMM / KLIEP /
            RULSIF / ULSIF branches; unused by the other branches.
        online_rep: Torch tensor of online representations.

    Returns:
        A ``(weights, est_loss)`` tuple. ``est_loss`` is whatever the
        estimator returned for ``'Accous'`` and ``0.`` for every other
        algorithm. For ``'ONES'`` the weights are
        ``torch.ones(offline_rep.shape[0])``.

    Raises:
        NotImplementedError: If ``estimator_algorithm`` is not one of the
            names listed above.
    """
    est_loss = 0.

    if estimator_algorithm in ('Accous', 'OLR'):
        # Both estimators receive detached host-side tensors.
        # NOTE(review): only offline_rep is cast to float32; online_rep keeps
        # its dtype, exactly as before -- confirm this asymmetry is intended.
        offline_cpu = offline_rep.detach().cpu().float()
        online_cpu = online_rep.detach().cpu()
        if estimator_algorithm == 'Accous':
            # This estimator also reports a loss alongside the weights.
            weights, est_loss = estimator.estimate(offline_cpu, online_cpu)
        else:
            weights = estimator.estimate(offline_cpu, online_cpu)
    # The model classes below are expected to come from the module-level
    # ``adapt.instance_based`` star import. Each name is looked up only when
    # its branch is taken, so a missing class affects only that algorithm.
    elif estimator_algorithm == 'KMM':
        weights = _adapt_weights(KMM, offline_rep, offline_label, online_rep)
    elif estimator_algorithm == 'KLIEP':
        weights = _adapt_weights(KLIEP, offline_rep, offline_label, online_rep)
    elif estimator_algorithm == 'RULSIF':
        weights = _adapt_weights(RULSIF, offline_rep, offline_label, online_rep)
    elif estimator_algorithm == 'ULSIF':
        weights = _adapt_weights(ULSIF, offline_rep, offline_label, online_rep)
    elif estimator_algorithm == 'ONES':
        # Uniform weighting: one weight of 1.0 per offline sample.
        weights = torch.ones(offline_rep.shape[0])
    else:
        raise NotImplementedError(
            f"Unknown estimator_algorithm: {estimator_algorithm!r}"
        )

    return weights, est_loss