import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import re

class StreetViewDataset(Dataset):
    """Street-view sequence dataset (random-sampling variant).

    Every sample is a group of ``sequence_length`` distinct images drawn at
    random, without replacement, from the images available to the dataset.
    Image positions are parsed from filenames of the form
    ``{x}_{y}_{heading}_{pitch}.jpg``.
    """

    # x, y and heading are non-negative integers; pitch may be negative.
    # Matched case-insensitively so that files admitted by the
    # (case-insensitive) ``.jpg`` extension filter are not silently dropped
    # just because their extension is ``.JPG``.
    _FILENAME_PATTERN = re.compile(r'(\d+)_(\d+)_(\d+)_(-?\d+)\.jpg', re.IGNORECASE)

    def __init__(self, data_path, sequence_length=64, num_sequences=500, transform=None, coordinates=None):
        """Initialise the street-view dataset.

        Args:
            data_path: Folder containing the street-view images. Image files
                are always read from here, even when ``coordinates`` is given.
            sequence_length: Number of frames per sample.
            num_sequences: Number of random sequences to generate.
            transform: Optional callable applied to the normalised
                ``(S, C, H, W)`` float32 array before tensor conversion.
            coordinates: Optional pre-filtered list of
                ``(x, y, heading, pitch, filename)`` tuples (used for a
                train/test split). When ``None``, ``data_path`` is scanned.

        Raises:
            ValueError: If ``data_path`` holds no ``.jpg`` files, if
                ``sequence_length`` is not positive, or if there are fewer
                usable images than ``sequence_length``.
        """
        self.data_path = data_path
        self.sequence_length = sequence_length
        self.transform = transform

        if coordinates is None:
            # Sorted so the coordinate order (and therefore which image each
            # sampled index refers to) does not depend on the arbitrary order
            # returned by os.listdir.
            self.image_files = sorted(
                f for f in os.listdir(data_path) if f.lower().endswith('.jpg')
            )

            if not self.image_files:
                raise ValueError(f"在路径 {data_path} 中没有找到任何JPG图片")

            # Parse (x, y, heading, pitch) out of each filename.
            self.coordinates = []
            for file in self.image_files:
                match = self._FILENAME_PATTERN.match(file)
                if match:
                    x, y, heading, pitch = map(int, match.groups())
                    self.coordinates.append((x, y, heading, pitch, file))

            # Make silently ignored files visible instead of just shrinking the pool.
            skipped = len(self.image_files) - len(self.coordinates)
            if skipped:
                print(f"警告：{skipped} 个JPG文件名不符合 x_y_heading_pitch.jpg 格式，已跳过")

            print(f"加载了 {len(self.coordinates)} 个街景图片，来源: {data_path}")
        else:
            # Use the caller's pre-filtered coordinate list as-is.
            self.coordinates = coordinates
            print(f"使用提供的 {len(self.coordinates)} 个坐标点")

        self.sequences = []
        self._build_random_sequences(num_sequences)

        print(f"生成了 {len(self.sequences)} 个随机序列")

    def _build_random_sequences(self, num_sequences):
        """Append ``num_sequences`` random index sequences to ``self.sequences``.

        Each sequence holds ``sequence_length`` distinct indices into
        ``self.coordinates``, drawn with the module-level ``random`` state
        (seed it beforehand for reproducible sequences).

        Raises:
            ValueError: If ``sequence_length`` is not positive or exceeds the
                number of available images.
        """
        num_images = len(self.coordinates)

        # A zero-length sequence has no reference frame; without this check
        # it would only fail later in __getitem__ with an opaque IndexError.
        if self.sequence_length < 1:
            raise ValueError(f"序列长度必须为正整数，当前为 {self.sequence_length}")

        if num_images < self.sequence_length:
            raise ValueError(f"图片总数 ({num_images}) 小于所需序列长度 ({self.sequence_length})")

        population = range(num_images)
        self.sequences.extend(
            random.sample(population, self.sequence_length) for _ in range(num_sequences)
        )

    def __len__(self):
        """Return the number of pre-generated sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load the ``idx``-th random sequence.

        Returns:
            A tuple ``(obs, rel_pos, abs_pos)``:
              * ``obs``: float32 tensor with pixels scaled to ``[-1, 1]``,
                shaped ``(S, 3, H, W)`` before any ``transform`` is applied.
              * ``rel_pos``: int64 tensor ``(S, 4)`` of
                ``(x, y, heading, pitch)`` minus those of the first frame.
              * ``abs_pos``: int64 tensor ``(S, 4)`` of the absolute
                ``(x, y, heading, pitch)``.
            ``S`` is ``sequence_length``.
        """
        sequence_coords = [self.coordinates[i] for i in self.sequences[idx]]

        abs_coords = np.array([(x, y, heading, pitch) for x, y, heading, pitch, _ in sequence_coords])
        # Broadcasting subtracts the first frame's coordinates from every row.
        rel_coords = abs_coords - abs_coords[0]

        frames = []
        for *_, filename in sequence_coords:
            img_path = os.path.join(self.data_path, filename)
            # The context manager closes the underlying file right away.
            # convert() returns a fully loaded copy, so nothing is read after
            # the file has been closed.
            with Image.open(img_path) as img:
                frames.append(np.asarray(img.convert('RGB')))

        # (S, H, W, 3) uint8 -> float32 in [-1, 1]. np.stack requires every
        # frame to share one resolution. Scaling in place avoids two
        # temporary full-size copies.
        obs_sequence = np.stack(frames).astype(np.float32)
        obs_sequence /= 127.5
        obs_sequence -= 1.0

        # (S, H, W, C) -> (S, C, H, W); this is a strided view, not a copy.
        obs_sequence = obs_sequence.transpose(0, 3, 1, 2)

        if self.transform:
            obs_sequence = self.transform(obs_sequence)

        # as_tensor shares memory with a float32 ndarray. .contiguous() then
        # makes the single copy needed to lay the view out densely, so the
        # result is contiguous exactly like the former FloatTensor(...) copy.
        obs_tensor = torch.as_tensor(obs_sequence, dtype=torch.float32).contiguous()
        return (
            obs_tensor,
            torch.as_tensor(rel_coords, dtype=torch.long),
            torch.as_tensor(abs_coords, dtype=torch.long),
        )

