import glob
import os

import torch
import tqdm
from torch.nn.utils import clip_grad_norm_
from pcdet.config import cfg
import wandb

def _intra_domain_triplet_loss(features, labels, margin=1.0, weight=1.0):
    """Hardest-positive / hardest-negative triplet loss among the ROIs of each sample.

    For every anchor ROI ``i`` of sample ``b``:
      * the positive is the *farthest* other ROI of the same sample with the same label;
      * the negative is the *closest* ROI of the same sample with a different label.
    Anchors lacking either a positive or a negative are skipped.

    Only intra-domain (source-only) triplets are computed; cross-domain variants
    that would need target-domain ROI features are not implemented here.

    Args:
        features: 3-D tensor of ROI features (B, N, D).
        labels: per-ROI labels indexed as ``labels[b][i]``; presumably (B, N) — TODO confirm.
        margin: hinge margin ``m`` in ``relu(d(a, p) - d(a, n) + m)``.
        weight: factor applied to every individual triplet term.

    Returns:
        The mean weighted triplet term over all valid anchors (a tensor), or the
        float ``0.0`` when no anchor had both a positive and a negative.

    Raises:
        ValueError: if ``features`` is not 3-D.
    """
    import torch.nn.functional as F

    if features.dim() != 3:
        raise ValueError('expected 3-D ROI features (B, N, D), got shape %s' % (tuple(features.shape),))

    total = 0.0
    valid_count = 0
    for feat_b, label_b in zip(features, labels):
        for i in range(feat_b.shape[0]):
            anchor = feat_b[i]
            same_label = label_b == label_b[i]
            pos_mask = same_label.clone()
            pos_mask[i] = False  # the anchor is not its own positive
            neg_mask = ~same_label
            if not pos_mask.any() or not neg_mask.any():
                continue
            # One distance row per anchor, shared by positive and negative mining.
            dists = torch.norm(feat_b - anchor, dim=1)
            pos_dists = dists[pos_mask]
            neg_dists = dists[neg_mask]
            hardest_pos = pos_dists[torch.argmax(pos_dists)]
            hardest_neg = neg_dists[torch.argmin(neg_dists)]
            total = total + weight * F.relu(hardest_pos - hardest_neg + margin)
            valid_count += 1

    if valid_count == 0:
        return 0.0
    return total / valid_count


