import os
import pandas as pd
import torch
import torch.utils.data as data
import cv2
import numpy as np
from PIL import Image
from transformers import AutoTokenizer
from torchvision import transforms
import random


class Stage3RelationDataset(data.Dataset):
    """Dataset that mixes paired image/evidence samples with unpaired ones.

    Two CSV files are read, both of which must provide ``path`` and
    ``evidences`` columns (an optional ``split`` column is used to filter
    rows):

    * the *paired* CSV supplies (image, evidences) pairs;
    * the *evidence* CSV supplies the pool from which unpaired samples are
      drawn; rows whose (path, evidences) combination also appears in the
      paired CSV are removed from that pool.

    ``len(dataset)`` equals the number of paired rows. Every ``__getitem__``
    call randomly returns the paired row at ``idx``, an image-only sample or
    a text-only sample, so results are not deterministic for a given index.
    """

    # Evidence strings in the CSV are joined with the full-width bar (U+FF5C).
    EVIDENCE_SEPARATOR = '｜'
    # Probability that an unpaired draw is image-only rather than text-only.
    UNPAIRED_IMAGE_PROB = 0.5

    def __init__(self, cfg, split="train", csv_path=None, stage2_sample_ratio=0.1, paired_csv_path=None):
        """Load both CSV files and prepare the tokenizer and image transform.

        Args:
            cfg: Config object; reads ``cfg.data`` (``paired_csv_path``,
                ``evidence_csv_path``, ``max_length``, ``pair_ratio``),
                ``cfg.data.image.imsize`` and ``cfg.model.text.bert_type``.
            split: Value matched against the ``split`` column when present;
                also selects the train/eval image transform.
            csv_path: Evidence CSV path; falls back to
                ``cfg.data.evidence_csv_path``.
            stage2_sample_ratio: Stored on the instance; not read by this
                class.
            paired_csv_path: Paired CSV path; falls back to
                ``cfg.data.paired_csv_path``.

        Raises:
            ValueError: If the paired CSV is missing or a required column is
                absent from either CSV.
        """
        self.split = split
        self.stage2_sample_ratio = stage2_sample_ratio

        if paired_csv_path is None:
            paired_csv_path = getattr(cfg.data, 'paired_csv_path', None)
        if paired_csv_path is None or not os.path.exists(paired_csv_path):
            raise ValueError(f"Paired CSV file not found: {paired_csv_path}")

        print(f" {paired_csv_path}")
        paired_df = self._load_split(paired_csv_path, split, "Paired CSV")
        self.paired_df = paired_df.reset_index(drop=True)

        if csv_path is None:
            # NOTE(review): the " " default looks like a placeholder; reading
            # it will fail unless the config supplies a real path.
            csv_path = getattr(cfg.data, 'evidence_csv_path',
                               " ")
        print(f" {csv_path}")
        self.df = self._load_split(csv_path, split, "CSV")

        # Compute the membership mask from a standalone key Series instead of
        # writing a helper column into the (possibly filtered) frame.
        paired_keys = set(self._pair_keys(self.paired_df))
        is_unpaired = ~self._pair_keys(self.df).isin(paired_keys)
        self.unpaired_df = self.df[is_unpaired].reset_index(drop=True)
        print(f" {len(self.paired_df)},  {len(self.unpaired_df)}")

        # NOTE(review): the " " default looks like a placeholder model name.
        bert_path = getattr(cfg.model.text, 'bert_type', " ")
        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.max_length = getattr(cfg.data, 'max_length', 128)

        self.imsize = getattr(cfg.data.image, 'imsize', 224)
        self.transform = self._get_transform()

        self.pair_ratio = getattr(cfg.data, 'pair_ratio', 0.3)

    @staticmethod
    def _load_split(csv_file, split, label):
        """Read a CSV, validate required columns and keep rows of ``split``."""
        frame = pd.read_csv(csv_file)
        for column in ('path', 'evidences'):
            if column not in frame.columns:
                raise ValueError(f"{label} file must contain '{column}' column")
        if 'split' in frame.columns:
            frame = frame[frame['split'] == split]
        return frame

    @staticmethod
    def _pair_keys(frame):
        """Return one string key per row combining ``path`` and ``evidences``."""
        return frame['path'].astype(str) + '|||' + frame['evidences'].astype(str)

    @classmethod
    def _parse_evidences(cls, raw):
        """Split a raw ``evidences`` cell into stripped, non-empty strings.

        Missing or empty cells yield ``[""]`` so that the sample still
        carries one (empty) evidence.
        """
        if pd.isna(raw):
            return [""]
        pieces = (piece.strip() for piece in str(raw).split(cls.EVIDENCE_SEPARATOR))
        evidences = [piece for piece in pieces if piece]
        return evidences or [""]

    def _encode_evidence(self, evidence):
        """Tokenize one evidence string into fixed-length 1-D tensors."""
        encoded = self.tokenizer(
            evidence,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Some tokenizers do not emit token_type_ids; substitute zeros.
        token_type_ids = encoded.get('token_type_ids', torch.zeros_like(encoded['input_ids']))
        return {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'token_type_ids': token_type_ids.squeeze(0)
        }

    def _get_transform(self):
        """Build the image transform; horizontal flip is applied for train only."""
        steps = [transforms.Resize((self.imsize, self.imsize))]
        if self.split == "train":
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        return transforms.Compose(steps)

    def read_image(self, img_path):
        """Load an image as RGB and apply ``self.transform`` when it is set.

        The file is read in grayscale with OpenCV first; PIL is used as a
        fallback when OpenCV cannot decode the file or raises.

        Raises:
            FileNotFoundError: If ``img_path`` does not exist.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        try:
            x = cv2.imread(str(img_path), 0)  # flag 0: read as grayscale
            if x is None:
                x = np.array(Image.open(img_path).convert('L'))
            elif x.ndim == 2:
                x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)
            elif x.shape[2] == 4:
                x = cv2.cvtColor(x, cv2.COLOR_BGRA2RGB)
            elif x.shape[2] == 3:
                x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        except Exception:
            # Deliberate best-effort: any OpenCV failure falls back to PIL,
            # which raises on its own if the file is truly unreadable.
            x = np.array(Image.open(img_path).convert('RGB'))

        img = Image.fromarray(x).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Return the number of paired rows."""
        return len(self.paired_df)

    def __getitem__(self, idx):
        """Return a paired, image-only or text-only sample.

        With probability ``self.pair_ratio`` the paired row at ``idx`` is
        returned. Otherwise a random row of the unpaired pool is used, either
        as an image-only sample (no evidences) or as a text-only sample
        (``image`` is ``None``). When the unpaired pool is empty the paired
        row at ``idx`` is returned instead.

        Raises:
            IndexError: If neither a paired row at ``idx`` nor any unpaired
                row is available.
        """
        has_paired_row = idx < len(self.paired_df)
        use_paired = random.random() < self.pair_ratio and has_paired_row

        if not use_paired and len(self.unpaired_df) == 0:
            # random.randint(0, -1) would raise an opaque ValueError here.
            if not has_paired_row:
                raise IndexError(
                    f"No sample available for index {idx}: unpaired pool is empty")
            use_paired = True

        if use_paired:
            row = self.paired_df.iloc[idx]
            img_path = str(row['path'])
            image = self.read_image(img_path)
            evidences = self._parse_evidences(row['evidences'])
            is_paired_sample = True
            pair_id = idx
        else:
            image_only = random.random() < self.UNPAIRED_IMAGE_PROB
            unpaired_idx = random.randint(0, len(self.unpaired_df) - 1)
            row = self.unpaired_df.iloc[unpaired_idx]
            if image_only:
                img_path = str(row['path'])
                image = self.read_image(img_path)
                evidences = []
            else:
                img_path = None
                image = None
                evidences = self._parse_evidences(row['evidences'])
            is_paired_sample = False
            pair_id = -1

        evidence_encodings = [self._encode_evidence(evidence) for evidence in evidences]

        return {
            'image': image,
            'img_path': img_path if img_path else None,
            'evidences': evidences,
            'evidence_encodings': evidence_encodings,
            'num_evidences': len(evidences),
            'is_paired': is_paired_sample,
            'pair_id': pair_id,
            'report_id': row.get('report_id', idx)
        }


def stage3_relation_collate_fn(batch):
    """Collate dataset items into image/text tensors plus a pairing matrix.

    Items whose ``image`` is not ``None`` contribute to the image stack and
    items with a non-empty ``evidences`` list contribute to the text tensors,
    so the two groups may differ in size. ``paired_matrix[i, j]`` is 1.0 when
    stacked image ``i`` and text entry ``j`` come from the same paired item.

    Args:
        batch: List of dicts with the keys ``image``, ``img_path``,
            ``evidences``, ``evidence_encodings``, ``num_evidences`` and
            ``is_paired``.

    Returns:
        Dict with ``images`` (stacked tensor or ``None``), ``img_indices`` /
        ``text_indices`` (positions within ``batch``), ``input_ids`` /
        ``attention_mask`` / ``token_type_ids`` (zero-padded long tensors of
        shape ``(num_texts, max_evidences, seq_len)`` or ``None``),
        ``num_evidences`` (long tensor or ``None``), ``paired_matrix``,
        ``num_images``, ``num_texts``, ``evidences`` and ``img_paths``.
    """
    image_tensors = []
    image_positions = []
    text_positions = []
    positive_pairs = []

    for position, sample in enumerate(batch):
        has_image = sample['image'] is not None
        has_text = len(sample['evidences']) > 0
        if has_image:
            image_tensors.append(sample['image'])
            image_positions.append(position)
        if has_text:
            text_positions.append(position)
        if sample['is_paired'] and has_image and has_text:
            # A paired sample: it was just appended to both lists, so its
            # coordinates are the last slot of each.
            positive_pairs.append((len(image_tensors) - 1, len(text_positions) - 1))

    num_images = len(image_tensors)
    num_texts = len(text_positions)
    images = torch.stack(image_tensors) if num_images > 0 else None

    input_ids = attention_mask = token_type_ids = num_evidences = None
    if num_texts > 0:
        text_samples = [batch[position] for position in text_positions]
        # Keep at least one evidence slot per text entry.
        evidence_slots = max(max(len(sample['evidences']) for sample in text_samples), 1)
        first_encodings = text_samples[0]['evidence_encodings']
        if len(first_encodings) > 0:
            seq_len = first_encodings[0]['input_ids'].shape[0]
        else:
            seq_len = 128

        input_ids = torch.zeros(num_texts, evidence_slots, seq_len, dtype=torch.long)
        attention_mask = torch.zeros(num_texts, evidence_slots, seq_len, dtype=torch.long)
        token_type_ids = torch.zeros(num_texts, evidence_slots, seq_len, dtype=torch.long)
        num_evidences = torch.zeros(num_texts, dtype=torch.long)

        for slot, sample in enumerate(text_samples):
            count = sample['num_evidences']
            num_evidences[slot] = count
            # range() is empty for a non-positive count, so no extra guard.
            for k in range(min(count, evidence_slots)):
                encoding = sample['evidence_encodings'][k]
                input_ids[slot, k] = encoding['input_ids']
                attention_mask[slot, k] = encoding['attention_mask']
                token_type_ids[slot, k] = encoding['token_type_ids']

    pair_matrix = torch.zeros(num_images, num_texts, dtype=torch.float)
    for image_slot, text_slot in positive_pairs:
        pair_matrix[image_slot, text_slot] = 1.0

    return {
        'images': images,
        'img_indices': image_positions,
        'text_indices': text_positions,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'num_evidences': num_evidences,
        'paired_matrix': pair_matrix,
        'num_images': num_images,
        'num_texts': num_texts,
        'evidences': [batch[position]['evidences'] for position in text_positions],
        'img_paths': [batch[position]['img_path'] for position in image_positions]
    }

