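"""
Script to optionally pre-train a GlobalGIN model on QM9 and, with ``--search``,
fine-tune a decoder on each of the first 12 QM9 targets.
"""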

import torch
import torch.nn as nn
from torch import Tensor
from datasets.QM9Dataset import QM9, conversion
import train_utils
from torchmetrics import MeanAbsoluteError
from torchmetrics.functional.regression.mae import _mean_absolute_error_compute
import torch_geometric.transforms as T
from torch_geometric.data import Data
from models.model_construction import make_decoder

from training_evaluation import *
import os
import torch.cuda
from torch_geometric.loader import DataLoader

from data import DataPreTransform
from models.GlobalGIN import GlobalGIN




class InputTransform:
    """Input feature transformation applied to QM9 graphs.

    Puts ``data.z`` in front of ``data.x`` as an extra leading feature column
    and replaces ``data.edge_attr`` with, for every entry equal to 1, its
    index along the last axis.
    """

    def __call__(self,
                 data: Data) -> Data:
        # z gets a trailing axis so it can be concatenated with x feature-wise.
        z_column = data.z.unsqueeze(-1)
        data.x = torch.cat([z_column, data.x], dim=-1)

        # Coordinates of all entries equal to 1; only the last axis is kept.
        one_positions = (data.edge_attr == 1).nonzero(as_tuple=True)
        data.edge_attr = one_positions[-1]
        return data


class MeanAbsoluteErrorQM9(MeanAbsoluteError):
    """Mean absolute error rescaled by ``std`` and divided by ``conversion``.

    Args:
        std: Factor the accumulated MAE is multiplied by.
        conversion: Factor the rescaled MAE is divided by.
        **kwargs: Forwarded unchanged to ``MeanAbsoluteError``.
    """

    def __init__(self, std, conversion, **kwargs):
        super().__init__(**kwargs)
        self.conversion = conversion
        self.std = std

    def compute(self) -> Tensor:
        """Return the accumulated MAE times ``std``, divided by ``conversion``."""
        accumulated_mae = _mean_absolute_error_compute(self.sum_abs_error, self.total)
        rescaled_mae = accumulated_mae * self.std
        return rescaled_mae / self.conversion
def pre_filter(data):
    """Return True for graphs with at least 6 nodes, False for smaller ones."""
    too_small = data.num_nodes < 6
    return not too_small

def _ensure_output_dirs(checkpoint_folder):
    """Create the checkpoint and plot directories for both training phases."""
    for root in ("checkpoints", "plots"):
        os.makedirs(root, exist_ok=True)
        os.makedirs(f"{root}/{checkpoint_folder}", exist_ok=True)
        os.makedirs(f"{root}/{checkpoint_folder}_finetune", exist_ok=True)


def _load_dataset(path, pre_transform, args):
    """Build the QM9 dataset; ``args.task`` is read at call time and handed to PostTransform."""
    return QM9(path,
               pre_transform=T.Compose([InputTransform(), pre_transform]),
               transform=train_utils.PostTransform(args.wo_node_feature,
                                                   args.wo_edge_feature,
                                                   args.task))


def _normalize_and_split(dataset):
    """Standardise the targets in place and split 10% test / 10% val / 80% train.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, std)`` where ``std`` is the
        per-target standard deviation used for the standardisation.
    """
    tenpercent = int(len(dataset) * 0.1)
    # NOTE(review): `dataset.data.y` is presumably the underlying storage in its
    # on-disk order rather than the permuted view, so these statistics may not
    # correspond exactly to the non-test split - confirm.
    mean = dataset.data.y[tenpercent:].mean(dim=0)
    std = dataset.data.y[tenpercent:].std(dim=0)
    dataset.data.y = (dataset.data.y - mean) / std

    train_dataset = dataset[2 * tenpercent:]
    val_dataset = dataset[tenpercent:2 * tenpercent]
    test_dataset = dataset[:tenpercent]
    return train_dataset, val_dataset, test_dataset, std


def _apply_pre_filter(subset):
    """Return the graphs of *subset* that ``pre_filter`` accepts."""
    # Index with a boolean tensor: a plain list of Python bools would be read
    # as integer indices (True -> 1, False -> 0) instead of as a mask.
    keep = torch.tensor([pre_filter(graph) for graph in subset], dtype=torch.bool)
    return subset[keep]


def _make_loaders(splits, args, follow_batch):
    """Wrap every dataset in *splits* in a DataLoader configured from *args*."""
    return tuple(
        DataLoader(split,
                   batch_size=args.batch_size,
                   num_workers=args.num_workers,
                   follow_batch=follow_batch)
        for split in splits
    )


def main():
    """Pre-train a GlobalGIN model on QM9 and optionally fine-tune it per target.

    When ``args.num_epochs_pre > 0`` the model is pre-trained with
    ``training_loop`` on the graphs accepted by ``pre_filter``. When
    ``--search`` is given, a decoder built by ``make_decoder`` around
    ``pretrain_model.GIN`` is then fine-tuned with ``finetune_loop`` on each of
    the first 12 targets. Without ``--search`` nothing happens after the
    (optional) pre-training.
    """
    parser = train_utils.args_setup()
    parser.add_argument('--dataset_name', type=str, default="QM9", help='Name of dataset.')
    parser.add_argument('--task', type=int, default=11, choices=list(range(19)), help='Train target.')
    parser.add_argument('--search', action="store_true", help="If true, run all first 12 targets.")

    args = parser.parse_args()
    args = train_utils.update_args(args, add_task=False)
    # data_setup also returns a pre-transform; this script uses DataPreTransform instead.
    path, _, follow_batch = train_utils.data_setup(args)
    pre_transform = DataPreTransform(args)

    _ensure_output_dirs(args.checkpoint_folder)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.mode = "min"

    dataset = _load_dataset(path, pre_transform, args)
    # NOTE(review): a fresh permutation is drawn (and perm.pt overwritten)
    # whenever `use_shuffle` is falsy - confirm the flag polarity is intended.
    if not args.use_shuffle or not os.path.exists("perm.pt"):
        dataset, perm = dataset.shuffle(True)
        torch.save(perm, "perm.pt")
    else:
        perm = torch.load("perm.pt")
        dataset = dataset[perm]

    train_dataset, val_dataset, test_dataset, _ = _normalize_and_split(dataset)

    # The input embedding width is taken from the first training graph.
    emb_size = train_dataset[0].x.shape[-1]
    # Built exactly once; it is used by fine-tuning even when pre-training is skipped.
    pretrain_model = GlobalGIN(4, 3, 5, emb_size, 60, 6, 40, 0.1, True, "Sum",  device, model_name=args.model_name, head_type=args.head_type).to(device)

    if args.num_epochs_pre > 0:
        # The filter needs a full pass over the data, so it only runs when
        # pre-training is actually requested.
        pre_train_loader, pre_val_loader, pre_test_loader = _make_loaders(
            [_apply_pre_filter(split) for split in (train_dataset, val_dataset, test_dataset)],
            args,
            follow_batch,
        )

        optimizer = torch.optim.Adam(pretrain_model.parameters(), lr=args.lr)

        print("Commencing Pre-Training")
        training_loop(pretrain_model, pre_train_loader, pre_val_loader, pre_test_loader, optimizer, device, args)

    # To reuse an existing checkpoint instead of pre-training, load it here, e.g.:
    # pretrain_model.load_state_dict(torch.load(f"checkpoints/{args.checkpoint_folder}/71.pt", weights_only=True))

    if not args.search:
        return

    for target in range(12):
        # The dataset is rebuilt per target because its transform receives args.task.
        args.task = target
        dataset = _load_dataset(path, pre_transform, args)
        dataset = dataset[perm]
        train_dataset, val_dataset, test_dataset, std = _normalize_and_split(dataset)

        train_loader, val_loader, test_loader = _make_loaders(
            (train_dataset, val_dataset, test_dataset), args, follow_batch)

        loss_cri = nn.MSELoss()
        # The evaluator undoes the standardisation and applies the target's conversion factor.
        evaluator = MeanAbsoluteErrorQM9(std[args.task].item(), conversion[args.task].item()).to(device)

        # NOTE(review): the same encoder object is handed to every target, so
        # unless make_decoder copies it, each target starts from the weights
        # left by the previous one - confirm this is intended.
        gnn_model = pretrain_model.GIN
        finetune_model = make_decoder(args, gnn_model)
        optimizer = torch.optim.Adam(finetune_model.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=args.factor, patience=args.patience, min_lr=args.min_lr
        )

        print("Commencing Fine Tuning")
        finetune_loop(finetune_model, train_loader, val_loader, test_loader, optimizer, device, args, loss_cri, scheduler, target, evaluator)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


