import torch, os, random
import numpy as np
from tqdm import tqdm
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from dataloader import EPNDataLoader
from config_3depn import params
from models.Model import Model
from cuda.ChamferDistance import L2_ChamferDistance


def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator. When ``None`` the function
            returns immediately and leaves all generator state untouched.

    Note:
        cuDNN is not switched into deterministic mode here
        (``torch.backends.cudnn.deterministic`` is left as-is), so some
        cuDNN kernels can still produce run-to-run differences.
    """
    if seed is None:
        return

    # Exported for child processes; the running interpreter's own string
    # hashing was already fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU-side generators: current device first, then every visible device.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)


def main():
    """Train ``Model`` on the 3D-EPN training list for the ``car`` category.

    Steps, in order:
      1. Seed all random number generators.
      2. Read hyperparameters from ``params()``.
      3. Build the model, the Chamfer-distance loss, the training
         ``DataLoader``, an Adam optimizer and a ``MultiStepLR`` schedule.
      4. Write the run configuration to ``<ckpt_dir>/CONFIG.txt``.
      5. Train for ``cfg.n_epochs`` epochs, minimising the sum of the
         Chamfer distances between the ground truth and each of the four
         outputs returned by the model.
      6. Save the final weights to ``<ckpt_dir>/<CLASS>.pt``.
    """
    # Seed first: model construction and the data loader both draw from the
    # global RNGs, so seeding after they are built leaves the weight
    # initialisation unreproducible.
    set_seed()

    # default setting
    cfg = params()
    MODEL = 'work3'
    FLAG = 'train_3depn'
    CLASS = 'car'
    # Coerce once and use the coerced values everywhere, so what is logged
    # in CONFIG.txt is exactly what the loader and the epoch loop use.
    BATCH_SIZE = int(cfg.batch_size)
    MAX_EPOCH = int(cfg.n_epochs)

    # create ckpt_dir (exist_ok avoids the check-then-create race)
    ckpt_dir = 'ckpt_3depn'
    os.makedirs(ckpt_dir, exist_ok=True)

    # create models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Model()
    model = torch.nn.DataParallel(model)
    model.to(device)

    # loss function
    loss_cd = L2_ChamferDistance()

    # dataset loading
    EPNDataset_train = EPNDataLoader('./dataset/3depn_train_list.txt',
                                     data_path=cfg.data_root,
                                     status="train",
                                     category=CLASS)
    train_loader = DataLoader(EPNDataset_train,
                              batch_size=BATCH_SIZE,
                              num_workers=cfg.nThreads,
                              shuffle=True,
                              drop_last=False)
    n_batches = len(train_loader)  # constant across epochs

    # optimizer setting
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = MultiStepLR(optimizer, milestones=cfg.milestones, gamma=0.7)

    # saving hyperparameters
    config_file = os.path.join(ckpt_dir, 'CONFIG.txt')
    with open(config_file, 'w') as f:
        f.write('MODEL:' + str(MODEL) + '\n')
        f.write('FLAG:' + str(FLAG) + '\n')
        f.write('CLASS:' + str(CLASS) + '\n')
        f.write('BATCH_SIZE:' + str(BATCH_SIZE) + '\n')
        f.write('MAX_EPOCH:' + str(MAX_EPOCH) + '\n')
        f.write(str(cfg.__dict__))

    # training
    for epoch in range(1, MAX_EPOCH + 1):
        model.train()
        with tqdm(total=n_batches, desc=f'Epoch {epoch}/{MAX_EPOCH}', unit='batch') as pbar:
            for data in train_loader:
                # Batch layout used here: index 0 -> image, 1 -> ground
                # truth, 2 -> partial input.
                image = data[0].to(device)
                partial = data[2].to(device)
                gt = data[1].to(device)

                # The model returns four outputs; each one is supervised
                # against the same ground truth.
                out = model(partial, image)
                loss_stage0 = loss_cd(gt, out[0])
                loss_stage1 = loss_cd(gt, out[1])
                loss_stage2 = loss_cd(gt, out[2])
                loss_stage3 = loss_cd(gt, out[3])
                loss = loss_stage0 + loss_stage1 + loss_stage2 + loss_stage3

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Progress bar shows only the last output's loss, scaled by 1e3.
                pbar.set_postfix(loss=1e3 * loss_stage3.item())
                pbar.update(1)

        # Report the learning rate used during this epoch, then advance the
        # schedule for the next one.
        print('lr: ', optimizer.param_groups[0]['lr'])
        scheduler.step()

    # The state dict is taken from the DataParallel wrapper, so its keys
    # carry the wrapper's ``module.`` prefix.
    torch.save({'model_state_dict': model.state_dict()},
               os.path.join(ckpt_dir, f'{CLASS}.pt'))

# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()