def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False):
    """Run ``total_it_each_epoch`` optimization steps and return the updated global step.

    Args:
        model: the network; put in train mode before every step.
        optimizer: optimizer whose learning rate is read from ``optimizer.lr`` when that
            attribute exists, otherwise from ``param_groups[0]['lr']``.
        train_loader: source of training batches.
        model_func: callable ``model_func(model, batch)`` returning
            ``(loss, tb_dict, disp_dict, batch_dict)``.
        lr_scheduler: stepped with ``accumulated_iter`` before each update.
        accumulated_iter: global step counter at the start of this epoch.
        optim_cfg: config providing ``GRAD_NORM_CLIP`` for gradient clipping.
        rank: process rank; only rank 0 drives progress bars and logging.
        tbar: outer (epoch) progress bar, updated with ``disp_dict``.
        total_it_each_epoch: number of steps to run.
        dataloader_iter: iterator to continue from. It is replaced by a fresh iterator
            when ``total_it_each_epoch == len(train_loader)``, and it is restarted when
            exhausted. A restarted iterator is local to this call and is not returned.
        tb_log: optional TensorBoard writer. When given (on rank 0), the same scalars
            are also sent to wandb.
        leave_pbar: whether the inner tqdm bar stays on screen after closing.

    Returns:
        ``accumulated_iter`` advanced by ``total_it_each_epoch``.

    When ``cfg.Triplet`` is set, a triplet loss over ``batch_dict['roi_head_features']`` /
    ``batch_dict['roi_labels']`` is added to the loss. It is scaled by
    ``lambda_loss_triplet`` (default 0.1). Each term is weighted by ``lambda_loss_intra``
    (default 1.0), with hinge ``margin`` (default 1.0).
    """
    if total_it_each_epoch == len(train_loader):
        dataloader_iter = iter(train_loader)

    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)

    triplet_cfg = cfg.get('Triplet', None)

    for _ in range(total_it_each_epoch):
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(train_loader)
            batch = next(dataloader_iter)
            print('new iters')

        lr_scheduler.step(accumulated_iter)

        try:
            cur_lr = float(optimizer.lr)
        except (AttributeError, TypeError, ValueError):
            # Standard torch optimizers have no ``.lr``; fall back to the first param group.
            cur_lr = optimizer.param_groups[0]['lr']

        model.train()
        optimizer.zero_grad()

        loss, tb_dict, disp_dict, batch_dict = model_func(model, batch)

        if triplet_cfg:
            triplet_loss = _intra_domain_triplet_loss(
                batch_dict['roi_head_features'],
                batch_dict['roi_labels'],
                margin=triplet_cfg.get('margin', 1.0),
                weight=triplet_cfg.get('lambda_loss_intra', 1.0),
            )
            # Out-of-place add keeps the tensor returned by model_func untouched.
            loss = loss + triplet_cfg.get('lambda_loss_triplet', 0.1) * triplet_loss

        loss.backward()
        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
        optimizer.step()

        accumulated_iter += 1
        loss_value = loss.item()
        disp_dict.update({'loss': loss_value, 'lr': cur_lr})

        # log to console, tensorboard and wandb (each scalar exactly once per step)
        if rank == 0:
            pbar.update()
            pbar.set_postfix(dict(total_it=accumulated_iter))
            tbar.set_postfix(disp_dict)
            tbar.refresh()

            if tb_log is not None:
                tb_log.add_scalar('train/loss', loss_value, accumulated_iter)
                tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
                for key, val in tb_dict.items():
                    tb_log.add_scalar('train/' + key, val, accumulated_iter)

                wandb_metrics = {'train/loss': loss_value, 'meta_data/learning_rate': cur_lr}
                wandb_metrics.update({'train/' + key: val for key, val in tb_dict.items()})
                wandb.log(wandb_metrics, step=accumulated_iter)

    if rank == 0:
        pbar.close()
    return accumulated_iter


def train_model(model, optimizer, train_loader, target_loader, model_func, lr_scheduler, optim_cfg,
                start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, ps_label_dir,
                source_sampler=None, target_sampler=None, lr_warmup_scheduler=None, ckpt_save_interval=1,
                max_ckpt_save_num=50, merge_all_iters_to_one_epoch=False, logger=None, ema_model=None,
                source_loader_detect=None, source_sampler_detect=None, source_model=None, dist=None, pretrained=None):
    """Train for epochs ``[start_epoch, total_epochs)`` on ``train_loader`` and save checkpoints.

    Only the source-domain arguments are used here. ``target_loader``, ``target_sampler``,
    ``ps_label_dir``, ``logger``, ``ema_model``, ``source_loader_detect``,
    ``source_sampler_detect``, ``source_model``, ``dist`` and ``pretrained`` are accepted
    for interface compatibility but are not read by this function.

    Args:
        start_iter: global step to resume from.
        lr_warmup_scheduler: used instead of ``lr_scheduler`` while
            ``cur_epoch < optim_cfg.WARMUP_EPOCH``.
        ckpt_save_dir: directory for ``checkpoint_epoch_<N>.pth`` files; assumed to be a
            ``pathlib.Path`` (it is joined with ``/``) — TODO confirm with callers.
        ckpt_save_interval: save every this many epochs (rank 0 only).
        max_ckpt_save_num: at most this many ``checkpoint_epoch_*.pth`` files remain after
            each save; the oldest by modification time are deleted first.
        merge_all_iters_to_one_epoch: ask the dataset to merge all epochs into one and run
            ``len(train_loader) // total_epochs`` steps per epoch.
    """
    accumulated_iter = start_iter
    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
        total_it_each_epoch = len(train_loader)
        if merge_all_iters_to_one_epoch:
            assert hasattr(train_loader.dataset, 'merge_all_iters_to_one_epoch')
            train_loader.dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
            total_it_each_epoch = len(train_loader) // max(total_epochs, 1)

        # Shared across epochs so that a partially consumed loader continues where it stopped.
        dataloader_iter = iter(train_loader)
        for cur_epoch in tbar:
            if source_sampler is not None:
                source_sampler.set_epoch(cur_epoch)

            # train one epoch
            if lr_warmup_scheduler is not None and cur_epoch < optim_cfg.WARMUP_EPOCH:
                cur_scheduler = lr_warmup_scheduler
            else:
                cur_scheduler = lr_scheduler
            accumulated_iter = train_one_epoch(
                model, optimizer, train_loader, model_func,
                lr_scheduler=cur_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter
            )

            # save trained model
            trained_epoch = cur_epoch + 1
            if trained_epoch % ckpt_save_interval == 0 and rank == 0:
                ckpt_list = sorted(
                    glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth')), key=os.path.getmtime
                )

                # Make room for the checkpoint about to be written. Slicing (instead of
                # indexing up to len - max + 1) cannot run past the list when
                # max_ckpt_save_num <= 0.
                num_to_remove = len(ckpt_list) - max_ckpt_save_num + 1
                for stale_ckpt in ckpt_list[:max(num_to_remove, 0)]:
                    os.remove(stale_ckpt)

                ckpt_name = ckpt_save_dir / ('checkpoint_epoch_%d' % trained_epoch)
                save_checkpoint(
                    checkpoint_state(model, optimizer, trained_epoch, accumulated_iter), filename=ckpt_name,
                )


