import os, logging, csv, numpy as np, wandb
from tqdm import tqdm
import torch, torch.nn as nn
from torch import distributed as dist
from torch.utils.tensorboard import SummaryWriter
from openpoints_student.utils import set_random_seed, save_checkpoint, load_checkpoint, load_checkpoint_inv, resume_checkpoint, setup_logger_dist, \
    cal_model_parm_nums, Wandb
from openpoints_student.utils import AverageMeter, ConfusionMatrix, get_mious
from openpoints_student.dataset import build_dataloader_from_cfg
from openpoints_student.transforms import build_transforms_from_cfg
from openpoints_student.optim import build_optimizer_from_cfg
from openpoints_student.scheduler import build_scheduler_from_cfg
# from openpoints_student.loss import build_criterion_from_cfg
from openpoints_student.models_student import build_model_student_from_cfg
from openpoints.models import build_model_from_cfg
from openpoints_student.models_student.layers import furthest_point_sample, fps


def get_features_by_keys(input_features_dim, data):
    """Assemble the per-point input features selected by channel count.

    Args:
        input_features_dim: 3 selects ``data['pos']`` alone; 4 selects
            ``data['pos']`` concatenated with ``data['heights']`` along the
            last axis.
        data: mapping holding the ``'pos'`` tensor (and ``'heights'`` when
            four channels are requested).

    Returns:
        The selected features with axes 1 and 2 swapped, made contiguous
        (presumably (B, N, C) -> (B, C, N); confirm against the dataset).

    Raises:
        NotImplementedError: for any other ``input_features_dim``.
    """
    if input_features_dim == 3:
        features = data['pos']
    elif input_features_dim == 4:
        features = torch.cat(
            (data['pos'], data['heights']), dim=-1)
    else:
        # Previously this raise sat inside the 4-channel branch (missing
        # ``else``), so that branch always failed and unsupported sizes
        # crashed later with an unbound local instead.
        raise NotImplementedError("error")
    return features.transpose(1, 2).contiguous()

def accuracy(outputs, targets, topk=(1, )):
    """Compute top-k accuracy as a fraction for each requested k.

    Args:
        outputs: 2-D score tensor with one row per sample and one column
            per class.
        targets: ground-truth class indices, one per row of ``outputs``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``:
        the number of samples whose target is among the k highest scores,
        multiplied by ``1 / targets.size(0)``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = targets.size(0)

        # Indices of the best `largest_k` classes per sample, transposed so
        # that row i holds every sample's (i+1)-th ranked prediction.
        ranked = outputs.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(targets.reshape(1, -1).expand_as(ranked))

        res = [
            hits[:k].reshape(-1).float().sum(dim=0, keepdim=True).mul_(1 / n_samples)
            for k in topk
        ]
    return res

class search_loss(nn.Module):
    """Task loss scaled by a log-latency term, with an optional extra loss.

    The combined value is ``alpha * network_loss * log(max(cost) ** beta)``.
    When an auxiliary ``rkd_loss`` is supplied, ``network_loss`` is the task
    loss plus ``gamma * rkd_loss``.
    """

    def __init__(self, alpha, beta, network_loss_func, gamma):
        """Store the weights and the task-loss callable.

        Args:
            alpha: multiplier applied to the final combined loss.
            beta: exponent applied to the maximum cost before the log.
            network_loss_func: callable ``(output, target) -> loss``.
            gamma: weight of the auxiliary ``rkd_loss`` term.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.beta = beta
        self.network_loss_func = network_loss_func

    def forward(self, output, teacher_output, cost, target, rkd_loss):
        """Compute the combined loss.

        Args:
            output: predictions passed to ``network_loss_func``.
            teacher_output: accepted for interface compatibility; not used
                by the current computation.
            cost: tensor of costs; only its maximum is used.
            target: ground truth passed to ``network_loss_func``.
            rkd_loss: auxiliary loss term, or ``None`` to disable it.

        Returns:
            ``(loss, ce_loss, rkd_loss, kl, latency_loss)`` when ``rkd_loss``
            is given, otherwise ``(loss, network_loss, latency_loss)``.
        """
        latency_loss = torch.log(torch.max(cost) ** self.beta)

        # Identity check against None: the idiomatic form, and it does not
        # depend on how the tensor type implements ``!=``.
        if rkd_loss is None:
            network_loss = self.network_loss_func(output, target)
            loss = self.alpha * network_loss * latency_loss
            return loss, network_loss, latency_loss

        # The logit-level KL term is currently disabled and contributes 0.
        kl = 0.0
        ce_loss = self.network_loss_func(output, target)
        network_loss = ce_loss + self.gamma * rkd_loss + kl
        loss = self.alpha * network_loss * latency_loss
        return loss, ce_loss, rkd_loss, kl, latency_loss


