import pathlib
import string

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch import FloatTensor
import dgl
from dgl.data.utils import load_graphs
from tqdm import tqdm
from abc import abstractmethod

import random

import numpy as np
from scipy.spatial.transform import Rotation

import argparse
import time
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything

from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt


def bounding_box_uvgrid(inp: torch.Tensor):
    """Compute the bounding box of the masked samples of a UV-grid.

    Channels 0-2 of the last axis are read as xyz coordinates and channel 6
    as a mask; only samples whose mask value equals 1 take part.

    Args:
        inp: Tensor with at least 7 channels on its last axis.

    Returns:
        The result of ``bounding_box_pointcloud`` on the selected points.
    """
    xyz = inp[..., :3].reshape((-1, 3))
    # Keep only the samples flagged by the mask channel
    keep = inp[..., 6].reshape(-1) == 1
    return bounding_box_pointcloud(xyz[keep, :])


def bounding_box_pointcloud(pts: torch.Tensor):
    """Return the axis-aligned bounding box of a point cloud.

    Args:
        pts: 2-D tensor; columns 0, 1 and 2 are read as x, y and z.

    Returns:
        A newly created 2x3 tensor: row 0 holds the per-axis minima and
        row 1 the per-axis maxima.
    """
    lower = [pts[:, axis].min() for axis in range(3)]
    upper = [pts[:, axis].max() for axis in range(3)]
    return torch.tensor([lower, upper])


def center_and_scale_uvgrid(inp: torch.Tensor, return_center_scale=False):
    """Center the xyz channels of a UV-grid and scale them, in place.

    The bounding box of the masked points is moved so its center sits at the
    origin, then the coordinates are multiplied by ``2 / longest box side``.

    Args:
        inp: UV-grid tensor; channels 0-2 of the last axis are modified in place.
        return_center_scale: When true, also return the center and the scale.

    Returns:
        ``inp``, or ``(inp, center, scale)`` when ``return_center_scale`` is set.
    """
    box = bounding_box_uvgrid(inp)
    lo, hi = box[0], box[1]
    extent = hi - lo
    scale = 2.0 / max(extent[0], extent[1], extent[2])
    center = 0.5 * (lo + hi)

    # Basic slicing yields a view, so these in-place ops update ``inp`` itself
    xyz = inp[..., :3]
    xyz -= center
    xyz *= scale

    if not return_center_scale:
        return inp
    return inp, center, scale


def get_random_rotation():
    """Draw a random rotation about a canonical axis by a multiple of 90 degrees.

    Returns:
        A ``scipy.spatial.transform.Rotation`` built from a rotation vector.
    """
    canonical_axes = (
        np.array([1, 0, 0]),
        np.array([0, 1, 0]),
        np.array([0, 0, 1]),
    )
    quarter_turns_deg = (0.0, 90.0, 180.0, 270.0)
    # The axis is drawn before the angle; both come from Python's ``random`` module
    chosen_axis = random.choice(canonical_axes)
    chosen_deg = random.choice(quarter_turns_deg)
    rotvec = np.radians(chosen_deg) * chosen_axis
    return Rotation.from_rotvec(rotvec)


def rotate_uvgrid(inp, rotation):
    """Rotate the two 3-vector channel groups of a UV-grid in place.

    Channels 0-2 and channels 3-5 of the last axis are each multiplied, as
    row vectors, by the 3x3 matrix of ``rotation``. Remaining channels are
    left untouched.

    Args:
        inp: Tensor with at least 6 channels on its last axis. Modified in place.
        rotation: Object exposing ``as_matrix()`` that returns a 3x3 matrix
            (e.g. a ``scipy.spatial.transform.Rotation``).

    Returns:
        The same tensor object ``inp``.
    """
    # Build the matrix in the input's floating dtype and on its device so the
    # matmul also works for float64 or GPU tensors; non-floating inputs keep
    # the float32 matrix.
    dtype = inp.dtype if inp.is_floating_point() else torch.float32
    Rmat = torch.as_tensor(rotation.as_matrix(), dtype=dtype, device=inp.device)
    for channels in (slice(0, 3), slice(3, 6)):
        vectors = inp[..., channels]
        # Row-vector convention: each vector v becomes v @ R
        inp[..., channels] = torch.mm(vectors.reshape(-1, 3), Rmat).reshape(vectors.size())
    return inp


