import datetime
import os
import yaml
import numpy as np
import torch
import random
import matplotlib.pyplot as plt


def set_random_seed(seed):
    """
    Seed Python's `random`, NumPy and PyTorch with the same value.

    :param seed: integer passed unchanged to every seeding function
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def init_np_seed(worker_id):
    """
    Re-seed NumPy's global RNG from PyTorch's current initial seed.

    The torch seed is reduced modulo 2**32 before being handed to NumPy.

    :param worker_id: unused; presumably present so the function can serve as
        a DataLoader ``worker_init_fn`` -- confirm against the caller
    """
    np.random.seed(torch.initial_seed() % (2 ** 32))


def seed_torch(seed: int = 42) -> None:
    """
    Seed all RNGs and switch PyTorch to deterministic execution.

    Seeds `random`, NumPy and PyTorch (CPU and CUDA), exports three
    reproducibility-related environment variables, turns off cuDNN
    benchmarking, and requests deterministic algorithms from PyTorch.

    :param seed: value applied to every generator (defaults to 42)
    """
    os.environ.update({
        # Disable hash randomization so that experiments are reproducible.
        # NOTE(review): presumably this only affects interpreters started
        # after this call, not the running one -- confirm.
        'PYTHONHASHSEED': str(seed),
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUBLAS_WORKSPACE_CONFIG": ":16:8",
    })

    # Same seed for every generator; manual_seed_all covers multi-GPU runs.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)

    print(f"Random seed set as {seed}")


def load_yaml(config_path):
    """
    Read a YAML configuration file and return the parsed object.

    :param config_path: path of the YAML file to read
    :return: whatever ``yaml.load`` produces for the file content
    :raises AssertionError: if ``config_path`` does not exist
    """
    assert os.path.exists(config_path), f"wrong config path: {config_path}"
    # NOTE(review): FullLoader is not meant for untrusted input; if configs can
    # come from outside the project, consider yaml.safe_load instead.
    with open(config_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


def save_results(res_tar1, res_tar2, res_big_tar, src, model_path, log_path='log.txt'):
    """
    Append one evaluation record to a text log.

    Each record consists of a header line, a timestamp line, one metrics line
    per target set, and a trailing blank line.

    :param res_tar1: dict with keys 'auroc', 'fpr_at_95_tpr', 'aupr_in', 'aupr_out' (SRC->TAR1)
    :param res_tar2: dict with the same keys (SRC->TAR2)
    :param res_big_tar: dict with the same keys (SRC->TAR1+TAR2)
    :param src: source identifier written into the header line
    :param model_path: model identifier/path written into the header line
    :param log_path: file to append to; defaults to 'log.txt' in the working directory
    :raises KeyError: if a result dict lacks one of the metric keys
    """
    metric_keys = ('auroc', 'fpr_at_95_tpr', 'aupr_in', 'aupr_out')
    # Labels are padded by hand so the metric columns line up in the log.
    labelled_results = (
        ("SRC->TAR1:      ", res_tar1),
        ("SRC->TAR2:      ", res_tar2),
        ("SRC->TAR1+TAR2: ", res_big_tar),
    )

    # Build the whole record first so that a missing key or a formatting
    # error cannot leave a half-written record in the log file.
    record = [
        f"SRC: {src}, infer model: {model_path}\n",
        f"log time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
    ]
    for label, res in labelled_results:
        auroc, fpr, aupr_in, aupr_out = (res[key] for key in metric_keys)
        record.append(
            f"{label}AUROC: {auroc:.4f}, FPR95: {fpr:.4f}, "
            f"AUPR_IN: {aupr_in:.4f}, AUPR_OUT: {aupr_out:.4f}\n"
        )
    record.append("\n")

    # The context manager closes the file; no explicit close() is needed.
    with open(log_path, 'a+', newline='\n') as f:
        f.writelines(record)


def format_time(time) -> str:
    """
    Render a duration in seconds as an ``h:mm:ss`` string.

    The value is rounded to the nearest whole second and formatted through
    ``datetime.timedelta``.

    :param time: elapsed time in seconds (int or float)
    :return: string form of the corresponding ``timedelta``
    """
    whole_seconds = int(round(time))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """
    Build a per-iteration schedule: optional linear warmup, then cosine decay.

    :param base_value: value reached at the end of warmup / start of the cosine phase
    :param final_value: value the cosine phase decays towards
    :param epochs: total number of epochs
    :param niter_per_ep: iterations per epoch
    :param warmup_epochs: number of epochs spent ramping up linearly (0 disables warmup)
    :param start_warmup_value: first value of the warmup ramp
    :return: 1-D np.ndarray with ``epochs * niter_per_ep`` entries
    :raises AssertionError: if the assembled schedule has an unexpected length
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    # Half-cosine from base_value down towards final_value over the remaining steps.
    steps = np.arange(total_iters - warmup_iters)
    phase = np.pi * steps / len(steps)
    half_range = 0.5 * (base_value - final_value)
    cosine_part = final_value + half_range * (1 + np.cos(phase))

    schedule = np.concatenate((warmup_part, cosine_part))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def rescale(x, a, b):
    """
    Linearly map ``x`` so that its minimum becomes ``a`` and its maximum ``b``.

    The extrema are taken with the built-in ``min``/``max``, so ``x`` must be
    iterable over scalars and support ``x - scalar`` (e.g. a 1-D array).

    :param x: values to rescale
    :param a: target value for the minimum of ``x``
    :param b: target value for the maximum of ``x``
    :return: rescaled values, same shape as ``x``
    """
    lowest = min(x)
    span = max(x) - lowest
    slope = (b - a) / span
    return a + slope * (x - lowest)


