import importlib
import os
import sys
import time

import torch
import torch.nn as nn

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    get_parameter_number,
)
from models.cvae import cvae, loss_function

import matplotlib.pyplot as plt

def train(epoch, cvae, optimizer, criterion, data):
    """Run one full-batch optimization step of the CVAE and report its loss.

    The whole ``data`` object is passed through the model in a single call, so
    one invocation performs exactly one forward pass, one backward pass and one
    optimizer step.

    Args:
        epoch: Epoch index; used only in the printed progress line.
        cvae: The model being trained. Called as ``cvae(data)`` and expected to
            return ``(pred, mu, logvar)``. Inside this function the name
            shadows the module-level ``cvae`` class.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: Callable ``criterion(pred, data.x, mu, logvar)`` returning a
            ``(loss, reconstruction_term, kld_term)`` triple of scalar tensors.
        data: Object exposing the reconstruction target as ``data.x``.

    Returns:
        float: the total loss divided by ``len(data.x)``.
    """
    cvae.train()
    # NOTE(review): .train() also switches any frozen submodules (e.g. a
    # pre-trained GNN) into training mode; if they contain dropout or
    # batch-norm, confirm that is intended.

    pred, mu, logvar = cvae(data)
    loss, recon, kld = criterion(pred, data.x, mu, logvar)

    # Read the scalar values before stepping, as the original code did.
    total_loss = loss.item()
    recon_loss = recon.item()
    kld_loss = kld.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Normalize by the number of rows in data.x (presumably the criterion
    # returns sums over rows -- TODO confirm against loss_function).
    num_rows = len(data.x)
    avg_loss = total_loss / num_rows
    print(f'Epoch:{epoch}, loss:{avg_loss}, reconstruction loss:{recon_loss / num_rows}, KLD:{kld_loss / num_rows}')
    return avg_loss

def _select_device():
    """Return the device string to train on.

    Prefers ``cuda:1`` (the GPU this script was written for). Falls back to
    ``cuda:0`` when only one GPU is visible, where ``cuda:1`` would be an
    invalid device ordinal, and to ``cpu`` when CUDA is unavailable.
    """
    if torch.cuda.device_count() > 1:
        return "cuda:1"
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"


def main():
    """Train the CVAE on top of a frozen, pre-trained GNN.

    Steps:
      1. Parse CLI args, optionally overlaid by ``args.configs``.
      2. Seed with ``args.gtseed``, build ``args.dataset`` and move it to the
         training device.
      3. Re-seed with ``args.seed``, build the CVAE, load the GNN weights from
         ``weights/gnn.pt`` and freeze every submodule whose name contains
         ``'gnn'``.
      4. Train full-batch for ``args.cvae_num_epoch`` epochs with Adam
         (lr 1e-2) and a cosine-annealed learning rate.
      5. Save the weights to ``weights/cvae_training.pt`` and plot the
         per-epoch loss to ``loss_training``. No extension is given, so
         matplotlib uses its default ``savefig.format``.
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    # Separate seed for dataset construction, presumably so the data is
    # reproducible independently of the model seed set below.
    setup_seed(args.gtseed)
    device = _select_device()

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function=args.mapping_function)
    dataset = dataset.to(device)

    setup_seed(args.seed)

    model = cvae(
        args.sensitive_size,
        args.hidden_channels,
        args.num_layers,
        args.A,
        args.input_size,
        args.latent_dim,
        args.z_size,
    )
    # The model is still on the CPU here. Loading to CPU keeps a checkpoint
    # saved from a GPU loadable on CPU-only machines.
    model.gnn.load_state_dict(torch.load('weights/gnn.pt', map_location='cpu'))
    # Freeze the pre-trained GNN: every submodule whose qualified name
    # contains 'gnn' stops receiving gradients.
    for module_name, module in model.named_modules():
        if 'gnn' in module_name:
            module.requires_grad_(False)

    model = model.to(device)

    # Only trainable parameters are optimized. Frozen ones never get a grad
    # anyway, so this just makes the intent explicit.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=1e-2)
    criterion = loss_function()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cvae_num_epoch)

    loss_history = []
    for epoch in range(args.cvae_num_epoch):
        loss_history.append(train(epoch, model, optimizer, criterion, dataset))
        scheduler.step()

    torch.save(model.state_dict(), 'weights/cvae_training.pt')
    plt.plot(range(len(loss_history)), loss_history)
    plt.savefig('loss_training')
    plt.close()
    
    
# Script entry point: training runs only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()