# Font-name fragments that mark a file as unusable. ``valid_font`` rejects a
# path when any entry occurs in it as a case-insensitive substring.
# Duplicate entries are harmless and kept as found.
# NOTE(review): a few entries look misspelled ("HoloLens MDL2 Assests",
# "Barkerville Old Face", "Ms Reference speciality") — confirm against the
# actual file names before correcting them.
INVALID_FONTS = [
    "Bokor",
    "Lao Muang Khong",
    "Lao Sans Pro",
    "MS Outlook",
    "Catamaran Black",
    "Dubai",
    "HoloLens MDL2 Assets",
    "Lao Muang Don",
    "Oxanium Medium",
    "Rounded Mplus 1c",
    "Moul Pali",
    "Noto Sans Tamil",
    "Webdings",
    "Armata",
    "Koulen",
    "Yinmar",
    "Ponnala",
    "Noto Sans Tamil",
    "Chenla",
    "Lohit Devanagari",
    "Metal",
    "MS Office Symbol",
    "Cormorant Garamond Medium",
    "Chiller",
    "Give You Glory",
    "Hind Vadodara Light",
    "Libre Barcode 39 Extended",
    "Myanmar Sans Pro",
    "Scheherazade",
    "Segoe MDL2 Assets",
    "Siemreap",
    # A missing comma used to merge the next two entries into one string
    "Signika SemiBold",
    "Taprom",
    "Times New Roman TUR",
    "Playfair Display SC Black",
    "Poppins Thin",
    "Raleway Dots",
    "Raleway Thin",
    "Segoe MDL2 Assets",
    "Segoe MDL2 Assets",
    "Spectral SC ExtraLight",
    "Txt",
    "Uchen",
    "Yinmar",
    "Almarai ExtraBold",
    "Fasthand",
    "Exo",
    "Freckle Face",
    "Montserrat Light",
    "Inter",
    "MS Reference Specialty",
    "MS Outlook",
    "Preah Vihear",
    "Sitara",
    "Barkerville Old Face",
    # A missing comma used to merge the next two entries into one string
    "Bodoni MT",
    "Bokor",
    "Fasthand",
    "HoloLens MDL2 Assests",
    "Libre Barcode 39",
    "Lohit Tamil",
    "Marlett",
    "MS outlook",
    "MS office Symbol Semilight",
    "MS office symbol regular",
    "Ms office symbol extralight",
    "Ms Reference speciality",
    "Segoe MDL2 Assets",
    "Siemreap",
    "Sitara",
    "Symbol",
    "Wingdings",
    "Metal",
    "Ponnala",
    "Webdings",
    "Souliyo Unicode",
    "Aguafina Script",
    "Yantramanav Black",
    # Entries below are commented out and therefore not filtered:
    # "Yaldevi",
    # "Zhi Mang Xing",
    # "Taviraj",
    # "SeoulNamsan EB",
]


def valid_font(filename):
    """Tell whether a file name is free of every black-listed font name.

    Each ``INVALID_FONTS`` entry is tested as a case-insensitive substring of
    ``str(filename)``.

    Returns:
        False if any entry matches, True otherwise.
    """
    lowered = str(filename).lower()
    return not any(name.lower() in lowered for name in INVALID_FONTS)


