import os, csv, time, numpy as np
from pickle import LIST
from tqdm import tqdm
import logging
import torch
from torch import distributed as dist
from torch_scatter import scatter
from openpoints.utils import set_random_seed
from openpoints.utils import ConfusionMatrix, get_mious
from openpoints.dataset import get_scene_seg_features
from openpoints.dataset.data_util import voxelize
from openpoints.transforms import build_transforms_from_cfg
# --- Hierarchy utils ---
from utils.hierarchy import (
    get_fine2coarse_tensor,
    fine_pred_to_coarse_pred,
    get_hierarchical_label_maps,
)


def write_to_csv(oa, macc, miou, ious, best_epoch, cfg, write_header=True, area=5):
    """Append one row of segmentation metrics to ``cfg.csv_path``.

    The row contains the method name (``cfg.cfg_basename``), the evaluated area,
    OA / mACC / mIoU, one IoU per entry of ``cfg.classes``, the best epoch and the
    run directory (``cfg.run_dir``). Metric values are formatted with two decimals.

    Args:
        oa: Overall accuracy.
        macc: Mean class accuracy.
        miou: Mean IoU.
        ious: Iterable of per-class IoU values, ordered like ``cfg.classes``.
        best_epoch: Epoch at which the reported metrics were obtained.
        cfg: Config providing ``classes``, ``cfg_basename``, ``run_dir`` and ``csv_path``.
        write_header: If True, write the column-header row before the data row
            (the file is opened in append mode, so the caller decides when a header
            is needed).
        area: Area identifier written in the ``Area`` column.
    """
    # NOTE: no wandb link column is written: ``wandb.run.get_url()`` fails under
    # multiprocessing, so that column was dropped from the report.
    header = ['method', 'Area', 'OA', 'mACC', 'mIoU'] + list(cfg.classes) + ['best_epoch', 'log_path']
    ious_table = [f'{item:.2f}' for item in ious]
    data = (
        [cfg.cfg_basename, str(area), f'{oa:.2f}', f'{macc:.2f}', f'{miou:.2f}']
        + ious_table
        + [str(best_epoch), cfg.run_dir]
    )
    # The context manager closes the file; no explicit close() is needed.
    with open(cfg.csv_path, 'a', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(data)

def _split_fine_coarse(logits_out, coarse_num):
    """Split a raw model output into ``(fine_logits, coarse_logits)``.

    A non-sequence output is the fine prediction and ``coarse_logits`` is None.
    For a tuple/list, entry 0 is the fine prediction and entry 1 (if present) is
    a coarse candidate. The candidate is kept only when its last dimension equals
    ``coarse_num``. Any entry that is itself a list is reduced to its last element.
    """
    if not isinstance(logits_out, (tuple, list)):
        fine_logits, coarse_logits = logits_out, None
    else:
        fine_logits = logits_out[0]
        coarse_logits = logits_out[1] if len(logits_out) > 1 else None
        if isinstance(coarse_logits, list):
            coarse_logits = coarse_logits[-1]
        # Treat the second output as coarse logits only when its channel count matches.
        if coarse_logits is not None and coarse_logits.shape[-1] != coarse_num:
            coarse_logits = None
    if isinstance(fine_logits, list):
        fine_logits = fine_logits[-1]
    return fine_logits, coarse_logits


@torch.no_grad()
def test_s3dis(model, area, cfg, global_cm=None):
    """Evaluate ``model`` on all rooms of S3DIS area ``area`` with voxel voting.

    Every file in ``<data_root>/raw`` whose name contains ``Area_<area>`` is
    loaded (columns: xyz, rgb, label) and voxelized with
    ``cfg.dataset.common.voxel_size``. The model is run once per point slot of
    the fullest voxel. Logits of all passes are averaged per original point with
    ``scatter(..., reduce='mean')``. The argmax of the averaged logits is scored
    against the full-resolution labels.

    When ``cfg.hier_coarse_num`` (or ``cfg.coarse_num_classes``) is set, a coarse
    confusion matrix is also filled. It uses the model's coarse logits when every
    pass produced them. Otherwise it maps fine predictions through the
    fine->coarse table. Coarse metrics are only logged, not returned.

    Args:
        model: Segmentation network, called as ``model(data)``.
        area: S3DIS area id used to select rooms.
        cfg: Experiment config. ``cfg.visualize`` is set (default False) as a side effect.
        global_cm: Optional ConfusionMatrix that accumulates fine-level results
            across calls; a new one is created when None.

    Returns:
        tuple: ``(miou, macc, oa, ious, accs, global_cm)`` for the fine level,
        computed from this call's rooms only.
    """
    model.eval()
    cm = ConfusionMatrix(num_classes=cfg.num_classes)
    global_cm = ConfusionMatrix(num_classes=cfg.num_classes) if global_cm is None else global_cm
    set_random_seed(0)
    cfg.visualize = cfg.get('visualize', False)

    # --- data ---
    trans_split = 'val' if cfg.datatransforms.get('test', None) is None else 'test'
    transform = build_transforms_from_cfg(trans_split, cfg.datatransforms)

    raw_root = os.path.join(cfg.dataset.common.data_root, 'raw')
    data_list = [name[:-4] for name in sorted(os.listdir(raw_root)) if 'Area_' in name]
    data_list = [name for name in data_list if 'Area_{}'.format(area) in name]

    # Loop-invariant config lookups.
    voxel_size = cfg.dataset.common.voxel_size
    variable = cfg.dataset.common.get('variable', False)
    n_shifted = cfg.dataset.common.get('n_shifted', 1)
    # During knowledge distillation the student model config describes the evaluated network.
    model_cfg = cfg.student_model if 'student_model' in cfg.keys() else cfg.model
    in_channels = model_cfg.in_channels if 'in_channels' in model_cfg.keys() \
        else model_cfg.encoder_args.in_channels

    # --- hierarchical (coarse) confusion matrix, created once ---
    coarse_num = cfg.get('hier_coarse_num', cfg.get('coarse_num_classes', None))
    use_coarse = coarse_num is not None
    if use_coarse:
        cm_coarse = ConfusionMatrix(num_classes=coarse_num)
        fine2coarse_tensor = get_fine2coarse_tensor(device=cfg.rank)

    if voxel_size is None:
        # Only voxel-voting inference is implemented; no room is evaluated without it.
        logging.warning("test_s3dis: voxel_size is None, no rooms are evaluated.")

    for item in tqdm(data_list):
        if voxel_size is None:
            continue
        cdata = np.load(os.path.join(raw_root, item + '.npy')).astype(np.float32)  # xyz, rgb, label, N*7
        cdata[:, :3] -= np.min(cdata[:, :3], 0)
        label = torch.from_numpy(cdata[:, 6].astype(np.int64).squeeze()).cuda(non_blocking=True)

        fine_parts, coarse_parts, all_point_inds = [], [], []
        uniq_idx, count = voxelize(cdata[:, :3], voxel_size, mode=1)
        # Offset of each voxel's first entry inside uniq_idx.
        voxel_offsets = np.cumsum(np.insert(count, 0, 0)[0:-1])
        for i in range(count.max()):
            # Take the (i mod count)-th point of every voxel: one point per voxel per pass.
            idx_part = uniq_idx[voxel_offsets + i % count]
            np.random.shuffle(idx_part)
            all_point_inds.append(idx_part)
            part = cdata[idx_part]
            # Re-anchor the chunk at its own minimum corner.
            coord, feat = part[:, 0:3] - np.min(part[:, :3], 0), part[:, 3:6]

            data = {'pos': coord, 'x': feat}
            if transform is not None:
                data = transform(data)
            if 'heights' in data.keys():
                data['x'] = torch.cat((data['x'], data['heights']), dim=1)
            else:
                # Append the last ``n_shifted`` coordinate columns (z when n_shifted=1) as features.
                data['x'] = torch.cat((data['x'], torch.from_numpy(
                    coord[:, 3 - n_shifted:3].astype(np.float32))), dim=-1)

            if not variable:
                data['x'] = data['x'].unsqueeze(0)
                data['pos'] = data['pos'].unsqueeze(0)
            else:
                data['o'] = torch.IntTensor([len(coord)])

            for key in data:
                data[key] = data[key].cuda(non_blocking=True)
            data['x'] = get_scene_seg_features(in_channels, data['pos'], data['x'])

            fine_logits, coarse_logits = _split_fine_coarse(model(data), coarse_num)
            fine_parts.append(fine_logits)
            coarse_parts.append(coarse_logits)

        # --- aggregate all voting passes ---
        fine_tensor = torch.cat(fine_parts, dim=0)
        if not variable and fine_tensor.dim() > 2:
            # Channel-first (B, C, N) -> (B * N, C) so rows align with all_point_inds.
            fine_tensor = fine_tensor.transpose(1, 2).reshape(-1, cfg.num_classes)

        coarse_tensor = None  # None -> coarse predictions are derived from fine ones below
        # Coarse logits are only usable if EVERY pass produced them; otherwise their
        # rows would not line up with all_point_inds.
        if use_coarse and all(c is not None for c in coarse_parts):
            # _split_fine_coarse only keeps channel-last coarse logits (last dim ==
            # coarse_num), so this flattens them to one row per point.
            coarse_tensor = torch.cat(coarse_parts, dim=0).reshape(-1, coarse_num)

        point_inds = torch.from_numpy(np.hstack(all_point_inds)).cuda(non_blocking=True)
        # Project the voxel-subsampled predictions back onto the original points.
        fine_tensor = scatter(fine_tensor, point_inds, dim=0, reduce='mean')
        if coarse_tensor is not None:
            coarse_tensor = scatter(coarse_tensor, point_inds, dim=0, reduce='mean')

        # --- confusion matrix update ---
        pred_fine = fine_tensor.argmax(dim=1)
        cm.update(pred_fine, label)
        global_cm.update(pred_fine, label)

        if use_coarse:
            if coarse_tensor is not None:
                pred_coarse = coarse_tensor.argmax(dim=1)
            else:
                pred_coarse = fine_pred_to_coarse_pred(pred_fine, fine2coarse_tensor)
            label_coarse = fine_pred_to_coarse_pred(label, fine2coarse_tensor)
            cm_coarse.update(pred_coarse, label_coarse)

    miou, macc, oa, ious, accs = get_mious(cm.tp, cm.union, cm.count)

    if use_coarse:
        tp_c, union_c, count_c = cm_coarse.tp, cm_coarse.union, cm_coarse.count
        if cfg.distributed:
            dist.all_reduce(tp_c)
            dist.all_reduce(union_c)
            dist.all_reduce(count_c)
        miou_c, macc_c, oa_c, ious_c, _ = get_mious(tp_c, union_c, count_c)
        logging.info(f"[TEST coarse-level] mIoU {miou_c:.2f} mAcc {macc_c:.2f} OA {oa_c:.2f}")
        with np.printoptions(precision=2, suppress=True):
            logging.info(f"Coarse IoU per cls: {ious_c}")

    return miou, macc, oa, ious, accs, global_cm


@torch.no_grad()
def test_s3dis_hierarchical(model, area, cfg):
    """Evaluate every hierarchy level of ``model`` on one S3DIS area.

    Every ``.npy`` file in ``<data_root>/raw`` whose name contains ``Area_<area>``
    is loaded (columns: xyz, rgb, label) and voxelized with voxel voting. The model
    is run once per point slot of the fullest voxel. For each level, the logits of
    all passes are averaged per original point with ``scatter(..., reduce='mean')``.
    The argmax is scored against that level's ground truth. Per-level ground truth
    is derived from the fine labels through ``get_hierarchical_label_maps``.

    The model may return one logits tensor or a list/tuple of tensors (one per
    level). Each tensor is either channel-first ``(B, C, N)`` or ``(N, C)``.

    Args:
        model: Segmentation network, called as ``model(sample)``.
        area: S3DIS area id used to select rooms.
        cfg: Experiment config.

    Returns:
        list[dict]: One dict per model level with keys ``miou, macc, oa, ious, accs``.

    Raises:
        ValueError: If the model returns more levels than there are label maps.
    """
    model.eval()
    set_random_seed(0)

    # One label-mapping callable per level, applied to the fine ground truth.
    # NOTE(review): assumed to be defined via cfg.dataset.common.h_matrix_list_file,
    # with the last map being the identity (fine -> fine); confirm in utils.hierarchy.
    label_maps = get_hierarchical_label_maps(cfg, device=cfg.rank)
    num_levels = len(label_maps)
    cm_list = []  # one ConfusionMatrix per model output level, created on the first pass

    # --- data (S3DIS raw-room protocol) ---
    trans_split = 'val' if cfg.datatransforms.get('test', None) is None else 'test'
    transform = build_transforms_from_cfg(trans_split, cfg.datatransforms)
    raw_root = os.path.join(cfg.dataset.common.data_root, 'raw')
    data_list = sorted([f for f in os.listdir(raw_root) if f"Area_{area}" in f and f.endswith('.npy')])

    # Loop-invariant config lookups. An unset OR explicitly-None voxel size falls back to 0.04.
    voxel_size = cfg.dataset.common.get('voxel_size', None)
    if voxel_size is None:
        voxel_size = 0.04
    n_shifted = cfg.dataset.common.get('n_shifted', 1)
    in_channels = cfg.model.in_channels if 'in_channels' in cfg.model.keys() \
        else cfg.model.encoder_args.in_channels

    logging.info(f"Testing on Area_{area} with {len(data_list)} scenes...")

    for item in tqdm(data_list, desc=f"Area {area}"):
        scene_data = np.load(os.path.join(raw_root, item)).astype(np.float32)
        scene_data[:, :3] -= np.min(scene_data[:, :3], 0)

        # Finest-level ground truth labels.
        labels_fine = torch.from_numpy(scene_data[:, 6].astype(np.int64)).cuda(non_blocking=True)

        logits_per_level = [[] for _ in range(num_levels)]
        point_indices = []

        # --- voxel-based voting ---
        uniq_idx, counts = voxelize(scene_data[:, :3], voxel_size, mode=1)
        voxel_offsets = np.cumsum(np.insert(counts, 0, 0)[:-1])  # first entry of each voxel
        for i in range(counts.max()):
            # Take the (i mod count)-th point of every voxel: one point per voxel per pass.
            idx_part = uniq_idx[voxel_offsets + i % counts]
            np.random.shuffle(idx_part)
            point_indices.append(idx_part)

            part = scene_data[idx_part]
            # Re-anchor the chunk at its own minimum corner, matching test_s3dis' preprocessing.
            coords_part = part[:, :3] - np.min(part[:, :3], 0)
            feats_part = part[:, 3:6]

            sample = {'pos': coords_part, 'x': feats_part}
            if transform:
                sample = transform(sample)

            if 'heights' in sample.keys():
                sample['x'] = torch.cat((sample['x'], sample['heights']), dim=1)
            else:
                # Height feature: the LAST ``n_shifted`` coordinate columns (z when n_shifted=1).
                sample['x'] = torch.cat((sample['x'], torch.from_numpy(
                    coords_part[:, 3 - n_shifted:3].astype(np.float32))), dim=-1)

            sample['pos'] = sample['pos'].unsqueeze(0)
            sample['x'] = sample['x'].unsqueeze(0)
            for k in sample:
                sample[k] = sample[k].cuda(non_blocking=True)
            sample['x'] = get_scene_seg_features(in_channels, sample['pos'], sample['x'])

            # --- inference ---
            logits_list = model(sample)
            if not isinstance(logits_list, (list, tuple)):
                logits_list = [logits_list]

            if not cm_list:
                if len(logits_list) > num_levels:
                    raise ValueError(
                        f"Model returned {len(logits_list)} hierarchy levels but only "
                        f"{num_levels} label maps are defined.")
                # dim 1 is the class dim for both (B, C, N) and (N, C) logits.
                cm_list = [ConfusionMatrix(num_classes=lg.size(1)) for lg in logits_list]

            for lvl, lg in enumerate(logits_list):
                if lg.dim() > 2:
                    # Channel-first (B, C, N) -> (B * N, C) so rows align with idx_part.
                    lg = lg.transpose(1, 2).reshape(-1, lg.size(1))
                logits_per_level[lvl].append(lg)

        # --- aggregate the votes of the whole scene and update each level ---
        point_indices = torch.from_numpy(np.hstack(point_indices)).cuda(non_blocking=True)
        for lvl, lvl_logits in enumerate(logits_per_level):
            if not lvl_logits:
                continue  # the model produced fewer levels than label_maps
            scene_logits = scatter(torch.cat(lvl_logits, dim=0), point_indices, dim=0, reduce='mean')
            labels_lvl = label_maps[lvl](labels_fine)
            cm_list[lvl].update(scene_logits.argmax(dim=1), labels_lvl)

    # --- final metrics per level (summed across ranks when distributed) ---
    all_metrics = []
    for lvl, cm in enumerate(cm_list):
        tp, union, count = cm.tp, cm.union, cm.count
        if cfg.distributed:
            dist.all_reduce(tp)
            dist.all_reduce(union)
            dist.all_reduce(count)

        mi, ma, oa, ious, accs = get_mious(tp, union, count)
        logging.info(f"[TEST level {lvl}] mIoU {mi:.2f} mAcc {ma:.2f} OA {oa:.2f}")
        all_metrics.append({'miou': mi, 'macc': ma, 'oa': oa, 'ious': ious, 'accs': accs})

    return all_metrics


# ---------------------------------------------------------------
# Generic loader-based hierarchical evaluation (dataset-agnostic)
# ---------------------------------------------------------------

def _level_target(labels, lvl, num_points):
    """Return the flat ``(num_points,)`` slice of ``labels`` for hierarchy level ``lvl``.

    If ``labels`` holds exactly ``num_points`` elements (e.g. ``(N,)`` or ``(B, N)``),
    it is a single label set and is returned flattened for every level. Otherwise
    its last axis is a level axis (coarse -> fine). Column ``lvl`` is used, falling
    back to the last column when ``labels`` has fewer levels than the model.
    Works for label tensors and for boolean masks of the same layout.
    """
    if labels.numel() == num_points:
        return labels.reshape(-1)
    col = min(lvl, labels.size(-1) - 1)
    selected = labels[..., col].reshape(-1)
    if selected.numel() != num_points:
        raise ValueError(
            f"Cannot align labels of shape {tuple(labels.shape)} with {num_points} predictions.")
    return selected


@torch.no_grad()
def evaluate_loader_hierarchical(model, data_loader, cfg):
    """Evaluate *any* scene/patch dataloader and return metrics for each
    hierarchy level.

    Each batch must be a dict with ``'pos'``, ``'x'`` and ``'y'``. Supported label
    layouts for ``'y'``:

    * same number of elements as predictions (``(N,)`` or ``(B, N)``): a single
      label set, used for every output level;
    * a trailing level axis (``(N, L)`` or ``(B, N, L)``): column ``lvl`` is used
      for output level ``lvl`` (assumed ordered coarse -> fine). If the model has
      more levels than ``L``, the last column is used.

    For Semantic3D/Toronto3D, points with label 0 are excluded. For ScanNetV2,
    points with label 255 are excluded.

    Parameters
    ----------
    model : torch.nn.Module
        Segmentation model returning logits (``(B, C, N)`` or ``(N, C)``), a list
        of such logits (one per level), or a tuple whose first item is that list.
    data_loader : DataLoader
        Loader yielding dict samples compatible with model forward.
    cfg : EasyConfig
        Global config (dataset name, model input channels, distributed flag).

    Returns
    -------
    list[dict]
        For each level a dict with keys ``miou, macc, oa, ious, accs``.
    """
    model.eval()

    cm_list = []
    # Loop-invariant config lookups.
    dataset_name = cfg.dataset.common.NAME.lower()
    in_channels = cfg.model.in_channels if 'in_channels' in cfg.model.keys() \
        else cfg.model.encoder_args.in_channels

    for data in tqdm(data_loader, total=len(data_loader), desc="Eval"):
        # Dataset-specific "unlabeled" masks (mirrors validate()).
        if ('semantic3d' in dataset_name) or ('toronto3d' in dataset_name):
            data['mask'] = data['y'] != 0
        elif 'scannetv2' in dataset_name:
            data['mask'] = data['y'] != 255

        # Move tensors (including tensors inside lists) to the GPU.
        keys = data.keys() if callable(data.keys) else data.keys
        for k in keys:
            if isinstance(data[k], list):
                for i in range(len(data[k])):
                    if torch.is_tensor(data[k][i]):
                        data[k][i] = data[k][i].cuda(non_blocking=True)
            elif torch.is_tensor(data[k]):
                data[k] = data[k].cuda(non_blocking=True)

        data['x'] = get_scene_seg_features(in_channels, data['pos'], data['x'])

        outputs = model(data)
        if isinstance(outputs, list):
            logits_list = outputs
        elif isinstance(outputs, tuple) and isinstance(outputs[0], list):
            logits_list = outputs[0]
        else:  # single level
            logits_list = [outputs]

        # Create one confusion matrix per level on the first batch (dim 1 = classes).
        if not cm_list:
            cm_list = [ConfusionMatrix(num_classes=lg.size(1), ignore_index=None) for lg in logits_list]

        mask = data.get('mask', None)
        for lvl, lg in enumerate(logits_list):
            # Flatten predictions point-major: (B, N) -> (B*N,), (N,) unchanged.
            pred_lvl = lg.argmax(dim=1).reshape(-1)
            tgt_lvl = _level_target(data['y'], lvl, pred_lvl.numel())
            if mask is not None:
                # The mask has the same layout as 'y', so select the same level column.
                keep = _level_target(mask, lvl, pred_lvl.numel())
                pred_lvl, tgt_lvl = pred_lvl[keep], tgt_lvl[keep]
            cm_list[lvl].update(pred_lvl, tgt_lvl)

    # --- metrics per level (summed across ranks when distributed) ---
    metrics = []
    for cm in cm_list:
        tp, union, count = cm.tp, cm.union, cm.count
        if cfg.distributed:
            dist.all_reduce(tp)
            dist.all_reduce(union)
            dist.all_reduce(count)
        mi, ma, oa, ious, accs = get_mious(tp, union, count)
        metrics.append({'miou': mi, 'macc': ma, 'oa': oa, 'ious': ious, 'accs': accs})

    return metrics
