import numpy as np
import sys
sys.path.append(".")
from Data.GM12878_DataModule import GM12878Module
from Models.hicsr import Generator, Discriminator, FeatureReconstructionLoss
from torch.optim import Adam
import torch
import torch.nn as nn
from tqdm import tqdm

def _to_device_pair(data, device):
    """Move one training batch to ``device`` and crop the high-res target.

    ``data[0]`` is the low-resolution input and ``data[1]`` the high-resolution
    target (``data[2]``, the chromosome, is not used here). The target is
    cropped by 6 entries on each side of its last two dimensions —
    presumably because the generator's output is 12 entries smaller than its
    input; TODO confirm against Models.hicsr.Generator.
    """
    low = data[0].to(device)
    high = data[1][:, :, 6:-6, 6:-6].to(device)
    return low, high


def _binary_accuracy(logits, target):
    """Return the fraction of elements whose thresholded logit matches ``target``.

    The discriminator is trained with ``BCEWithLogitsLoss``, so its outputs are
    raw logits and the decision boundary is 0 (sigmoid(0) == 0.5). Using
    ``mean`` keeps the result in [0, 1] whatever the output shape is.
    """
    predicted = (logits > 0).float()
    return (predicted == target).float().mean()


def _train_step(G_model, D_model, G_optimizer, D_optimizer, low, high,
                adv_loss, l1_loss, feature_reconstruction_loss):
    """Run one generator update followed by one discriminator update.

    Returns a tuple of Python floats in the order printed each epoch:
    (G total, G adversarial, G image, G feature,
     D total, D real, D fake, D real accuracy, D fake accuracy).
    """
    # --- Generator step: fool D while matching the target pixel- and feature-wise.
    output = G_model(low)
    image_loss = l1_loss(output, high)
    feature_loss = sum(feature_reconstruction_loss(output, high))
    pred_fake_for_g = D_model(output)
    gan_loss = adv_loss(pred_fake_for_g, torch.ones_like(pred_fake_for_g))
    total_loss_g = 2.5e-3 * gan_loss + image_loss + feature_loss

    G_optimizer.zero_grad()
    total_loss_g.backward()
    G_optimizer.step()

    # --- Discriminator step: real targets -> 1, generated outputs -> 0.
    pred_real = D_model(high)
    labels_real = torch.ones_like(pred_real)
    loss_real = adv_loss(pred_real, labels_real)

    # Regenerate with the just-updated G. D's loss must not reach G, so no
    # autograd graph is needed for this forward pass.
    with torch.no_grad():
        fake = G_model(low)
    pred_fake = D_model(fake)
    labels_fake = torch.zeros_like(pred_fake)
    loss_fake = adv_loss(pred_fake, labels_fake)
    total_loss_d = loss_fake + loss_real

    # zero_grad also discards the gradients the generator's backward pass
    # left on D's parameters.
    D_optimizer.zero_grad()
    total_loss_d.backward()
    D_optimizer.step()

    return (
        total_loss_g.item(),
        gan_loss.item(),
        image_loss.item(),
        feature_loss.item(),
        total_loss_d.item(),
        loss_real.item(),
        loss_fake.item(),
        _binary_accuracy(pred_real.detach(), labels_real).item(),
        _binary_accuracy(pred_fake.detach(), labels_fake).item(),
    )


def _validate(G_model, dataloader, l1_loss, device):
    """Return the generator's mean per-batch L1 loss over ``dataloader`` (0.0 if empty)."""
    G_model.eval()
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for data in dataloader:
            low, high = _to_device_pair(data, device)
            total += l1_loss(G_model(low), high).item()
            n_batches += 1
    return total / max(n_batches, 1)


def train(epochs=500, device=None, checkpoint_dir='./Trained_Models'):
    """Train the HiCSR generator/discriminator pair on the GM12878 data module.

    Each epoch runs one generator update and one discriminator update per
    training batch, then prints the epoch number followed by the per-batch
    averages of: G total, G adversarial, G image, G feature, D total, D real,
    D fake, D real accuracy, D fake accuracy. Every 5th epoch (starting at 0)
    the generator's mean validation L1 loss is printed. Every 10th epoch a
    checkpoint holding both models and both optimizers is written.

    Args:
        epochs: number of training epochs.
        device: torch device string. Defaults to ``'cuda:0'`` when CUDA is
            available, otherwise ``'cpu'``.
        checkpoint_dir: directory for ``HiCSR<epoch>.ckpt`` files. It is
            created if missing.
    """
    import os

    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    dm = GM12878Module(batch_size=8, piece_size=256)
    # NOTE(review): dm.prepare_data() was commented out here; presumably the
    # data has already been prepared on disk — confirm on a fresh machine.
    dm.setup(stage='40')

    print("Build model")
    G_model = Generator(num_res_blocks=15)
    D_model = Discriminator()
    G_model.init_params()
    D_model.init_params()

    print("Train model")
    # Move to the device before building the optimizers so they reference the
    # final parameter tensors.
    G_model.to(device)
    D_model.to(device)
    G_optimizer = Adam(G_model.parameters(), lr=1e-5)
    D_optimizer = Adam(D_model.parameters(), lr=1e-5)

    adv_loss = nn.BCEWithLogitsLoss().to(device)
    l1_loss = nn.L1Loss().to(device)
    feature_reconstruction_loss = FeatureReconstructionLoss().to(device)

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        G_model.train()
        D_model.train()

        sums = [0.0] * 9
        n_batches = 0
        for data in tqdm(dm.train_dataloader()):
            low, high = _to_device_pair(data, device)
            stats = _train_step(G_model, D_model, G_optimizer, D_optimizer,
                                low, high, adv_loss, l1_loss,
                                feature_reconstruction_loss)
            sums = [s + v for s, v in zip(sums, stats)]
            n_batches += 1

        # Divide by the batch count (not the last index) so the averages are exact.
        denom = max(n_batches, 1)
        print(f'{epoch:03d}', *(s / denom for s in sums))

        if epoch % 5 == 0:
            val_loss_mean = _validate(G_model, dm.val_dataloader(), l1_loss, device)
            print(f'{epoch:03d}', val_loss_mean)

        if epoch % 10 == 0:
            torch.save({'G_model_state_dict': G_model.state_dict(),
                        'D_model_state_dict': D_model.state_dict(),
                        'G_optimizer_state_dict': G_optimizer.state_dict(),
                        'D_optimizer_state_dict': D_optimizer.state_dict()
                        }, os.path.join(checkpoint_dir, 'HiCSR' + str(epoch) + '.ckpt'))

if __name__ == "__main__":
    # Script entry point: run training with its default settings.
    train()