class BaseDataset(Dataset):
    """In-memory dataset of graph samples loaded from files.

    Every sample is a dict with the keys ``"graph"`` (the loaded graph) and
    ``"filename"`` (the file stem). Subclasses may add further keys by
    overriding ``load_one_graph`` and ``_collate``.
    """

    # Defaults read by ``__getitem__``; subclasses normally set both in
    # ``__init__``. With these defaults no rotation is applied.
    random_rotate = False
    split = None

    @staticmethod
    @abstractmethod
    def num_classes():
        """Return the number of target classes; subclasses override this."""
        pass

    def load_graphs(self, file_paths, center_and_scale=True):
        """Load all existing files into ``self.data``.

        Files that do not exist, samples for which ``load_one_graph`` returns
        None, and graphs without edges are skipped.

        Args:
            file_paths: Iterable of ``pathlib.Path`` objects.
            center_and_scale: Whether to normalize the loaded graphs with
                ``self.center_and_scale()``.
        """
        self.data = []
        for fn in tqdm(file_paths):
            if not fn.exists():
                continue
            sample = self.load_one_graph(fn)
            if sample is None:
                continue
            if sample["graph"].edata["x"].size(0) == 0:
                # Catch the case of graphs with no edges
                continue
            self.data.append(sample)
        if center_and_scale:
            self.center_and_scale()
        self.convert_to_float32()

    def load_one_graph(self, file_path):
        """Load the first graph stored in ``file_path`` and wrap it in a sample dict."""
        graph = load_graphs(str(file_path))[0][0]
        return {"graph": graph, "filename": file_path.stem}

    def center_and_scale(self):
        """Normalize node features and apply the same shift/scale to edge features."""
        for sample in self.data:
            graph = sample["graph"]
            graph.ndata["x"], center, scale = center_and_scale_uvgrid(
                graph.ndata["x"], return_center_scale=True
            )
            # Edge coordinates reuse the center/scale derived from the node features
            graph.edata["x"][..., :3] -= center
            graph.edata["x"][..., :3] *= scale

    def convert_to_float32(self):
        """Cast the node and edge feature tensors of every sample to float32."""
        for sample in self.data:
            graph = sample["graph"]
            graph.ndata["x"] = graph.ndata["x"].float()
            graph.edata["x"] = graph.edata["x"].float()

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``, randomly rotated for the training split if enabled.

        The rotation is written back into the stored sample, so the cached
        features themselves change on every access.
        """
        sample = self.data[idx]
        if self.random_rotate and self.split == "train":
            rotation = get_random_rotation()
            sample["graph"].ndata["x"] = rotate_uvgrid(sample["graph"].ndata["x"], rotation)
            sample["graph"].edata["x"] = rotate_uvgrid(sample["graph"].edata["x"], rotation)
        return sample

    def _collate(self, batch):
        """Merge a list of samples into one batched graph plus a list of file names."""
        batched_graph = dgl.batch([sample["graph"] for sample in batch])
        batched_filenames = [sample["filename"] for sample in batch]
        return {"graph": batched_graph, "filename": batched_filenames}

    def get_dataloader(self, batch_size=128, shuffle=True, num_workers=0, drop_last=True):
        """Create a ``DataLoader`` over this dataset using ``self._collate``.

        Args:
            batch_size: Number of samples per batch.
            shuffle: Whether to reshuffle the samples every epoch.
            num_workers: Number of loader worker processes.
            drop_last: Whether to discard a trailing incomplete batch. The
                default (True) keeps the previous behavior; pass False for
                evaluation so that every sample is visited.

        Returns:
            The configured ``DataLoader``.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,
            drop_last=drop_last,
        )


def _get_filenames(root_dir, filelist):
    with open(str(root_dir / f"{filelist}"), "r") as f:
        file_list = [x.strip() for x in f.readlines()]

    files = list(
        x
        for x in root_dir.rglob(f"*.bin")
        if x.stem in file_list
        #if util.valid_font(x) and x.stem in file_list
    )
    return files


# Maps each lowercase ASCII letter to its position in the alphabet ('a' -> 0, ..., 'z' -> 25)
CHAR2LABEL = dict(zip(string.ascii_lowercase, range(len(string.ascii_lowercase))))


