import os
import time
import tqdm
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from pcdet.utils import common_utils, commu_utils


class InfoNCELoss(nn.Module):
    '''
    InfoNCE contrastive loss over RoI features.

    For every anchor ``pos_feature_1[n]`` the logits are its dot product with the positive
    ``pos_feature_2[n]`` (column 0) followed by its dot products with its negatives. These are
    divided by the temperature and fed to cross-entropy with target class 0.

    Args:
        loss_cfg: mapping providing ``TEMPERATURE``.
    '''

    def __init__(self, loss_cfg):
        super().__init__()
        self.T = loss_cfg['TEMPERATURE']
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, pos_feature_1, pos_feature_2, neg_feature):
        '''
        Args:
            pos_feature_1: (N, C) anchor features.
            pos_feature_2: (N, C) positive features, row-aligned with the anchors.
            neg_feature: either (N, C), one negative per anchor, or anything reshapeable to
                (N, K, C), i.e. K consecutive negatives per anchor.

        Returns:
            Scalar cross-entropy loss (mean over the N anchors).
        '''
        N, C = pos_feature_1.shape
        l_pos = torch.bmm(pos_feature_1.view(N, 1, C), pos_feature_2.view(N, C, 1)).view(N, 1)  # N x 1
        # (N, K, C) @ (N, C, 1) -> (N, K, 1): one logit per negative. The anchor is the column
        # operand so no transpose of the negatives is needed; a plain view(N, C, K) of an
        # (N, K, C) tensor would reinterpret memory and mix channels of different negatives.
        # The paired (N, C) case reshapes to (N, 1, C) and yields N x 1.
        l_neg = torch.bmm(neg_feature.reshape(N, -1, C), pos_feature_1.view(N, C, 1)).view(N, -1)
        logits = torch.cat([l_pos, l_neg], dim=1)  # N x (1 + K)
        # The positive is always column 0.
        labels = logits.new_zeros(N).long()
        return self.criterion(logits / self.T, labels)

def split_instances_by_distance(boxes, pred_inds, loss_cfg):
    '''
    Bucket instances by the Euclidean norm of ``boxes[:, 0:3]``.

    With ``loss_cfg['SPLIT_DISTANCE'] = [d0, d1, ..., dk]`` (e.g. ``[0, 30, 50]``), bin ``i < k``
    covers ``[d_i, d_{i+1})`` and the last bin covers ``[d_k, inf)``. Edges are expected to be
    ascending.

    Args:
        boxes: (M, >=3) tensor; only the first three columns are used.
        pred_inds: M indices, one per box; these are what the per-bin lists contain.
        loss_cfg: mapping providing ``SPLIT_DISTANCE``.

    Returns:
        distance_split_labels: float numpy array of shape (M,) with each box's bin index. Boxes
            that fall in no bin (distance below d0) keep label 0.
        split_distance_dict: ``{'distance_label_<i>': [pred_inds in bin i]}`` for every bin i.
    '''
    distance_split_labels = np.zeros(boxes.shape[0])
    centers = boxes.detach().cpu().numpy()[:, 0:3]
    distance = np.sqrt(np.sum(centers * centers, axis=1))
    inds = np.array(pred_inds)
    split_distance = loss_cfg['SPLIT_DISTANCE']
    split_distance_dict = {}
    for bin_idx, lower in enumerate(split_distance):
        # Every bin, including the last, is closed on the left, so a box lying exactly on the
        # last edge is not dropped from all bins.
        dis_mask = distance >= lower
        if bin_idx < len(split_distance) - 1:
            dis_mask &= distance < split_distance[bin_idx + 1]
        split_distance_dict['distance_label_' + str(bin_idx)] = inds[dis_mask].tolist()
        distance_split_labels[dis_mask] = bin_idx
    return distance_split_labels, split_distance_dict

