import os
import pandas as pd
import torch
import torch.utils.data as data
from transformers import AutoTokenizer
from functools import lru_cache
import numpy as np


class EvidenceDataset(data.Dataset):
    """Map-style dataset of per-report evidence lists, tokenized per evidence.

    Each CSV row's ``evidences`` cell is split on the full-width bar ``'｜'``
    into individual evidence strings (whitespace-stripped, blank pieces
    dropped). Each evidence is tokenized separately, padded/truncated to
    ``max_length``. A row with no usable evidence yields ``[""]`` so every
    sample has at least one evidence.

    Args:
        cfg: Config object. Reads ``cfg.data.max_length`` (default 128),
            ``cfg.model.text.bert_type`` (tokenizer name/path) and, when
            ``csv_path`` is None, a CSV-path attribute on ``cfg.data``.
        split: Split name; rows are filtered on it only if the CSV has a
            ``split`` column.
        csv_path: Explicit CSV path; overrides the config lookup.
        preprocess: If True, tokenize every row once in ``__init__`` and
            serve samples from that cache; otherwise tokenize lazily in
            ``__getitem__``.

    Raises:
        ValueError: If the CSV has no ``evidences`` column.
    """

    # Delimiter between evidence strings inside one CSV cell (full-width bar).
    _EVIDENCE_SEP = '｜'

    def __init__(self, cfg, split="train", csv_path=None, preprocess=True):
        self.split = split
        self.preprocess = preprocess

        if csv_path is None:
            # NOTE(review): the attribute name and fallback path look like
            # redacted placeholders — confirm the real config key.
            csv_path = getattr(cfg.data, ' ', " ")

        self.df = pd.read_csv(csv_path)

        if 'evidences' not in self.df.columns:
            # NOTE(review): the error message appears to be a placeholder.
            raise ValueError(" ")

        # Without a 'split' column every row is kept, whatever `split` says.
        # The filtered frame keeps its original index; all row access below
        # is positional, so that is harmless.
        if 'split' in self.df.columns:
            self.df = self.df[self.df['split'] == split]

        # NOTE(review): " " is not a usable tokenizer path; presumably a
        # redacted placeholder — confirm.
        bert_path = getattr(cfg.model.text, 'bert_type', " ")
        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.max_length = getattr(cfg.data, 'max_length', 128)

        if self.preprocess:
            print(f"Preprocessing {len(self.df)} samples for {split} split...")
            self._preprocess_all()
            print("Preprocessing completed!")

    @classmethod
    def _parse_evidences(cls, raw):
        """Split one raw ``evidences`` cell into a non-empty list of strings.

        Missing values (NaN/None) are treated as an empty cell. Pieces are
        stripped and blank pieces dropped; if nothing remains, ``[""]`` is
        returned so the sample still has one (empty) evidence to tokenize.
        """
        text = str(raw) if pd.notna(raw) else ""
        evidences = [
            piece.strip() for piece in text.split(cls._EVIDENCE_SEP) if piece.strip()
        ]
        return evidences or [""]

    def _encode(self, evidences):
        """Tokenize ``evidences`` as one batch.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` tensors with one
            row per evidence, padded/truncated to ``self.max_length``. If the
            tokenizer emits no ``token_type_ids``, an all-zero tensor shaped
            like ``input_ids`` is substituted.
        """
        encoded = self.tokenizer(
            evidences,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoded['input_ids']
        # Allocate the zero fallback only when the tokenizer omits the key.
        token_type_ids = encoded.get('token_type_ids')
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        return input_ids, encoded['attention_mask'], token_type_ids

    @staticmethod
    def _make_sample(evidences, input_ids, attention_mask, token_type_ids, report_id):
        """Build the sample dict returned by ``__getitem__``.

        The three tensors are split row-wise into one encoding dict per
        evidence.
        """
        encodings = [
            {'input_ids': ids, 'attention_mask': mask, 'token_type_ids': types}
            for ids, mask, types in zip(input_ids, attention_mask, token_type_ids)
        ]
        return {
            'evidences': evidences,
            'evidence_encodings': encodings,
            'num_evidences': len(evidences),
            'report_id': report_id,
        }

    def _preprocess_all(self):
        """Tokenize every row once and cache the results in ``self.preprocessed_data``.

        Each cache entry holds the evidence strings, the tokenizer outputs as
        NumPy arrays, the evidence count, and the report id. The report id
        falls back to the row's position when there is no ``report_id``
        column.
        """
        self.preprocessed_data = []
        for pos, (_, row) in enumerate(self.df.iterrows()):
            evidences = self._parse_evidences(row['evidences'])
            input_ids, attention_mask, token_type_ids = self._encode(evidences)
            # Cached as NumPy arrays (the original intent was memory savings);
            # __getitem__ wraps them back into tensors with torch.from_numpy.
            self.preprocessed_data.append({
                'evidences': evidences,
                'input_ids': input_ids.numpy(),
                'attention_mask': attention_mask.numpy(),
                'token_type_ids': token_type_ids.numpy(),
                'num_evidences': len(evidences),
                'report_id': row.get('report_id', pos),
            })

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample.

        The sample dict has keys ``evidences``, ``evidence_encodings`` (one
        dict of per-evidence tensors), ``num_evidences`` and ``report_id``.
        It is served from the preprocessing cache when one exists; otherwise
        the row is tokenized on the fly.
        """
        if self.preprocess and hasattr(self, 'preprocessed_data'):
            cached = self.preprocessed_data[idx]
            # torch.from_numpy shares memory with the cache: in-place edits to
            # the returned tensors would also modify the cached arrays.
            return self._make_sample(
                cached['evidences'],
                torch.from_numpy(cached['input_ids']),
                torch.from_numpy(cached['attention_mask']),
                torch.from_numpy(cached['token_type_ids']),
                cached['report_id'],
            )

        row = self.df.iloc[idx]
        evidences = self._parse_evidences(row['evidences'])
        return self._make_sample(
            evidences, *self._encode(evidences), row.get('report_id', idx)
        )


def evidence_collate_fn(batch):
    """Collate ``EvidenceDataset`` samples into padded batch tensors.

    Samples may carry different numbers of evidences. Each sample is padded
    with all-zero rows up to the largest evidence count in the batch, so the
    padded rows have ``attention_mask`` 0.

    Args:
        batch: Non-empty list of sample dicts shaped like
            ``EvidenceDataset.__getitem__`` output. Every sample is assumed to
            have at least one evidence and equal-length encodings.

    Returns:
        dict with:
            ``input_ids`` / ``attention_mask`` / ``token_type_ids``: tensors
                of shape ``(batch_size, max_num_evidences, max_length)``.
            ``num_evidences``: LongTensor ``(batch_size,)`` of real counts,
                for masking out padded evidence rows.
            ``evidences``: list of per-sample evidence string lists.
            ``report_ids``: list of per-sample report ids.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("evidence_collate_fn received an empty batch")

    batch_size = len(batch)
    max_num_evidences = max(item['num_evidences'] for item in batch)
    max_length = batch[0]['evidence_encodings'][0]['input_ids'].shape[0]

    fields = ('input_ids', 'attention_mask', 'token_type_ids')
    per_field = {name: [] for name in fields}
    num_evidences = torch.zeros(batch_size, dtype=torch.long)

    for i, item in enumerate(batch):
        num_ev = item['num_evidences']
        num_evidences[i] = num_ev
        encodings = item['evidence_encodings'][:num_ev]
        n_pad = max_num_evidences - num_ev

        for name in fields:
            rows = torch.stack([enc[name] for enc in encodings])
            if n_pad > 0:
                # One all-zero (max_length,) row per missing evidence; the
                # previous code appended (n_pad, max_length) blocks, which made
                # torch.stack fail on any mixed-size batch.
                rows = torch.cat([rows, rows.new_zeros(n_pad, max_length)])
            per_field[name].append(rows)

    return {
        'input_ids': torch.stack(per_field['input_ids']),
        'attention_mask': torch.stack(per_field['attention_mask']),
        'token_type_ids': torch.stack(per_field['token_type_ids']),
        'num_evidences': num_evidences,
        'evidences': [item['evidences'] for item in batch],
        'report_ids': [item['report_id'] for item in batch],
    }

