from sklearn.metrics import roc_auc_score
from typing import Callable
import torch
from torch_geometric.loader import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import argparse
import time
import numpy as np
from datasets import loaddataset
from norm import NormMomentumScheduler, basenormdict
from SizeAlignedLoader import batch2dense
from torch import autograd
import numpy as np
from DropoutScheduler import dropoutscheduler
### importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from PiOModel import PiOModel
import os
from CosAnnDecay import CosineAnnealingWarmRestarts

# torch.autograd.set_detect_anomaly(True)

# torch.set_float32_matmul_precision('high')

def get_criterion(task, args):
    """Return an element-wise loss module (``reduction="none"``) for ``task``.

    Args:
        task: one of ``"bincls"``, ``"cls"``, ``"reg"``, ``"l1reg"`` or
            ``"smoothl1reg"``.
        args: parsed CLI namespace; only ``args.lossparam`` is read, and only
            for ``"smoothl1reg"`` (used as the Smooth-L1 ``beta``).

    Raises:
        KeyError: if ``task`` is not a known task name.
    """
    if task == "smoothl1reg":
        return torch.nn.SmoothL1Loss(reduction="none", beta=args.lossparam)
    # Loss classes that need no extra hyper-parameter; instantiate only the one asked for.
    loss_classes = {
        "bincls": torch.nn.BCEWithLogitsLoss,
        "cls": torch.nn.CrossEntropyLoss,
        "reg": torch.nn.MSELoss,
        "l1reg": torch.nn.L1Loss,
    }
    return loss_classes[task](reduction="none")


def set_seed(seed):
    """Seed torch (CPU and every CUDA device) and NumPy's global RNG with ``seed``."""
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)