def to_numpy(t):
    """
    Convert a tensor, NumPy object or list to a NumPy array.

    :param t: torch tensor, object whose type lives in the ``numpy`` module, or list
    :return: tensors are moved to CPU and converted; NumPy objects are
        returned unchanged; lists go through ``np.asarray``
    :raises ValueError: for any other input type
    """
    if torch.is_tensor(t):
        return t.data.cpu().numpy()
    if type(t).__module__ == 'numpy':
        # Already a NumPy object: hand it back untouched (no copy).
        return t
    if isinstance(t, list):
        return np.asarray(t)
    raise ValueError(f"t is {type(t)}")


def make_random_sphere_grid(n):
    """
    Draw ``n`` random points on the unit sphere.

    Each column is a 3-D standard-normal sample scaled to unit length.

    :param n: number of points
    :return: np.ndarray of shape (3, n), dtype float32
    """
    samples = np.random.randn(3, n)
    lengths = np.linalg.norm(samples, axis=0)  # per-column Euclidean norm
    unit_vectors = samples / lengths
    return unit_vectors.astype(np.float32)  # (3, n)


def show_point_cloud(points):
    """
    Display a point cloud as a 3-D scatter plot with matplotlib.

    :param points: np.ndarray (3, N); rows 0, 1 and 2 are used as the x, y
        and z coordinates
    """
    xs, ys, zs = points[0], points[1], points[2]  # (3, n) -> three coordinate rows
    axes3d = plt.axes(projection='3d', aspect='equal')
    axes3d.scatter3D(xs, ys, zs, s=3)
    # Equal box aspect so the cloud is not stretched along any axis.
    axes3d.set_box_aspect([1, 1, 1])
    plt.show()


def show_depth_map(img):
    """
    Display the first channel of every view as a grid of grayscale images.

    The grid has 4 columns and at least 3 rows; extra rows are added when
    there are more than 12 views, so that every view gets a subplot.

    :param img: tensor with size of [B * num_views, 3, 224, 224]
    """
    views = img.shape[0]
    n_cols = 4
    # Keep the 3x4 layout for up to 12 views; grow the row count beyond that,
    # because plt.subplot rejects an index larger than rows * cols.
    n_rows = max(3, (views + n_cols - 1) // n_cols)
    plt.figure("depth map", figsize=(5, 5))

    for i in range(views):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.title(f"view{i + 1}")
        plt.imshow(img[i, 0, ...], cmap='gray')  # only channel 0 is shown
        plt.xticks([])
        plt.yticks([])
        plt.axis('off')
    plt.show()

