import torch
from monotonenorm.monotonicnetworks import SigmaNet, GroupSort, direct_norm
from tqdm import tqdm
from loaders.blog_loader import load_data, mono_list

from BLNN import PICNN, PICNN_multiclass

# Module-level side effect: make float64 the default floating-point dtype, so
# tensors and parameters created without an explicit dtype use double precision.
torch.set_default_dtype(torch.float64)


def get_comp(list, length):
    """Return the indices in ``range(length)`` that are not in ``list``.

    Args:
        list: Container of indices to exclude. The name shadows the builtin
            but is kept so existing keyword callers continue to work.
        length: Exclusive upper bound of the index range to consider.

    Returns:
        A 1-D ``torch.long`` tensor holding the remaining indices in
        ascending order (empty when every index is excluded).
    """
    complement = [i for i in range(length) if i not in list]
    # An explicit integer dtype keeps the result usable as an index tensor
    # even when it is empty; torch.tensor([]) would otherwise be created with
    # the default floating-point dtype.
    return torch.tensor(complement, dtype=torch.long)

def run_exp(
    Xtr, Ytr, Xts, Yts, monotone_constraints,
    max_lr, expwidth, depth, Lip, batchsize, seed
):
    """Train a ``PICNN_multiclass`` model and return its best test RMSE.

    Features are standardised with the training set's per-feature mean and
    standard deviation, the model is trained with Adam on the MSE between
    ``sigmoid(model output)`` and the targets, and the test set is evaluated
    after every epoch.

    Args:
        Xtr: Training features; converted to a float64 tensor.
        Ytr: Training targets; converted to float64 and reshaped to ``(-1, 1)``.
        Xts: Test features; converted to a float64 tensor.
        Yts: Test targets; converted to float64 and reshaped to ``(-1, 1)``.
        monotone_constraints: Not used by this function.
        max_lr: Learning rate passed to ``torch.optim.Adam``.
        expwidth: Not used by this function.
        depth: Not used by this function.
        Lip: Not used by this function.
        batchsize: Mini-batch size of the training ``DataLoader``.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        The smallest test RMSE observed over all epochs. The running best
        starts at 1, so the returned value is never greater than 1.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(seed)
    Xtrt = torch.tensor(Xtr, dtype=torch.float64).to(device)
    Ytrt = torch.tensor(Ytr, dtype=torch.float64).view(-1, 1).to(device)
    Xtst = torch.tensor(Xts, dtype=torch.float64).to(device)
    Ytst = torch.tensor(Yts, dtype=torch.float64).view(-1, 1).to(device)
    # Position of every training sample; batched alongside the data so the
    # model receives the indices of the samples in each mini-batch.
    sample_idx = torch.arange(len(Xtrt))
    c_mono_list = get_comp(mono_list, len(Xtrt[0]))
    print("mono",mono_list)
    print("cmono", c_mono_list)

    # Standardise with training-set statistics. A constant feature has zero
    # std (and a single-sample set has NaN std); substitute 1 there so the
    # division yields finite values instead of NaN/inf.
    mean = Xtrt.mean(0)
    std = Xtrt.std(0)
    std = torch.where(std > 0, std, torch.ones_like(std))
    Xtrt = (Xtrt - mean) / std
    Xtst = (Xtst - mean) / std

    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(Xtrt, Ytrt, sample_idx),
        batch_size=batchsize,
        shuffle=True,
    )

    # NOTE(review): the positional arguments are passed through unchanged;
    # their meaning is defined by PICNN_multiclass in BLNN — confirm there.
    model  = PICNN_multiclass(len(Xtrt[0])-1,1,3,2, 1.0,0, 8).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=max_lr)
    EPOCHS = 1000

    # Feature indices handed to the model as its second argument on every
    # call. NOTE(review): presumably the monotone feature columns — confirm
    # against mono_list. A fresh list is built per call, as before.
    feature_ids = (50, 51, 52, 53, 55, 56, 57, 58)

    print("params:", sum(p.numel() for p in model.parameters()))
    bar = tqdm(range(EPOCHS))
    best_rmse = 1
    for _ in bar:
        model.train()
        for Xi, yi, batch_idx in dataloader:
            y_pred = model(Xi, list(feature_ids), batch_idx)
            losstr = torch.nn.functional.mse_loss(torch.sigmoid(y_pred), yi)
            optimizer.zero_grad()
            losstr.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            y_predts = model(Xtst, list(feature_ids))
            lossts = torch.nn.functional.mse_loss(torch.sigmoid(y_predts), Ytst)
            tsrmse = lossts.item() ** 0.5
            # Train RMSE reflects only the last mini-batch of the epoch.
            trrmse = losstr.item() ** 0.5
            best_rmse = min(best_rmse, tsrmse)
            bar.set_description(
                f"train rmse: {trrmse:.5f} test rmse: {tsrmse:.5f}, best: {best_rmse:.5f}, lr: {optimizer.param_groups[0]['lr']:.5f}"
            )
    return best_rmse