def write_to_csv(oa, macc, accs, best_epoch, cfg_s, write_header=True):
    """Append one result row (optionally preceded by a header) to a CSV file.

    Args:
        oa: overall accuracy, written with three decimals.
        macc: mean class accuracy, written with two decimals.
        accs: per-class accuracies, each written with two decimals.
        best_epoch: epoch number recorded in the row.
        cfg_s: config providing ``classes`` (a list), ``exp_name``,
            ``run_dir``, ``csv_path`` and ``wandb.use_wandb``.
        write_header: when true, a header row is written before the data.
    """
    accs_table = [f'{item:.2f}' for item in accs]
    header = ['method', 'OA', 'mAcc'] + \
        cfg_s.classes + ['best_epoch', 'log_path', 'wandb link']
    data = [cfg_s.exp_name, f'{oa:.3f}', f'{macc:.2f}'] + accs_table + [
        str(best_epoch), cfg_s.run_dir, wandb.run.get_url() if cfg_s.wandb.use_wandb else '-']
    # The context manager closes the file; no explicit close() is needed.
    with open(cfg_s.csv_path, 'a', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(data)


def print_cls_results(oa, macc, accs, epoch, cfg_s):
    """Log a per-class accuracy table followed by an OA / mAcc summary line.

    Args:
        oa: overall accuracy shown in the summary line.
        macc: mean class accuracy shown in the summary line.
        accs: per-class accuracies, paired in order with ``cfg_s.classes``.
        epoch: epoch number shown in the summary line.
        cfg_s: config providing the ``classes`` names.
    """
    pieces = ['\nClasses\tAcc\n']
    pieces.extend(
        '{:10}: {:3.2f}%\n'.format(cls_name, cls_acc)
        for cls_name, cls_acc in zip(cfg_s.classes, accs)
    )
    pieces.append(f'E@{epoch}\tOA: {oa:3.2f}\tmAcc: {macc:3.2f}\n')
    logging.info(''.join(pieces))


def main(gpu, cfg, cfg_s, profile=False):
    """Train the student model with a pretrained teacher, then evaluate it.

    Args:
        gpu: device index; used as the rank when ``cfg_s.mp`` is set and as
            the CUDA device in distributed mode.
        cfg: config used here for the teacher model, optimizer, scheduler,
            dataloaders, run mode, epoch range and checkpointing.
        cfg_s: config used here for the student model, logging, seeding and
            the train/validate routines.
        profile: accepted for interface compatibility; unused in this function.

    Returns:
        ``True`` after a pure ``'test'`` or ``'val'`` run; otherwise ``None``
        once training and the final tests have finished.
    """
    if cfg_s.distributed:
        if cfg_s.mp:
            cfg_s.rank = gpu
        dist.init_process_group(backend=cfg_s.dist_backend,
                                init_method=cfg_s.dist_url,
                                world_size=cfg_s.world_size,
                                rank=cfg_s.rank)
        dist.barrier()
    # logger; only rank 0 owns the wandb run and the TensorBoard writer
    setup_logger_dist(cfg_s.log_path, cfg_s.rank, name=cfg_s.dataset.common.NAME)
    if cfg_s.rank == 0:
        Wandb.launch(cfg_s, cfg_s.wandb.use_wandb)
        writer = SummaryWriter(log_dir=cfg_s.run_dir)
    else:
        writer = None
    set_random_seed(cfg_s.seed + cfg_s.rank, deterministic=cfg_s.deterministic)
    torch.backends.cudnn.enabled = True
    logging.info(cfg_s)

    if not cfg.model.get('criterion_args', False):
        cfg.model.criterion_args = cfg.criterion_args
    if not cfg_s.model.get('criterion_args', False):
        cfg_s.model.criterion_args = cfg_s.criterion_args
    model = build_model_student_from_cfg(cfg_s.model).to(cfg_s.rank)
    teacher_model = build_model_from_cfg(cfg.model).to(cfg.rank)
    model_size = cal_model_parm_nums(model)
    logging.info(model)
    logging.info('Number of params: %.4f M' % (model_size / 1e6))
    if cfg.model.get('in_channels', None) is None:
        cfg.model.in_channels = cfg.model.encoder_args.in_channels
    if cfg_s.model.get('in_channels', None) is None:
        cfg_s.model.in_channels = cfg_s.model.encoder_args.in_channels
    # NOTE(review): the student is converted based on the teacher config's
    # `sync_bn` flag (cfg, not cfg_s) -- confirm this is intended.
    if cfg.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logging.info('Using Synchronized BatchNorm ...')
    if cfg_s.distributed:
        torch.cuda.set_device(gpu)
        model = nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[cfg_s.rank], output_device=cfg_s.rank)
        logging.info('Using Distributed Data parallel ...')

    # optimizer & scheduler
    optimizer = build_optimizer_from_cfg(model, lr=cfg.lr, **cfg.optimizer)
    scheduler = build_scheduler_from_cfg(cfg, optimizer)

    # build dataset
    val_loader = build_dataloader_from_cfg(cfg.get('val_batch_size', cfg.batch_size),
                                           cfg.dataset,
                                           cfg.dataloader,
                                           datatransforms_cfg=cfg.datatransforms,
                                           split='val',
                                           distributed=cfg.distributed
                                           )
    logging.info(f"length of validation dataset: {len(val_loader.dataset)}")
    test_loader = build_dataloader_from_cfg(cfg.get('val_batch_size', cfg.batch_size),
                                            cfg.dataset,
                                            cfg.dataloader,
                                            datatransforms_cfg=cfg.datatransforms,
                                            split='test',
                                            distributed=cfg.distributed
                                            )
    num_classes = val_loader.dataset.num_classes if hasattr(
        val_loader.dataset, 'num_classes') else None
    num_points = val_loader.dataset.num_points if hasattr(
        val_loader.dataset, 'num_points') else None
    if num_classes is not None:
        assert cfg.num_classes == num_classes
    logging.info(f"number of classes of the dataset: {num_classes}, "
                 f"number of points sampled from dataset: {num_points}, "
                 f"number of points as model input: {cfg.num_points}")

    # Class names: config first, then the dataset, then plain indices.
    # The former one-liner mixed `or` with a conditional expression (which
    # binds looser than `or`) and called the non-existent `np.range`, so a
    # dataset without `classes` crashed even when the config provided them.
    classes = cfg.get('classes', None)
    if not classes:
        classes = getattr(val_loader.dataset, 'classes', None)
    if classes is None:
        classes = np.arange(num_classes)
    cfg.classes = classes

    # NOTE(review): eval() executes whatever string the config supplies for
    # `val_fn`; only use with trusted configs (consider a name lookup instead).
    validate_fn = eval(cfg.get('val_fn', 'validate'))
    path = os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))))
    model_path = os.path.join(path, cfg.load_path)
    load_checkpoint(teacher_model, pretrained_path=model_path)
    # NOTE(review): the teacher is never switched to eval mode in this file;
    # confirm whether `teacher_model.eval()` is wanted for distillation.

    train_loss_func = nn.CrossEntropyLoss(reduction='mean')
    # Defined before the checkpoint branches because validation needs it.
    # NOTE(review): on resume this restarts at 5.0 instead of the value
    # reached at `cfg.start_epoch` -- confirm that is acceptable.
    temperature = 5.0
    # optionally resume from a checkpoint
    if cfg.pretrained_path is not None:
        if cfg.mode == 'resume':
            resume_checkpoint(cfg, model, optimizer, scheduler,
                              pretrained_path=cfg.pretrained_path)
            # The validate routine requires the teacher and the temperature,
            # exactly as in the training loop below; the previous
            # three-argument calls failed with a TypeError.
            macc, oa, accs, cm = validate_fn(
                model, val_loader, cfg_s, teacher_model, temperature)
            print_cls_results(oa, macc, accs, cfg.start_epoch, cfg)
        elif cfg.mode == 'test':
            # test mode
            epoch, best_val = load_checkpoint(
                model, pretrained_path=cfg.pretrained_path)
            # Uses cfg_s like every other validation call in this function.
            macc, oa, accs, cm = validate_fn(
                model, test_loader, cfg_s, teacher_model, temperature)
            print_cls_results(oa, macc, accs, epoch, cfg)
            return True
        elif cfg.mode == 'val':
            # validation mode
            epoch, best_val = load_checkpoint(model, cfg.pretrained_path)
            macc, oa, accs, cm = validate_fn(
                model, val_loader, cfg_s, teacher_model, temperature)
            print_cls_results(oa, macc, accs, epoch, cfg)
            return True
        elif cfg.mode == 'finetune':
            # finetune the whole model
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint(model, cfg.pretrained_path)
        elif cfg.mode == 'finetune_encoder':
            # load pretrained weights into the encoder only
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint(model.encoder, cfg.pretrained_path)
        elif cfg.mode == 'finetune_encoder_inv':
            # same as above, but through the `_inv` checkpoint loader
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint_inv(model.encoder, cfg.pretrained_path)
    else:
        logging.info('Training from scratch')
    train_loader = build_dataloader_from_cfg(cfg.batch_size,
                                             cfg.dataset,
                                             cfg.dataloader,
                                             datatransforms_cfg=cfg.datatransforms,
                                             split='train',
                                             distributed=cfg.distributed,
                                             )
    logging.info(f"length of training dataset: {len(train_loader.dataset)}")

    # ===> start training
    val_macc, val_oa, val_accs, best_val, macc_when_best, best_epoch = 0., 0., [], 0., 0., 0
    model.zero_grad()
    for epoch in range(cfg.start_epoch, cfg.epochs + 1):
        if cfg.distributed:
            train_loader.sampler.set_epoch(epoch)
        if hasattr(train_loader.dataset, 'epoch'):
            train_loader.dataset.epoch = epoch - 1
        train_loss, train_macc, train_oa, _, _, cost_meter, network_loss_meter, rkd_loss_meter, kl_loss_meter, acc1_meter, acc5_meter, effective_chanel = \
            train_one_epoch(model, train_loader,
                            optimizer, scheduler, epoch, cfg_s, teacher_model, train_loss_func, temperature, writer)
        # exponential temperature decay, applied once per epoch
        temperature = temperature * np.exp(-0.045)
        is_best = False
        if epoch % cfg.val_freq == 0:
            val_macc, val_oa, val_accs, val_cm = validate_fn(
                model, val_loader, cfg_s, teacher_model, temperature)
            is_best = val_oa > best_val
            if is_best:
                best_val = val_oa
                macc_when_best = val_macc
                best_epoch = epoch
                logging.info(f'Find a better ckpt @E{epoch}')
                print_cls_results(val_oa, val_macc, val_accs, epoch, cfg)

        lr = optimizer.param_groups[0]['lr']
        logging.info(f'Epoch {epoch} LR {lr:.6f} '
                     f'train_oa {train_oa:.2f}, val_oa {val_oa:.2f}, best val oa {best_val:.2f}')
        if writer is not None:
            writer.add_scalar('train_loss', train_loss, epoch)
            # Was logging train_macc under the 'train_oa' tag.
            writer.add_scalar('train_oa', train_oa, epoch)
            writer.add_scalar('lr', lr, epoch)
            writer.add_scalar('val_oa', val_oa, epoch)
            writer.add_scalar('mAcc_when_best', macc_when_best, epoch)
            writer.add_scalar('best_val', best_val, epoch)
            writer.add_scalar('epoch', epoch, epoch)
            writer.add_scalar('cost', cost_meter, epoch)
            writer.add_scalar('network_loss', network_loss_meter, epoch)
            writer.add_scalar('intermediate-kl', rkd_loss_meter, epoch)
            writer.add_scalar('kl', kl_loss_meter, epoch)
            writer.add_scalar('acc@1', acc1_meter, epoch)
            writer.add_scalar('acc@5', acc5_meter, epoch)
            writer.add_scalar('effective_chanel', effective_chanel, epoch)

        if cfg.sched_on_epoch:
            scheduler.step(epoch)
        if cfg.rank == 0:
            save_checkpoint(cfg, model, epoch, optimizer, scheduler,
                            additioanl_dict={'best_val': best_val},
                            is_best=is_best
                            )
    # test the last epoch
    test_macc, test_oa, test_accs, test_cm = validate(
        model, test_loader, cfg_s, teacher_model, temperature)
    print_cls_results(test_oa, test_macc, test_accs, best_epoch, cfg)
    if writer is not None:
        writer.add_scalar('test_oa', test_oa, epoch)
        writer.add_scalar('test_macc', test_macc, epoch)

    # test the best validation model
    best_epoch, _ = load_checkpoint(model, pretrained_path=os.path.join(
        cfg.ckpt_dir, f'{cfg.run_name}_ckpt_best.pth'))
    test_macc, test_oa, test_accs, test_cm = validate(
        model, test_loader, cfg_s, teacher_model, temperature)
    if writer is not None:
        writer.add_scalar('test_oa', test_oa, best_epoch)
        writer.add_scalar('test_macc', test_macc, best_epoch)
    print_cls_results(test_oa, test_macc, test_accs, best_epoch, cfg)

    if writer is not None:
        writer.close()
    # dist.destroy_process_group()

