import pathlib
import json
import os

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch import FloatTensor
import dgl
from dgl.data.utils import load_graphs
from tqdm import tqdm
from abc import abstractmethod

import random

import numpy as np
from scipy.spatial.transform import Rotation

import argparse
import time
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything

from EquiCAD.EquiNN import Classification


def bounding_box_uvgrid(inp: torch.Tensor):
    """Return the bounding box of the masked-in samples of a UV-grid tensor.

    The first three channels of the last axis are read as point coordinates
    and channel 6 as a mask; only samples whose mask value equals 1 take part
    in the box computation.

    Args:
        inp: Tensor whose last axis has at least 7 channels
            (assumed layout: xyz, three further channels, mask -- TODO confirm).

    Returns:
        The 2x3 tensor produced by ``bounding_box_pointcloud``
        (row 0 = per-axis minimum, row 1 = per-axis maximum).
    """
    xyz = inp[..., :3].reshape((-1, 3))
    keep = inp[..., 6].reshape(-1) == 1
    return bounding_box_pointcloud(xyz[keep, :])


def bounding_box_pointcloud(pts: torch.Tensor):
    """Compute the axis-aligned bounding box of a point cloud.

    Args:
        pts: 2-D tensor; columns 0, 1 and 2 are read as x, y and z.

    Returns:
        A new 2x3 tensor: row 0 holds the per-axis minima, row 1 the
        per-axis maxima.
    """
    columns = [pts[:, axis] for axis in range(3)]
    lower = [column.min() for column in columns]
    upper = [column.max() for column in columns]
    return torch.tensor([lower, upper])


def center_and_scale_uvgrid(inp: torch.Tensor, return_center_scale=False):
    """Center a UV-grid on its bounding box and scale its longest side to 2.

    The first three channels of ``inp`` are modified in place: the bounding-box
    center is subtracted, then the result is multiplied by
    ``2 / longest box side``.

    Args:
        inp: UV-grid tensor accepted by ``bounding_box_uvgrid``.
        return_center_scale: When true, also return the applied center and scale.

    Returns:
        ``inp`` itself, or the tuple ``(inp, center, scale)`` when
        ``return_center_scale`` is true.
    """
    lower, upper = bounding_box_uvgrid(inp)
    extent = upper - lower
    scale = 2.0 / max(extent[0], extent[1], extent[2])
    center = 0.5 * (lower + upper)

    # Basic slicing yields a view, so the in-place ops below update ``inp``.
    xyz = inp[..., :3]
    xyz -= center
    xyz *= scale

    return (inp, center, scale) if return_center_scale else inp


def get_random_rotation():
    """Draw a random rotation by a multiple of 90 degrees about x, y or z.

    The axis is drawn first, then the angle, both with ``random.choice``.

    Returns:
        A ``scipy.spatial.transform.Rotation`` built from the rotation vector.
    """
    axis_index = random.choice((0, 1, 2))
    angle_degrees = random.choice((0.0, 90.0, 180.0, 270.0))

    # A rotation vector along a canonical axis has a single non-zero entry.
    rotvec = np.zeros(3)
    rotvec[axis_index] = np.radians(angle_degrees)
    return Rotation.from_rotvec(rotvec)


def rotate_uvgrid(inp, rotation):
    """Rotate channels 0-2 and 3-5 of a feature tensor in place.

    Each 3-channel group is flattened to rows and right-multiplied by the
    float32 rotation matrix (``rows @ R``).

    Args:
        inp: Tensor whose last axis has at least 6 channels.
        rotation: Object exposing ``as_matrix()``, e.g. a scipy ``Rotation``.

    Returns:
        ``inp`` itself, with both channel groups overwritten.
    """
    rot_matrix = torch.tensor(rotation.as_matrix()).float()
    target_shape = inp[..., :3].size()
    for channels in (slice(0, 3), slice(3, 6)):
        flat_rows = inp[..., channels].reshape(-1, 3)
        inp[..., channels] = torch.mm(flat_rows, rot_matrix).reshape(target_shape)
    return inp


