import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class AwA2Dataset(Dataset):
    """Map-style dataset over a table of AwA2 image annotations.

    Every row of ``data_frame`` describes one image.  The ``image_path`` and
    ``class_id`` fields are read by name; attribute values are read by
    position: columns ``2 .. attribute_num + 1`` are the attributes and any
    columns after those are treated as soft attributes.

    Args:
        data_dir: Dataset root; images are looked up under
            ``<data_dir>/image/Animals_with_Attributes2/JPEGImages``.
        data_frame: Table with one row per image (layout described above).
        attribute_num: Number of attribute columns following the first two.
        transform: Callable applied to the RGB PIL image.
    """

    def __init__(self, data_dir, data_frame, attribute_num, transform):
        super().__init__()
        self.transform = transform
        self.attribute_num = attribute_num
        self.data_frame = data_frame
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of rows (images) in the table."""
        return len(self.data_frame)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, class_id, attributes, soft_attributes,
            image_path)`` where ``image`` is the transformed picture and
            ``image_path`` is the path as stored in the row.
        """
        row = self.data_frame.iloc[index, :]
        relative_path = row.image_path

        # Read the picture first so a missing file fails before any tensor work.
        full_path = f"{self.data_dir}/image/Animals_with_Attributes2/JPEGImages/{relative_path}"
        image = self.transform(Image.open(full_path).convert('RGB'))

        # Position of the first column past the attribute block.
        soft_start = self.attribute_num + 2
        attributes = torch.FloatTensor(list(row.iloc[2:soft_start]))

        if len(row) <= soft_start:
            # No extra columns: fall back to an all-zero placeholder.
            soft_attributes = torch.zeros_like(attributes)
        else:
            soft_attributes = torch.FloatTensor(list(row.iloc[soft_start:]))

        return image, row.class_id, attributes, soft_attributes, relative_path


class AwA2DataModule(pl.LightningDataModule):
    """LightningDataModule for the Animals with Attributes 2 (AwA2) images.

    The module scans ``<data_dir>/image/Animals_with_Attributes2`` for the
    class list, the predicate (attribute) list, the binary class-by-attribute
    matrix and the JPEG folders, builds one table row per image, and exposes
    stratified train/val/test splits as ``AwA2Dataset`` instances.

    Args:
        data_dir: Dataset root directory.
        batch_size: Batch size used by all three dataloaders.
        train_with_c_gt: Stored on the instance; not used inside this class.
        concept_weight: When false, ``imbalance_weight`` returns all ones.
        arch: Accepted for call-site compatibility; not used inside this class.
        num_workers: Worker processes per dataloader (defaults to 4, the
            value that was previously hard-coded).
    """

    def __init__(self, data_dir, batch_size, train_with_c_gt=True, concept_weight=True, arch=None,
                 num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_with_c_gt = train_with_c_gt
        self.concept_weight = concept_weight
        self.num_workers = num_workers

        # Standard ImageNet normalization
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

        # Training pipeline: random crop + flip for augmentation.
        self.aug_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        # Evaluation pipeline: deterministic centre crop.
        self.noaug_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    @staticmethod
    def _read_names(path):
        """Return the second tab-separated field of every line in ``path``.

        Lines with fewer than two tab-separated fields are skipped.
        """
        names = []
        with open(path, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 2:
                    names.append(parts[1])
        return names

    def _ensure_prepared(self):
        """Build the annotation table on first use if it does not exist yet."""
        if not hasattr(self, 'df'):
            self.prepare_data()

    def prepare_data(self):
        """Read the dataset metadata and build the per-image annotation table.

        Sets ``self.df`` (columns: ``image_path``, ``class_id``, then one
        column per attribute), ``self.attribute_names``, ``self.class_names``,
        ``self.attribute_num`` and ``self.has_soft_attributes``.
        """
        import os  # local import: os is not among this module's top-level imports

        root = f"{self.data_dir}/image/Animals_with_Attributes2"

        attribute_names = self._read_names(f"{root}/predicates.txt")
        class_names = self._read_names(f"{root}/classes.txt")

        # Attribute matrix (class x attribute); row order follows classes.txt.
        attribute_matrix = np.loadtxt(f"{root}/predicate-matrix-binary.txt",
                                      delimiter=' ', skiprows=0)

        image_paths = []
        class_ids = []
        for class_id, class_name in enumerate(class_names):
            class_dir = f"{root}/JPEGImages/{class_name}"
            # os.listdir yields entries in arbitrary order; sorting pins the
            # row order so the seeded splits in setup() are reproducible
            # across machines and filesystems.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.endswith(('.jpg', '.jpeg', '.png')):
                    image_paths.append(f"{class_name}/{img_file}")
                    class_ids.append(class_id)

        df = pd.DataFrame({'image_path': image_paths, 'class_id': class_ids})

        # Every image inherits the attribute vector of its class.  One
        # vectorised lookup per column replaces a per-row Python callback.
        class_index = np.asarray(class_ids, dtype=np.intp)
        for i, attr_name in enumerate(attribute_names):
            df[attr_name] = attribute_matrix[class_index, i]

        self.attribute_num = len(attribute_names)
        self.attribute_names = attribute_names
        self.class_names = class_names
        self.df = df

        # True only if the table carries a column literally named 'soft_attributes'.
        self.has_soft_attributes = 'soft_attributes' in df.columns

    @property
    def attribute_list(self):
        """Attribute names in column order (loads the metadata on first use)."""
        self._ensure_prepared()
        return self.attribute_names

    @property
    def imbalance_weight(self):
        """Per-attribute weight ``len(df) / positive_count`` as a tensor.

        Counts are taken over the whole table.  When ``concept_weight`` is
        false a tensor of ones with the same shape and dtype is returned.
        NOTE: an attribute with no positive example yields ``inf``.
        """
        self._ensure_prepared()

        count = self.df[self.attribute_names].sum().values
        weight = torch.tensor(len(self.df) / count)

        if not self.concept_weight:
            weight = torch.ones_like(weight)

        return weight

    def setup(self, stage=None):
        """Create the train/val/test datasets.

        The table is split with class stratification and a fixed seed: 20% is
        held out for test, then 25% of the remainder for validation (i.e.
        roughly 60/20/20 overall).  ``stage`` is accepted but ignored; all
        three datasets are always built.
        """
        self._ensure_prepared()

        train_val_df, test_df = train_test_split(
            self.df,
            test_size=0.2,
            stratify=self.df['class_id'],
            random_state=42
        )

        train_df, val_df = train_test_split(
            train_val_df,
            test_size=0.25,
            stratify=train_val_df['class_id'],
            random_state=42
        )

        # Only the training split is augmented.
        self.train_dataset = AwA2Dataset(
            self.data_dir,
            train_df,
            self.attribute_num,
            self.aug_transform
        )

        self.val_dataset = AwA2Dataset(
            self.data_dir,
            val_df,
            self.attribute_num,
            self.noaug_transform
        )

        self.test_dataset = AwA2Dataset(
            self.data_dir,
            test_df,
            self.attribute_num,
            self.noaug_transform
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test split."""
        return self._make_loader(self.test_dataset, shuffle=False)

# Example usage:
# datamodule = AwA2DataModule(data_dir='./data/AWA2', batch_size=128)
# datamodule.prepare_data()
# datamodule.setup()
# train_loader = datamodule.train_dataloader()