# Filename layout ``{x}_{y}_{heading}_{pitch}.jpg`` (pitch may be negative),
# matched case-insensitively to agree with the ``.jpg`` extension filter.
_STREETVIEW_FILENAME_PATTERN = re.compile(r'(\d+)_(\d+)_(\d+)_(-?\d+)\.jpg', re.IGNORECASE)


def _scan_streetview_coordinates(data_path):
    """Return ``(x, y, heading, pitch, filename)`` for every parsable JPG in ``data_path``.

    Files are visited in sorted order so the result does not depend on the
    arbitrary order of ``os.listdir``. Files whose names do not match the
    expected layout are skipped.

    Raises:
        ValueError: If ``data_path`` contains no ``.jpg`` files.
    """
    image_files = sorted(f for f in os.listdir(data_path) if f.lower().endswith('.jpg'))
    if not image_files:
        raise ValueError(f"在路径 {data_path} 中没有找到任何JPG图片")

    coordinates = []
    for file in image_files:
        match = _STREETVIEW_FILENAME_PATTERN.match(file)
        if match:
            x, y, heading, pitch = map(int, match.groups())
            coordinates.append((x, y, heading, pitch, file))
    return coordinates


def _make_streetview_loader(dataset, config, shuffle):
    """Wrap ``dataset`` in a DataLoader using the batch/worker settings of ``config``."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        # Pinned memory only speeds up host-to-GPU copies; PyTorch warns that
        # it has no effect when no accelerator is present.
        pin_memory=torch.cuda.is_available(),
    )


def get_streetview_loaders(config):
    """Create the street-view train and test DataLoaders.

    If both ``config.train_data_path`` and ``config.test_data_path`` exist,
    each split is read from its own folder. Otherwise all images in
    ``config.data_path`` are shuffled and split 9:1 into train/test images in
    memory.

    ``config.num_sequences`` (default 500) random sequences are generated and
    split 9:1 between train and test. When ``config.seed`` is set, the global
    ``random`` module is seeded with it once up front, so both the in-memory
    split and the sampled sequences are reproducible.

    Returns:
        ``(train_loader, test_loader)``. Only the train loader shuffles.
    """
    # Total number of random sequences, split 9:1 between train and test
    # (500 -> 450 / 50 when the config does not specify it).
    total_sequences = getattr(config, 'num_sequences', 500)
    train_sequences = int(0.9 * total_sequences)
    test_sequences = total_sequences - train_sequences

    # Seed once before any random call: the in-memory shuffle below and the
    # sequence sampling inside StreetViewDataset both use the global RNG.
    seed = getattr(config, 'seed', None)
    if seed is not None:
        random.seed(seed)

    train_path = getattr(config, 'train_data_path', None)
    test_path = getattr(config, 'test_data_path', None)
    if train_path is not None and os.path.exists(train_path) and \
       test_path is not None and os.path.exists(test_path):
        # Use the pre-split train/test directories.
        print("使用预先分离的训练集和测试集目录...")

        train_dataset = StreetViewDataset(
            data_path=train_path,
            sequence_length=config.sequence_length,
            num_sequences=train_sequences
        )

        test_dataset = StreetViewDataset(
            data_path=test_path,
            sequence_length=config.sequence_length,
            num_sequences=test_sequences
        )

        print(f"从训练集文件夹读取: {train_path}")
        print(f"从测试集文件夹读取: {test_path}")
    else:
        # Backward compatibility: split one image folder in memory.
        print("警告：未找到分离的训练集/测试集目录，使用运行时内存划分...")
        data_path = config.data_path

        coordinates = _scan_streetview_coordinates(data_path)
        print(f"加载了总共 {len(coordinates)} 个街景图片")

        # The scan is sorted, so with a fixed seed this shuffle (and hence
        # the split) is identical across machines and filesystems.
        random.shuffle(coordinates)

        split_point = int(0.9 * len(coordinates))
        train_coordinates = coordinates[:split_point]
        test_coordinates = coordinates[split_point:]

        print(f"划分为 {len(train_coordinates)} 个训练图片和 {len(test_coordinates)} 个测试图片")

        # Both splits read images from the same folder but see disjoint
        # coordinate lists.
        train_dataset = StreetViewDataset(
            data_path=data_path,
            sequence_length=config.sequence_length,
            num_sequences=train_sequences,
            coordinates=train_coordinates
        )

        test_dataset = StreetViewDataset(
            data_path=data_path,
            sequence_length=config.sequence_length,
            num_sequences=test_sequences,
            coordinates=test_coordinates
        )

    print(f"训练集: {len(train_dataset)} 个序列，测试集: {len(test_dataset)} 个序列")
    if seed is not None:
        print(f"使用随机种子 {seed} 进行数据集划分，确保可重复性")

    train_loader = _make_streetview_loader(train_dataset, config, shuffle=True)
    test_loader = _make_streetview_loader(test_dataset, config, shuffle=False)

    return train_loader, test_loader