import argparse
from arguments import parse_arguments
from model.mol_graph import MolGraph
from model.CQVAE import CQ_VAE, VAE_Output
from model.dataset import MyDataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import random
import os.path as path
from tensorboardX import SummaryWriter
from datetime import datetime
from model.mydataclass import ModelParams, TrainingParams, batch_train_data, PathTool
from model.scheduler import beta_annealing_schedule
from typing import List
from benchmark import benchmark
import torch.multiprocessing as mp

def train(
    model: CQ_VAE,
    training_params: TrainingParams,
    pathtool: PathTool, 
):
    """Train ``model`` and keep the checkpoint with the lowest validation loss.

    Each epoch makes one pass over the processed training set and then:
      * saves ``(model state, optimizer state, total_step, beta)`` to
        ``pathtool.model_save_path`` and the motif embeddings to
        ``pathtool.motifs_embed_save_path``;
      * if a processed validation set exists, evaluates on it and, when the
        loss returned by ``VAE_Output.log_epoch_results`` is lower than the
        best seen so far (or on the first evaluated epoch), also saves to
        ``pathtool.best_model_save_path`` /
        ``pathtool.best_motifs_embed_save_path``;
      * if no validation set exists, the latest epoch is saved as the best.

    Args:
        model: The CQ-VAE to train. Batches are moved with ``.cuda()``, so the
            model presumably already lives on a CUDA device.
        training_params: Optimisation hyper-parameters (lr and its annealing,
            number of epochs, beta schedule, gradient clipping, property weight).
        pathtool: Locations of the processed datasets, checkpoints, motif
            embeddings and the TensorBoard log directory.
    """
    optimizer = optim.Adam(model.parameters(), lr=training_params.lr)
    total_step, beta = 0, training_params.beta_min
    tb = SummaryWriter(log_dir=pathtool.tensorboard_dir)

    if pathtool.load_model_path is not None:
        # Only the weights are restored: the optimizer state, step counter and
        # beta stored in the checkpoint are discarded, so the lr / beta
        # schedules restart from their initial values.
        model_state, _, _, _ = torch.load(pathtool.load_model_path)
        model.load_state_dict(model_state)
        print(f"Load model from {pathtool.load_model_path}.\n")

    scheduler = lr_scheduler.ExponentialLR(optimizer, training_params.lr_anneal_rate)
    beta_scheduler = beta_annealing_schedule(params=training_params, init_beta=beta, init_step=total_step)
    train_dataset: List[batch_train_data] = MyDataLoader(pathtool.train_processed_path)
    print(f"The training set has {len(train_dataset)} batches. Run {training_params.epoch} epochs. Totally {len(train_dataset) * training_params.epoch} steps.")
    if path.exists(pathtool.valid_processed_path):
        dev_dataset: List[batch_train_data] = MyDataLoader(pathtool.valid_processed_path)
    else:
        dev_dataset = None
        print(f"No validation set found at {pathtool.valid_processed_path}; the latest epoch will be kept as the best model.")

    # Initialised up front so the final report works even when zero epochs run.
    best_dev_loss = None
    best_epoch = None

    print(f"[{datetime.now()}] Begin training...")
    try:
        for epoch in range(training_params.epoch):

            print(f"\r[{datetime.now()}] Epoch {epoch}/{training_params.epoch}. {len(train_dataset)} steps in each epoch.")
            # NOTE(review): every step's VAE_Output is retained until the end of
            # the epoch; if it holds un-detached GPU tensors this pins memory.
            epoch_loss_list: List[VAE_Output] = []

            for batch in train_dataset:

                total_step += 1
                model.zero_grad()

                batch = batch.cuda()
                output: VAE_Output = model(batch, beta=beta, prop_weight=training_params.prop_weight)

                output.total_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), training_params.grad_clip_norm)

                optimizer.step()
                epoch_loss_list.append(output)

                output.log_tb_results("Train", total_step, tb, beta, scheduler.get_last_lr()[0])

                if total_step % 50 == 0:
                    output.print_results(total_step, lr=scheduler.get_last_lr()[0], beta=beta)

                # The learning rate decays every `lr_anneal_iter` optimisation
                # steps, not once per epoch.
                if total_step % training_params.lr_anneal_iter == 0:
                    scheduler.step()

                beta = beta_scheduler.step()

            model.eval()
            model.zero_grad()
            torch.cuda.empty_cache()
            with torch.no_grad():
                ckpt = (model.state_dict(), optimizer.state_dict(), total_step, beta)
                torch.save(ckpt, pathtool.model_save_path)
                # Save then reload the motif embeddings. The reload presumably
                # populates decoder.motif_node_embed / motif_graph_embed, which
                # are used by the dev-mode forward pass and deleted below.
                model.save_motifs_embed(pathtool.motifs_embed_save_path)
                model.load_motifs_embed(pathtool.motifs_embed_save_path)

                if dev_dataset is None:
                    # No validation split: there is nothing to select on, so the
                    # latest epoch is treated as the best. log_epoch_results is
                    # skipped because its handling of an empty dev list is unknown.
                    is_best = True
                else:
                    dev_loss_list: List[VAE_Output] = [
                        model(
                            dev_batch.cuda(),
                            beta=training_params.beta_max,
                            prop_weight=training_params.prop_weight,
                            dev=True,
                        )
                        for dev_batch in dev_dataset
                    ]
                    dev_loss = VAE_Output.log_epoch_results(epoch, epoch_loss_list, dev_loss_list)
                    is_best = best_dev_loss is None or dev_loss < best_dev_loss
                    if is_best:
                        best_dev_loss = dev_loss

                if is_best:
                    best_epoch = epoch
                    torch.save(ckpt, pathtool.best_model_save_path)
                    model.save_motifs_embed(pathtool.best_motifs_embed_save_path)
                    print(f"[Epoch {epoch}] is the best model. Save the checkpoint.\n")
                del model.decoder.motif_node_embed, model.decoder.motif_graph_embed
            model.train()
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        tb.close()

    if best_epoch is None:
        print(f"[{datetime.now()}] Training completed. No epoch was run.\n")
    else:
        print(f"[{datetime.now()}] Training completed. Epoch {best_epoch} is the best model.\n")

if __name__ == "__main__":

    # Build every configuration object from a single parse of the command line.
    args = parse_arguments()
    training_params = TrainingParams.from_arguments(args)
    model_params = ModelParams.from_arguments(args)
    pathtool = PathTool.from_arguments(args)

    # Select the "spawn" start method for torch.multiprocessing before any
    # worker process can be created.
    mp.set_start_method("spawn")
    # Seed both the stdlib RNG and torch from the same CLI seed.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # The vocabulary is loaded before the model is constructed (presumably
    # the model depends on it).
    MolGraph.load_vocab(pathtool.vocab_path)

    print(f"[{datetime.now()}] Loading CQ-VAE Model.")
    model = CQ_VAE.load_model(model_params=model_params, pathtool=pathtool, training=True)

    num_params = sum(p.nelement() for p in model.parameters())
    print(f"Total number of model parameters: {num_params / 1000}K.")
    for config in (model_params, training_params):
        print(config)

    train(model=model, training_params=training_params, pathtool=pathtool)

    # Move the trained weights off the GPU and drop this reference before the
    # benchmark runs.
    model.cpu()
    del model

    benchmark_results = benchmark(
        model_params=model_params,
        pathtool=pathtool,
        num_samples=args.num_samples,
        num_workers=1,
        return_results=True,
    )
    print(benchmark_results)