# Font names whose files are rejected by ``valid_font``.
INVALID_FONTS = [
    "Bokor",
    "Lao Muang Khong",
]


def valid_font(filename):
    """Return False when ``filename`` mentions a font listed in INVALID_FONTS.

    The comparison is a case-insensitive substring test against
    ``str(filename)``.
    """
    haystack = str(filename).lower()
    return not any(font.lower() in haystack for font in INVALID_FONTS)


class BaseDataset(Dataset):
    """In-memory dataset of graphs loaded from DGL graph files.

    Subclasses call ``load_graphs`` to populate ``self.data`` with sample
    dicts (``{"graph": ..., "filename": ...}``) and may set ``self.split`` and
    ``self.random_rotate`` to enable rotation augmentation on the training
    split.
    """

    @staticmethod
    @abstractmethod
    def num_classes():
        """Return the number of target classes (to be provided by subclasses).

        NOTE(review): ``@abstractmethod`` is only enforced when the class uses
        ``ABCMeta``, which is not set up here.
        """
        pass

    def load_graphs(self, file_paths, center_and_scale=True):
        """Load every existing file into ``self.data`` and normalize it.

        Files that do not exist, samples for which ``load_one_graph`` returns
        ``None``, and graphs without edges are skipped.

        Args:
            file_paths: Iterable of ``pathlib.Path`` objects.
            center_and_scale: When true, normalize every sample with
                ``center_and_scale`` before the float32 conversion.
        """
        self.data = []
        for fn in tqdm(file_paths):
            if not fn.exists():
                continue
            sample = self.load_one_graph(fn)
            if sample is None:
                continue
            if sample["graph"].edata["x"].size(0) == 0:
                # Catch the case of graphs with no edges
                continue
            self.data.append(sample)
        if center_and_scale:
            self.center_and_scale()
        self.convert_to_float32()

    def load_one_graph(self, file_path):
        """Load the first graph stored in ``file_path`` as a sample dict."""
        graph = load_graphs(str(file_path))[0][0]
        return {"graph": graph, "filename": file_path.stem}

    def center_and_scale(self):
        """Normalize node features and apply the same transform to edge xyz."""
        for sample in self.data:
            graph = sample["graph"]
            graph.ndata["x"], center, scale = center_and_scale_uvgrid(
                graph.ndata["x"], return_center_scale=True
            )
            # Edges reuse the center/scale computed from the node features.
            graph.edata["x"][..., :3] -= center
            graph.edata["x"][..., :3] *= scale

    def convert_to_float32(self):
        """Cast the node and edge features of every sample to ``FloatTensor``."""
        for sample in self.data:
            graph = sample["graph"]
            graph.ndata["x"] = graph.ndata["x"].type(FloatTensor)
            graph.edata["x"] = graph.edata["x"].type(FloatTensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``, randomly rotated for an augmented train split.

        NOTE: the rotated features are written back to the cached sample, so
        within one process rotations compose across repeated accesses.
        """
        sample = self.data[idx]
        augment = (
            getattr(self, "random_rotate", False)
            and getattr(self, "split", None) == "train"
        )
        if augment:
            rotation = get_random_rotation()
            graph = sample["graph"]
            graph.ndata["x"] = rotate_uvgrid(graph.ndata["x"], rotation)
            graph.edata["x"] = rotate_uvgrid(graph.edata["x"], rotation)
        return sample

    def _collate(self, batch):
        """Merge sample dicts into one batched graph plus a list of filenames."""
        batched_graph = dgl.batch([sample["graph"] for sample in batch])
        batched_filenames = [sample["filename"] for sample in batch]
        return {"graph": batched_graph, "filename": batched_filenames}

    def get_dataloader(self, batch_size=128, shuffle=True, num_workers=0, drop_last=True):
        """Build a ``DataLoader`` over this dataset using ``self._collate``.

        Args:
            batch_size: Number of samples per batch.
            shuffle: Whether to reshuffle the samples every epoch.
            num_workers: Number of loader worker processes.
            drop_last: Whether to discard the final incomplete batch. The
                default (``True``) keeps the historical behaviour; pass
                ``False`` for evaluation so that every sample is scored.

        Returns:
            The configured ``torch.utils.data.DataLoader``.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,
            drop_last=drop_last,
        )


