import os
from itertools import combinations

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

class AwA2Dataset(Dataset):
    """Map-style dataset yielding AwA2 images together with their selected concept labels.

    Each item is ``(image, class_id, attributes, soft_attributes, image_path)``:
    ``attributes`` holds the row's values for ``selected_attributes`` as a float32
    tensor; ``soft_attributes`` comes from an optional ``soft_attributes`` column,
    or is a zero tensor shaped like ``attributes`` when that column is absent.

    Args:
        data_dir: dataset root; images are read from
            ``{data_dir}/image/Animals_with_Attributes2/JPEGImages/{image_path}``.
        data_frame: one row per image with at least ``image_path``, ``class_id`` and
            one column per name in ``selected_attributes``.
        selected_attributes: ordered attribute (concept) column names to return.
        transform: callable applied to the RGB PIL image.
    """

    def __init__(self, data_dir, data_frame, selected_attributes, transform):
        super().__init__()
        self.data_dir = data_dir
        self.data_frame = data_frame
        self.selected_attributes = selected_attributes
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def _attribute_tensors(self, sample):
        """Return ``(attributes, soft_attributes)`` float32 tensors for one data-frame row."""
        # Only the selected attributes (concepts), in the configured order.
        attributes = torch.tensor(
            [float(sample[attr]) for attr in self.selected_attributes],
            dtype=torch.float32,
        )
        # Soft attributes (e.g. model predictions) are optional.
        if 'soft_attributes' in sample:
            # Assumes soft_attributes is ordered like selected_attributes -- TODO confirm with the producer.
            soft_attributes = torch.tensor(sample['soft_attributes'], dtype=torch.float32)
        else:
            soft_attributes = torch.zeros_like(attributes)
        return attributes, soft_attributes

    def __getitem__(self, index):
        sample = self.data_frame.iloc[index, :]

        image_path = f"{self.data_dir}/image/Animals_with_Attributes2/JPEGImages/{sample.image_path}"
        # Close the file handle promptly; convert() returns a new, fully loaded image,
        # so nothing keeps the file open across many DataLoader worker iterations.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        attributes, soft_attributes = self._attribute_tensors(sample)
        return image, sample.class_id, attributes, soft_attributes, sample.image_path


