import copy
import os

import hydra
import numpy as np
import torch
import wandb
from omegaconf import DictConfig
from torch import optim
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from models import UnsupervisedGNN
from trainer import ProbsTrainer
from utils.experiment import save_run_config, setup_wandb, count_parameters, get_data

# Let float32 matmuls use faster reduced-precision internal kernels (e.g. TF32)
# where the hardware supports them, trading a little precision for throughput.
torch.set_float32_matmul_precision("high")


@hydra.main(version_base=None, config_path='./config', config_name="mpnn")
def main(args: DictConfig):
    """Train ``UnsupervisedGNN`` over several independent runs and log test losses.

    For each of ``args.train.runs`` runs, a fresh model, Adam optimizer,
    ``ReduceLROnPlateau`` scheduler and ``ProbsTrainer`` are created. Training
    uses early stopping on the validation total loss. The weights with the
    lowest validation total loss are restored and evaluated on the test set.
    The mean and standard deviation of the test losses across runs are logged
    to wandb at the end.

    Args:
        args: Hydra config. Reads ``args.train.*`` (runs, datapath, debug,
            batchsize, lr, weight_decay, patience, accum, epoch, ckpt) and
            ``args.gnn.*`` (model hyper-parameters).

    Raises:
        ValueError: if ``args.train.runs`` is smaller than 1. Without at least
            one run there is no model or test loss to report.
    """
    if args.train.runs < 1:
        raise ValueError(f"args.train.runs must be >= 1, got {args.train.runs}")

    log_folder_name = save_run_config(args)
    setup_wandb(args)

    train_set, valid_set, test_set = get_data(args.train.datapath)

    if args.train.debug:
        # Tiny subsets for quick smoke tests.
        train_set = train_set[:20]
        valid_set = valid_set[:20]
        test_set = test_set[:20]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pinned host memory only speeds up host->GPU copies. Without CUDA it just
    # emits a warning, so only request it when a GPU is present.
    pin_memory = device == 'cuda'

    train_loader = DataLoader(train_set,
                              batch_size=args.train.batchsize,
                              shuffle=True,
                              pin_memory=pin_memory)
    # Evaluation needs no gradients, so larger batches are used.
    val_loader = DataLoader(valid_set,
                            batch_size=args.train.batchsize * 2,
                            shuffle=False,
                            pin_memory=pin_memory)
    test_loader = DataLoader(test_set,
                             batch_size=args.train.batchsize * 2,
                             shuffle=False,
                             pin_memory=pin_memory)

    test_open_losses = []
    test_trans_losses = []
    test_total_losses = []

    for run in range(args.train.runs):
        model = UnsupervisedGNN(hid_dim=args.gnn.hidden,
                                num_encode_layers=args.gnn.num_encode_layers,
                                num_conv_layers=args.gnn.num_conv_layers,
                                edge_encode_layers=args.gnn.edge_encode_layers,
                                gnn_mlp_layers=args.gnn.gnn_mlp_layers,
                                num_pred_layers=args.gnn.num_pred_layers,
                                aggr=args.gnn.aggr,
                                square_dist=args.gnn.square_dist).to(device)
        # Snapshot of the best weights seen so far in this run. A deep copy is
        # needed because state_dict() returns references to the live tensors.
        best_model = copy.deepcopy(model.state_dict())

        optimizer = optim.Adam(model.parameters(), lr=args.train.lr, weight_decay=args.train.weight_decay)
        # Halve the LR after 60% of the early-stopping patience without improvement,
        # so the LR can drop before training is stopped.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=0.5,
                                                         patience=int(args.train.patience * 0.6),
                                                         min_lr=1.e-5)

        # A fresh trainer per run, so best_loss/patience are not carried over
        # between runs. This presumably relies on ProbsTrainer initialising
        # best_loss to +inf and patience to 0 (TODO confirm).
        trainer = ProbsTrainer(args.train.accum)

        pbar = tqdm(range(args.train.epoch))
        for _ in pbar:
            train_loss = trainer.train(train_loader, model, optimizer, device).item()
            val_open_loss, val_trans_loss, val_total_loss = trainer.eval(val_loader, model, device)

            scheduler.step(val_total_loss)

            if trainer.best_loss > val_total_loss:
                trainer.patience = 0
                trainer.best_loss = val_total_loss
                best_model = copy.deepcopy(model.state_dict())
                if args.train.ckpt:
                    torch.save(model.state_dict(), os.path.join(log_folder_name, f'best_model{run}.pt'))
            else:
                trainer.patience += 1

            stats_dict = {'train_loss': train_loss,
                          'val_open_loss': val_open_loss,
                          'val_trans_loss': val_trans_loss,
                          'val_total_loss': val_total_loss,
                          'lr': scheduler.optimizer.param_groups[0]["lr"]}

            pbar.set_postfix(stats_dict)
            wandb.log(stats_dict)

            # Check early stopping only after logging, so the epoch that
            # triggers the stop is still recorded.
            if trainer.patience > args.train.patience:
                break

        # Evaluate the best (lowest validation loss) weights, not the last ones.
        model.load_state_dict(best_model)
        test_open_loss, test_trans_loss, test_total_loss = trainer.eval(test_loader, model, device)

        test_open_losses.append(test_open_loss)
        test_trans_losses.append(test_trans_loss)
        test_total_losses.append(test_total_loss)

    # Metric keys (including the 'test_tran_*' spelling) are kept unchanged so
    # existing wandb dashboards and queries still match.
    wandb.log({
        'num_params': count_parameters(model),
        'test_open_loss_mean': np.mean(test_open_losses),
        'test_open_loss_std': np.std(test_open_losses),
        'test_tran_loss_mean': np.mean(test_trans_losses),
        'test_tran_loss_std': np.std(test_trans_losses),
        'test_total_loss_mean': np.mean(test_total_losses),
        'test_total_loss_std': np.std(test_total_losses)
    })


if __name__ == "__main__":
    # No arguments: the @hydra.main decorator builds the config from the
    # command line and passes it to main().
    main()
