"""
dataloader.py
电路数据加载和预处理
"""

import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import random
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cosine, euclidean
import logging
import cv2

logger = logging.getLogger(__name__)


class CircuitDataset(Dataset):
    """Dataset that loads one circuit per item.

    ``data_dir`` is scanned for sub-directories.  A sub-directory is used as a
    circuit only when it contains all of ``graph.json``, ``shape.npy`` and
    ``vector.npy``.  Each item is returned as ``(graph, matrix, vector)``
    (see ``load_circuit``).
    """

    # Width of the per-node feature row built by ``parse_circuit_json``:
    # type flag, location (assumed to have two entries, as the default
    # ``[0, 0]`` does), width, height, rotation, pin count.
    NUM_NODE_FEATURES = 7

    # Rotation label -> degrees; unknown labels fall back to 0 degrees.
    _ROTATION_DEGREES = {'R0': 0, 'R90': 90, 'R180': 180, 'R270': 270}

    def __init__(self,
                 data_dir: str,
                 mode: str = 'train',
                 distance_metric: str = 'euclidean',
                 normalize_features: bool = True,
                 scaler: "Optional[StandardScaler]" = None,
                 matrix_size: Tuple[int, int] = (64, 64)):
        """
        Args:
            data_dir: Directory holding one sub-directory per circuit.
            mode: 'train', 'val' or 'test'.
            distance_metric: Metric used by ``calculate_distance``
                ('euclidean', 'cosine' or 'manhattan').
            normalize_features: Whether node features are standardized.
            scaler: Already fitted scaler to reuse (validation / test sets).
                In 'train' mode a new scaler is fitted when this is None.
            matrix_size: Target ``(height, width)`` every matrix is resized to.
        """
        self.data_dir = Path(data_dir)
        self.mode = mode
        self.distance_metric = distance_metric
        self.normalize_features = normalize_features
        self.matrix_size = matrix_size

        # Index the circuits found on disk.
        self.circuit_data = []
        self.load_circuit_data()

        # Feature normalization.
        if normalize_features:
            if mode == 'train' and scaler is None:
                self.scaler = StandardScaler()
                self.fit_scaler()
            else:
                self.scaler = scaler
                if scaler is None:
                    # Make the silent fallback visible: without a scaler the
                    # features of this split are used unnormalized.
                    logger.warning(
                        f"normalize_features is enabled but no scaler was provided "
                        f"for {mode} set; node features will not be normalized"
                    )
        else:
            self.scaler = None

        logger.info(f"Loaded {len(self.circuit_data)} circuits for {mode} set")

    def load_circuit_data(self):
        """Scan ``data_dir`` and record every complete circuit directory.

        Directories are visited in sorted order so that dataset indices are
        the same on every run and every filesystem (``iterdir`` itself yields
        entries in arbitrary order).
        """
        if not self.data_dir.exists():
            logger.warning(f"Data directory {self.data_dir} does not exist")
            return

        for circuit_dir in sorted(d for d in self.data_dir.iterdir() if d.is_dir()):
            json_file = circuit_dir / 'graph.json'
            npy_file = circuit_dir / 'shape.npy'
            vector_file = circuit_dir / 'vector.npy'

            if not (json_file.exists() and npy_file.exists() and vector_file.exists()):
                continue

            # The vector is kept in memory and always stored one-dimensional.
            vector = np.load(vector_file)
            if vector.ndim > 1:
                original_shape = vector.shape  # capture before flattening
                vector = vector.flatten()
                logger.debug(f"Flattened vector for {circuit_dir.name} from shape {original_shape}")

            # The matrix is only inspected for its shape here; it is loaded
            # again on demand in ``load_circuit``.
            matrix = np.load(npy_file)
            logger.debug(f"Circuit {circuit_dir.name}: matrix shape = {matrix.shape}, vector shape = {vector.shape}")

            self.circuit_data.append({
                'id': circuit_dir.name,
                'json_path': json_file,
                'npy_path': npy_file,
                'vector_path': vector_file,
                'vector': vector,
                'original_matrix_shape': matrix.shape  # shape before resizing
            })

    def fit_scaler(self):
        """Fit ``self.scaler`` on the node features of every loaded circuit.

        When no circuit provides any node feature the scaler cannot be
        fitted; it is then dropped (set to None) so that later ``transform``
        calls on an unfitted scaler cannot fail, and a warning is logged.
        """
        feature_blocks = []
        for circuit in self.circuit_data:
            node_features, _, _ = self.parse_circuit_json(circuit['json_path'])
            if node_features.size > 0:  # skip circuits without nodes
                feature_blocks.append(node_features)

        if not feature_blocks:
            logger.warning("No node features available to fit the scaler; feature normalization is disabled")
            self.scaler = None
            return

        stacked = np.vstack(feature_blocks)
        self.scaler.fit(stacked)
        logger.info(f"Fitted scaler on {len(stacked)} nodes")

    def _empty_graph(self) -> Tuple[np.ndarray, np.ndarray, Dict]:
        """Return the (features, edge_index, metadata) triple of an empty graph."""
        return (
            np.zeros((0, self.NUM_NODE_FEATURES), dtype=np.float32),
            np.zeros((2, 0), dtype=np.int64),
            {'num_nodes': 0},
        )

    def parse_circuit_json(self, json_path: Path) -> Tuple[np.ndarray, np.ndarray, Dict]:
        """Parse a circuit ``graph.json`` file.

        Any failure (unreadable file, malformed JSON, inconsistent node
        records) is logged and reported as an empty graph instead of raising.

        Returns:
            node_features: float32 array with one row per node.
            edge_index: int64 array of shape (2, E); every edge is stored in
                both directions.  Edges naming unknown node ids are dropped.
            metadata: dict with 'num_nodes', and for non-empty graphs also
                'num_edges' and 'circuit_id'.
        """
        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(json_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            nodes = json_data.get('nodes', [])
            edges = json_data.get('edges', [])

            if not nodes:
                logger.warning(f"No nodes found in {json_path}")
                return self._empty_graph()

            # Map node ids to row indices.
            node_id_to_idx = {node['id']: i for i, node in enumerate(nodes)}

            node_features = np.array(
                [
                    [
                        1.0 if node.get('type', '') == 'cell' else 0.0,      # node type flag
                        *node.get('location', [0, 0]),                        # location
                        node.get('width', 0),                                 # width
                        node.get('height', 0),                                # height
                        self._encode_rotation(node.get('rotation', 'R0')),    # rotation
                        len(node.get('pins', [])),                            # pin count
                    ]
                    for node in nodes
                ],
                dtype=np.float32,
            )

            # The graph is undirected, so each edge is added in both directions.
            edge_list = []
            for edge in edges:
                source = edge.get('source')
                target = edge.get('target')
                if source in node_id_to_idx and target in node_id_to_idx:
                    source_idx = node_id_to_idx[source]
                    target_idx = node_id_to_idx[target]
                    edge_list.extend([[source_idx, target_idx], [target_idx, source_idx]])

            if edge_list:
                edge_index = np.array(edge_list, dtype=np.int64).T
            else:
                edge_index = np.zeros((2, 0), dtype=np.int64)

            metadata = {
                'num_nodes': len(nodes),
                'num_edges': len(edge_list) // 2,
                'circuit_id': json_path.parent.name
            }

            return node_features, edge_index, metadata

        except Exception as e:
            # Deliberately best-effort: a broken circuit must not abort loading.
            logger.error(f"Error parsing {json_path}: {e}")
            return self._empty_graph()

    def _encode_rotation(self, rotation: str) -> float:
        """Encode a rotation label as a fraction of a full turn, in [0, 1)."""
        return self._ROTATION_DEGREES.get(rotation, 0) / 360.0

    def _resize_matrix(self, matrix: np.ndarray) -> np.ndarray:
        """Bring a matrix to ``self.matrix_size`` in channel-first layout.

        Args:
            matrix: Array of shape (height, width, 2).
        Returns:
            New contiguous array of shape (2, target_height, target_width).
        Raises:
            ValueError: If ``matrix`` is not 3-D or does not have 2 channels.
        """
        if matrix.ndim != 3:
            raise ValueError(f"Expected 3D matrix (height, width, 2), got {matrix.shape}")

        if matrix.shape[2] != 2:
            raise ValueError(f"Expected 2 channels in last dimension, got {matrix.shape[2]}")

        height, width, channels = matrix.shape
        target_h, target_w = self.matrix_size

        # Already the right size: only the layout changes.  A contiguous copy
        # is returned so the result never aliases the caller's array, which
        # matches what the resize branch below produces.
        if height == target_h and width == target_w:
            return np.ascontiguousarray(np.transpose(matrix, (2, 0, 1)))

        # Resize channel by channel with area interpolation.  Note that
        # cv2.resize takes the target size as (width, height).
        resized_channels = [
            cv2.resize(
                matrix[:, :, c],
                (target_w, target_h),
                interpolation=cv2.INTER_AREA
            )
            for c in range(channels)
        ]

        # Stacking on axis 0 yields (channels, height, width) directly.
        return np.stack(resized_channels, axis=0)

    def load_circuit(self, idx: int) -> "Tuple[Data, torch.Tensor, torch.Tensor]":
        """Load every representation of the circuit at ``idx``.

        Returns:
            graph_data: PyG ``Data`` object with ``x`` and ``edge_index``.
            matrix: float32 tensor of shape (2, H, W) with (H, W) equal to
                ``self.matrix_size``.
            vector: one-dimensional float32 tensor.
        """
        circuit = self.circuit_data[idx]

        node_features, edge_index, _ = self.parse_circuit_json(circuit['json_path'])

        # Standardize node features when a scaler is available.
        if self.normalize_features and self.scaler is not None and node_features.size > 0:
            node_features = self.scaler.transform(node_features).astype(np.float32)

        # A graph without nodes is replaced by a single all-zero node.
        if node_features.size == 0:
            node_features = np.zeros((1, self.NUM_NODE_FEATURES), dtype=np.float32)
            edge_index = np.zeros((2, 0), dtype=np.int64)

        graph_data = Data(
            x=torch.as_tensor(node_features, dtype=torch.float32),
            edge_index=torch.as_tensor(edge_index, dtype=torch.long)
        )

        # Matrix representation, resized to the common size.
        matrix = np.load(circuit['npy_path']).astype(np.float32)
        matrix_tensor = torch.from_numpy(self._resize_matrix(matrix))  # [2, H, W]

        # Vector representation (cached at scan time).
        vector = circuit['vector'].astype(np.float32)
        if vector.ndim > 1:
            vector = vector.flatten()
        vector_tensor = torch.from_numpy(vector)

        return graph_data, matrix_tensor, vector_tensor

    def calculate_distance(self, v1: np.ndarray, v2: np.ndarray) -> float:
        """Return the distance between two vectors under ``distance_metric``.

        Inputs are flattened.  Vectors of different lengths are zero-padded
        to the longer length (a warning is logged).

        Raises:
            ValueError: If ``distance_metric`` is not a supported metric.
        """
        v1 = v1.flatten() if v1.ndim > 1 else v1
        v2 = v2.flatten() if v2.ndim > 1 else v2

        if len(v1) != len(v2):
            logger.warning(f"Vector length mismatch: {len(v1)} vs {len(v2)}")
            max_len = max(len(v1), len(v2))
            v1_padded = np.zeros(max_len)
            v2_padded = np.zeros(max_len)
            v1_padded[:len(v1)] = v1
            v2_padded[:len(v2)] = v2
            v1, v2 = v1_padded, v2_padded

        if self.distance_metric == 'euclidean':
            return float(euclidean(v1, v2))
        if self.distance_metric == 'cosine':
            # Cosine distance is undefined for a zero vector; report 1.0
            # instead of dividing by a zero norm.
            if np.allclose(v1, 0) or np.allclose(v2, 0):
                return 1.0
            return float(cosine(v1, v2))
        if self.distance_metric == 'manhattan':
            return float(np.sum(np.abs(v1 - v2)))
        raise ValueError(f"Unknown distance metric: {self.distance_metric}")

    def __len__(self):
        """Number of circuits found in ``data_dir``."""
        return len(self.circuit_data)

    def __getitem__(self, idx):
        """Return ``load_circuit(idx)``."""
        return self.load_circuit(idx)


# CircuitPairDataset类保持不变...
class CircuitPairDataset(Dataset):
    """Dataset of circuit pairs with their target distance (distance learning).

    Each item is ``(graph1, matrix1, graph2, matrix2, distance)`` where
    ``distance`` is a one-element float tensor taken from the pairwise
    distance matrix, which is scaled so that its largest entry is 1.
    """

    def __init__(self,
                 circuit_dataset: "CircuitDataset",
                 num_pairs: Optional[int] = None,
                 sampling_strategy: str = 'random',
                 hard_negative_ratio: float = 0.3):
        """
        Args:
            circuit_dataset: Underlying single-circuit dataset.
            num_pairs: Maximum number of pairs to keep; None (or 0) keeps
                every unordered pair.
            sampling_strategy: Requested strategy ('random', 'hard',
                'balanced').  Only random sampling is implemented; other
                values are stored, reported with a warning and otherwise
                ignored.
            hard_negative_ratio: Share of hard negatives.  Stored only; it is
                not used by the random sampling implemented here.
        """
        self.circuit_dataset = circuit_dataset
        self.sampling_strategy = sampling_strategy
        self.hard_negative_ratio = hard_negative_ratio

        if sampling_strategy != 'random':
            logger.warning(
                f"Sampling strategy '{sampling_strategy}' is not implemented; "
                f"falling back to random sampling"
            )

        n = len(circuit_dataset)
        if n < 2:
            # Too few circuits for a real pair: pair the single circuit with
            # itself, or keep no pair at all for an empty dataset.
            logger.warning(f"Dataset has only {len(circuit_dataset)} samples, creating dummy pairs")
            self.pairs = [(0, 0)] if n == 1 else []
            self.distance_matrix = np.zeros((n, n))
        else:
            self.precompute_distances()
            self.generate_pairs(num_pairs)

    def precompute_distances(self):
        """Fill ``self.distance_matrix`` with all pairwise vector distances.

        A pair whose distance computation fails is logged and given the
        default distance 1.0.  The matrix is then divided by its maximum so
        that values lie in [0, 1].
        """
        n = len(self.circuit_dataset)
        self.distance_matrix = np.zeros((n, n))
        circuits = self.circuit_dataset.circuit_data

        logger.info("Precomputing distance matrix...")
        for i in range(n):
            for j in range(i + 1, n):
                try:
                    dist = self.circuit_dataset.calculate_distance(
                        circuits[i]['vector'],
                        circuits[j]['vector']
                    )
                except Exception as e:
                    # Best-effort: one bad pair must not abort the whole matrix.
                    logger.error(f"Error calculating distance between circuit {i} and {j}: {e}")
                    dist = 1.0
                # The matrix is symmetric; fill both halves.
                self.distance_matrix[i, j] = dist
                self.distance_matrix[j, i] = dist

        # Normalize distances to [0, 1].
        largest = self.distance_matrix.max()
        if largest > 0:
            self.distance_matrix /= largest

        logger.info(f"Distance matrix computed: shape={self.distance_matrix.shape}")

    def generate_pairs(self, num_pairs: Optional[int]):
        """Build ``self.pairs`` from all unordered index pairs ``(i, j)``, i < j.

        When ``num_pairs`` is set and smaller than the number of available
        pairs, that many pairs are drawn at random without replacement.
        """
        n = len(self.circuit_dataset)

        if n < 2:
            self.pairs = []
            logger.warning("Not enough circuits to generate pairs")
            return

        all_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
        if num_pairs and len(all_pairs) > num_pairs:
            all_pairs = random.sample(all_pairs, num_pairs)
        self.pairs = all_pairs

        logger.info(f"Generated {len(self.pairs)} pairs with {self.sampling_strategy} sampling")

    def __len__(self):
        # Report at least one item so that an empty dataset still yields the
        # placeholder sample produced by ``__getitem__``.
        return max(1, len(self.pairs))

    def __getitem__(self, idx):
        """Return ``(graph1, matrix1, graph2, matrix2, distance)`` for a pair.

        ``idx`` wraps around the number of pairs.  With no pairs at all a
        placeholder sample (single zero node, zero matrices, distance 0) is
        returned; its matrices use the dataset's ``matrix_size`` and the same
        float32 tensor type as real samples.
        """
        if not self.pairs:
            height, width = getattr(self.circuit_dataset, 'matrix_size', (64, 64))
            graph_data = Data(x=torch.zeros(1, 7), edge_index=torch.zeros((2, 0), dtype=torch.long))
            matrix = torch.zeros((2, height, width), dtype=torch.float32)
            return graph_data, matrix, graph_data, matrix, torch.FloatTensor([0.0])

        i, j = self.pairs[idx % len(self.pairs)]  # wrap to stay in range

        graph1, matrix1, _ = self.circuit_dataset[i]
        graph2, matrix2, _ = self.circuit_dataset[j]

        distance = self.distance_matrix[i, j]

        return graph1, matrix1, graph2, matrix2, torch.FloatTensor([distance])


def collate_circuit_pairs(batch):
    """Collate a list of circuit-pair samples into one batch.

    Args:
        batch: Sequence of ``(graph1, matrix1, graph2, matrix2, distance)``
            samples.
    Returns:
        ``(batch_graph1, batch_matrix1, batch_graph2, batch_matrix2,
        batch_distances)``: the graphs merged with ``Batch.from_data_list``,
        the matrices stacked along a new leading dimension as float32
        tensors, and the per-sample distance tensors concatenated.
    Raises:
        ValueError: If the matrices of one side do not share a common shape.
    """
    graphs1, matrices1, graphs2, matrices2, distances = zip(*batch)

    # Merge the graphs of each side into one PyG batch.
    batch_graph1 = Batch.from_data_list(graphs1)
    batch_graph2 = Batch.from_data_list(graphs2)

    def _stack(matrices):
        # Accept tensors and NumPy arrays alike and stack them in torch,
        # without a tensor -> NumPy -> tensor round trip.
        return torch.stack([torch.as_tensor(m, dtype=torch.float32) for m in matrices])

    # All matrices are expected to share one shape at this point.
    try:
        batch_matrix1 = _stack(matrices1)
        batch_matrix2 = _stack(matrices2)
    except (RuntimeError, ValueError) as e:
        # Report the shapes of both sides; either one may be the culprit.
        logger.error(f"Matrix stacking error: {e}")
        logger.error(f"Matrix shapes: {[m.shape for m in matrices1]}")
        logger.error(f"Matrix shapes (second circuit): {[m.shape for m in matrices2]}")
        if isinstance(e, ValueError):
            raise
        # torch reports shape mismatches as RuntimeError; keep raising
        # ValueError so existing callers see the same exception type.
        raise ValueError(str(e)) from e

    # Each distance is a one-element tensor; concatenate them into shape [B].
    batch_distances = torch.cat(distances)

    return batch_graph1, batch_matrix1, batch_graph2, batch_matrix2, batch_distances


def create_dataloaders(config: Dict) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the training, validation and test pair loaders from ``config``.

    Required keys: 'train_data_dir', 'val_data_dir', 'test_data_dir' and
    'batch_size'.  Optional keys (with defaults): 'matrix_size' ([64, 64]),
    'distance_metric' ('euclidean'), 'normalize_features' (True),
    'num_pairs_per_epoch' (1000), 'num_val_pairs' (200), 'num_test_pairs'
    (200), 'sampling_strategy' ('random') and 'num_workers' (0).

    Args:
        config: Configuration dictionary.
    Returns:
        train_loader, val_loader, test_loader
    """
    # Settings shared by the three single-circuit datasets.
    common = {
        'distance_metric': config.get('distance_metric', 'euclidean'),
        'normalize_features': config.get('normalize_features', True),
        'matrix_size': tuple(config.get('matrix_size', [64, 64])),
    }

    # Single-circuit datasets; validation and test reuse the training scaler.
    train_circuits = CircuitDataset(config['train_data_dir'], mode='train', **common)
    val_circuits = CircuitDataset(
        config['val_data_dir'], mode='val', scaler=train_circuits.scaler, **common
    )
    test_circuits = CircuitDataset(
        config['test_data_dir'], mode='test', scaler=train_circuits.scaler, **common
    )

    # Pair datasets, built in train / val / test order.
    train_pairs = CircuitPairDataset(
        train_circuits,
        num_pairs=config.get('num_pairs_per_epoch', 1000),
        sampling_strategy=config.get('sampling_strategy', 'random'),
    )
    val_pairs = CircuitPairDataset(val_circuits, num_pairs=config.get('num_val_pairs', 200))
    test_pairs = CircuitPairDataset(test_circuits, num_pairs=config.get('num_test_pairs', 200))

    def _loader(pair_dataset, shuffle):
        # Every loader shares the batch size, collate function and workers.
        return DataLoader(
            pair_dataset,
            batch_size=config['batch_size'],
            shuffle=shuffle,
            collate_fn=collate_circuit_pairs,
            num_workers=config.get('num_workers', 0),
        )

    # Only the training loader shuffles.
    train_loader = _loader(train_pairs, True)
    val_loader = _loader(val_pairs, False)
    test_loader = _loader(test_pairs, False)

    return train_loader, val_loader, test_loader