def _get_filenames(root_dir, filelist):
    """Find the ``*.bin`` files under ``root_dir`` that a list file names.

    Args:
        root_dir: ``pathlib.Path`` of the dataset root; it holds the list file
            and is searched recursively.
        filelist: Name of the list file, relative to ``root_dir``. Each line
            is one file stem; surrounding whitespace is ignored.

    Returns:
        List of matching paths, in the order ``rglob`` yields them.

    Raises:
        FileNotFoundError: If the list file does not exist.
    """
    with open(str(root_dir / str(filelist)), "r") as f:
        # A set keeps the per-file membership test O(1) for large list files.
        wanted_stems = {line.strip() for line in f}

    return [path for path in root_dir.rglob("*.bin") if path.stem in wanted_stems]


class IndustrialParts(BaseDataset):
    """Classification dataset labelled by the file-name prefix before ``_``.

    ``train.txt`` is split 80/20 (stratified, ``random_state=42``) into the
    train and val splits; ``test.txt`` provides the test split.
    """

    # Category name -> integer label. Stored on the class so that every split
    # (and the surrounding script) shares one mapping; filled on first use.
    CATEGORY_TO_LABEL = {}

    @staticmethod
    def extract_category(filename):
        """Return the text before the first ``_`` of a file stem.

        Args:
            filename: A ``pathlib.Path`` (its stem is used) or a stem string.

        Returns:
            The category string, or ``"未知"`` when ``filename`` cannot be split.
        """
        name = filename.stem if isinstance(filename, pathlib.Path) else filename
        try:
            return name.split('_')[0]
        except (AttributeError, TypeError, IndexError):
            # Only non-string inputs get here; unrelated errors now propagate.
            print(f"警告：无法从'{name}'中提取类别")
            return "未知"

    @classmethod
    def num_classes(cls):
        """Return the number of known categories (16 if none were scanned yet)."""
        if not cls.CATEGORY_TO_LABEL:
            print("警告：类别映射尚未构建，返回默认值16")
            return 16
        return len(cls.CATEGORY_TO_LABEL)

    def __init__(
        self,
        root_dir,
        split="train",
        center_and_scale=True,
        random_rotate=False,
    ):
        """Load one split of the dataset into memory.

        Args:
            root_dir: Dataset root containing ``train.txt`` / ``test.txt`` and
                the ``*.bin`` graph files.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            center_and_scale: Forwarded to ``load_graphs``.
            random_rotate: Enables rotation augmentation on the train split.
        """
        assert split in ("train", "val", "test")
        path = pathlib.Path(root_dir)

        self.split = split
        self.random_rotate = random_rotate

        self._build_category_mapping(path)

        if split in ("train", "val"):
            # Directory enumeration order is arbitrary. Sorting makes the seeded
            # split reproducible and keeps separately built train/val disjoint.
            file_paths = sorted(_get_filenames(path, filelist="train.txt"))
            labels = [
                IndustrialParts.CATEGORY_TO_LABEL.get(self.extract_category(fn.stem), 0)
                for fn in file_paths
            ]
            train_files, val_files = train_test_split(
                file_paths, test_size=0.2, random_state=42, stratify=labels,
            )
            file_paths = train_files if split == "train" else val_files
        else:
            file_paths = _get_filenames(path, filelist="test.txt")

        print(f"Loading {split} data...")
        self.load_graphs(file_paths, center_and_scale)
        print("Done loading {} files".format(len(self.data)))
        if split == "train":
            print(f"类别映射 ({len(IndustrialParts.CATEGORY_TO_LABEL)}个类): {IndustrialParts.CATEGORY_TO_LABEL}")

    def _build_category_mapping(self, root_dir):
        """Scan the dataset file lists and build the category mapping.

        Does nothing when the mapping is already populated. Categories from
        ``train.txt`` and ``test.txt`` are numbered in sorted order; a missing
        list file is skipped.
        """
        if IndustrialParts.CATEGORY_TO_LABEL:
            return

        categories = set()
        for filelist in ("train.txt", "test.txt"):
            try:
                files = _get_filenames(root_dir, filelist)
            except FileNotFoundError:
                continue
            categories.update(self.extract_category(fp.stem) for fp in files)

        sorted_categories = sorted(categories)
        # Update in place so existing references to the dict see the result.
        IndustrialParts.CATEGORY_TO_LABEL.update(
            {category: index for index, category in enumerate(sorted_categories)}
        )

        print(f"发现{len(IndustrialParts.CATEGORY_TO_LABEL)}个类别: {sorted_categories}")

    def load_one_graph(self, file_path):
        """Load one graph and attach its integer label.

        Returns:
            The sample dict with an added ``"label"`` tensor of shape ``(1,)``,
            or ``None`` when loading fails. An unmapped category is labelled 0
            after printing a warning.
        """
        try:
            sample = super().load_one_graph(file_path)
            category = self.extract_category(file_path.stem)

            label = IndustrialParts.CATEGORY_TO_LABEL.get(category)
            if label is None:
                print(f"警告：文件'{file_path.stem}'的类别'{category}'未在映射中找到，使用默认值0")
                label = 0

            sample["label"] = torch.tensor([label]).long()
            return sample
        except Exception as e:
            # Deliberately best-effort: a failed file is reported and skipped
            # by ``load_graphs`` rather than aborting the whole load.
            print(f"处理文件'{file_path}'时出错: {e}")
            return None

    def _collate(self, batch):
        """Batch graphs and filenames, then concatenate the labels."""
        collated = super()._collate(batch)
        collated["label"] = torch.cat([x["label"] for x in batch], dim=0)
        return collated