def train_one_epoch(model, train_loader, optimizer, scheduler, epoch, cfg_s, teacher_model, train_loss_func, temperature, writer):
    """Run one training epoch of the student under teacher guidance.

    Args:
        model: student network (possibly wrapped so that the real module is
            reachable as ``model.module``); must provide ``get_logits_loss``.
        train_loader: iterable of batches; each batch is a dict of tensors
            with at least ``'x'`` (points) and ``'y'`` (labels).
        optimizer: optimizer stepped every ``cfg_s.step_per_update`` batches.
        scheduler: stepped with ``epoch`` after each optimizer step unless
            ``cfg_s.sched_on_epoch`` is set.
        epoch: current epoch number (for the scheduler and progress text).
        cfg_s: config providing ``num_classes``, ``num_points``,
            ``model.in_channels``, ``step_per_update``, ``print_freq``,
            ``epochs`` and optionally ``grad_norm_clip``.
        teacher_model: network run without gradients on every batch; its
            second output is handed to the student.
        train_loss_func: task loss passed on to ``search_loss``.
        temperature: value forwarded to the student's ``get_logits_loss``.
        writer: scalar writer, or ``None`` (e.g. on ranks other than 0).

    Returns:
        A 12-tuple: average loss, mean class accuracy, overall accuracy,
        per-class accuracies, the confusion matrix, and the averages of
        cost, network loss, rkd loss, kl loss, acc@1, acc@5 and effective
        channel.
    """
    loss_meter = AverageMeter()
    network_loss_meter = AverageMeter()
    rkd_loss_meter = AverageMeter()
    cost_meter = AverageMeter()
    kl_loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    effective_channel_meter = AverageMeter()
    cm = ConfusionMatrix(num_classes=cfg_s.num_classes)
    npoints = cfg_s.num_points

    # Loop-invariant: the loss wrapper holds only constants, so build it once.
    search_loss_func = search_loss(0.2, 0.5, train_loss_func, 1.0)
    # `get_logits_loss` lives on the underlying module when `model` is wrapped.
    student = model.module if hasattr(model, 'module') else model

    model.train()  # set model to training mode
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    num_iter = 0
    for idx, data in pbar:
        for key in data.keys():
            data[key] = data[key].cuda(non_blocking=True)
        num_iter += 1
        points = data['x']
        target = data['y']
        # debug aid:
        #   from openpoints.dataset import vis_points
        #   vis_points(data['pos'].cpu().numpy()[0])
        num_curr_pts = points.shape[1]
        if num_curr_pts > npoints:  # point resampling strategy
            # Oversample with FPS, then randomly keep `npoints` of them.
            if npoints == 1024:
                point_all = 1200
            elif npoints == 4096:
                point_all = 4800
            elif npoints == 8192:
                point_all = 8192
            else:
                raise NotImplementedError()
            if points.size(1) < point_all:
                point_all = points.size(1)
            fps_idx = furthest_point_sample(
                points[:, :, :3].contiguous(), point_all)
            fps_idx = fps_idx[:, np.random.choice(
                point_all, npoints, False)]
            points = torch.gather(
                points, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, points.shape[-1]))

        data['pos'] = points[:, :, :3].contiguous()
        data['x'] = points[:, :, :cfg_s.model.in_channels].transpose(1, 2).contiguous()
        with torch.no_grad():
            teacher_output, output_layer = teacher_model(data)
        logits, loss, cost_accumulate, kl_accumulate, effective_channel = \
            student.get_logits_loss(data, target, output_layer, temperature)
        loss, network_loss, rkd_loss, kl_loss, latency_loss = search_loss_func(
            logits, teacher_output, cost_accumulate, target, kl_accumulate)
        loss, network_loss, rkd_loss, kl_loss, latency_loss = loss.mean(), \
            network_loss.mean(), rkd_loss.mean(), kl_loss, latency_loss.mean()
        loss.backward()

        # optimize (gradients accumulate over `step_per_update` batches)
        if num_iter == cfg_s.step_per_update:
            if cfg_s.get('grad_norm_clip') is not None and cfg_s.grad_norm_clip > 0.:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg_s.grad_norm_clip, norm_type=2)
            num_iter = 0
            optimizer.step()
            model.zero_grad()
            if not cfg_s.sched_on_epoch:
                scheduler.step(epoch)

        # One row per prediction with the class axis last. The previous
        # `view(1, C, -1)` only reinterpreted memory without moving the class
        # axis, so for batch sizes above 1 the scores fed to `accuracy` were
        # scrambled (and a single-sample 2-D batch crashed outright).
        flat_logits = logits.transpose(1, -1).reshape(-1, logits.size(1))
        acc1, acc5 = accuracy(flat_logits, target.view(-1), topk=(1, 5))
        effective_channel = effective_channel.item()
        effective_channel_meter.update(effective_channel)
        # `writer` is None on processes that do not own a summary writer.
        if writer is not None:
            writer.add_scalar('effective_channel', effective_channel_meter.avg, idx)
        rkd_loss = rkd_loss.item()
        cost = latency_loss.item()
        network_loss = network_loss.item()
        acc1 = acc1.item()
        acc5 = acc5.item()

        # Number of predictions in this batch (was hard-coded to 40 classes).
        num = flat_logits.size(0)
        rkd_loss_meter.update(rkd_loss, num)
        kl_loss_meter.update(kl_loss, num)
        cost_meter.update(cost, num)
        network_loss_meter.update(network_loss, num)
        acc1_meter.update(acc1, num)
        acc5_meter.update(acc5, num)
        # update confusion matrix
        cm.update(logits.argmax(dim=1), target)
        loss_meter.update(loss.item())
        if idx % cfg_s.print_freq == 0:
            pbar.set_description(f"Train Epoch [{epoch}/{cfg_s.epochs}] "
                                 f"Loss {loss_meter.val:.3f} Acc {cm.overall_accuray:.2f} "
                                 f'cost {cost_meter.val:.4f} ({cost_meter.avg:.4f}) '
                                 f'network_loss {network_loss_meter.val:.4f} ({network_loss_meter.avg:.4f}) '
                                 f'intermediate-kl {rkd_loss_meter.val:.8f} ({rkd_loss_meter.avg:.8f}) '
                                 f'rkd_loss {0.2*rkd_loss_meter.val:.4f} ({0.2*rkd_loss_meter.avg:.4f}) '
                                 f'kl {kl_loss_meter.val:.4f} ({kl_loss_meter.avg:.4f}) '
                                 f'acc@1 {acc1_meter.val:.4f} ({acc1_meter.avg:.4f}) '
                                 f'acc@5 {acc5_meter.val:.4f} ({acc5_meter.avg:.4f})'
                                 )
    macc, overallacc, accs = cm.all_acc()
    return loss_meter.avg, macc, overallacc, accs, cm, cost_meter.avg, network_loss_meter.avg, rkd_loss_meter.avg, kl_loss_meter.avg, acc1_meter.avg, acc5_meter.avg, effective_channel_meter.avg