def get_contrastive_loss(pred_dict, loss_cfg):
    '''
    Sample contrastive pairs from the predictions and compute the configured contrastive loss.

    Args:
        pred_dict: per-sample prediction dicts, as consumed by ``sample_contrastive_pairs``.
        loss_cfg: contrastive loss config; ``loss_cfg['NAME']`` selects the loss.

    Returns:
        Scalar loss tensor. When no pairs could be sampled, a zero 0-dim tensor created from the
        first sample's ``select_roi`` (same device/dtype) is returned, so callers can use
        tensor methods such as ``.item()`` unconditionally.

    Raises:
        NotImplementedError: if ``NAME`` is not a supported loss.
    '''
    pos_feature_1, pos_feature_2, neg_feature = sample_contrastive_pairs(pred_dict, loss_cfg)
    if pos_feature_1 is None:
        # The sampler sets all three outputs together, so all three must be None here.
        assert pos_feature_2 is None and neg_feature is None
        return pred_dict[0]['select_roi'].new_zeros(())

    loss_type = loss_cfg['NAME']
    if loss_type != 'InfoNCELoss':
        raise NotImplementedError('Unsupported contrastive loss: {}'.format(loss_type))
    criterion = InfoNCELoss(loss_cfg)
    return criterion(pos_feature_1, pos_feature_2, neg_feature)

def _filter_bin_members(distance_split_dict, keep_mask, num_bins):
    '''Return, for each distance bin, the per-class indices in that bin whose ``keep_mask`` entry is True.'''
    return [
        [idx for idx in distance_split_dict['distance_label_' + str(bin_idx)] if keep_mask[idx]]
        for bin_idx in range(num_bins)
    ]


def _fill_rows_from_bins(target, roi_feature, anchor_bins, bin_members, fallback_inds=None):
    '''
    Overwrite each row ``r`` of ``target`` with a random ``roi_feature`` row from bin ``anchor_bins[r]``.

    ``bin_members[b]`` lists the candidate indices of bin ``b``. If a bin has no candidates, its rows
    are drawn from ``fallback_inds`` when given, otherwise they are left unchanged. Returns ``target``.
    '''
    for bin_idx, candidates in enumerate(bin_members):
        rows = np.nonzero(anchor_bins == bin_idx)[0].tolist()
        if len(rows) == 0:
            continue
        if len(candidates) == 0:
            if fallback_inds is None:
                continue
            candidates = fallback_inds
        sample_inds = np.random.randint(0, len(candidates), len(rows)).tolist()
        target[rows] = roi_feature[candidates][sample_inds]
    return target