def _build_arg_parser():
    """Build the command-line parser, extended with the Lightning Trainer flags."""
    parser = argparse.ArgumentParser("UV-Net solid model classification")
    parser.add_argument(
        "traintest", choices=("train", "test"), help="Whether to train or test"
    )
    parser.add_argument("--dataset_path", type=str, help="Path to dataset")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of workers for the dataloader. NOTE: set this to 0 on Windows, any other value leads to poor performance",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint file to load weights from for testing",
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        default="classification",
        help="Experiment name (used to create folder inside ./results/ to save logs and checkpoints)",
    )
    parser.add_argument(
        "--load_checkpoint",
        type=str,
        default="False",
        help="Path to pre-trained checkpoint file or 'False' to create a new model",
    )
    return Trainer.add_argparse_args(parser)


def _report_test_accuracy(results):
    """Print the test accuracy found in the result list of ``trainer.test``.

    ``test_acc`` is preferred, then ``test_acc_epoch``, then the first key
    containing ``acc``; otherwise the available keys are listed.
    """
    print("Test results:", results)
    metrics = results[0] if results else {}
    for key in ("test_acc", "test_acc_epoch"):
        if key in metrics:
            print(f"Classification accuracy (%) on test set: {metrics[key] * 100.0}")
            return

    print("Available keys:", list(metrics.keys()))
    acc_keys = [k for k in metrics if 'acc' in k.lower()]
    if acc_keys:
        print(f"Classification accuracy (%) on test set: {metrics[acc_keys[0]] * 100.0}")
    else:
        print("无法找到准确率相关的键，请查看上面的可用键列表并手动提取")


