import numpy as np
import sys
sys.path.append(".")
from Data.GM12878_DataModule import GM12878Module
from Models.deephic import Generator, Discriminator, GeneratorLoss
from torch.optim import Adam
import torch
import torch.nn as nn
from tqdm import tqdm

def _train_epoch(netG, netD, loader, optimizerG, optimizerD, criterionG, criterionD, device):
    """Run one adversarial epoch; return (mean G loss, mean D loss) per batch.

    Returns NaN for both means when ``loader`` yields no batches.
    """
    netG.train()
    netD.train()
    g_total = 0.0
    d_total = 0.0
    n_batches = 0

    for data in tqdm(loader):
        # data[0] = low, data[1] = high, data[2] = chrm
        low = data[0].to(device)
        high = data[1].to(device)

        # Train G: GeneratorLoss receives the mean discriminator score on the
        # fake batch, the fake batch itself and the high-resolution target.
        optimizerG.zero_grad()
        fake_img = netG(low)
        fake_out = netD(fake_img)
        g_loss = criterionG(fake_out.mean(), fake_img, high)
        g_loss.backward()
        optimizerG.step()

        # Train D: regenerate with the just-updated G. No gradient has to flow
        # back into G here, so skip building its graph entirely.
        optimizerD.zero_grad()
        real_out = netD(high)
        with torch.no_grad():
            fake_img = netG(low)
        fake_out = netD(fake_img)
        d_loss_real = criterionD(real_out, torch.ones_like(real_out))
        d_loss_fake = criterionD(fake_out, torch.zeros_like(fake_out))
        d_loss = d_loss_real + d_loss_fake
        # Nothing backpropagates through this graph again, so no retain_graph.
        d_loss.backward()
        optimizerD.step()

        g_total += g_loss.item()
        d_total += d_loss.item()
        n_batches += 1

    if n_batches == 0:
        return float('nan'), float('nan')
    return g_total / n_batches, d_total / n_batches


@torch.no_grad()
def _validate(netG, loader, loss_fn, device):
    """Return the mean per-batch ``loss_fn(netG(low), high)`` over ``loader`` (NaN if empty)."""
    netG.eval()
    total = 0.0
    n_batches = 0
    for data in loader:
        low = data[0].to(device)
        high = data[1].to(device)
        total += loss_fn(netG(low), high).item()
        n_batches += 1
    return total / n_batches if n_batches else float('nan')


def train(epochs=200, device=None, checkpoint_dir='./Trained_Models', save_every=10):
    """Train the DeepHiC generator/discriminator pair on GM12878 data.

    Each epoch does one generator step and one discriminator step per training
    batch. It then prints the epoch number, the mean generator loss, the mean
    discriminator loss and the mean L1 validation loss of the generator, each
    averaged over the number of batches. The generator's ``state_dict`` is
    saved every ``save_every`` epochs, starting at epoch 0.

    Args:
        epochs: Number of training epochs (default 200).
        device: Torch device string. If None, uses ``'cuda:0'`` when CUDA is
            available and ``'cpu'`` otherwise.
        checkpoint_dir: Directory that receives ``Deephic<epoch>.ckpt`` files.
            It is created if missing.
        save_every: Checkpoint interval in epochs (default 10).
    """
    import os

    dm = GM12878Module(batch_size=16, piece_size=256)
    dm.prepare_data()
    dm.setup(stage='40')
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print("Build model")
    netG = Generator(scale_factor=1, in_channel=1, resblock_num=5).to(device)
    netD = Discriminator(in_channel=1).to(device)

    print("Train model")
    optimizerG = Adam(netG.parameters(), lr=1e-4)
    optimizerD = Adam(netD.parameters(), lr=1e-4)

    criterionG = GeneratorLoss().to(device)
    # NOTE(review): BCELoss expects probabilities, so this assumes
    # Discriminator ends in a sigmoid. Confirm in Models/deephic.
    criterionD = torch.nn.BCELoss().to(device)

    l1_loss = nn.L1Loss().to(device)

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        g_loss_mean, d_loss_mean = _train_epoch(
            netG, netD, dm.train_dataloader(),
            optimizerG, optimizerD, criterionG, criterionD, device,
        )
        val_loss_mean = _validate(netG, dm.val_dataloader(), l1_loss, device)

        print(f'{epoch:03d}', g_loss_mean, d_loss_mean, val_loss_mean)

        if epoch % save_every == 0:
            torch.save(
                netG.state_dict(),
                os.path.join(checkpoint_dir, 'Deephic' + str(epoch) + '.ckpt'),
            )

if __name__ == '__main__':
    # Script entry point: run the full training loop with default settings.
    train()