import numpy as np
import sys
sys.path.append(".")
from Data.GM12878_DataModule import GM12878Module
from Models.hicplus import Net
from torch.optim import Adam
import torch
import torch.nn as nn

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given, the model is put in training mode and its
    parameters are updated after every batch. Otherwise the model is put in
    evaluation mode and the pass runs with autograd disabled, so no graph is
    built for the validation forward passes.

    Each batch is indexed as ``data[0]`` (input, "low") and ``data[1]``
    (target, "high"); ``data[2]`` (chromosome) is not used here.

    Returns:
        The sum of per-batch losses divided by the number of batches, or
        ``float('nan')`` if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total = 0.0
    n_batches = 0
    with torch.set_grad_enabled(training):
        for data in loader:
            low = data[0].to(device)
            # Crop 6 entries from each side of the target's last two dims.
            # Presumably this matches the spatial size of Net's output
            # (unpadded convolutions) -- TODO confirm against Models.hicplus.
            high = data[1][:, :, 6:-6, 6:-6].to(device)
            outputs = model(low)
            loss = criterion(outputs, high)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item()
            n_batches += 1
    # Divide by the batch COUNT (not the last 0-based index).
    return total / n_batches if n_batches else float('nan')


def train(*, epochs=300, lr=3e-5, device=None,
          checkpoint_dir='./Trained_Models', save_every=10):
    """Train the HiCPlus ``Net`` on the GM12878 data module.

    Builds ``GM12878Module(batch_size=512, piece_size=256)``, prepares it and
    calls ``setup(stage='40')``. It then trains ``Net`` with Adam and
    ``nn.MSELoss``. After every epoch it prints the zero-padded epoch number,
    the mean training loss and the mean validation loss.

    Args:
        epochs: Number of passes over the training set.
        lr: Adam learning rate.
        device: Torch device string. ``None`` selects ``'cuda:0'`` when CUDA
            is available and ``'cpu'`` otherwise.
        checkpoint_dir: Directory for ``HiCPlus<epoch>.ckpt`` state-dict
            checkpoints. It is created if missing.
        save_every: A checkpoint is written when ``epoch % save_every == 0``
            and also after the final epoch.
    """
    import os

    dm = GM12878Module(batch_size=512, piece_size=256)
    dm.prepare_data()
    dm.setup(stage='40')

    print("Build model")
    model = Net()

    print("Train model")
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        train_loss = _run_epoch(model, dm.train_dataloader(), criterion,
                                device, optimizer=optimizer)
        val_loss = _run_epoch(model, dm.val_dataloader(), criterion, device)

        print(f'{epoch:03d}', train_loss, val_loss)

        # Also save after the last epoch; otherwise the final weights are
        # lost whenever (epochs - 1) is not a multiple of save_every.
        if epoch % save_every == 0 or epoch == epochs - 1:
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_dir, f'HiCPlus{epoch}.ckpt'))

if __name__ == "__main__":
    # Start training only when this file is executed directly, not on import.
    train()