def sample_contrastive_pairs(pred_dict, loss_cfg):
    '''
    Sample positive and negative RoI-feature pairs for the contrastive loss.

    Predictions from every sample in the batch are pooled and processed per class; class
    ``i`` of ``loss_cfg.CLASS`` matches ``pred_labels == i + 1``. For each class:

      * Positives are predictions with score > ``POS_CFG.POS_THR``. ``POS_NUM`` anchors are drawn
        from them with replacement. The partner of each anchor is another random positive. With
        ``POS_CFG.DISTANCE_AWARE`` the partner comes from the anchor's own distance bin, falling
        back to any positive when that bin has none.
      * Negatives are predictions with score < ``NEG_CFG.NEG_THR``. With ``NEG_CFG.DISTANCE_AWARE``
        one negative per anchor is drawn from the anchor's distance bin; rows whose bin has no
        negatives stay zero. Otherwise ``NEG_NUM_PER_ANCHOR`` negatives are drawn per anchor,
        giving a (POS_NUM, NEG_NUM_PER_ANCHOR, C) tensor.

    Classes without positives or without negatives are skipped.

    TODO: represent positives by a unified prototype; in-scene copy-and-paste.

    Args:
        pred_dict: indexable by sample (0..len-1); each entry has 'pred_boxes', 'pred_scores',
            'pred_labels' and 'select_roi' tensors.
        loss_cfg: contrastive config with ``CLASS``, ``SPLIT_DISTANCE`` and one sub-config per class.

    Returns:
        ``(positive_feature_1, positive_feature_2, negative_feature)`` concatenated over classes
        along dim 0, or ``(None, None, None)`` if no class produced samples.
    '''
    batch_size = len(pred_dict)
    # Pool the predictions of every sample in the batch.
    pred_boxes_ori = torch.cat([pred_dict[b]['pred_boxes'] for b in range(batch_size)], dim=0)
    pred_scores_ori = torch.cat([pred_dict[b]['pred_scores'] for b in range(batch_size)], dim=0)
    pred_labels_ori = torch.cat([pred_dict[b]['pred_labels'] for b in range(batch_size)], dim=0)
    roi_feature_ori = torch.cat([pred_dict[b]['select_roi'] for b in range(batch_size)], dim=0)

    positive_feature_1 = None
    positive_feature_2 = None
    negative_feature = None

    for class_idx, cur_class in enumerate(loss_cfg.CLASS):
        cur_loss_cfg = loss_cfg[cur_class]
        cls_mask = pred_labels_ori == (class_idx + 1)
        pred_boxes = pred_boxes_ori[cls_mask]
        pred_scores = pred_scores_ori[cls_mask]
        roi_feature = roi_feature_ori[cls_mask]

        pred_inds = list(range(pred_scores.size(0)))
        distance_split_labels, distance_split_dict = split_instances_by_distance(pred_boxes, pred_inds, loss_cfg)
        num_bins = len(loss_cfg['SPLIT_DISTANCE'])

        # ---- anchors and positives ----
        pos_cfg = cur_loss_cfg['POS_CFG']
        positive_mask = (pred_scores > pos_cfg['POS_THR']).detach().cpu().numpy()
        pos_inds = np.array(pred_inds)[positive_mask].tolist()
        if len(pos_inds) == 0:
            continue

        # Row 0: anchors; row 1 (non-distance-aware only): their partners. Entries index pos_inds.
        num_rows = 1 if pos_cfg['DISTANCE_AWARE'] else 2
        pos_pair_inds = np.random.randint(0, len(pos_inds), (num_rows, pos_cfg['POS_NUM'])).tolist()
        pos_feature_1 = roi_feature[pos_inds][pos_pair_inds[0]]
        # pos_pair_inds index into pos_inds, so map through pos_inds to get each anchor's bin.
        anchor_bins = distance_split_labels[pos_inds][pos_pair_inds[0]]

        if pos_cfg['DISTANCE_AWARE']:
            pos_bins = _filter_bin_members(distance_split_dict, positive_mask, num_bins)
            # Fallback to all positives when an anchor's bin has none (e.g. an anchor left at the
            # default label 0 because its distance fell outside every bin); sampling from an empty
            # bin would call randint(0, 0, ...), which raises.
            pos_feature_2 = _fill_rows_from_bins(
                pos_feature_1.new_zeros(pos_feature_1.shape), roi_feature, anchor_bins, pos_bins,
                fallback_inds=pos_inds,
            )
        else:
            pos_feature_2 = roi_feature[pos_inds][pos_pair_inds[1]]

        # ---- negatives ----
        neg_cfg = cur_loss_cfg['NEG_CFG']
        negative_mask = (pred_scores < neg_cfg['NEG_THR']).detach().cpu().numpy()
        if not negative_mask.any():
            continue
        if neg_cfg['DISTANCE_AWARE']:
            # One negative per anchor from the anchor's distance bin; paired with pos_feature_1.
            neg_bins = _filter_bin_members(distance_split_dict, negative_mask, num_bins)
            neg_feature = _fill_rows_from_bins(
                pos_feature_1.new_zeros(pos_feature_1.shape), roi_feature, anchor_bins, neg_bins,
            )
        else:
            # NEG_NUM_PER_ANCHOR consecutive negatives per anchor -> (POS_NUM, K, C).
            pos_num = pos_cfg['POS_NUM']
            neg_per_anchor = neg_cfg['NEG_NUM_PER_ANCHOR']
            neg_inds = np.array(pred_inds)[negative_mask].tolist()
            neg_pair_inds = np.random.randint(0, len(neg_inds), (1, pos_num * neg_per_anchor)).tolist()
            neg_feature = roi_feature[neg_inds][neg_pair_inds[0]].view(pos_num, neg_per_anchor, -1)

        if positive_feature_1 is None:
            positive_feature_1, positive_feature_2, negative_feature = pos_feature_1, pos_feature_2, neg_feature
        else:
            positive_feature_1 = torch.cat([positive_feature_1, pos_feature_1], dim=0)
            positive_feature_2 = torch.cat([positive_feature_2, pos_feature_2], dim=0)
            negative_feature = torch.cat([negative_feature, neg_feature], dim=0)

    return positive_feature_1, positive_feature_2, negative_feature


