import numpy as np
import torch
from . import common_utils
from torchvision.transforms.functional import rotate, affine
import pickle

def random_flip_along_x(gt_occ, points):
    """Mirror the scene across the x axis with probability 0.5.

    When the coin flip succeeds, the occupancy grid is reversed along its
    second axis and the y coordinate (column 1) of every point is negated.

    Args:
        gt_occ: occupancy grid, [x, y, z] or [x, y].
        points: (M, 3 + C) point array; its y column is modified in place
            when the flip is applied.

    Returns:
        ``(gt_occ, points)`` -- the possibly mirrored grid and the points.
    """
    do_flip = np.random.choice([True, False], p=[0.5, 0.5], replace=False)
    if not do_flip:
        return gt_occ, points

    # np.flip returns a reversed view, not a copy.
    mirrored_occ = np.flip(gt_occ, axis=1)
    points[:, 1] = -points[:, 1]
    return mirrored_occ, points

def random_flip_along_y(gt_occ, points):
    """Mirror the scene across the y axis with probability 0.5.

    When the coin flip succeeds, the occupancy grid is reversed along its
    first axis and the x coordinate (column 0) of every point is negated.

    Args:
        gt_occ: occupancy grid, [x, y, z] or [x, y].
        points: (M, 3 + C) point array; its x column is modified in place
            when the flip is applied.

    Returns:
        ``(gt_occ, points)`` -- the possibly mirrored grid and the points.
    """
    do_flip = np.random.choice([True, False], p=[0.5, 0.5], replace=False)
    if not do_flip:
        return gt_occ, points

    # np.flip returns a reversed view, not a copy.
    mirrored_occ = np.flip(gt_occ, axis=0)
    points[:, 0] = -points[:, 0]
    return mirrored_occ, points

def random_rotate90(gt_occ, points):
    """Rotate the scene by 90 degrees with probability 0.5.

    When the coin flip succeeds, the grid is rotated once with ``np.rot90``
    (default ``axes=(0, 1)``, i.e. in the plane of its first two axes). The
    points are mapped ``(x, y) -> (-y, x)`` in place, leaving any other
    columns untouched.

    Args:
        gt_occ: occupancy grid. NOTE(review): the previous docstring listed a
            leading N axis ([N, x, y, z] / [N, x, y]), but rotating axes
            (0, 1) would then mix N with x; the flip helpers treat axis 0 as
            x, so [x, y, z] / [x, y] is presumably intended -- confirm.
        points: (M, 3 + C) point array; columns 0 and 1 are rewritten in
            place when the rotation is applied.

    Returns:
        ``(gt_occ, points)`` -- the possibly rotated grid and the points.
    """
    do_rotate = np.random.choice([True, False], p=[0.5, 0.5], replace=False)
    if not do_rotate:
        return gt_occ, points

    rotated_occ = np.rot90(gt_occ, 1)
    # Negation allocates a fresh array, so the old y survives the overwrite below.
    new_x = -points[:, 1]
    points[:, 1] = points[:, 0]
    points[:, 0] = new_x
    return rotated_occ, points

def balanced_infos_resampling(dataset_cfg, infos, logger,
                              occ_infos_path="./tools/occ_infos.pkl"):
    """Resample frame infos so the configured classes are more evenly represented.

    Each frame is assigned to every balanced class whose entry in the
    per-frame occupancy statistics is non-zero; a frame may therefore belong
    to several classes. Each class group is then sampled with replacement,
    scaled by ``ratio = (1 / len(BALANCED_RESAMPLING_CLASS)) / share``.
    ``share`` is the group's fraction of all (duplicated) class assignments.
    Ratios above 1 are square-rooted to damp up-sampling of rare classes.

    Args:
        dataset_cfg: config exposing ``BALANCED_RESAMPLING_CLASS`` (the class
            indices to balance) and ``CLASS_NUM``. Only indices in
            ``range(CLASS_NUM)`` are inspected.
        infos: sequence of per-frame info dicts; ``info['frame_id']`` is the
            key into the occupancy statistics.
        logger: logger for progress/warning messages, or ``None``.
        occ_infos_path: pickle file mapping ``frame_id`` to an indexable
            per-class statistic (non-zero means the class is present).
            Defaults to the previously hard-coded location.

    Returns:
        list: the resampled infos (entries may repeat). Classes that occur in
        no frame contribute nothing. If no frame contains any balanced class,
        a copy of ``infos`` is returned unchanged.

    Raises:
        KeyError: if a frame's ``frame_id`` is missing from the statistics.
    """
    # NOTE(security): pickle.load can execute arbitrary code -- only point
    # occ_infos_path at trusted, locally generated files.
    with open(occ_infos_path, "rb") as f:
        occ_infos = pickle.load(f)

    class_names_balance = dataset_cfg.BALANCED_RESAMPLING_CLASS
    cls_infos = {name: [] for name in class_names_balance}
    # Only indices that are both balanced and inside range(CLASS_NUM) can ever
    # receive frames; resolve them once instead of testing every class per frame.
    active_classes = [c for c in range(dataset_cfg.CLASS_NUM) if c in cls_infos]

    for info in infos:
        occ_info = occ_infos[info['frame_id']]
        for cls_ind in active_classes:
            if occ_info[cls_ind] != 0:
                cls_infos[cls_ind].append(info)

    duplicated_samples = sum(len(v) for v in cls_infos.values())
    if duplicated_samples == 0:
        # Nothing to balance; the share/ratio computation would divide by zero.
        if logger is not None:
            logger.warning(
                'Balanced resampling skipped: no frame contains any of the classes %s',
                class_names_balance)
        return list(infos)

    frac = 1.0 / len(class_names_balance)
    sampled_infos = []
    for cur_cls_infos in cls_infos.values():
        num_frames = len(cur_cls_infos)
        if num_frames == 0:
            # Absent class: its share is 0, so frac / share is undefined.
            continue
        ratio = frac / (num_frames / duplicated_samples)
        if ratio > 1:
            # Damp up-sampling of rare classes.
            ratio = np.power(ratio, 1 / 2)
        # Sample indices with replacement and map them back, so the info dicts
        # are returned as-is instead of round-tripping through a NumPy array.
        picked = np.random.choice(num_frames, int(num_frames * ratio))
        sampled_infos.extend(cur_cls_infos[i] for i in picked)

    if logger is not None:
        logger.info('Total samples after balanced resampling: %s', len(sampled_infos))

    return sampled_infos
