import torch
from tqdm import tqdm
from monotonenorm.monotonicnetworks import GroupSort, direct_norm, SigmaNet
import numpy as np
from sklearn.metrics import accuracy_score

from BLNN import PICNN, PICNN_multiclass
from loaders.loan_loader import load_data, mono_list

# Run every tensor/parameter created without an explicit dtype in double precision
# (torch.double is the same dtype object as torch.float64).
torch.set_default_dtype(torch.double)

def run_exp(Xtr, Ytr, Xts, Yts, monotone_constraints, width, depth, seed,
            epochs=200, batch_size=2 ** 9, lr=1e-3):
    """Train a ``PICNN_multiclass`` binary classifier and return its best test accuracy.

    Features are standardized with training-set statistics, the model is trained
    with Adam on a logit-space binary cross-entropy loss, and after every epoch
    the test set is scored by sweeping 100 decision thresholds in [0, 1] over the
    predicted probability (sigmoid of the model output).

    Args:
        Xtr, Xts: Training / test feature matrices (anything ``torch.tensor``
            accepts). NOTE(review): the model is built with ``n_features - 1`` as
            its first argument — confirm what the last column represents.
        Ytr, Yts: Training / test labels, reshaped to ``(-1, 1)`` columns.
            Presumably 0/1 binary labels — TODO confirm against the loader.
        monotone_constraints, width, depth: Currently unused — the architecture
            below is hard-coded. Kept so existing callers keep working.
        seed: Passed to ``torch.manual_seed`` before model initialisation and
            data shuffling.
        epochs: Number of passes over the training data (default 200).
        batch_size: Mini-batch size (default 512).
        lr: Adam learning rate (default 1e-3).

    Returns:
        The highest threshold-swept test accuracy observed across all epochs.
        NOTE(review): both the threshold and the epoch are selected on the test
        set, so this is an optimistic estimate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(seed)
    Xtrt = torch.tensor(Xtr, dtype=torch.float64).to(device)
    Ytrt = torch.tensor(Ytr, dtype=torch.float64).view(-1, 1).to(device)
    Xtst = torch.tensor(Xts, dtype=torch.float64).to(device)
    Ytst = torch.tensor(Yts, dtype=torch.float64).view(-1, 1).to(device)
    # Per-sample row indices, yielded alongside each batch and passed to the model.
    sample_idx = torch.arange(len(Xtrt))

    model = PICNN_multiclass(len(Xtrt[0]) - 1, 1, 10, 2, 0.5, 1, 5).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    print("params:", sum(p.numel() for p in model.parameters()))

    # Standardize with training statistics only. Constant columns have std 0
    # (or NaN for a single row); divide those by 1 so they become 0, not NaN.
    mean = Xtrt.mean(0)
    std = Xtrt.std(0)
    std = torch.where(std > 0, std, torch.ones_like(std))
    Xtrt = (Xtrt - mean) / std
    Xtst = (Xtst - mean) / std

    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(Xtrt, Ytrt, sample_idx),
        batch_size=int(batch_size),
        shuffle=True,
    )

    # Test labels and candidate thresholds never change; build them once.
    y_true = Ytst.cpu().numpy().ravel()
    thresholds = np.linspace(0, 1, 100)
    max_acc = 0

    bar = tqdm(range(epochs))
    for _epoch in bar:
        model.train()
        for Xi, yi, batch_idx in dataloader:
            y_pred = model(Xi, [0, 1, 2, 3, 4], batch_idx)
            # Outputs are treated as logits (a sigmoid used to be applied before
            # BCE); the fused *_with_logits loss is the numerically stable form.
            losstr = torch.nn.functional.binary_cross_entropy_with_logits(y_pred, yi)
            optimizer.zero_grad()
            losstr.backward()
            optimizer.step()

        # Disable any train-only layers (e.g. dropout) during evaluation.
        model.eval()
        with torch.no_grad():
            y_predts = model(Xtst, [0, 1, 2, 3, 4])
            lossts = torch.nn.functional.binary_cross_entropy_with_logits(y_predts, Ytst)

            # Sweep thresholds over probabilities, not raw logits: thresholds in
            # [0, 1] on logits would only cover probabilities in [0.5, ~0.73].
            probs = torch.sigmoid(y_predts).cpu().numpy().ravel()
            acc = max(accuracy_score(y_true, probs > t) for t in thresholds)

            max_acc = max(max_acc, acc)
            # Format acc directly: accuracy_score may return a plain float,
            # which has no .item().
            bar.set_description(
                f"Loss: {losstr.item():.4f} {lossts.item():.4f}, acc: {acc:.4f}, max_acc: {max_acc:.4f}"
            )

    return max_acc