def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False):
    '''
    Run ``total_it_each_epoch`` optimization steps and return the updated global iteration count.

    With ``optim_cfg.USE_CONTRASTIVE_LOSS`` set, ``model_func(model, batch)`` must return
    ``(loss_det, pred_dict, tb_dict, disp_dict)``. The optimized loss is then
    ``loss_det + optim_cfg.CONTRASTIVE_LOSS_CFG.RATIO * get_contrastive_loss(...)``. Otherwise
    ``model_func`` returns ``(loss, tb_dict, disp_dict)``. Only rank 0 updates progress bars and
    writes per-step losses; timings are averaged across ranks on every rank.

    Args:
        dataloader_iter: iterator to continue from; replaced by a fresh iterator when
            ``total_it_each_epoch == len(train_loader)`` or when it is exhausted.
        tb_log: optional scalar logger (``add_scalar``); the learning rate is logged once per step.

    Returns:
        ``accumulated_iter`` incremented by the number of steps run.
    '''
    if total_it_each_epoch == len(train_loader):
        dataloader_iter = iter(train_loader)

    use_contrastive = optim_cfg.get('USE_CONTRASTIVE_LOSS', False)

    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
        data_time = common_utils.AverageMeter()
        batch_time = common_utils.AverageMeter()
        forward_time = common_utils.AverageMeter()

    for _ in range(total_it_each_epoch):
        end = time.time()
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            # Iterator exhausted: restart it.
            dataloader_iter = iter(train_loader)
            batch = next(dataloader_iter)
            print('new iters')

        data_timer = time.time()
        cur_data_time = data_timer - end

        lr_scheduler.step(accumulated_iter)

        try:
            cur_lr = float(optimizer.lr)
        except (AttributeError, TypeError):
            # Optimizer without a usable `.lr`: read it from the first param group.
            cur_lr = optimizer.param_groups[0]['lr']

        if tb_log is not None:
            tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)

        model.train()
        optimizer.zero_grad()

        # Only defined in contrastive mode; initialised so logging never hits a NameError.
        loss_det = loss_contrast = None
        if use_contrastive:
            loss_det, pred_dict, tb_dict, disp_dict = model_func(model, batch)
            loss_contrast = get_contrastive_loss(pred_dict, optim_cfg.CONTRASTIVE_LOSS_CFG)
            loss = loss_det + optim_cfg.CONTRASTIVE_LOSS_CFG.RATIO * loss_contrast
        else:
            loss, tb_dict, disp_dict = model_func(model, batch)

        forward_timer = time.time()
        cur_forward_time = forward_timer - data_timer

        loss.backward()
        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
        optimizer.step()

        accumulated_iter += 1

        cur_batch_time = time.time() - end
        # average reduce (collective call, executed on every rank)
        avg_data_time = commu_utils.average_reduce_value(cur_data_time)
        avg_forward_time = commu_utils.average_reduce_value(cur_forward_time)
        avg_batch_time = commu_utils.average_reduce_value(cur_batch_time)

        # log to console and tensorboard
        if rank == 0:
            data_time.update(avg_data_time)
            forward_time.update(avg_forward_time)
            batch_time.update(avg_batch_time)

            # float() accepts both 0-dim tensors and plain Python numbers (the contrastive loss
            # may be either depending on the helper's version).
            loss_scalars = {'loss': float(loss)}
            if use_contrastive:
                loss_scalars['loss_det'] = float(loss_det)
                loss_scalars['loss_contrast'] = float(loss_contrast)

            disp_dict.update(loss_scalars)
            disp_dict.update({
                'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
                'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
            })

            pbar.update()
            pbar.set_postfix(dict(total_it=accumulated_iter))
            tbar.set_postfix(disp_dict)
            tbar.refresh()

            if tb_log is not None:
                for key, val in loss_scalars.items():
                    tb_log.add_scalar('train/' + key, val, accumulated_iter)
                for key, val in tb_dict.items():
                    tb_log.add_scalar('train/' + key, val, accumulated_iter)
    if rank == 0:
        pbar.close()
    return accumulated_iter

