import os
import pandas as pd
import torch
import torch.utils.data as data
import cv2
import numpy as np
from PIL import Image
from transformers import AutoTokenizer
from torchvision import transforms


class ImageTextPairDataset(data.Dataset):
    """Dataset pairing each image with one or more tokenized text evidences.

    Rows come from a CSV that must contain a ``path`` column (image file
    path) and an ``evidences`` column (several evidences in one cell,
    separated by ``EVIDENCE_SEPARATOR``). When the CSV has a ``split``
    column, only rows whose value equals ``split`` are kept. An optional
    ``report_id`` column is forwarded with each sample.
    """

    # Full-width vertical bar (U+FF5C) separating evidences inside one CSV cell.
    EVIDENCE_SEPARATOR = '｜'

    def __init__(self, cfg, split="train", csv_path=None, sample_ratio=0.1, preprocess_text=True, paired_csv_path=None):
        """Load the paired CSV, build the tokenizer and the image transform.

        Args:
            cfg: Config object. Reads ``cfg.data.paired_csv_path`` (only when
                ``paired_csv_path`` is None), ``cfg.model.text.bert_type``,
                ``cfg.data.max_length`` (default 128) and
                ``cfg.data.image.imsize`` (default 224).
            split: Split name; selects rows and enables the random flip
                augmentation when equal to ``"train"``.
            csv_path: Accepted for interface compatibility; not read here.
            sample_ratio: Stored on the instance; not otherwise used by
                this class.
            preprocess_text: If True, tokenize every row once up front and
                cache the result as numpy arrays.
            paired_csv_path: Path to the CSV; overrides the config value.

        Raises:
            ValueError: If no CSV path is available, the file does not
                exist, or a required column is missing.
        """
        super().__init__()
        self.split = split
        self.sample_ratio = sample_ratio
        self.preprocess_text = preprocess_text

        if paired_csv_path is None:
            paired_csv_path = getattr(cfg.data, 'paired_csv_path', None)

        if paired_csv_path is None:
            raise ValueError(
                "paired_csv_path must be passed as an argument or set in cfg.data.paired_csv_path"
            )
        if not os.path.exists(paired_csv_path):
            raise ValueError(f"Paired CSV file not found: {paired_csv_path}")

        print(f" {paired_csv_path}")
        frame = pd.read_csv(paired_csv_path)

        for required in ('path', 'evidences'):
            if required not in frame.columns:
                raise ValueError(f"CSV file must contain '{required}' column")

        if 'split' in frame.columns:
            frame = frame[frame['split'] == split]
        # Re-index so positional and label indices agree after filtering.
        self.df = frame.reset_index(drop=True)

        # NOTE(review): the fallback " " is not a usable model identifier, so
        # a missing cfg.model.text.bert_type will fail inside from_pretrained.
        # Confirm the intended default.
        bert_path = getattr(cfg.model.text, 'bert_type', " ")
        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.max_length = getattr(cfg.data, 'max_length', 128)

        self.imsize = getattr(cfg.data.image, 'imsize', 224)
        self.transform = self._get_transform()

        if self.preprocess_text:
            print(f"Preprocessing text tokenization for {len(self.df)} samples in {split} split...")
            self._preprocess_text()
            print("Text preprocessing completed!")

    def _get_transform(self):
        """Return the image pipeline: resize, (train-only) flip, tensor, normalize."""
        steps = [transforms.Resize((self.imsize, self.imsize))]
        if self.split == "train":
            # Augmentation is applied to the training split only.
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        return transforms.Compose(steps)

    def read_image(self, img_path):
        """Load an image as grayscale, expand it to RGB and apply the transform.

        OpenCV is tried first; when it cannot decode the file, PIL is used
        instead. Both paths produce a grayscale image replicated over three
        channels, so the result does not depend on which decoder succeeded.

        Args:
            img_path: Path of the image file.

        Returns:
            The output of ``self.transform`` applied to the RGB PIL image,
            or the PIL image itself when ``self.transform`` is None.

        Raises:
            FileNotFoundError: If ``img_path`` does not exist.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        try:
            gray = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        except cv2.error:
            gray = None

        if gray is None:
            # cv2.imread signals an undecodable file by returning None.
            with Image.open(img_path) as pil_img:
                gray = np.array(pil_img.convert('L'))

        img = Image.fromarray(gray).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self) -> int:
        """Return the number of rows kept for this split."""
        return len(self.df)

    @classmethod
    def _split_evidences(cls, raw):
        """Split one ``evidences`` cell into a non-empty list of stripped strings.

        Missing values and cells without any non-blank evidence yield ``[""]``
        so that every sample has at least one text to tokenize.
        """
        text = str(raw) if pd.notna(raw) else ""
        parts = [piece.strip() for piece in text.split(cls.EVIDENCE_SEPARATOR)]
        parts = [piece for piece in parts if piece]
        return parts or [""]

    def _encode(self, evidences):
        """Tokenize a list of evidences to fixed-length tensors.

        Returns:
            Tuple ``(input_ids, attention_mask, token_type_ids)``, each with
            one row per evidence. ``token_type_ids`` is all zeros when the
            tokenizer does not return that field.
        """
        encoded = self.tokenizer(
            evidences,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoded['input_ids']
        token_type_ids = encoded.get('token_type_ids')
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        return input_ids, encoded['attention_mask'], token_type_ids

    def _preprocess_text(self):
        """Tokenize every row once and cache the result in ``self.preprocessed_text``.

        Encodings are stored as numpy arrays and converted back to tensors
        in ``__getitem__``.
        """
        self.preprocessed_text = []
        for raw in self.df['evidences']:
            evidences = self._split_evidences(raw)
            input_ids, attention_mask, token_type_ids = self._encode(evidences)
            self.preprocessed_text.append({
                'evidences': evidences,
                'input_ids': input_ids.numpy(),
                'attention_mask': attention_mask.numpy(),
                'token_type_ids': token_type_ids.numpy(),
                'num_evidences': len(evidences),
            })

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            Dict with keys ``image``, ``img_path``, ``evidences`` (list of
            str), ``evidence_encodings`` (one dict of ``input_ids`` /
            ``attention_mask`` / ``token_type_ids`` tensors per evidence),
            ``num_evidences`` and ``report_id`` (the row's ``report_id``
            value, or ``idx`` when the column is absent).
        """
        row = self.df.iloc[idx]
        img_path = str(row['path'])
        image = self.read_image(img_path)

        if self.preprocess_text and hasattr(self, 'preprocessed_text'):
            cached = self.preprocessed_text[idx]
            evidences = cached['evidences']
            # Convert the cached numpy rows back to torch tensors.
            columns = (
                torch.from_numpy(cached['input_ids']),
                torch.from_numpy(cached['attention_mask']),
                torch.from_numpy(cached['token_type_ids']),
            )
        else:
            evidences = self._split_evidences(row['evidences'])
            columns = self._encode(evidences)

        input_ids, attention_mask, token_type_ids = columns
        evidence_encodings = [
            {
                'input_ids': input_ids[i],
                'attention_mask': attention_mask[i],
                'token_type_ids': token_type_ids[i],
            }
            for i in range(len(evidences))
        ]

        return {
            'image': image,
            'img_path': img_path,
            'evidences': evidences,
            'evidence_encodings': evidence_encodings,
            'num_evidences': len(evidences),
            'report_id': row.get('report_id', idx)
        }


