import numpy as np
import torch
import time
import sys
import builtins
import pandas as pd
import os
import matplotlib.pyplot as plt

sys.path.append(builtins.ROOT_PATH)
from exp.dataloader import load_data, check_device, log
from baselines.IT.utils_IT import *
from models import load_model

def IT(dataset, N_more, N_less, rs, check, is_balanced, n_test, n_per, alpha, is_selection=True):
    """Run ``n_test`` repetitions of the LOTT test and collect reject decisions.

    Each repetition k:
      1. draws fresh data via ``load_data`` with seed ``rs * 1000 + k``;
      2. optionally embeds both samples with the model named by
         ``builtins.MODEL_ARCH`` (skipped when it is ``None``);
      3. randomly splits ``X`` into train (40%), calibration (10%) and
         hold-out (the remainder) parts;
      4. fits a ``LOTTWithSelection`` (or plain ``LOTT``) on those parts;
      5. unless ``is_balanced`` is set, tests ``Y_test`` and records the
         decision.

    Args:
        dataset, N_more, N_less, check: forwarded unchanged to ``load_data``.
        rs: base seed; seeds NumPy and torch once, and derives the
            per-repetition data seed ``rs * 1000 + k``.
        is_balanced: when truthy the detector is fitted but ``test`` is never
            called, so the returned decisions stay 0.
        n_test: number of repetitions (length of the returned array).
        n_per: not used by this function (presumably kept for a shared
            baseline signature -- confirm with callers).
        alpha: forwarded as ``alpha`` to the LOTT constructor (presumably
            the significance level -- confirm in utils_IT).
        is_selection: if True use ``LOTTWithSelection`` with
            ``selection_method='precision_weight'``, else plain ``LOTT``.
            Both use 500 permutations.

    Returns:
        ``(H_IT, None, None)``: ``H_IT`` is a float ndarray of length
        ``n_test`` holding ``int(results['reject'])`` per repetition (0 where
        no test was run).
    """
    device = check_device()
    np.random.seed(rs)
    torch.manual_seed(rs)
    # Force deterministic cuDNN kernels so runs with the same ``rs`` repeat.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    use_model = builtins.MODEL_ARCH is not None
    if use_model:
        model = load_model(builtins.MODEL_ARCH)
        model.to(device)
        model.eval()

    H_IT = np.zeros(n_test)
    P_IT = np.zeros(n_test)  # p-values: recorded but not returned
    test_time = 0.0
    for k in range(n_test):
        # perf_counter is monotonic and intended for measuring durations.
        # NOTE(review): CUDA work is asynchronous, so host-side timing may
        # under-count GPU time unless something in the loop synchronizes.
        start_time = time.perf_counter()
        X, Y_test, _ = load_data(dataset, N_more, N_less, rs * 1000 + k, check, need_labels=True)
        X = X.to(device, dtype=torch.float32)
        Y_test = Y_test.to(device, dtype=torch.float32)
        if use_model:
            # Replace the raw samples with the model's outputs (no autograd).
            with torch.no_grad():
                X = model(X)
                Y_test = model(Y_test)

        # Random train / calibration / hold-out split of X: 40% / 10% / rest.
        n = len(X)
        perm = torch.randperm(n, device=device)
        n_train, n_calib = int(n * 0.4), int(n * 0.1)
        X_train = X[perm[:n_train]]
        X_calib = X[perm[n_train:n_train + n_calib]]
        X_hold = X[perm[n_train + n_calib:]]

        if is_selection:
            lott = LOTTWithSelection(alpha=alpha, n_permutations=500, selection_method='precision_weight', verbose=False)
        else:
            lott = LOTT(alpha=alpha, n_permutations=500)
        lott.fit(X_train, X_calib, X_hold)

        # NOTE(review): in balanced mode the detector is fitted (and timed)
        # but never applied, so H_IT stays all-zero -- confirm intended.
        if not is_balanced:
            results = lott.test(Y_test)
            H_IT[k] = int(results['reject'])
            P_IT[k] = results['p_value']
        test_time += time.perf_counter() - start_time

    # Log the mean per-repetition time only while the global counter is 0,
    # then bump it so later calls stay quiet. Skip when n_test == 0 to avoid
    # dividing by zero.
    if builtins.IT_TIME_LOG == 0 and n_test > 0:
        log("IT avg test time: ", test_time / n_test, "s")
        builtins.IT_TIME_LOG += 1
    torch.cuda.empty_cache()
    return H_IT, None, None