import torch
from pathlib import Path
import logging

# Named "sfda_reg" logger used for progress messages in this module; no
# handlers or levels are configured here, so output depends on how that
# logger is set up by the caller.
txt_logger = logging.getLogger("sfda_reg")

def dropout_sampling(net, dataloader, sampling_num, do_p, device="cuda"):
    """Collect Monte-Carlo dropout predictions for every sample in a dataset.

    For each batch, backbone features are computed once with
    ``net.feature(x)``. The prediction head is then run ``sampling_num``
    times with dropout enabled (``net.predict_from_feature(feature, do_p,
    True)``), followed by one call without the dropout arguments to obtain
    the point prediction. Results are written into dataset-sized banks at
    the positions given by each batch's ``idx``.

    NOTE: ``net`` is switched to ``eval()`` and left in that mode. Features
    are shared across the dropout draws and the point prediction, which
    assumes ``net.feature`` is deterministic in eval mode. Confirm this if
    the backbone contains explicitly stochastic layers.

    Args:
        net: Model exposing ``reg_num``, ``eval()``, ``feature(x)`` and
            ``predict_from_feature(feature[, p, dropout])``.
        dataloader: Iterable of ``(x, y_true, idx)`` batches that also has a
            ``dataset`` attribute; ``len(dataloader.dataset)`` sizes the banks.
        sampling_num: Number of dropout draws per sample; must be >= 1.
        do_p: Dropout rate forwarded to ``predict_from_feature``.
        device: Device for inputs and result banks. Defaults to ``"cuda"``,
            matching the previously hard-coded behaviour.

    Returns:
        dict of CPU tensors (N = dataset length):
            "y_true": ``[N, reg_num]`` targets as float.
            "y_pred": ``[N, reg_num]`` predictions without dropout.
            "y_pred_sample": ``[N, reg_num, sampling_num]`` dropout draws,
                squeezed to ``[N, sampling_num]`` when ``reg_num == 1``.
            "y_bank_min" / "y_bank_max": min / max over the draws (last
                axis of "y_pred_sample").

    Raises:
        ValueError: If ``sampling_num`` is less than 1.
    """
    if sampling_num < 1:
        raise ValueError(f"sampling_num must be >= 1, got {sampling_num}")

    reg_num = net.reg_num
    dataset_len = len(dataloader.dataset)

    y_dropout_bank = torch.zeros(dataset_len, reg_num, sampling_num, device=device)
    y_true_bank = torch.zeros(dataset_len, reg_num, device=device)
    y_pred_bank = torch.zeros(dataset_len, reg_num, device=device)

    net.eval()
    with torch.no_grad():
        for x, y_true, idx in dataloader:
            x = x.to(device)
            y_true = y_true.float().to(device)

            # Run the backbone once per batch; only the head is re-sampled.
            feature = net.feature(x)
            draws = [
                net.predict_from_feature(feature, do_p, True).reshape(-1, reg_num)
                for _ in range(sampling_num)
            ]
            # Each draw is [batch, reg_num]; stacking on a new last axis gives
            # [batch, reg_num, sampling_num] without a permute.
            y_dropout_bank[idx] = torch.stack(draws, dim=-1)

            y_pred_bank[idx] = net.predict_from_feature(feature).reshape(-1, reg_num)
            y_true_bank[idx] = y_true.reshape(-1, reg_num)

    if reg_num == 1:
        # Drop the singleton regression axis -> [N, sampling_num].
        y_dropout_bank = y_dropout_bank.squeeze(1)
    y_bank_min = torch.min(y_dropout_bank, dim=-1).values
    y_bank_max = torch.max(y_dropout_bank, dim=-1).values

    return {
        "y_true": y_true_bank.cpu(),
        "y_pred": y_pred_bank.cpu(),
        "y_pred_sample": y_dropout_bank.cpu(),
        "y_bank_min": y_bank_min.cpu(),
        "y_bank_max": y_bank_max.cpu(),
    }



def get_sampling(regressor_do, target_dl, config_do):
    """Log the dropout settings and run ``dropout_sampling`` with them.

    ``config_do`` must contain the keys ``"sampling_num"`` and ``"do_p"``;
    a missing key raises ``KeyError``. Both values are forwarded unchanged,
    and the dict produced by ``dropout_sampling`` is returned as-is.
    """
    n_samples, dropout_rate = config_do["sampling_num"], config_do["do_p"]
    txt_logger.info(f"Dropout Result Generating, sampling number is [{n_samples}], dropout rate is [{dropout_rate}].")
    return dropout_sampling(regressor_do, target_dl, n_samples, dropout_rate)