import argparse
import os
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import datetime
import random
# import torch.nn as nn
import yaml
import torch
import torch.optim as Optim

from torch.utils.data.dataloader import DataLoader
from tensorboardX import SummaryWriter

from dataset import ShapeNet_Heart_Slice,ShapeNet_Heart_Slice_components
from models import PCN2Brunch, PCN6Brunch,PCN3Brunch
from metrics.metric import l1_cd
from metrics.loss import cd_loss_L1  # , emd_loss
from visualization import plot_pcd_one_view
import pandas as pd

def make_dir(dir_path):
    """Create ``dir_path`` (and any missing parents) if it does not exist yet.

    Calling this on an already existing directory is a no-op.

    Args:
        dir_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of
    # "if not exists: makedirs" when several processes start together.
    os.makedirs(dir_path, exist_ok=True)


def log(fd, message, time=True):
    """Write one line to the log file and echo it on stdout.

    Args:
        fd: Writable text stream; it is flushed after the line is written.
        message: Text to record (a newline is appended for the file).
        time: When true, the line is prefixed with the current local time
            formatted as ``YYYY-mm-dd HH:MM:SS`` followed by ``' ==> '``.
    """
    if time:
        stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        message = stamp + ' ==> ' + message

    line = message + '\n'
    fd.write(line)
    fd.flush()
    print(message)


def prepare_logger(params):
    """Set up the output folders, the text log and the TensorBoard writers.

    The layout is ``<log_dir>/<exp_name>/<category>/`` containing
    ``checkpoints/``, ``epochs/``, ``logger.log`` and the ``train`` / ``val``
    summary directories.

    Args:
        params: Object exposing ``log_dir``, ``exp_name`` and ``category``;
            its ``str()`` form is written to the log as a header line.

    Returns:
        Tuple ``(ckpt_dir, epochs_dir, log_fd, train_writer, val_writer)``.
        ``log_fd`` is opened in append mode and is left open for the caller.
    """
    exp_root = os.path.join(params.log_dir, params.exp_name)
    logger_path = os.path.join(exp_root, params.category)
    ckpt_dir = os.path.join(logger_path, 'checkpoints')
    epochs_dir = os.path.join(logger_path, 'epochs')

    # Created from the outermost folder inwards.
    for folder in (params.log_dir, exp_root, logger_path, ckpt_dir, epochs_dir):
        make_dir(folder)

    log_fd = open(os.path.join(logger_path, 'logger.log'), 'a')

    # Header lines are written without a timestamp.
    header_lines = (
        "Experiment: {}".format(params.exp_name),
        "Logger directory: {}".format(logger_path),
        str(params),
    )
    for header in header_lines:
        log(log_fd, header, False)

    train_writer = SummaryWriter(os.path.join(logger_path, 'train'))
    val_writer = SummaryWriter(os.path.join(logger_path, 'val'))

    return ckpt_dir, epochs_dir, log_fd, train_writer, val_writer


def smooth_labels(labels, smoothing=0.1):
    """Pull labels towards 0.5 by the given smoothing factor.

    Args:
        labels: Value(s) supporting ``*`` and ``+`` with floats.
        smoothing: Blend factor; ``0`` returns the labels unchanged and
            ``1`` maps every label to ``0.5``.

    Returns:
        ``labels * (1 - smoothing) + 0.5 * smoothing``.
    """
    keep_weight = 1 - smoothing
    midpoint_share = 0.5 * smoothing
    return labels * keep_weight + midpoint_share


def _alpha_for_epoch(epoch):
    """Return the weight applied to the dense-shape losses at ``epoch``."""
    if epoch < 100:
        return 0.01
    if epoch < 200:
        return 0.1
    return 0.5


def _batch_to_device(batch, device):
    """Move every element of ``batch`` except the last one to ``device``.

    The last element (unpacked as ``path`` by the caller) is passed through
    unchanged.
    """
    *clouds, path = batch
    return tuple(cloud.to(device) for cloud in clouds) + (path,)


def _run_training(params, ckpt_dir, epochs_dir, log_fd, train_writer, val_writer):
    """Run the train/validate loop; the caller owns and closes the log handles."""
    # Fail before any data is loaded instead of on the first iteration.
    if params.coarse_loss != 'cd':
        raise ValueError('Not implemented loss {}'.format(params.coarse_loss))

    device = params.device
    data_root = '../data/CTA/cta_normal/pointcloud'

    log(log_fd, 'Loading Data...')

    train_dataset = ShapeNet_Heart_Slice_components(data_root, 'train', params.category)
    val_dataset = ShapeNet_Heart_Slice_components(data_root, 'valid', params.category)

    train_dataloader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True,
                                  num_workers=params.num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=params.batch_size, shuffle=False,
                                num_workers=params.num_workers)
    log(log_fd, "Dataset loaded!")

    model = PCN6Brunch(num_dense=16384, latent_dim=1024, grid_size=4).to(device)

    optimizer = Optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.999))
    lr_schedual = Optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.7)

    best_cd_l1 = 1e8
    best_epoch_l1 = -1
    train_step, val_step = 0, 0
    start_epoch = 1
    epochL, slice_lossL, shape_lossL, componentrec_lossL = [], [], [], []

    loss_csv_path = os.path.join(ckpt_dir, 'epoch_loss.csv')
    state_path = ckpt_dir + '/output.yaml'

    # Resume from a previous run: model weights, counters and loss history.
    # NOTE(review): optimizer and LR-scheduler state are not checkpointed,
    # so a resumed run restarts the schedule from params.lr.
    if params.ckpt_path is not None:
        # map_location keeps loading independent of the device used at save time.
        model.load_state_dict(torch.load(params.ckpt_path, map_location=device))
        print('imported', params.ckpt_path)

        with open(state_path, 'r') as file:
            data_read = yaml.safe_load(file)

        print('data_read:', data_read)
        best_cd_l1 = data_read['best_cd_l1']
        best_epoch_l1 = data_read['best_epoch_l1']
        train_step, val_step = data_read['train_step'], data_read['val_step']
        # The state file is written after an epoch has finished, so continue
        # with the following one instead of repeating it.
        start_epoch = data_read['epoch'] + 1

        # A missing history file only means the per-epoch CSV starts afresh.
        if os.path.exists(loss_csv_path):
            df_loss = pd.read_csv(loss_csv_path)
            epochL = df_loss['epoch'].tolist()
            slice_lossL = df_loss['slice loss'].tolist()
            shape_lossL = df_loss['shape loss'].tolist()
            componentrec_lossL = df_loss['component rec loss'].tolist()

    num_train_batches = len(train_dataloader)

    for epoch in range(start_epoch, params.epochs + 1):
        alpha = _alpha_for_epoch(epoch)

        # ---- training ----
        model.train()
        for i, batch in enumerate(train_dataloader):
            (p_slice, c_slice, c_shape,
             lv_pc, rv_pc, aro_pc, la_pc,
             ra_pc, myo_pc, _path) = _batch_to_device(batch, device)

            optimizer.zero_grad()

            (rec_slice_pred,
             rotate_slice_pred,
             coarse_shape_pred,
             coarse_lv_pred, coarse_rv_pred, coarse_aro_pred,
             coarse_la_pred, coarse_ra_pred, coarse_myo_pred,
             fine_shape_component, fine_shape_pred, _feature_vector) = model(p_slice)

            # Coarse term: slices, whole coarse shape and the six components.
            loss1 = (cd_loss_L1(rotate_slice_pred, c_slice)
                     + cd_loss_L1(coarse_shape_pred, c_shape[:, :, :3])
                     + cd_loss_L1(coarse_lv_pred, lv_pc[:, :, :3])
                     + cd_loss_L1(coarse_rv_pred, rv_pc[:, :, :3])
                     + cd_loss_L1(coarse_aro_pred, aro_pc[:, :, :3])
                     + cd_loss_L1(coarse_la_pred, la_pc[:, :, :3])
                     + cd_loss_L1(coarse_ra_pred, ra_pc[:, :, :3])
                     + cd_loss_L1(coarse_myo_pred, myo_pc[:, :, :3])
                     + cd_loss_L1(rec_slice_pred, p_slice))
            # Dense term: both dense outputs against the target shape.
            loss2 = (cd_loss_L1(fine_shape_pred, c_shape[:, :, :3])
                     + cd_loss_L1(fine_shape_component, c_shape[:, :, :3]))
            # Consistency term between the two dense outputs.
            loss3 = cd_loss_L1(fine_shape_pred, fine_shape_component)

            loss = loss1 + alpha * loss2 + 0.1 * alpha * loss3

            # Only one backward pass runs per step, so the graph need not be retained.
            loss.backward()
            optimizer.step()

            log(log_fd,
                "Training Epoch [{:03d}/{:03d}] - Iteration [{:03d}/{:03d}]: coarse loss = {:.6f}, dense l1 cd = {:.6f}, total loss = {:.6f}"
                .format(epoch, params.epochs, i + 1, num_train_batches, loss1.item() * 1e3, loss2.item() * 1e3,
                        loss.item() * 1e3))

            train_writer.add_scalar('coarse', loss1.item(), train_step)
            train_writer.add_scalar('dense', loss2.item(), train_step)
            train_writer.add_scalar('total', loss.item(), train_step)
            train_step += 1

        lr_schedual.step()

        # ---- validation ----
        model.eval()
        total_cd_l1_slice, total_cd_l1_shape = 0.0, 0.0
        total_cd_l1_componentrec = 0.0
        with torch.no_grad():
            # One randomly chosen validation batch is rendered per epoch.
            rand_iter = random.randint(0, len(val_dataloader) - 1)

            for i, batch in enumerate(val_dataloader):
                (p_slice, c_slice, c_shape,
                 lv_pc, rv_pc, aro_pc, la_pc,
                 ra_pc, myo_pc, _path) = _batch_to_device(batch, device)

                (rec_slice_pred,
                 rotate_slice_pred,
                 coarse_shape_pred,
                 coarse_lv_pred, coarse_rv_pred, coarse_aro_pred,
                 coarse_la_pred, coarse_ra_pred, coarse_myo_pred,
                 fine_shape_component, fine_shape_pred, _feature_vector) = model(p_slice)

                total_cd_l1_slice += l1_cd(rotate_slice_pred, c_slice).item()
                # Compare against the same first three channels the training
                # loss uses (presumably xyz; a no-op if c_shape has only
                # three channels) -- TODO confirm the channel layout.
                total_cd_l1_shape += l1_cd(fine_shape_pred,
                                           c_shape[:, :, :3].contiguous()).item()
                total_cd_l1_componentrec += l1_cd(fine_shape_pred, fine_shape_component).item()

                if rand_iter == i:
                    index = random.randint(0, fine_shape_pred.shape[0] - 1)
                    clouds = [p_slice, rec_slice_pred, rotate_slice_pred, c_slice,
                              coarse_shape_pred,
                              coarse_lv_pred, coarse_rv_pred, coarse_aro_pred,
                              coarse_la_pred, coarse_ra_pred, coarse_myo_pred,
                              fine_shape_component, fine_shape_pred, c_shape]
                    plot_pcd_one_view(os.path.join(epochs_dir, 'epoch_{:03d}.png'.format(epoch)),
                                      [cloud[index].detach().cpu().numpy() for cloud in clouds],
                                      ['Input Slice', 'Coarse Slice', 'Rotate Slice', 'Ground Truth Slice',
                                       'Coarse Shape', 'lv', 'rv', 'aro', 'la', 'ra', 'myo', 'Dense component',
                                       'Dense Shape', 'Ground Truth Shape'], xlim=(-0.35, 0.35),
                                      ylim=(-0.35, 0.35), zlim=(-0.35, 0.35))

        # NOTE(review): dividing by the dataset size assumes l1_cd returns a
        # per-batch sum -- confirm against metrics.metric.
        num_val_samples = len(val_dataset)
        total_cd_l1_slice /= num_val_samples
        total_cd_l1_shape /= num_val_samples
        total_cd_l1_componentrec /= num_val_samples

        # The whole per-epoch history is rewritten after every epoch.
        epochL.append(epoch)
        slice_lossL.append(total_cd_l1_slice)
        shape_lossL.append(total_cd_l1_shape)
        componentrec_lossL.append(total_cd_l1_componentrec)
        pd.DataFrame({
            'epoch': epochL,
            'slice loss': slice_lossL,
            'shape loss': shape_lossL,
            'component rec loss': componentrec_lossL,
        }).to_csv(loss_csv_path)

        # add_scalar takes (tag, value, global_step): one call per metric.
        val_writer.add_scalar('l1_cd', total_cd_l1_slice, val_step)
        val_writer.add_scalar('l1_cd_shape', total_cd_l1_shape, val_step)
        val_step += 1

        log(log_fd,
            "Validate Epoch [{:03d}/{:03d}]: L1 Chamfer Distance = {:.6f}, L1 Chamfer Distance Shape = {:.6f}".format(
                epoch, params.epochs,
                total_cd_l1_slice * 1e3, total_cd_l1_shape * 1e3))

        # The best checkpoint is selected on the slice metric.
        if total_cd_l1_slice < best_cd_l1:
            best_epoch_l1 = epoch
            best_cd_l1 = total_cd_l1_slice
            torch.save(model.state_dict(), os.path.join(ckpt_dir, 'best_l1_cd.pth'))

        # Counters needed to resume the run.
        data_to_write = {
            "train_step": train_step,
            "val_step": val_step,
            "best_epoch_l1": best_epoch_l1,
            "epoch": epoch,
            "best_cd_l1": best_cd_l1,
            "alpha": alpha
        }

        with open(state_path, 'w') as file:
            yaml.dump(data_to_write, file, default_flow_style=False)

    log(log_fd, 'Best l1 cd model in epoch {}, the minimum l1 cd is {}'.format(best_epoch_l1, best_cd_l1 * 1e3))


def train(params):
    """Train and validate a ``PCN6Brunch`` model, logging and checkpointing as it goes.

    Every epoch runs one pass over the training split, steps the LR
    scheduler, evaluates on the validation split, renders one random
    validation sample to ``epochs/epoch_XXX.png``, rewrites
    ``checkpoints/epoch_loss.csv`` and ``checkpoints/output.yaml``, and saves
    ``checkpoints/best_l1_cd.pth`` whenever the slice metric improves.

    When ``params.ckpt_path`` is not ``None`` the weights are loaded from it
    and the counters are restored from ``output.yaml``; training continues
    with the epoch after the one recorded there.

    Args:
        params: Namespace providing ``log_dir``, ``exp_name``, ``category``,
            ``batch_size``, ``num_workers``, ``device``, ``lr``,
            ``ckpt_path``, ``epochs`` and ``coarse_loss``.
            ``log_frequency`` and ``save_frequency`` are not used here.

    Raises:
        ValueError: If ``params.coarse_loss`` is not ``'cd'``.
    """
    torch.backends.cudnn.benchmark = True

    ckpt_dir, epochs_dir, log_fd, train_writer, val_writer = prepare_logger(params)
    try:
        _run_training(params, ckpt_dir, epochs_dir, log_fd, train_writer, val_writer)
    finally:
        # Release the writers and the log file even if training fails.
        train_writer.close()
        val_writer.close()
        log_fd.close()


# Experiment tag: default for --exp_name and part of the default --ckpt_path.
# Alternative tags kept for reference:
#   'Rec-Slice-6sigcomponent-compare-0929'
#   'PCN-Slice-240827'
TYPE = 'Slice-nosig-0929'

if __name__ == '__main__':
    # (flag, type, default, help) for every command-line option, in order.
    # NOTE(review): the default --ckpt_path is never None, so train() takes
    # its resume branch unless the code is edited to default to None.
    cli_options = (
        ('--exp_name', str, TYPE, 'Tag of experiment'),
        ('--log_dir', str, 'log', 'Logger directory'),
        ('--ckpt_path', str, './log/' + TYPE + '/all/checkpoints/best_l1_cd.pth',
         'The path of pretrained model'),
        ('--lr', float, 0.0001, 'Learning rate'),
        ('--category', str, 'all', 'Category of point clouds'),
        ('--epochs', int, 300, 'Epochs of training'),
        ('--batch_size', int, 16, 'Batch size for data loader'),
        ('--coarse_loss', str, 'cd', 'loss function for coarse point cloud'),
        ('--num_workers', int, 1, 'num_workers for data loader'),
        ('--device', str, 'cuda:0', 'device for training'),
        ('--log_frequency', int, 10, 'Logger frequency in every epoch'),
        ('--save_frequency', int, 10, 'Model saving frequency'),
    )

    parser = argparse.ArgumentParser('PCN')
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    train(parser.parse_args())
