import os

# Expose only GPU 0 to this process. This assignment is deliberately placed
# above the project imports below so the variable is already set when they
# run (presumably they import torch / touch CUDA -- TODO confirm).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from models.get_model import get_py_module
from models.get_model import get_model
from utils.dataset import *
from utils.PUNet.dataset import get_loader as get_PUNet_loader
from utils.PUNet.PUNet_utils import *
from utils.loss import cd_loss



from tqdm import tqdm



def train(model, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Every batch supplies two noisy inputs (``pcl_noisy1`` and ``pcl_noisy2``).
    Each is fed through ``model`` separately, giving two branches. For each
    branch a stage list is built: the noisy input followed by the entries of
    the model's ``'denoised'`` output. Two loss terms are summed over all
    consecutive stage pairs ``(i, i + 1)``:

    * ``loss_SR``: ``cd_loss`` between each branch's stage ``i`` and the
      *other* branch's stage ``i + 1``, plus ``cd_loss`` between the two
      stage ``i + 1`` outputs.
    * ``loss_MPCL``: summed squared error between ``denoised_MP[i]`` and the
      same branch's stage ``i + 1`` output.

    Args:
        model: Callable taking ``{'pcl_noisy': tensor}`` and returning a
            mapping with ``'denoised'`` (a list) and ``'denoised_MP'``
            (indexable per stage).
        optimizer: Optimizer that is zeroed before and stepped after every
            batch.
        train_loader: Sized iterable of dicts of tensors containing at least
            the keys ``'pcl_noisy1'`` and ``'pcl_noisy2'``.

    Returns:
        list[float]: The total loss of every batch, in iteration order.

    Raises:
        ValueError: If the model returns an empty ``'denoised'`` list, which
            would leave no loss to back-propagate.
    """
    train_rec = []
    for data in tqdm(train_loader,
                     total=len(train_loader),
                     desc="Training"):

        # Build a fresh dict rather than mutating the batch object handed out
        # by the loader.
        data = {key: data[key].cuda() for key in data}

        data1 = {'pcl_noisy': data['pcl_noisy1']}
        data2 = {'pcl_noisy': data['pcl_noisy2']}

        optimizer.zero_grad()
        out1 = model(data1)
        out2 = model(data2)

        # Stage 0 is the noisy input; stages 1..N are the model outputs.
        branch1_pc = [data1['pcl_noisy']] + out1['denoised']
        branch2_pc = [data2['pcl_noisy']] + out2['denoised']

        num_stages = len(branch1_pc) - 1
        if num_stages < 1:
            # Without this check the loss below would stay the int 0 and
            # ``backward()`` would fail with an unhelpful AttributeError.
            raise ValueError(
                "model returned no 'denoised' stages; nothing to optimise")

        loss_SR_total = 0
        loss_MPCL_total = 0

        for i in range(num_stages):
            nxt = i + 1
            # Cross-branch terms plus agreement of the two stage outputs.
            loss_SR = (cd_loss(branch1_pc[i], branch2_pc[nxt]) +
                       cd_loss(branch2_pc[i], branch1_pc[nxt]) +
                       cd_loss(branch1_pc[nxt], branch2_pc[nxt]))
            # NOTE: neither argument is detached, so gradients flow through
            # both the ``denoised_MP`` entry and the stage output.
            loss_MPCL = (
                torch.nn.functional.mse_loss(
                    out1['denoised_MP'][i], branch1_pc[nxt], reduction='sum') +
                torch.nn.functional.mse_loss(
                    out2['denoised_MP'][i], branch2_pc[nxt], reduction='sum'))
            loss_SR_total += loss_SR
            loss_MPCL_total += loss_MPCL

        loss = loss_SR_total + loss_MPCL_total

        loss.backward()
        optimizer.step()

        # Record the scalar loss so the returned list is actually populated.
        train_rec.append(loss.item())

    return train_rec



def main(cfg):
    """Build the model, data loader and optimizer from ``cfg.hp`` and train.

    Args:
        cfg: Configuration object whose ``hp`` attribute provides
            ``model_name``, ``cfg_name``, ``learning_rate`` and ``max_epoch``;
            ``cfg.hp`` itself is also passed to the PUNet loader factory.
    """
    # ``.cuda()`` moves the module in place, so one call is sufficient.
    model = get_model(cfg.hp.model_name, cfg.hp.cfg_name).cuda()

    # The factory also returns a test loader, which this script does not use.
    train_loader, _test_loader = get_PUNet_loader(cfg.hp)

    print('PUNet_loader_len:', len(train_loader))

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.hp.learning_rate)

    for _epoch in range(1, cfg.hp.max_epoch + 1):
        train(model=model, optimizer=optimizer, train_loader=train_loader)


if __name__ == '__main__':
    # Script entry point: load the SIMPC configuration module by its dotted
    # name and hand it straight to the training driver.
    main(get_py_module('config.SIMPC'))