import numpy as np
import sys
sys.path.append(".")
from Data.GM12878_DataModule import GM12878Module
from Models.hicsr import DAE
from torch.optim import Adam
import torch
import torch.nn as nn
from tqdm import tqdm

def _add_noise(clean, noise_std=0.1):
    """Return a noise-corrupted copy of ``clean`` for denoising training.

    Adds zero-mean Gaussian noise with standard deviation ``noise_std`` and
    clamps the result to [-1, 1].

    NOTE(review): the clamp presumes the data module normalises pieces to
    [-1, 1] — confirm against GM12878Module.
    """
    noisy = clean + noise_std * torch.randn_like(clean)
    return torch.clamp(noisy, -1, 1)


def train(device='cuda:1', epochs=600, lr=1e-4,
          ckpt_dir='./Trained_Models', save_every=10):
    """Train the DAE to reconstruct clean "high" pieces from noisy copies.

    Each epoch has three steps:

    1. Run one optimisation pass over the training loader. The loss is the
       MSE between the model's output on a noise-corrupted input and the
       clean input.
    2. Evaluate the same objective on the validation loader without
       gradients.
    3. Print the per-batch mean losses and periodically save the weights.

    Args:
        device: torch device string the model and batches are moved to.
        epochs: number of training epochs.
        lr: Adam learning rate.
        ckpt_dir: directory for ``DAE<epoch>.ckpt`` files; created if missing.
        save_every: save a checkpoint whenever ``epoch % save_every == 0``
            (must be positive). The final epoch is always saved as well.
    """
    import os

    dm = GM12878Module(batch_size=4096, piece_size=256)
    dm.prepare_data()
    # NOTE(review): stage '40' is project-specific (presumably a
    # downsampling ratio) — confirm against GM12878Module.setup.
    dm.setup(stage='40')

    print("Build model")
    model = DAE()

    print("Train model")
    optimizer = Adam(model.parameters(), lr=lr)
    model.to(device)
    criterion = nn.MSELoss()
    # torch.save does not create parent directories.
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        train_total, n_train = 0.0, 0
        for data in tqdm(dm.train_dataloader()):
            # data[0] = low, data[1] = high, data[2] = chrm
            high = data[1].to(device)
            output = model(_add_noise(high))
            loss = criterion(output, high)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_total += loss.item()
            n_train += 1

        # Switch to eval mode once for the whole validation pass.
        model.eval()
        val_total, n_val = 0.0, 0
        with torch.no_grad():
            for data in dm.val_dataloader():
                high = data[1].to(device)
                output = model(_add_noise(high))
                # Same (prediction, target) order as the training loss.
                val_total += criterion(output, high).item()
                n_val += 1

        # Divide by the number of batches (not the last enumerate index) to
        # get the true per-batch mean; max(..., 1) guards an empty loader.
        print(f'{epoch:03d}',
              train_total / max(n_train, 1),
              val_total / max(n_val, 1))

        if epoch % save_every == 0 or epoch == epochs - 1:
            torch.save(model.state_dict(),
                       os.path.join(ckpt_dir, f'DAE{epoch}.ckpt'))

if __name__ == "__main__":
    # Script entry point: run the full training loop with default settings.
    train()