def model_state_to_cpu(model_state):
    """Return a copy of ``model_state`` with every value moved to the CPU.

    The result has the same mapping type as the input (e.g. ``OrderedDict``), and key
    order is preserved. Every value must provide a ``.cpu()`` method.
    """
    result = type(model_state)()
    for param_name, param_value in model_state.items():
        result[param_name] = param_value.cpu()
    return result


def checkpoint_state(model=None, optimizer=None, epoch=None, it=None):
    """Bundle model/optimizer state and bookkeeping into a checkpoint dict.

    Args:
        model: network to snapshot, or None. A ``DistributedDataParallel`` wrapper is
            unwrapped (its ``.module`` is saved) and the state is copied to CPU; other
            models' ``state_dict()`` is stored as-is.
        optimizer: optimizer to snapshot, or None.
        epoch: epoch number to record.
        it: global iteration to record.

    Returns:
        dict with keys ``'epoch'``, ``'it'``, ``'model_state'``, ``'optimizer_state'`` and
        ``'version'``. ``'version'`` is ``'pcdet+<version>'``, or ``'none'`` when the
        ``pcdet`` package or its version string is unavailable.
    """
    optim_state = optimizer.state_dict() if optimizer is not None else None

    if model is None:
        model_state = None
    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_state = model_state_to_cpu(model.module.state_dict())
    else:
        model_state = model.state_dict()

    try:
        import pcdet
        version = 'pcdet+' + pcdet.__version__
    except (ImportError, AttributeError, TypeError):
        # Version tagging is best-effort; missing/odd version info must not block saving.
        version = 'none'

    return {'epoch': epoch, 'it': it, 'model_state': model_state, 'optimizer_state': optim_state, 'version': version}


def save_checkpoint(state, filename='checkpoint', *, separate_optimizer=False):
    """Serialize a checkpoint dict with ``torch.save`` to ``<filename>.pth``.

    Args:
        state: checkpoint dict, e.g. as produced by ``checkpoint_state``.
        filename: output path without extension (str or path-like); ``.pth`` is appended.
        separate_optimizer: if True and ``state`` has an ``'optimizer_state'`` entry, write
            that entry to ``<filename>_optim.pth`` and omit it from the main file. Off by
            default, so by default everything goes into a single file. The caller's
            ``state`` dict is never modified.
    """
    if separate_optimizer and 'optimizer_state' in state:
        state = dict(state)  # shallow copy so popping does not mutate the caller's dict
        optimizer_state = state.pop('optimizer_state')
        torch.save({'optimizer_state': optimizer_state}, '{}_optim.pth'.format(filename))

    torch.save(state, '{}.pth'.format(filename))
