'''
@Author: Wenhao Ding
@Email: wenhaod@andrew.cmu.edu
@Date: 2020-07-09 13:51:09
LastEditTime: 2021-05-30 20:28:51
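@Description: Training / evaluation loop for Pointnet2MSG point-wise foreground (vehicle) segmentation on SemanticKITTI.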
'''

import contextlib

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch.utils.data import DataLoader

from dataloader import SemanticKittiDataset
from model import Pointnet2MSG
from utils import CUDA, CPU, save_ply, COLOR
from loss import CrosEntropyLoss, DiceLoss


class Trainer(object):
    """Train / evaluate a Pointnet2MSG foreground segmentation model on SemanticKITTI.

    The network outputs one logit per point (``number_class = 1``); a point is
    predicted as foreground when ``sigmoid(logit) > FG_THRESH``. Quality is
    tracked with the foreground IoU computed by :meth:`fg_iou_single`.
    """

    def __init__(self, args):
        print(COLOR.GREEN+'Mode:', args.mode)
        print('Continue Training:', args.continue_training)
        print(COLOR.WHITE+'')

        self.args = args
        self.continue_training = args.continue_training
        self.workers = args.workers
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epochs = args.max_epochs
        self.save_interval = args.save_interval  # evaluate (and maybe checkpoint) every N epochs
        self.best_eval_iou = 0.0
        self.number_class = 1  # single foreground logit per point
        self.FG_THRESH = 0.3  # sigmoid probability above which a point counts as foreground
        self.npoints = args.npoints

        self.kitti_path = args.kitti_path
        self.model_path = args.model_path
        self.log_path = args.log_path

        # NOTE: don't use self.* variables inside this function
        def lr_lbmd(cur_epoch):
            # multiplicative LR factor: lr_decay applied once per decay step already
            # reached, floored so the effective lr never drops below args.lr_clip
            cur_decay = 1
            for decay_step in args.decay_step_list:
                if cur_epoch >= decay_step:
                    cur_decay = cur_decay * args.lr_decay
            return max(cur_decay, args.lr_clip / args.lr)

        # define model and optimizer
        self.model = CUDA(Pointnet2MSG(input_channels=0, number_class=self.number_class))
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=args.weight_decay)
        self.scheduler = LambdaLR(self.optimizer, lr_lbmd)
        #self.loss_func = CrosEntropyLoss(number_class=self.number_class)
        self.loss_func = DiceLoss()
        self.training = args.mode == 'train' # used for saving model

    @staticmethod
    def fg_iou(pred_class, cls_labels, number_class):
        """Compute the IoU of every non-background class.

        Args:
            pred_class: predicted class ids (any shape, flattened internally).
            cls_labels: ground-truth class ids with the same number of elements.
            number_class: total number of class ids, background id 0 included.

        Returns:
            List of 0-dim float tensors, one per class id ``1 .. number_class - 1``
            (empty when ``number_class <= 1``). A class absent from both the
            prediction and the labels gets IoU 0 since the union is clamped to >= 1.
        """
        # reshape (not view) so non-contiguous inputs are accepted too
        pred_class = pred_class.reshape(-1)
        cls_labels = cls_labels.reshape(-1)

        iou_list = []
        for c_i in range(1, number_class):  # class 0 is background and is skipped
            fg_mask = cls_labels == c_i
            correct = ((pred_class == cls_labels) & fg_mask).float().sum()
            union = fg_mask.sum().float() + (pred_class == c_i).sum().float() - correct
            iou_list.append(correct / torch.clamp(union, min=1.0))
        return iou_list

    @staticmethod
    def fg_iou_single(pred_class, cls_labels):
        """Compute the foreground (vehicle) IoU, treating every id > 0 as foreground.

        A point counts as correct when it is labelled foreground and its predicted
        id equals its label. The union is clamped to >= 1, so an empty
        prediction/label pair yields 0.

        Returns:
            0-dim float tensor holding the IoU.
        """
        fg_mask = cls_labels > 0
        correct = ((pred_class == cls_labels) & fg_mask).float().sum()
        union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct
        return correct / torch.clamp(union, min=1.0)

    @staticmethod
    def save_pc(name, xyz, label):
        """Save every cloud of a batch as a coloured PLY under ``./log/samples/``.

        Args:
            name: file-name suffix; cloud ``p_i`` goes to ``./log/samples/<p_i>_<name>``.
            xyz: batched coordinates indexed ``xyz[p_i]`` per cloud
                (assumed (B, N, 3) — TODO confirm).
            label: per-point labels indexed ``label[p_i]`` per cloud. Label 0 is
                drawn white, 1 red, 2 green; any other id stays black.
        """
        colors = {0: (255, 255, 255), 1: (255, 0, 0), 2: (0, 255, 0)}
        for p_i in range(xyz.shape[0]):
            one_xyz = CPU(xyz[p_i])
            one_label = np.asarray(CPU(label[p_i]))

            # vectorised colouring instead of a Python loop over every point
            rgb = np.zeros_like(one_xyz)
            for label_id, color in colors.items():
                rgb[one_label == label_id] = color
            xyzrgb = np.concatenate([one_xyz, rgb], axis=1)

            filename = './log/samples/' + str(p_i) + '_' + name
            save_ply(filename, xyzrgb)

    @staticmethod
    def _progress_bar(pbar_itr, total):
        """Return a context manager yielding a progress bar.

        Reuses ``pbar_itr`` (left open on exit) when given; otherwise creates a
        fresh tqdm bar with ``total`` steps that is closed on exit.
        """
        if pbar_itr is None:
            return tqdm(total=total)
        return contextlib.nullcontext(pbar_itr)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the configured batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=self.workers)

    def train_and_eval(self, name='kitti'):
        """Train for ``max_epochs`` epochs, evaluating every ``save_interval`` epochs.

        ``name`` is currently unused and kept for interface compatibility.
        """
        # load validation data and training data
        self.eval_set = SemanticKittiDataset(args=self.args, split='valid')
        self.train_set = SemanticKittiDataset(args=self.args, split='train')
        self.eval_loader = self._make_loader(self.eval_set, shuffle=False)
        self.train_loader = self._make_loader(self.train_set, shuffle=True)

        if self.continue_training:
            # load_model's return value is used as the best IoU so far
            self.best_eval_iou = self.model.load_model(self.model_path)

        with tqdm(total=len(self.train_loader), desc='Start to train') as pbar_itr:
            for epoch in range(self.max_epochs):
                self.train_one_epoch(epoch, pbar_itr)
                self.scheduler.step()

                if (epoch+1) % self.save_interval == 0:
                    self.eval_one_epoch(pbar_itr)

    def train_one_epoch(self, epoch, pbar_itr=None):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            epoch: current epoch index (used for logging only).
            pbar_itr: tqdm bar to reuse; when None a temporary bar is used.
        """
        self.model.train()
        avg_loss = []
        avg_iou_veh = []
        with self._progress_bar(pbar_itr, len(self.train_loader)) as pbar_itr:
            pbar_itr.set_description('[Train]')
            pbar_itr.reset(len(self.train_loader))
            for pts_input, cls_labels in self.train_loader:
                self.optimizer.zero_grad()
                pts_input = CUDA(pts_input).float()
                # flatten labels so they line up with the flattened per-point logits
                cls_labels = CUDA(cls_labels).long().view(-1)

                pred_cls = self.model(pts_input).view(-1)
                loss = self.loss_func(pred_cls, cls_labels)
                loss.backward()
                clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

                # foreground iou of the thresholded sigmoid prediction
                #pred_class = torch.argmax(pred_cls, dim=2)  # for cross-entropy loss
                pred_class = (torch.sigmoid(pred_cls) > self.FG_THRESH).long()
                iou_veh = self.fg_iou_single(pred_class, cls_labels).item()
                loss_value = loss.item()

                # statistic
                pbar_itr.update(1)
                pbar_itr.set_description((COLOR.GREEN+'Epoch [%d/%d]'+COLOR.WHITE+': loss = %.5f, iou_veh = %.5f')
                    % (epoch, self.max_epochs, loss_value, iou_veh))
                avg_loss.append(loss_value)
                avg_iou_veh.append(iou_veh)

            if not avg_loss:
                # no batch ran, so there is no sample to save and nothing to average
                pbar_itr.write('Epoch [%d/%d]: training loader is empty, skipping'
                    % (epoch, self.max_epochs))
                return

            # NOTE: save samples, the number depends on the last batch size
            self.save_pc('train_gt.ply', pts_input, cls_labels.view(pts_input.shape[0], self.npoints))
            self.save_pc('train_predict.ply', pts_input, pred_class.view(pts_input.shape[0], self.npoints))
            pbar_itr.write('Epoch [%d/%d]: avg_loss = %.5f, avg_iou_veh = %.3f'
                % (epoch, self.max_epochs, np.mean(avg_loss), np.mean(avg_iou_veh)))

    def eval_one_epoch(self, pbar_itr=None):
        """Evaluate on ``self.eval_loader``; checkpoint when the IoU improves.

        Args:
            pbar_itr: tqdm bar to reuse; when None a temporary bar is used.

        Returns:
            Mean of the per-batch foreground IoUs, or ``nan`` when the loader is empty.
        """
        self.model.eval()
        avg_iou_veh = []
        with self._progress_bar(pbar_itr, len(self.eval_loader)) as pbar_itr:
            pbar_itr.reset(len(self.eval_loader))
            pbar_itr.set_description('[Evaluation]')
            with torch.no_grad():
                for pts_input, cls_labels in self.eval_loader:
                    pts_input = CUDA(pts_input).float()
                    cls_labels = CUDA(cls_labels).long().view(-1)

                    pred_cls = self.model(pts_input).view(-1)
                    pred_class = (torch.sigmoid(pred_cls) > self.FG_THRESH).long()
                    avg_iou_veh.append(self.fg_iou_single(pred_class, cls_labels).item())
                    pbar_itr.update(1)

            if not avg_iou_veh:
                pbar_itr.write('Evaluation: evaluation loader is empty, skipping')
                return float('nan')

            # should be 0.65 ~ 0.70 for kitti evaluation dataset
            avg_iou_veh = np.mean(avg_iou_veh)
            if self.best_eval_iou < avg_iou_veh:
                self.best_eval_iou = avg_iou_veh
                if self.training:
                    self.model.save_model(self.model_path, self.best_eval_iou)

            pbar_itr.write((COLOR.GREEN+'Evaluation:'+COLOR.WHITE+' avg_iou_veh = %.6f, best_avg_iou_veh = %.6f') % (avg_iou_veh, self.best_eval_iou))
            # samples come from the last evaluated batch
            self.save_pc('eval_gt.ply', pts_input, cls_labels.view(pts_input.shape[0], self.npoints))
            self.save_pc('eval_predict.ply', pts_input, pred_class.view(pts_input.shape[0], self.npoints))
        return avg_iou_veh

    def evaluation(self, name='kitti'):
        """Load the checkpoint at ``model_path`` and evaluate it on the validation split.

        ``name`` is currently unused and kept for interface compatibility.

        Returns:
            The mean foreground IoU reported by :meth:`eval_one_epoch`.
        """
        # load validation data
        self.eval_set = SemanticKittiDataset(args=self.args, split='valid')
        self.eval_loader = self._make_loader(self.eval_set, shuffle=False)
        self.model.load_model(self.model_path)
        with tqdm(total=len(self.eval_loader), desc='[Evaluation]') as pbar_itr:
            return self.eval_one_epoch(pbar_itr)