class BinaryFeatureDataset(BaseDataset):
    """Two-class dataset whose labels are encoded in the file-name prefix."""

    @staticmethod
    def num_classes():
        """Return the number of classes (binary classification)."""
        return 2

    def __init__(
        self,
        root_dir,
        split="train",
        center_and_scale=True,
        random_rotate=False,
    ):
        """
        Load the binary feature dataset.

        Args:
            root_dir (str): Root directory of the dataset.
            split (str, optional): Split to load (train, val, or test). Defaults to "train".
            center_and_scale (bool, optional): Whether to center and scale the models. Defaults to True.
            random_rotate (bool, optional): Whether to apply random rotations in 90 degree
                increments to the models. Defaults to False.
        """
        assert split in ("train", "val", "test")
        path = pathlib.Path(root_dir)

        self.random_rotate = random_rotate
        # Stored so that __getitem__ can tell which split it is serving
        self.split = split

        if split == "test":
            file_paths = _get_filenames(path, filelist="test.txt")
        elif split in ("train", "val"):
            candidates = _get_filenames(path, filelist="train.txt")
            # Labels come from the file-name prefix and drive the stratified split
            labels = [int(fn.stem.startswith("1_")) for fn in candidates]
            train_files, val_files = train_test_split(
                candidates, test_size=0.11111111, random_state=66, stratify=labels,
            )
            file_paths = train_files if split == "train" else val_files

        print(f"Loading {split} data...")
        self.load_graphs(file_paths, center_and_scale)
        print(f"Done loading {len(self.data)} files")

    def load_one_graph(self, file_path):
        """Load one graph through the base class and attach its label.

        The label is 1 when the file stem starts with ``"1_"`` and 0 otherwise.
        """
        sample = super().load_one_graph(file_path)
        is_positive = file_path.stem.startswith("1_")
        sample["label"] = torch.tensor([int(is_positive)]).long()
        return sample

    def _collate(self, batch):
        """Batch graphs and file names via the base class and concatenate the labels."""
        collated = super()._collate(batch)
        label_tensors = [item["label"] for item in batch]
        collated["label"] = torch.cat(label_tensors, dim=0)
        return collated


# ---------------------------------------------------------------------------
# Command-line interface and output directories (executed at module level)
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser("UV-Net solid model classification")
parser.add_argument("traintest", choices=("train", "test"), help="Whether to train or test")
parser.add_argument("--dataset", choices=("features",), help="Dataset to train on")
parser.add_argument("--dataset_path", type=str, help="Path to dataset")
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument(
    "--num_workers", type=int, default=0, help="Number of workers for the dataloader."
)
parser.add_argument(
    "--experiment_name", type=str, default="classification", help="Experiment name"
)
# Optional directory used for saving/loading checkpoints instead of the timestamped default
parser.add_argument("--ckpt_dir", type=str, default=None, help="手动指定保存/加载ckpt的文件夹")
# Number of times each model is trained/evaluated
parser.add_argument("--repeat", type=int, default=3, help="每个模型重复次数")
# Let pytorch-lightning register its own Trainer options on the same parser
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

results_path = pathlib.Path(__file__).parent / "Results" / args.experiment_name
if not results_path.exists():
    results_path.mkdir(parents=True, exist_ok=True)

# Timestamp parts; also used later to name the ROC figure
month_day = time.strftime("%m%d")
hour_min_second = time.strftime("%H%M%S")

# Use the explicit checkpoint directory when given, otherwise
# Results/<experiment>/<MMDD>/<HHMMSS>
save_dir = (
    results_path.joinpath(month_day, hour_min_second)
    if args.ckpt_dir is None
    else pathlib.Path(args.ckpt_dir)
)
save_dir.mkdir(parents=True, exist_ok=True)

def get_model(model_type, num_classes):
    """Create a new classification model of the requested type.

    Args:
        model_type: ``"enn"`` (``EquiCAD.EquiNN``) or ``"uvnet"`` (``uvnet.models``).
        num_classes: Passed to the ``Classification`` constructor.

    Returns:
        A ``Classification`` instance from the selected module.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names
            (previously this silently returned None).
    """
    # The imports are local, so only the module of the selected model is imported
    if model_type == "enn":
        from EquiCAD.EquiNN import Classification
    elif model_type == "uvnet":
        from uvnet.models import Classification
    else:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    return Classification(num_classes=num_classes)

def get_model_from_ckpt(model_type, checkpoint):
    """Restore a classification model of the requested type from a checkpoint.

    Args:
        model_type: ``"enn"`` (``EquiCAD.EquiNN``) or ``"uvnet"`` (``uvnet.models``).
        checkpoint: Checkpoint path passed to ``Classification.load_from_checkpoint``.

    Returns:
        The model returned by ``Classification.load_from_checkpoint``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names
            (previously this silently returned None).
    """
    # The imports are local, so only the module of the selected model is imported
    if model_type == "enn":
        from EquiCAD.EquiNN import Classification
    elif model_type == "uvnet":
        from uvnet.models import Classification
    else:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    return Classification.load_from_checkpoint(checkpoint)