def train(criterion: Callable,
          model: PiOModel,
          device: torch.device,
          loader: DataLoader,
          optimizer: optim.Optimizer,
          task_type: str,
          ampscaler: torch.cuda.amp.GradScaler = None,
          advloss: bool=False,
          scheduler: optim.lr_scheduler.LinearLR = None,
          gradclipnorm: float=1):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Args:
        criterion: element-wise loss called as ``criterion(pred, target)``; its
            output is reduced here (mean, or softmax-weighted when ``advloss``).
        model: network called as ``model(A, X, nodemask)`` on the dense tensors
            produced by ``batch2dense``.
        device: device each batch is moved to.
        loader: iterable of batches exposing ``.y`` and ``.to(device, ...)``.
        optimizer: optimizer stepped once per batch.
        task_type: targets are cast to float unless this is ``"cls"``.
        ampscaler: when given, enables CUDA autocast and gradient scaling.
        advloss: if True, weight the flattened per-element losses by a softmax
            over their detached values (larger losses get larger weights)
            instead of taking a plain mean.
        scheduler: optional LR scheduler, stepped once per batch.
        gradclipnorm: max total gradient norm used for clipping.

    Returns:
        The average of the per-batch losses (``nan`` for an empty loader).
    """
    model.train()
    batch_losses = []
    for batch in loader:
        batch = batch.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.autocast(device_type='cuda', enabled=ampscaler is not None):
            A, X, nodemask, _ = batch2dense(batch)
            finalpred = model(A, X, nodemask)
            y = batch.y
            if task_type != "cls":
                y = y.to(torch.float)
            if advloss:
                value_loss = criterion(finalpred, y).flatten()
                weight = torch.softmax(value_loss.detach(), dim=-1)
                value_loss = torch.inner(value_loss, weight)
            else:
                value_loss = torch.mean(criterion(finalpred, y))
        if ampscaler is not None:
            ampscaler.scale(value_loss).backward()
            # Unscale before clipping so the threshold applies to true gradient norms.
            ampscaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradclipnorm)
            ampscaler.step(optimizer)
            ampscaler.update()
        else:
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradclipnorm)
            optimizer.step()
        # The scheduler is optional (default None); only step it when provided.
        if scheduler is not None:
            scheduler.step()
        # Detach so the stored loss does not hold references to this step's autograd
        # graph; .item() is deferred to the end to avoid a host sync on every batch.
        batch_losses.append(value_loss.detach())
    return np.average([l.item() for l in batch_losses])

@torch.no_grad()
def eval(model: PiOModel,
         device: torch.device,
         loader: DataLoader,
         evaluator,
         amp: bool = False):
    """Run ``model`` over ``loader`` and return ``evaluator(y_pred, y_true)``.

    Args:
        model: network called as ``model(A, X, nodemask)`` on the dense tensors
            from ``batch2dense``; must expose ``num_tasks``.
        device: device batches and the prediction buffer are placed on.
        loader: loader whose ``dataset.y`` only sizes the buffers; targets are
            re-collected from each batch, so predictions and targets stay
            aligned in iteration order.
        evaluator: callable taking ``(y_pred, y_true)`` as CPU tensors.
        amp: enable CUDA autocast for the forward pass.

    Returns:
        Whatever ``evaluator`` returns.
    """
    model.eval()
    ty = loader.dataset.y
    # Allocate targets with the labels' own shape: a flat (N,) buffer cannot receive
    # 2-D labels such as (N, 1) or (N, num_tasks). Identical for 1-D labels.
    y_true = torch.zeros(ty.shape, dtype=ty.dtype)
    y_pred = torch.zeros((ty.shape[0], model.num_tasks), device=device)
    step = 0
    for batch in loader:
        steplen = batch.y.shape[0]
        y_true[step:step + steplen] = batch.y
        batch = batch.to(device, non_blocking=True)
        with torch.autocast(device_type='cuda', enabled=amp):
            A, X, nodemask, _ = batch2dense(batch)
            tpred = model(A, X, nodemask)
        y_pred[step:step + steplen] = tpred
        step += steplen
    # Every target row must have been filled exactly once.
    assert step == y_true.shape[0]
    return evaluator(y_pred.cpu(), y_true)


def parserarg(argv=None):
    """Build the command-line parser, parse the arguments and print the result.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; pass a list to parse
            programmatically (e.g. from tests or notebooks).

    Returns:
        The parsed ``argparse.Namespace``.

    NOTE(review): ``--simmodel``, ``--ssimmodel``, ``--K2``, ``--input_encoder``,
    ``--tprod_tailact`` and ``--eldp`` are parsed but not read anywhere in this
    file; confirm whether they are still needed.
    """
    parser = argparse.ArgumentParser()
    # Data / run bookkeeping.
    parser.add_argument('--dataset', type=str, default="ogbg-molhiv")
    parser.add_argument('--repeat', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=0)

    # Optimization.
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--simmodel", action="store_true")
    parser.add_argument("--ssimmodel", action="store_true")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--testbatch_size', type=int, default=1024)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--wd', type=float, default=0)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--minlr', type=float, default=0.0)
    parser.add_argument('--K', type=float, default=0.0)
    parser.add_argument('--K2', type=float, default=0.0)

    parser.add_argument('--gradclipnorm', type=float, default=1.0)
    parser.add_argument('--decompnoise', type=float, default=0)
    parser.add_argument('--noiseK', type=float, default=1)

    parser.add_argument('--seedoffset', type=int, default=0)

    # LR schedule (lengths in epochs; converted to steps in main()).
    parser.add_argument('--warmstart', type=int, default=0)
    parser.add_argument('--conststep', type=int, default=100)
    parser.add_argument('--cosstep', type=int, default=40)
    parser.add_argument("--permutedata", action="store_true")

    # Dropout.
    parser.add_argument('--dppreepoch', type=int, default=-1)
    parser.add_argument('--dp', type=float, default=0.0)
    parser.add_argument('--eldp', type=float, default=0.0)

    parser.add_argument("--act", type=str, default="silu")

    # Loss.
    parser.add_argument('--lossparam', type=float, default=0.05)
    parser.add_argument('--advloss', action="store_true")

    parser.add_argument("--normA", action="store_true")

    # Input embedding.
    parser.add_argument('--embdp', type=float, default=0.0)
    parser.add_argument("--embbn", action="store_true")
    parser.add_argument("--embln", action="store_true")
    parser.add_argument("--emborthoinit", action="store_true")
    parser.add_argument('--featdim', type=int, default=64)
    parser.add_argument("--degreeemb", action="store_true")
    parser.add_argument('--hiddim', type=int, default=64)
    parser.add_argument('--caldim', type=int, default=64)

    parser.add_argument('--input_encoder', type=str, default="vanilla")
    parser.add_argument('--laplacian', action="store_true")

    # Residual / normalization switches.
    parser.add_argument("--elres", action="store_true")
    parser.add_argument("--tgres", action="store_true")
    parser.add_argument("--usesvmix", action="store_true")
    parser.add_argument("--usetg", action="store_true")
    parser.add_argument("--vmean", action="store_true")
    parser.add_argument("--vnorm", action="store_true")
    parser.add_argument("--elvmean", action="store_true")
    parser.add_argument("--elvnorm", action="store_true")
    parser.add_argument("--tgvmean", action="store_true")
    parser.add_argument("--tgvnorm", action="store_true")
    parser.add_argument("--snorm", action="store_true")

    parser.add_argument("--gsizenorm", type=float, default=1)
    parser.add_argument("--lsizenorm", type=float, default=1)

    # Lambda encoder.
    parser.add_argument("--l_encoder", type=str, default="deepset")
    parser.add_argument("--l_layers", type=int, default=1)
    parser.add_argument("--l_nhead", type=int, default=1)
    parser.add_argument("--l_dffn", type=int, default=32)
    parser.add_argument("--l_normfirst", action="store_true")
    parser.add_argument("--l_combine", type=str, default="mul")
    parser.add_argument('--l_aggr', type=str, default="mean")
    parser.add_argument("--l_res", action="store_true")
    parser.add_argument("--l_mlptailact1", action="store_true")
    parser.add_argument("--l_mlplayers1", type=int, default=1)
    parser.add_argument("--l_mlpnorm1", type=str, default="ln")
    parser.add_argument("--l_mlptailact2", action="store_true")
    parser.add_argument("--l_mlplayers2", type=int, default=1)
    parser.add_argument("--l_mlpnorm2", type=str, default="ln")

    parser.add_argument("--num_layers", type=int, default=3)

    # svmix sub-module.
    parser.add_argument("--sv_uselinv", action="store_true")
    parser.add_argument("--sv_tailact", action="store_true")
    parser.add_argument("--sv_res", action="store_true")
    parser.add_argument("--sv_numlayer", type=int, default=1)
    parser.add_argument("--sv_norm", type=str, default="ln")

    # gaggr sub-module.
    parser.add_argument("--gaggr_layers", type=int, default=1)
    parser.add_argument("--gaggr_combine", type=str, default="mul")
    parser.add_argument("--gaggr_aggr", type=str, default="mean")
    parser.add_argument("--gaggr_res", action="store_true")
    parser.add_argument("--gaggr_pool", type=str, default="mean")
    parser.add_argument("--gaggr_mlptailact1", action="store_true")
    parser.add_argument("--gaggr_mlplayers1", type=int, default=1)
    parser.add_argument("--gaggr_mlpnorm1", type=str, default="ln")
    parser.add_argument("--gaggr_mlptailact2", action="store_true")
    parser.add_argument("--gaggr_mlplayers2", type=int, default=1)
    parser.add_argument("--gaggr_mlpnorm2", type=str, default="ln")
    parser.add_argument("--gaggr_vln", action="store_true")
    parser.add_argument("--gaggr_reduce", action="store_true")

    # tprod sub-module.
    parser.add_argument("--tprod_tailact", action="store_true")
    parser.add_argument("--tprod_numlayer", type=int, default=1)
    parser.add_argument("--tprod_norm", type=str, default="ln")
    parser.add_argument("--tprod_vln", action="store_true")

    # elproj sub-module.
    parser.add_argument("--el_uselinv", action="store_true")
    parser.add_argument("--el_uselins", action="store_true")
    parser.add_argument("--el_tailact", action="store_true")
    parser.add_argument("--el_numlayer", type=int, default=1)
    parser.add_argument("--el_norm", type=str, default="ln")
    parser.add_argument("--el_uses", action="store_true")

    # conv sub-module.
    parser.add_argument("--conv_uselinv", action="store_true")
    parser.add_argument("--conv_tailact", action="store_true")
    parser.add_argument("--conv_numlayer", type=int, default=1)
    parser.add_argument("--conv_norm", type=str, default="ln")

    # Prediction head.
    parser.add_argument("--predlin_numlayer", type=int, default=2)
    parser.add_argument("--predlin_norm",
                        choices=["bn", "ln", "in", "none"],
                        default="none")

    # Eigen-decomposition / lambda expansion.
    parser.add_argument("--sizenormU", action="store_true")
    parser.add_argument("--lexp",
                        type=str,
                        choices=["gauss", "gg", "mlp", "sin"],
                        default="mlp")
    parser.add_argument("--lexp_layer", type=int, default=1)
    parser.add_argument("--lexp_norm",
                        type=str,
                        choices=["bn", "ln", "in", "none"],
                        default="none")
    parser.add_argument("--outln", action="store_true")

    parser.add_argument("--pool", type=str, default="sum")

    parser.add_argument("--Tm", type=float, default=1)

    # Checkpoint names under mod/ (see main()).
    parser.add_argument('--save', type=str, default=None)
    parser.add_argument('--load', type=str, default=None)

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    print(args)
    return args


def buildModel(args, num_tasks, device, dataset, needcompile: bool=True):
    """Construct a PiOModel from parsed CLI ``args`` and move it to ``device``.

    Args:
        args: namespace from ``parserarg()``. If ``args.nodetask`` is truthy
            (``main()`` sets it), ``args.pool`` is overwritten with ``"none"``
            (this mutates ``args``); a missing ``nodetask`` counts as False.
        num_tasks: number of prediction targets passed to PiOModel.
        device: device the model is moved to.
        dataset: dataset whose ``x`` attribute sets the input-embedding sizes:
            ``None`` -> ``[None]``; integer features -> per-column ``max + 1``
            (every entry must be non-zero); any other dtype -> ``[]``.
        needcompile: wrap the model with ``torch.compile`` before returning.

    Returns:
        The (optionally compiled) model on ``device``.
    """
    tx = dataset.x
    if tx is None:
        xembdims = [None]
    elif tx.dtype == torch.long:
        assert tx.dim() == 2
        # NOTE(review): zero is presumably reserved for padding ("lastzeropad") — confirm.
        assert torch.all(tx != 0)
        xembdims = (torch.max(tx, dim=0)[0] + 1).tolist()
        print(xembdims)
    else:
        xembdims = []
    # assert torch.all(dataset.edge_attr != 0)
    from utils import act_dict
    basic = {
        "dropout": args.dp,
        "activation": act_dict[args.act],
    }
    kwargs = {
        "elres": args.elres,
        "tgres": args.tgres,
        "usesvmix": args.usesvmix,
        "usetg": args.usetg,
        "gsizenorm": args.gsizenorm,
        "lsizenorm": args.lsizenorm,
        "vmean": args.vmean,
        "vnorm": args.vnorm,
        "elvmean": args.elvmean,
        "elvnorm": args.elvnorm,
        # Use the dedicated --tgvmean/--tgvnorm flags (previously copied from the el* flags).
        "tgvmean": args.tgvmean,
        "tgvnorm": args.tgvnorm,
        "snorm": args.snorm,
        "basic": basic,
        "inputencoder": {
            "laplacian": args.laplacian,
            "permutedata": args.permutedata,
            "xemb": {
                "orthoinit": args.emborthoinit,
                "bn": args.embbn,
                "ln": args.embln,
                "dropout": args.embdp,
                "lastzeropad": len(xembdims)
            },
            "lambdaemb": {
                "numlayer": args.lexp_layer,
                "norm": args.lexp_norm,
            },
            "sizenormU": args.sizenormU,
            "normA": args.normA,
            "degreeemb": args.degreeemb,
            "lexp": args.lexp,
            "xembdims": xembdims,
            "decompnoise": args.decompnoise
        },
        "l_model": {
            "numlayers": args.l_layers,
            "nhead": args.l_nhead,
            "dffn": args.l_dffn,
            "norm_first": args.l_normfirst,
            "aggr": args.l_aggr,
            "combine": args.l_combine,
            "res": args.l_res,
            "mlpargs1": {
                "numlayer": args.l_mlplayers1,
                "norm": args.l_mlpnorm1,
                "tailact": args.l_mlptailact1,
            },
            "mlpargs2": {
                "numlayer": args.l_mlplayers2,
                "norm": args.l_mlpnorm2,
                "tailact": args.l_mlptailact2,
            }
        },
        "svmix": {
            "uselinv": args.sv_uselinv,
            "numlayer": args.sv_numlayer,
            "norm": args.sv_norm,
            "tailact": args.sv_tailact,
            "res": args.sv_res
        },
        "elproj": {
            "uselinv": args.el_uselinv,
            "uselins": args.el_uselins,
            # A projection layer is forced when caldim and hiddim differ.
            "numlayer": args.el_numlayer if args.caldim == args.hiddim else max(1, args.el_numlayer),
            "norm": args.el_norm,
            "tailact": args.el_tailact,
            "uses": args.el_uses,
        },
        "conv": {
            "uselinv": args.conv_uselinv,
            "numlayer": args.conv_numlayer,
            "norm": args.conv_norm,
            "tailact": args.conv_tailact,
        },
        "gaggr": {
            "permlayer": {
                "numlayers": args.gaggr_layers,
                "aggr": args.gaggr_aggr,
                "combine": args.gaggr_combine,
                "res": args.gaggr_res,
                "mlpargs1": {
                    "numlayer": args.gaggr_mlplayers1,
                    "norm": args.gaggr_mlpnorm1,
                    "tailact": args.gaggr_mlptailact1,
                },
                "mlpargs2": {
                    "numlayer": args.gaggr_mlplayers2,
                    "norm": args.gaggr_mlpnorm2,
                    "tailact": args.gaggr_mlptailact2,
                },
                "pool": args.gaggr_pool
            },
            "vln": args.gaggr_vln,
            "isreduce": args.gaggr_reduce,
        },
        "tprod": {
            "mlp": {
                "numlayer": args.tprod_numlayer,
                "norm": args.tprod_norm,
                "tailact": True,
            },
            "vln": args.tprod_vln,
        },
        "predlin": {
            "numlayer": args.predlin_numlayer,
            "norm": args.predlin_norm,
            "tailact": False,
        },
        "outln": args.outln
    }
    # Every MLP-like sub-config inherits the shared dropout/activation settings.
    for subconfig in (
        kwargs["predlin"],
        kwargs["svmix"],
        kwargs["tprod"]["mlp"],
        kwargs["l_model"]["mlpargs1"],
        kwargs["l_model"]["mlpargs2"],
        kwargs["gaggr"]["permlayer"]["mlpargs1"],
        kwargs["gaggr"]["permlayer"]["mlpargs2"],
        kwargs["conv"],
        kwargs["elproj"],
    ):
        subconfig.update(basic)
    kwargs["inputencoder"]["basic"] = basic
    # NOTE(review): --eldp is parsed but never passed to the model; elproj uses the
    # shared args.dp dropout set above. Wire it in (e.g.
    # kwargs["elproj"]["dropout"] = args.eldp) once the intended target is confirmed.
    print("num_task", num_tasks)
    if getattr(args, "nodetask", False):
        # Node-level prediction: no graph pooling.
        args.pool = "none"
    model = PiOModel(args.featdim, args.caldim, args.hiddim, num_tasks, args.l_encoder, args.num_layers, args.pool, **kwargs)
    model = model.to(device)
    print(model)
    if needcompile:
        return torch.compile(model)
    return model


def main():
    """Train and evaluate a PiOModel ``args.repeat`` times and report the results.

    For each repetition:
    * build the train/valid/test split (fixed, rotated k-fold, or re-loaded);
    * train for ``args.epochs`` epochs with AdamW and a per-batch LR schedule;
    * record validation and test scores every epoch;
    * keep the test score at the best validation epoch.

    Finally prints the per-run ``[best epoch, valid score, test score]``
    rows, followed by their means and standard deviations.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = parserarg()
    # Automatic data loading and splitting.
    datasets, split, evaluator, task = loaddataset(args.dataset)
    # A "node" prefix on the task name marks node-level tasks.
    args.nodetask = "node" in task
    task = task.removeprefix("node")
    print(split, task)
    outs = []
    set_seed(0)
    if split.startswith("fold"):
        # The last three "-"-separated fields of `split` are integer train/val/test
        # proportions. One fixed permutation is rotated by one fold per repetition.
        trn_ratio, val_ratio, tst_ratio = (int(part) for part in split.split("-")[-3:])
        num_fold = trn_ratio + val_ratio + tst_ratio
        num_data = len(datasets[0])
        idx = torch.randperm(num_data)
        splitsize = num_data // num_fold
        idxs = [
            torch.cat((idx[splitsize * k:], idx[:splitsize * k]))
            for k in range(num_fold)
        ]
        num_trn = int(trn_ratio / num_fold * num_data)
        num_val = int(val_ratio / num_fold * num_data)
    for rep in range(args.repeat):
        set_seed(rep + args.seedoffset)
        if split == "fixed":
            trn_d, val_d, tst_d = datasets
        elif split.startswith("fold"):
            idx = idxs[rep]
            trn_idx = idx[:num_trn]
            val_idx = idx[num_trn:num_trn + num_val]
            tst_idx = idx[num_trn + num_val:]
            trn_d, val_d, tst_d = datasets[0][trn_idx], datasets[0][val_idx], datasets[0][tst_idx]
        else:
            # NOTE(review): presumably a randomized split, re-drawn per repetition — confirm.
            datasets, split, evaluator, task = loaddataset(args.dataset)
            # The reload returns the raw task name; strip the node prefix again so
            # get_criterion() receives e.g. "cls" rather than "nodecls".
            task = task.removeprefix("node")
            trn_d, val_d, tst_d = datasets
        print(len(trn_d), len(val_d), len(tst_d))
        train_loader = DataLoader(trn_d,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=args.num_workers)
        valid_loader = DataLoader(val_d,
                                  batch_size=args.testbatch_size,
                                  shuffle=False,
                                  drop_last=False,
                                  num_workers=args.num_workers)
        test_loader = DataLoader(tst_d,
                                 batch_size=args.testbatch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=args.num_workers)
        print(f"split {len(trn_d)} {len(val_d)} {len(tst_d)}")
        model = buildModel(args, trn_d.num_tasks, device, trn_d)
        criterion = get_criterion(task, args)

        # Three parameter groups: input/lambda encoders, prediction head, and the rest.
        grouplambda = list(model.inputencoder.parameters()) + list(model.LambdaEncoder.parameters())
        grouppred = list(model.predlin.parameters())
        # Exclude by identity (tensor `==` is elementwise, so `in` cannot be used).
        taken = {id(p) for p in grouplambda + grouppred}
        groupconv = [p for p in model.parameters() if id(p) not in taken]
        print(len(grouplambda), len(grouppred), len(groupconv), len(list(model.parameters())))
        optimizer = optim.AdamW(
            [{"params": grouplambda, "lr": args.lr, "weight_decay": args.wd},
             {"params": grouppred, "lr": args.lr, "weight_decay": args.wd},
             {"params": groupconv, "lr": args.lr, "weight_decay": args.wd}],
        )
        if args.load is not None:
            print(f"mod/{args.load}.{rep}.pt")
            loadparams = torch.load(f"mod/{args.load}.{rep}.pt",
                                    map_location="cpu")
            print(model.load_state_dict(loadparams, strict=False))
            loadoptparams = torch.load(f"mod/{args.load}.opt.{rep}.pt",
                                       map_location="cpu")
            print(optimizer.load_state_dict(loadoptparams))

        # Schedules are stepped once per batch, so epoch counts are scaled to steps.
        steps_per_epoch = len(train_loader)
        # LinearLR with start and end factor 1 keeps the LR constant. After
        # `conststep` epochs, the project's CosineAnnealingWarmRestarts takes over.
        scheduler0 = optim.lr_scheduler.LinearLR(optimizer, 1, 1, args.warmstart * steps_per_epoch)
        scheduler1 = CosineAnnealingWarmRestarts(optimizer, args.cosstep * steps_per_epoch, eta_min=args.minlr, T_mult=args.Tm, K=args.K)
        scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler0, scheduler1], milestones=[args.conststep * steps_per_epoch])
        dpscheduler = dropoutscheduler(model, args.dp, args.epochs,
                                       args.dppreepoch)
        valid_curve = []
        test_curve = []
        train_curve = []

        ampscaler = torch.cuda.amp.GradScaler() if args.amp else None

        if args.debug:
            # Debug mode: one training epoch with autograd anomaly detection, then stop.
            with autograd.detect_anomaly():
                train(criterion, model, device, train_loader, optimizer, task,
                      ampscaler, args.advloss, scheduler, args.gradclipnorm)
            raise SystemExit(0)

        for epoch in range(1, args.epochs + 1):
            model.inputencoder.setnoiseratio(args.decompnoise * args.noiseK**epoch)
            t1 = time.time()
            loss = train(criterion, model, device, train_loader, optimizer, task,
                         ampscaler, args.advloss, scheduler, args.gradclipnorm)
            print(
                f"Epoch {epoch} train time : {time.time()-t1:.1f} loss: {loss:.2e}",
                flush=True)

            t1 = time.time()
            # Train-set evaluation is disabled; 0.0 is a placeholder in the log line.
            train_perf = 0.0
            valid_perf = eval(model, device, valid_loader, evaluator, args.amp)
            test_perf = eval(model, device, test_loader, evaluator, args.amp)
            print(
                f" test time : {time.time()-t1:.1f} Train {train_perf} Validation {valid_perf} Test {test_perf}",
                flush=True)
            train_curve.append(loss)
            valid_curve.append(valid_perf)
            test_curve.append(test_perf)
            if np.isnan(test_perf):
                break
            if args.save is not None:
                # Overwritten every epoch: the files hold the latest checkpoint.
                torch.save(model.state_dict(),
                           f"mod/{args.save}.{rep}.pt")
                torch.save(optimizer.state_dict(),
                           f"mod/{args.save}.opt.{rep}.pt")
            dpscheduler.step()

        # Higher is better for classification, lower otherwise. The tiny
        # index-proportional offset nudges exact ties toward later epochs.
        tie_break = np.arange(len(valid_curve)) * 1e-15
        if 'cls' in task:
            best_val_epoch = np.argmax(np.array(valid_curve) + tie_break)
        else:
            best_val_epoch = np.argmin(np.array(valid_curve) - tie_break)

        print(
            f'Best @{best_val_epoch} validation score: {valid_curve[best_val_epoch]:.4f} Test score: {test_curve[best_val_epoch]:.4f}', flush=True
        )
        outs.append([
            best_val_epoch, valid_curve[best_val_epoch],
            test_curve[best_val_epoch]
        ])
    print(outs)
    print("all runs: ", end=" ")
    for value in np.average(outs, axis=0):
        print(value, end=" ")
    for value in np.std(outs, axis=0):
        print(value, end=" ")
    print()


if __name__ == "__main__":
    main()
    # os._exit skips interpreter cleanup, including stdio flushing, so flush stdout first.
    # NOTE(review): presumably used to avoid hanging on leftover worker/compile threads at
    # shutdown — confirm.
    print("", end="", flush=True)
    # os.EX_OK is Unix-only; 0 is the same success status and works everywhere.
    os._exit(0)