def image_text_pair_collate_fn(batch):
    """Collate samples with a variable number of evidences into batch tensors.

    Each sample's evidence encodings are padded with all-zero rows up to the
    largest evidence count in the batch, so the three text tensors share the
    layout ``(batch, max_num_evidences, max_length)``. Padded rows have an
    all-zero attention mask; ``num_evidences`` records how many rows of each
    sample are real.

    Args:
        batch: Non-empty list of sample dicts with the keys ``image``,
            ``img_path``, ``report_id``, ``evidences``, ``num_evidences``
            and ``evidence_encodings`` (a list of dicts holding 1-D
            ``input_ids`` / ``attention_mask`` / ``token_type_ids`` tensors
            of equal length).

    Returns:
        Dict with stacked ``images``, the padded ``input_ids``,
        ``attention_mask`` and ``token_type_ids`` tensors, a long tensor
        ``num_evidences``, and the per-sample lists ``img_paths``,
        ``evidences`` and ``report_ids``.
    """
    images = torch.stack([item['image'] for item in batch])

    max_num_evidences = max(item['num_evidences'] for item in batch)
    # Every encoding is padded to the same token length, so read it once.
    max_length = batch[0]['evidence_encodings'][0]['input_ids'].shape[0]

    # One 1-D zero row per missing evidence. It must have the same shape as a
    # real encoding row, otherwise torch.stack rejects the mixed shapes.
    pad_row = torch.zeros(max_length, dtype=torch.long)

    fields = ('input_ids', 'attention_mask', 'token_type_ids')
    per_sample = {name: [] for name in fields}
    for item in batch:
        num_ev = item['num_evidences']
        encodings = item['evidence_encodings'][:num_ev]
        n_pad = max_num_evidences - num_ev
        for name in fields:
            rows = [enc[name] for enc in encodings] + [pad_row] * n_pad
            per_sample[name].append(torch.stack(rows))

    return {
        'images': images,
        'img_paths': [item['img_path'] for item in batch],
        'input_ids': torch.stack(per_sample['input_ids']),
        'attention_mask': torch.stack(per_sample['attention_mask']),
        'token_type_ids': torch.stack(per_sample['token_type_ids']),
        'num_evidences': torch.tensor([item['num_evidences'] for item in batch], dtype=torch.long),
        'evidences': [item['evidences'] for item in batch],
        'report_ids': [item['report_id'] for item in batch]
    }