def evaluate_model_with_auc(model, test_loader, device):
    """Evaluate a model on a loader and compute its ROC curve and AUC.

    Args:
        model: Callable model; put into eval mode and called on each batched graph.
        test_loader: Iterable of batches with the keys ``"graph"`` and ``"label"``.
        device: Device the graph and labels are moved to.

    Returns:
        Tuple ``(roc_auc, fpr, tpr, all_preds, all_labels)`` where ``all_preds``
        holds the softmax probability of class 1 for every evaluated sample.
    """
    # Axis permutation applied to the "x" feature, keyed by its number of dims.
    # Assumed layouts (TODO confirm): node features (B, H, W, 7) -> (B, 7, H, W)
    # or (B, L, 6) -> (B, 6, L); edge features (B, L, 6) -> (B, 6, L).
    channel_first_orders = (
        ("ndata", {4: (0, 3, 1, 2), 3: (0, 2, 1)}),
        ("edata", {3: (0, 2, 1)}),
    )

    model.eval()
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in test_loader:
            graph = batch["graph"].to(device)
            labels = batch["label"].to(device)

            for frame_name, orders in channel_first_orders:
                if not hasattr(graph, frame_name):
                    continue
                frame = getattr(graph, frame_name)
                if "x" not in frame:
                    continue
                order = orders.get(frame["x"].dim())
                if order is not None:
                    frame["x"] = frame["x"].permute(*order)

            logits = model(graph)
            probabilities = torch.nn.functional.softmax(logits, dim=1)

            # Keep the probability of the positive class
            score_chunks.append(probabilities[:, 1].cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    # Merge the per-batch predictions and labels
    all_preds = np.concatenate(score_chunks)
    all_labels = np.concatenate(label_chunks)

    # AUC and ROC curve over the whole evaluated set
    roc_auc = roc_auc_score(all_labels, all_preds)
    fpr, tpr, _ = roc_curve(all_labels, all_preds)

    return roc_auc, fpr, tpr, all_preds, all_labels

# Pick the dataset class for the requested dataset. A dedicated name is used
# so that the torch ``Dataset`` import is not rebound.
if args.dataset == "features":
    dataset_cls = BinaryFeatureDataset
else:
    raise ValueError("Unsupported dataset")

if args.traintest == "train":
    torch.cuda.empty_cache()
    print(
        f"""
-----------------------------------------------------------------------------------
UV-Net Classification (BOTH)
-----------------------------------------------------------------------------------
Logs written to {save_dir}

To monitor the logs, run:
tensorboard --logdir {save_dir}

The trained models with the best validation loss will be written to:
{save_dir}/best.enn
{save_dir}/best.uvnet
-----------------------------------------------------------------------------------
    """
    )

    train_data = dataset_cls(root_dir=args.dataset_path, split="train")
    val_data = dataset_cls(root_dir=args.dataset_path, split="val")
    train_loader = train_data.get_dataloader(
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )
    val_loader = val_data.get_dataloader(
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )

    # (model type, display name) pairs to train, each ``args.repeat`` times.
    # UV-Net training is currently disabled; add ("uvnet", "UVNet") to train
    # it as well. The test branch below expects best.uvnet.<i>.ckpt files to
    # be present in ``save_dir``.
    models_to_train = (("enn", "ENN"),)
    for model_type, display_name in models_to_train:
        for repeat_idx in range(args.repeat):
            # Fresh random seed per run, printed so the run can be identified
            seed = random.randint(0, 2**31 - 1)
            print(f"第 {repeat_idx+1} 次训练 {display_name}，随机种子：{seed}")
            seed_everything(seed, workers=True)
            # Checkpoint is written as best.<model_type>.<repeat_idx>.ckpt
            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                dirpath=str(save_dir),
                filename=f"best.{model_type}.{repeat_idx}",
                save_last=True,
            )
            trainer = Trainer.from_argparse_args(
                args,
                callbacks=[checkpoint_callback],
                logger=TensorBoardLogger(str(save_dir), name=f"{model_type}_{repeat_idx}"),
                gpus=1 if torch.cuda.is_available() else None,
                precision=32,
            )
            model = get_model(model_type, dataset_cls.num_classes())
            trainer.fit(model, train_loader, val_loader)