# TODO: dynamically update the PS100K
def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
                start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir,
                source_sampler=None, lr_warmup_scheduler=None, ckpt_save_interval=1,
                max_ckpt_save_num=50, merge_all_iters_to_one_epoch=False):
    '''
    Train epochs ``start_epoch .. total_epochs - 1`` and periodically write checkpoints on rank 0.

    Each epoch uses ``lr_warmup_scheduler`` while it is given and ``cur_epoch < optim_cfg.WARMUP_EPOCH``,
    and ``lr_scheduler`` otherwise. Every ``ckpt_save_interval`` epochs, rank 0 removes the oldest
    ``checkpoint_epoch_*.pth`` files (ordered by mtime), leaving ``max_ckpt_save_num - 1`` in place,
    then writes ``checkpoint_epoch_<trained_epoch>.pth``.

    ``ckpt_save_dir`` must support ``/`` with a string (e.g. ``pathlib.Path``). If
    ``optim_cfg.PROGRESSIVE_TRAINING`` is set, NotImplementedError is raised after the first epoch.
    '''
    accumulated_iter = start_iter
    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
        total_it_each_epoch = len(train_loader)
        if merge_all_iters_to_one_epoch:
            # The dataset exposes a merge hook; afterwards the loader length is split evenly per epoch.
            assert hasattr(train_loader.dataset, 'merge_all_iters_to_one_epoch')
            train_loader.dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
            total_it_each_epoch = len(train_loader) // max(total_epochs, 1)

        # One iterator object shared by all epochs; train_one_epoch advances it in place.
        dataloader_iter = iter(train_loader)
        for cur_epoch in tbar:
            if source_sampler is not None:
                source_sampler.set_epoch(cur_epoch)

            warming_up = lr_warmup_scheduler is not None and cur_epoch < optim_cfg.WARMUP_EPOCH
            accumulated_iter = train_one_epoch(
                model, optimizer, train_loader, model_func,
                lr_scheduler=lr_warmup_scheduler if warming_up else lr_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter,
            )

            if optim_cfg.get('PROGRESSIVE_TRAINING', False):
                raise NotImplementedError

            trained_epoch = cur_epoch + 1
            if trained_epoch % ckpt_save_interval != 0 or rank != 0:
                continue

            # Rotate: drop the oldest files so the new checkpoint keeps the count within the limit.
            existing_ckpts = sorted(glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth')), key=os.path.getmtime)
            for stale_idx in range(len(existing_ckpts) - max_ckpt_save_num + 1):
                os.remove(existing_ckpts[stale_idx])

            snapshot = checkpoint_state(model, optimizer, trained_epoch, accumulated_iter)
            save_checkpoint(snapshot, filename=ckpt_save_dir / ('checkpoint_epoch_%d' % trained_epoch))


def model_state_to_cpu(model_state):
    '''Return a new mapping of the same type as ``model_state`` with each value's ``.cpu()`` copy, keys in order.'''
    cpu_state = type(model_state)()
    cpu_state.update((name, tensor.cpu()) for name, tensor in model_state.items())
    return cpu_state


def checkpoint_state(model=None, optimizer=None, epoch=None, it=None):
    '''
    Assemble the dict that is written to a checkpoint file.

    Args:
        model: module to snapshot, optionally wrapped in DistributedDataParallel. For DDP the
            wrapped ``model.module``'s state dict is stored, copied to CPU via
            ``model_state_to_cpu``. Other models store ``model.state_dict()`` as is.
        optimizer: optimizer whose ``state_dict()`` is stored, or None.
        epoch: epoch number to record.
        it: global iteration count to record.

    Returns:
        dict with keys 'epoch', 'it', 'model_state', 'optimizer_state', 'version'. 'version' is
        ``'pcdet+<pcdet.__version__>'`` when importable, else ``'none'``.
    '''
    optim_state = optimizer.state_dict() if optimizer is not None else None

    if model is None:
        model_state = None
    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_state = model_state_to_cpu(model.module.state_dict())
    else:
        model_state = model.state_dict()

    try:
        import pcdet
        version = 'pcdet+' + pcdet.__version__
    except (ImportError, AttributeError, TypeError):
        # pcdet unavailable or without a string __version__; a bare except here would also
        # swallow KeyboardInterrupt/SystemExit.
        version = 'none'

    return {'epoch': epoch, 'it': it, 'model_state': model_state, 'optimizer_state': optim_state, 'version': version}


def save_checkpoint(state, filename='checkpoint'):
    '''
    Serialize ``state`` with ``torch.save`` to ``<filename>.pth``.

    Args:
        state: checkpoint dict (see ``checkpoint_state``). It is saved whole, including any
            'optimizer_state'; no separate optimizer file is written and ``state`` is not modified.
        filename: target path without extension (str or path-like); '.pth' is appended.
    '''
    torch.save(state, '{}.pth'.format(filename))