def _run_training(args, trainer, run_dir, run_label):
    """Train a classifier and save the category mapping next to the checkpoints.

    Args:
        args: Parsed command-line arguments.
        trainer: Configured Lightning ``Trainer``.
        run_dir: Directory of this run (checkpoints and mapping file).
        run_label: Human-readable run location shown in the banner.
    """
    torch.cuda.empty_cache()
    seed_everything(workers=True)
    print(
        f"""
-----------------------------------------------------------------------------------
UV-Net Classification
-----------------------------------------------------------------------------------
Logs written to {run_label}

To monitor the logs, run:
tensorboard --logdir {run_label}

The trained model with the best validation loss will be written to:
{run_label}/best.ckpt
-----------------------------------------------------------------------------------
    """
    )

    # Building the training split fills IndustrialParts.CATEGORY_TO_LABEL, so
    # no separate probing load of the dataset is needed.
    train_data = IndustrialParts(root_dir=args.dataset_path, split="train")

    run_dir.mkdir(parents=True, exist_ok=True)
    mapping_file = str(run_dir / "category_mapping.json")
    with open(mapping_file, "w", encoding="utf-8") as f:
        json.dump(IndustrialParts.CATEGORY_TO_LABEL, f, ensure_ascii=False, indent=4)
    print(f"类别映射已保存到: {mapping_file}")

    if args.load_checkpoint != "False":
        print(f"Loading pre-trained model from {args.load_checkpoint}...")
        model = Classification.load_from_checkpoint(args.load_checkpoint)
    else:
        model = Classification(num_classes=IndustrialParts.num_classes())
        print(f"创建模型，类别数: {IndustrialParts.num_classes()}")

    val_data = IndustrialParts(root_dir=args.dataset_path, split="val")
    train_loader = train_data.get_dataloader(
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )
    val_loader = val_data.get_dataloader(
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )
    trainer.fit(model, train_loader, val_loader)


def _run_testing(args, trainer):
    """Evaluate a checkpoint on the test split and print its accuracy."""
    # Reuse the label mapping saved at training time so class indices match.
    checkpoint_dir = os.path.dirname(args.checkpoint)
    mapping_file = os.path.join(checkpoint_dir, "category_mapping.json")
    if os.path.exists(mapping_file):
        with open(mapping_file, "r", encoding="utf-8") as f:
            IndustrialParts.CATEGORY_TO_LABEL = json.load(f)
        print(f"已加载类别映射: {IndustrialParts.CATEGORY_TO_LABEL}")
        print(f"类别数: {IndustrialParts.num_classes()}")
    else:
        print("警告：找不到类别映射文件！将依赖数据集自动检测，可能导致测试类别不匹配。")

    model = Classification.load_from_checkpoint(args.checkpoint)

    test_data = IndustrialParts(root_dir=args.dataset_path, split="test")
    # Built directly so the last partial batch is kept: every test sample
    # must contribute to the reported accuracy.
    test_loader = DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=test_data._collate,
        num_workers=args.num_workers,
        drop_last=False,
    )
    results = trainer.test(model=model, test_dataloaders=[test_loader], verbose=False)
    _report_test_accuracy(results)


def main():
    """Parse the command line, set up the Trainer, then train or test."""
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.traintest == "test" and args.checkpoint is None:
        # Fail before any directory or Trainer is created.
        parser.error("Expected the --checkpoint argument to be provided")

    results_path = (
        pathlib.Path(__file__).parent.joinpath("results").joinpath(args.experiment_name)
    )
    results_path.mkdir(parents=True, exist_ok=True)

    # Format both parts from one timestamp so they cannot straddle a clock tick.
    now = time.localtime()
    month_day = time.strftime("%m%d", now)
    hour_min_second = time.strftime("%H%M%S", now)
    run_dir = results_path.joinpath(month_day, hour_min_second)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=str(run_dir),
        filename="best",
        save_last=True,
    )
    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[checkpoint_callback],
        logger=TensorBoardLogger(
            str(results_path), name=month_day, version=hour_min_second,
        ),
        gpus=1 if torch.cuda.is_available() else None
    )

    if args.traintest == "train":
        run_label = f"results/{args.experiment_name}/{month_day}/{hour_min_second}"
        _run_training(args, trainer, run_dir, run_label)
    else:
        _run_testing(args, trainer)


# The guard keeps DataLoader workers started with the "spawn" method (and any
# importer of this module) from re-running argument parsing and the script.
if __name__ == "__main__":
    main()