elif args.traintest == "test":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_data = dataset_cls(root_dir=args.dataset_path, split="test")
    # NOTE(review): get_dataloader drops a trailing incomplete batch by
    # default, so those test samples do not contribute to the AUC.
    test_loader = test_data.get_dataloader(
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )

    # Common FPR grid onto which every ENN / UVNet ROC curve is interpolated
    roc_grid = np.linspace(0, 1, 100)
    enn_aucs, enn_tprs = [], []
    uvnet_aucs, uvnet_tprs = [], []

    for repeat_idx in range(args.repeat):
        enn_ckpt = str(save_dir.joinpath(f"best.enn.{repeat_idx}.ckpt"))
        uvnet_ckpt = str(save_dir.joinpath(f"best.uvnet.{repeat_idx}.ckpt"))

        enn_model = get_model_from_ckpt("enn", enn_ckpt)
        uvnet_model = get_model_from_ckpt("uvnet", uvnet_ckpt)
        enn_model.to(device)
        uvnet_model.to(device)

        print(f"评估 enn 模型... (第 {repeat_idx+1} 次)")
        enn_auc, enn_fpr, enn_tpr, _, _ = evaluate_model_with_auc(enn_model, test_loader, device)
        enn_aucs.append(enn_auc)
        enn_tprs.append(np.interp(roc_grid, enn_fpr, enn_tpr))

        print(f"评估 UVNet 模型... (第 {repeat_idx+1} 次)")
        uvnet_auc, uvnet_fpr, uvnet_tpr, _, _ = evaluate_model_with_auc(uvnet_model, test_loader, device)
        uvnet_aucs.append(uvnet_auc)
        uvnet_tprs.append(np.interp(roc_grid, uvnet_fpr, uvnet_tpr))

    # One entry per plotted curve: (label, color, FPR grid, per-run TPRs, per-run AUCs).
    # Every curve keeps its own FPR grid, so a baseline file with a different
    # grid cannot distort the ENN / UVNet curves.
    curves = [
        ("ENN", "blue", roc_grid, np.asarray(enn_tprs), enn_aucs),
        ("UVNet", "red", roc_grid, np.asarray(uvnet_tprs), uvnet_aucs),
    ]

    # Baseline results are read from .npz files expected in ``save_dir``
    baselines = (
        ("PointNet", "green", "pointnet_auc_results.npz"),
        ("DGCNN", "purple", "dgcnn_auc_results.npz"),
        ("PointNet2", "orange", "pointnet2_auc_results.npz"),
    )
    for label, color, filename in baselines:
        # The context manager closes the archive once the arrays are read
        with np.load(str(save_dir.joinpath(filename))) as results:
            curves.append(
                (label, color, results["mean_fpr"], results["tprs"], results["aucs"])
            )

    plt.figure(figsize=(8, 6))
    for label, color, fpr_grid, tprs, aucs in curves:
        # Mean curve with a +/- one standard deviation band across runs
        mean_tpr = np.mean(tprs, axis=0)
        std_tpr = np.std(tprs, axis=0)
        plt.plot(fpr_grid, mean_tpr, color=color, lw=2, label=f'{label} Mean ROC (AUC = {np.mean(aucs):.3f})')
        plt.fill_between(fpr_grid, mean_tpr - std_tpr, mean_tpr + std_tpr, color=color, alpha=0.2)
    # Chance-level diagonal
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    # Report the configured number of runs instead of a hard-coded 3
    plt.title(f'Mean ROC curves ({args.repeat} runs)')
    plt.legend(loc="lower right")
    roc_file = str(save_dir.joinpath(f"roc_comparison_mean_{month_day}_{hour_min_second}.png"))
    plt.savefig(roc_file)
    plt.close()
    print(f"平均ROC曲线已保存至: {roc_file}")