@torch.no_grad()
def validate(model, val_loader, cfg_s, teacher_model, temperature,):
    """Evaluate the student on a loader and return accuracy statistics.

    Args:
        model: student network, called as ``model(data, teacher_layers,
            temperature)`` and expected to return four values, the first
            being the logits.
        val_loader: iterable of batches; each batch is a dict of tensors
            with at least ``'x'`` and ``'y'``.
        cfg_s: config providing ``num_classes``, ``num_points``,
            ``model.in_channels`` and ``distributed``.
        teacher_model: network whose second output is fed to the student.
        temperature: value forwarded to the student.

    Returns:
        ``(macc, overallacc, accs, cm)`` as computed by the confusion matrix
        (summed across processes first when running distributed).
    """
    model.eval()  # set model to eval mode
    cm = ConfusionMatrix(num_classes=cfg_s.num_classes)
    max_points = cfg_s.num_points
    progress = tqdm(enumerate(val_loader), total=len(val_loader))
    for _, batch in progress:
        for name, tensor in batch.items():
            batch[name] = tensor.cuda(non_blocking=True)
        labels = batch['y']
        # keep only the first `max_points` points of every sample
        cloud = batch['x'][:, :max_points]
        batch['pos'] = cloud[:, :, :3].contiguous()
        batch['x'] = cloud[:, :, :cfg_s.model.in_channels].transpose(1, 2).contiguous()
        # gradients are already disabled by the decorator
        _, teacher_layers = teacher_model(batch)
        logits, _, _, _ = model(batch, teacher_layers, temperature)
        cm.update(logits.argmax(dim=1), labels)

    tp, count = cm.tp, cm.count
    if cfg_s.distributed:
        # sum the per-process counters in place before computing accuracy
        dist.all_reduce(tp)
        dist.all_reduce(count)
    macc, overallacc, accs = cm.cal_acc(tp, count)
    return macc, overallacc, accs, cm