class AwA2DataModule_selected(pl.LightningDataModule):
    """Lightning data module for AwA2 using a predefined subset of attributes as concepts.

    Images are discovered under
    ``{data_dir}/image/Animals_with_Attributes2/JPEGImages/<class_name>/``. Each image is
    labelled with its class's row of ``predicate-matrix-binary.txt``, restricted to the
    selected attributes. ``setup`` performs a class-stratified 60/20/20
    train/val/test split with a fixed seed (42).

    Args:
        data_dir: dataset root containing ``image/Animals_with_Attributes2``.
        batch_size: batch size for every dataloader.
        train_with_c_gt: stored on the instance; not read inside this class.
        concept_weight: when False, ``imbalance_weight`` returns all-ones weights.
        arch: accepted for interface compatibility; unused here.
        num_attributes: size of the predefined attribute subset to use; must be a key
            of ``ATTRIBUTE_SUBSETS`` (5, 10, 15, 20, 25 or 30). Defaults to 20.
        num_workers: worker processes per DataLoader. Defaults to 4.

    Raises:
        ValueError: if ``num_attributes`` is not a predefined subset size.
    """

    # Predefined concept subsets keyed by size. Names must match predicates.txt;
    # tuple order defines the order of the concept vector.
    ATTRIBUTE_SUBSETS = {
        5: (
            "swims",                                            # aquatic environment
            "meatteeth",                                        # diet and teeth
            "quadrapedal",                                      # terrestrial locomotion
            "big", "fierce",                                    # body size and defence
        ),
        10: (
            "swims", "water",                                   # aquatic environment
            "meatteeth", "meat", "chewteeth",                   # diet and teeth
            "quadrapedal", "fast",                              # terrestrial locomotion
            "forest",                                           # habitat
            "big", "fierce",                                    # body size and defence
        ),
        15: (
            "swims", "water", "ocean", "coastal", "flippers",   # aquatic environment
            "meatteeth", "meat", "chewteeth",                   # diet and teeth
            "quadrapedal", "fast",                              # terrestrial locomotion
            "forest", "tree", "nocturnal",                      # habitat and habits
            "big", "fierce",                                    # body size and defence
        ),
        20: (
            "swims", "water", "ocean", "coastal", "flippers",   # aquatic adaptations (full aquatic feature set)
            "meatteeth", "meat", "chewteeth",                   # diet / tooth morphology (dietary adaptation)
            "quadrapedal", "fast", "agility", "hooves",         # terrestrial locomotion (interdependent cues)
            "forest", "tree", "nocturnal",                      # habitat and lifestyle
            "group", "solitary", "smart",                       # social behaviour and intelligence
            "big", "fierce",                                    # body size and defence (survival strategy)
        ),
        25: (
            "swims", "water", "ocean", "coastal", "flippers",   # aquatic environment
            "meatteeth", "meat", "chewteeth",                   # diet and teeth
            "quadrapedal", "fast", "agility", "hooves",         # terrestrial locomotion
            "forest", "tree", "nocturnal",                      # habitat and habits
            "group", "solitary", "smart",                       # social behaviour and intelligence
            "big", "fierce",                                    # body size and defence
            "hunter", "stalker",                                # additional hunting behaviour
            "horns", "claws", "tusks",                          # more defensive/offensive organs
        ),
        30: (
            "swims", "water", "ocean", "coastal", "flippers",   # aquatic environment
            "meatteeth", "meat", "chewteeth",                   # diet and teeth
            "quadrapedal", "fast", "agility", "hooves",         # terrestrial locomotion
            "forest", "tree", "nocturnal",                      # habitat and habits
            "group", "solitary", "smart",                       # social behaviour and intelligence
            "big", "fierce",                                    # body size and defence
            "hunter", "stalker",                                # additional hunting behaviour
            "horns", "claws", "tusks",                          # more defensive/offensive organs
            "grazer",                                           # another diet type
            "plains",                                           # habitat
            "active",                                           # activity level
            "domestic",                                         # domesticated or not
            "strong",                                           # physical strength
        ),
    }

    def __init__(self, data_dir, batch_size, train_with_c_gt=True, concept_weight=True, arch=None,
                 num_attributes=20, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_with_c_gt = train_with_c_gt
        self.concept_weight = concept_weight
        self.num_workers = num_workers

        if num_attributes not in self.ATTRIBUTE_SUBSETS:
            raise ValueError(
                f"num_attributes must be one of {sorted(self.ATTRIBUTE_SUBSETS)}, got {num_attributes!r}"
            )
        # Copy into a fresh list so per-instance edits never touch the shared class table.
        self.selected_attributes_names = list(self.ATTRIBUTE_SUBSETS[num_attributes])

        # Standard ImageNet normalization
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

        self.aug_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        self.noaug_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    @property
    def _awa2_root(self):
        """Directory holding the extracted Animals_with_Attributes2 release."""
        return f"{self.data_dir}/image/Animals_with_Attributes2"

    @staticmethod
    def _read_index_file(path):
        """Return the name column of an AwA2 index file (classes.txt / predicates.txt).

        Lines are expected as ``<id><TAB><name>``; any-whitespace separation is accepted
        as a fallback. Lines without a name column are skipped.
        """
        names = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    # Fallback for space-separated lines such as "1   black".
                    parts = line.split()
                if len(parts) >= 2:
                    names.append(parts[1])
        return names

    def prepare_data(self):
        """Build ``self.df``: one row per image with its class id and selected attribute values.

        Also sets ``selected_indices``, ``selected_attributes``, ``attribute_names``,
        ``attribute_num`` and ``class_names``, then prints per-attribute statistics.

        Raises:
            ValueError: if the attribute matrix has fewer rows than there are classes.
            RuntimeError: if no images are found.
        """
        root = self._awa2_root
        all_attribute_names = self._read_index_file(f"{root}/predicates.txt")

        # Resolve the requested attribute names to column indices of the attribute matrix.
        self.selected_indices = []
        self.selected_attributes = []
        for attr_name in self.selected_attributes_names:
            if attr_name in all_attribute_names:
                self.selected_indices.append(all_attribute_names.index(attr_name))
                self.selected_attributes.append(attr_name)
            else:
                print(f"Warning: Attribute '{attr_name}' not found in dataset!")

        if len(self.selected_attributes) != len(self.selected_attributes_names):
            print(f"Only {len(self.selected_attributes)} out of {len(self.selected_attributes_names)} requested attributes were found.")

        class_names = self._read_index_file(f"{root}/classes.txt")

        # Class x attribute matrix. Default whitespace splitting tolerates repeated or
        # trailing blanks, which a single-space delimiter would reject.
        full_attribute_matrix = np.loadtxt(f"{root}/predicate-matrix-binary.txt", ndmin=2)
        if full_attribute_matrix.shape[0] < len(class_names):
            raise ValueError(
                f"predicate-matrix-binary.txt has {full_attribute_matrix.shape[0]} rows "
                f"but classes.txt lists {len(class_names)} classes"
            )
        # Keep only the selected attribute columns.
        attribute_matrix = full_attribute_matrix[:, self.selected_indices]

        image_paths = []
        class_ids = []
        for class_id, class_name in enumerate(class_names):
            class_dir = f"{root}/JPEGImages/{class_name}"
            # os.listdir order is arbitrary; sorting keeps the DataFrame order, and
            # therefore the seeded split in setup(), identical across machines.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.endswith(('.jpg', '.jpeg', '.png')):
                    image_paths.append(f"{class_name}/{img_file}")
                    class_ids.append(class_id)
        if not image_paths:
            raise RuntimeError(f"No images found under {root}/JPEGImages")

        df = pd.DataFrame({'image_path': image_paths, 'class_id': class_ids})

        # Every image inherits its class's attribute row (vectorized lookup by class id).
        class_index = df['class_id'].to_numpy()
        for col, attr_name in enumerate(self.selected_attributes):
            df[attr_name] = attribute_matrix[class_index, col]

        self.attribute_num = len(self.selected_attributes)
        self.attribute_names = self.selected_attributes
        self.class_names = class_names
        self.df = df

        self._print_attribute_statistics(df, self.selected_attributes)

    @staticmethod
    def _print_attribute_statistics(df, attribute_names, top_k=10):
        """Print positive-image counts per attribute and the ``top_k`` most correlated pairs."""
        attribute_counts = df[attribute_names].sum()
        print(f"Selected {len(attribute_names)} attributes:")
        for attr_name, count in attribute_counts.items():
            percentage = (count / len(df)) * 100
            print(f"  - {attr_name}: {int(count)} images ({percentage:.2f}%)")

        correlation_matrix = df[attribute_names].corr()
        correlation_pairs = [
            (attr1, attr2, correlation_matrix.loc[attr1, attr2])
            for attr1, attr2 in combinations(attribute_names, 2)
        ]
        # Constant attributes give NaN correlations; rank them last rather than letting
        # NaN comparisons scramble the ordering.
        correlation_pairs.sort(
            key=lambda pair: -1.0 if np.isnan(pair[2]) else abs(pair[2]),
            reverse=True,
        )
        print(f"\nTop {top_k} correlated attribute pairs:")
        for attr1, attr2, corr in correlation_pairs[:top_k]:
            print(f"  - {attr1} & {attr2}: {corr:.3f}")

    @property
    def attribute_list(self):
        """Names of the attributes actually found in the dataset (loads data on first use)."""
        if not hasattr(self, 'attribute_names'):
            self.prepare_data()
        return self.attribute_names

    @property
    def imbalance_weight(self):
        """Per-attribute weights ``N / positives``; all ones when ``concept_weight`` is False.

        Attributes with no positive image use a count of 1 so the weight stays finite
        instead of becoming ``inf``.
        """
        if not hasattr(self, 'df'):
            self.prepare_data()

        count = self.df[self.attribute_names].sum().values
        weight = torch.tensor(len(self.df) / np.maximum(count, 1))

        if not self.concept_weight:
            weight = torch.ones_like(weight)

        return weight

    def _make_dataset(self, df, transform):
        """Create an AwA2Dataset over ``df`` with the selected attributes."""
        return AwA2Dataset(self.data_dir, df, self.selected_attributes, transform)

    def setup(self, stage=None):
        """Split ``self.df`` 60/20/20 (stratified by class, seed 42) and build the datasets."""
        if not hasattr(self, 'df'):
            self.prepare_data()

        train_val_df, test_df = train_test_split(
            self.df,
            test_size=0.2,
            stratify=self.df['class_id'],
            random_state=42
        )
        # 25% of the remaining 80% gives a 20% validation share overall.
        train_df, val_df = train_test_split(
            train_val_df,
            test_size=0.25,
            stratify=train_val_df['class_id'],
            random_state=42
        )

        self.train_dataset = self._make_dataset(train_df, self.aug_transform)
        self.val_dataset = self._make_dataset(val_df, self.noaug_transform)
        self.test_dataset = self._make_dataset(test_df, self.noaug_transform)

        print(f"Dataset split: {len(train_df)} train, {len(val_df)} validation, {len(test_df)} test")

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

# Example usage:
# if __name__ == "__main__":
#     datamodule = AwA2DataModule_selected(data_dir='./data/AWA2', batch_size=128)
#     datamodule.prepare_data()
#     datamodule.setup()
#     train_loader = datamodule.train_dataloader()
    
#     
#     for images, class_ids, attributes, soft_attributes, image_paths in train_loader:
#         print(f"Images shape: {images.shape}")
#         print(f"Class IDs shape: {class_ids.shape}")
#         print(f"Attributes shape: {attributes.shape}")  
#         print(f"Soft attributes shape: {soft_attributes.shape}")
#         print(f"Image paths count: {len(image_paths)}")
#         break  