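"""
Script to pre-train and fine-tune a GlobalGIN model on the ZINC dataset
(the ZINC subset by default, the full dataset with --full).
"""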
import os
import torch.cuda
import torch.nn as nn
from torch_geometric.datasets import ZINC
import train_utils

from torch_geometric.loader import DataLoader
from models.model_construction import make_decoder

from training_evaluation import *
from data import DataPreTransform, CustomGraphDataset
from models.GlobalGIN import GlobalGIN

from compute_alternative_targets import normalize_alternative_targets


# Alternative targets that are predicted per graph; every other enabled
# alternative target is treated as a per-node target.
_GRAPH_LEVEL_ALT_TARGETS = ("cycle_target", "lap_eval_target")


def _make_output_dirs(folder):
    """Create the checkpoint and plot directories for pre-training and fine-tuning.

    ``os.makedirs`` creates missing parents, so the ``checkpoints`` and
    ``plots`` roots are created implicitly.
    """
    for root in ("checkpoints", "plots"):
        os.makedirs(root, exist_ok=True)
        for suffix in ("", "_finetune"):
            os.makedirs(f"{root}/{folder}{suffix}", exist_ok=True)


def _load_zinc_splits(path, args, pre_transform):
    """Return the (train, val, test) ZINC splits, optionally normalizing alternative targets."""
    splits = [
        ZINC(path, subset=not args.full, split=split, pre_transform=pre_transform)
        for split in ("train", "val", "test")
    ]
    if args.normalize_alt_targets:
        print("Normalizing alternative targets...")
        # NOTE(review): each split is normalized by its own call; confirm that
        # normalize_alternative_targets does not derive statistics per split,
        # which would scale val/test differently from train.
        splits = [
            CustomGraphDataset(normalize_alternative_targets(ds, args))
            for ds in splits
        ]
    return splits


def _alt_target_sizes(args, data_sample):
    """Return (node_sizes, graph_sizes) for every alternative target with positive weight.

    The size of each target is the last dimension of the matching attribute
    on ``data_sample``. Both dicts are empty when ``args.predict_alt_targets``
    is falsy.
    """
    node_sizes, graph_sizes = {}, {}
    if not args.predict_alt_targets:
        return node_sizes, graph_sizes
    for key, weight in args.lambda_alt_targets.items():
        if weight > 0:
            size = getattr(data_sample, key).shape[-1]
            if key in _GRAPH_LEVEL_ALT_TARGETS:
                graph_sizes[key] = size
            else:
                node_sizes[key] = size
    return node_sizes, graph_sizes


def main():
    """Parse CLI args, optionally pre-train a GlobalGIN, then fine-tune it on ZINC."""
    parser = train_utils.args_setup()
    parser.add_argument('--dataset_name', type=str, default="ZINC", help='Name of dataset.')
    parser.add_argument('--runs', type=int, default=10, help='Number of repeat run.')
    parser.add_argument('--full', action="store_true", help="If true, run ZINC full.")
    args = parser.parse_args()
    args = train_utils.update_args(args)

    _make_output_dirs(args.checkpoint_folder)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Lower is better for the monitored metric; ReduceLROnPlateau below reads this.
    args.mode = "min"

    if args.full:
        args.exp_name = "full_" + args.exp_name

    # data_setup's own pre-transform is intentionally discarded: DataPreTransform
    # is what the datasets are built with.
    path, _, follow_batch = train_utils.data_setup(args)
    pre_transform = DataPreTransform(args)

    train_dataset, val_dataset, test_dataset = _load_zinc_splits(path, args, pre_transform)

    # Shuffle the training set so mini-batch order differs between epochs;
    # evaluation loaders keep a fixed order.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, follow_batch=follow_batch)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            num_workers=args.num_workers, follow_batch=follow_batch)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=args.num_workers, follow_batch=follow_batch)

    data_sample = train_dataset[0]
    emb_size = data_sample.x.shape[-1]

    alt_node_target_sizes, alt_graph_target_sizes = _alt_target_sizes(args, data_sample)

    # NOTE(review): the positional hyper-parameters below are kept verbatim;
    # their meaning is defined by GlobalGIN's signature.
    # The model is built unconditionally: previously it existed only inside the
    # pre-training branch, so running without --pre_train crashed when reading
    # pretrain_model.GIN below.
    pretrain_model = GlobalGIN(4, 3, 5, emb_size, 60, 6, 40, 0.1, True, "Sum", device,
                               alt_node_target_sizes, alt_graph_target_sizes,
                               model_name=args.model_name,
                               head_type=args.head_type).to(device)

    if args.pre_train:
        optimizer = torch.optim.Adam(pretrain_model.parameters(), lr=args.lr)
        print("Commencing Pre-Training")
        training_loop(pretrain_model, train_loader, val_loader, test_loader, optimizer, device, args)
    else:
        print("Skipping Pre-Training; fine-tuning a freshly initialised encoder")

    gnn_model = pretrain_model.GIN
    finetune_model = make_decoder(args, gnn_model)
    optimizer = torch.optim.Adam(finetune_model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode=args.mode, factor=args.factor, patience=args.patience, min_lr=args.min_lr
    )

    criterion = nn.L1Loss()

    print("Commencing Fine Tuning")
    finetune_loop(finetune_model, train_loader, val_loader, test_loader, optimizer, device,
                  args, criterion, scheduler)


if __name__ == "__main__":
    main()
