import torch
import nltk
from roberta_model_loader import RobertaModelLoader
from feature_ref_loader import feature_ref_loader
from meta_train import net
from regression_model_loader import regression_model
from MMD import MMD_3_Sample_Test
from utils import FeatureExtractor, HWT, MGT, config

# Device used for the MMD computation in RelativeTester.test. The deployment
# pins the 4th GPU ("cuda:3"); when that device does not exist on this host,
# fall back to the default CUDA device, or to the CPU, instead of crashing.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = "cuda:3"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class RelativeTester:
    """Relative MMD tester for a piece of text.

    The text is split into sentences and each sentence is turned into a
    feature vector by ``FeatureExtractor``. ``MMD_3_Sample_Test`` then compares
    those features with two reference feature sets loaded from
    ``config["feature_ref_HWT"]`` and ``config["feature_ref_MGT"]``
    (presumably human-written and machine-generated text; verify against
    ``utils``).
    """

    def __init__(self):
        """Build the sentence feature extractor and load both reference sets."""
        print("Relative Tester init")
        self.feature_extractor = FeatureExtractor(RobertaModelLoader(), net)
        self.feature_hwt_ref = feature_ref_loader(config["feature_ref_HWT"])
        self.feature_mgt_ref = feature_ref_loader(config["feature_ref_MGT"])

    def sents_split(self, text):
        """Split *text* into sentences and keep only those with more than 5 words.

        Uses ``nltk.sent_tokenize``. The NLTK ``punkt`` / ``punkt_tab`` data is
        not downloaded here; it must already be installed.
        """
        return [
            sent for sent in nltk.sent_tokenize(text) if len(sent.split()) > 5
        ]

    @staticmethod
    def _random_subset(features, size):
        """Return up to *size* distinct rows of *features* in random order."""
        return features[torch.randperm(len(features))[:size]]

    def test(self, input_text, threshold=0.2, round=20):
        """Run the relative MMD test on *input_text*.

        Args:
            input_text: Text to test.
            threshold: Unused; kept only so existing callers keep working.
            round: Number of random resampling rounds; must be >= 1. The name
                shadows the builtin but is part of the keyword interface.

        Returns:
            The mean of the ``h_u`` values returned by ``MMD_3_Sample_Test``
            over all rounds. If fewer than two sentence features are
            available, the string
            ``"Too short to test! Please input more than 2 sentences."``
            is returned instead.

        Raises:
            ValueError: If ``round`` is less than 1. The check runs before any
                feature extraction is done.
        """
        print("Relative Tester test")
        if round < 1:
            raise ValueError(f"round must be >= 1, got {round!r}")

        # NOTE(review): the length check below accepts exactly two sentences,
        # while this message asks for "more than 2". Confirm which is intended.
        too_short = "Too short to test! Please input more than 2 sentences."

        sents = self.sents_split(input_text)
        print("DEBUG: sents:", len(sents))
        if not sents:
            # Nothing survived the sentence filter; don't run the extractor on
            # an empty batch.
            return too_short
        feature_for_sents = self.feature_extractor.process_sents(sents, False)
        if len(feature_for_sents) <= 1:
            return too_short

        # All three samples in a round are cut to the same size.
        min_len = min(
            len(feature_for_sents),
            len(self.feature_hwt_ref),
            len(self.feature_mgt_ref),
        )

        model = net.net
        model.to(device)
        h_u_list = []
        # Inference only: do not build autograd graphs for the forward passes.
        with torch.no_grad():
            # These tensors are identical in every round, so move them once.
            sigma = net.sigma.to(device)
            sigma0_u = net.sigma0_u.to(device)
            ep = net.ep.to(device)
            for _ in range(round):
                # Sampling order (input, HWT, MGT) is kept so the RNG stream
                # is consumed exactly as before.
                sents_sample = self._random_subset(feature_for_sents, min_len).to(device)
                hwt_sample = self._random_subset(self.feature_hwt_ref, min_len).to(device)
                mgt_sample = self._random_subset(self.feature_mgt_ref, min_len).to(device)
                h_u, _p_value, _t, *_rest = MMD_3_Sample_Test(
                    model(sents_sample),
                    model(hwt_sample),
                    model(mgt_sample),
                    # Raw features flattened to one row per sample.
                    sents_sample.view(sents_sample.shape[0], -1),
                    hwt_sample.view(hwt_sample.shape[0], -1),
                    mgt_sample.view(mgt_sample.shape[0], -1),
                    sigma,
                    sigma0_u,
                    ep,
                    0.05,  # presumably the test's significance level; TODO confirm
                )
                h_u_list.append(h_u)

        power = sum(h_u_list) / len(h_u_list)
        print("DEBUG: power:", power)
        print("DEBUG: power list:", h_u_list)
        return power
