import sys
sys.path.append(".")

from Data.GM12878_DataModule import GM12878Module
import torch
from Models.hicbridge import Unet, GaussianDiffusion
import random
import numpy as np
from torch.optim import Adam
from tqdm import tqdm
import os 

def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with *seed*."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _build_diffusion():
    """Construct the HiCBridge U-Net and its GaussianDiffusion wrapper.

    Returns:
        (model, diffusion): the bare U-Net, whose parameters are the ones
        optimised, and the diffusion module that is called to get the loss.
    """
    model = Unet(
        dim=64,
        dim_mults=(1, 1, 2, 2, 4, 4),
        channels=1,
        self_condition=False,
    )
    diffusion = GaussianDiffusion(
        model,
        image_size=256,
        beta_schedule='linear',
        timesteps=1000,
        indi=True,
        objective='pred_x0',
        noise_schedule='brownian',
        indi_step_size=1000,
        loss_type='l1',
    )
    return model, diffusion


def _train_one_epoch(diffusion, loader, optimizer, scaler, device):
    """Run one mixed-precision pass over *loader* and return the mean batch loss.

    Returns 0.0 for an empty loader instead of failing.
    """
    total_loss = 0.0
    num_batches = 0
    for data in tqdm(loader):
        # data[0] / data[1] are used as the (low, high) pair, as in the
        # original script; assumed low-res input / high-res target from the
        # names — TODO confirm against GM12878Module.
        low = data[0].to(device)
        high = data[1].to(device)
        # Only the forward pass needs autocast; device copies are unaffected.
        with torch.cuda.amp.autocast():
            loss = diffusion(high, low)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1
    # Divide by the real batch count (the old code divided by the last index,
    # which is off by one and breaks for 0 or 1 batches).
    return total_loss / num_batches if num_batches else 0.0


def train(*, epochs=151, checkpoint_dir='./Trained_Models', save_every=10, seed=0):
    """Train HiCBridge on the GM12878 data module and checkpoint the U-Net.

    Args:
        epochs: number of passes over the training dataloader.
        checkpoint_dir: directory receiving ``hicbridge_<epoch>.ckpt`` files;
            created if it does not exist.
        save_every: save the U-Net ``state_dict`` whenever
            ``epoch % save_every == 0`` (epoch 0 is always saved).
        seed: value passed to the Python, NumPy and PyTorch RNGs before the
            data module and model are built.
    """
    _seed_everything(seed)

    batch_size = 4
    dm = GM12878Module(batch_size=batch_size, piece_size=256)
    dm.prepare_data()
    dm.setup(stage='fit')

    # Built after the data module so RNG consumption order matches the
    # original script (seed -> data setup -> weight init).
    model, diffusion = _build_diffusion()
    device = 'cuda:0'
    model.to(device=device)
    diffusion.to(device=device)

    print(len(dm.train_dataloader()))

    model.train()
    scaler = torch.cuda.amp.GradScaler()
    optimizer = Adam(model.parameters(), lr=1e-4)

    # Create the output directory before training so the first save cannot
    # fail after an epoch of work has already been spent.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        loss_mean = _train_one_epoch(
            diffusion, dm.train_dataloader(), optimizer, scaler, device
        )
        print(f'{epoch:03d}', loss_mean)

        if epoch % save_every == 0:
            ckpt_path = os.path.join(checkpoint_dir, f'hicbridge_{epoch}.ckpt')
            torch.save(model.state_dict(), ckpt_path)

if __name__ == '__main__':
    # Enumerate GPUs in PCI bus order and expose only device 0, so the
    # 'cuda:0' used by train() maps to that physical card. These variables
    # only take effect if set before CUDA is initialised; this assumes the
    # module-level imports above do not initialise CUDA — TODO confirm.
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    })
    train()
