import torch
import pickle
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torchvision.transforms as transforms
from data.CUB200.cub_loader import generate_data

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from models.evi_clm import Evi_CLM
from utils import _convert_image_to_rgb
import scipy.linalg


class CUBDatamodules(pl.LightningDataModule):
    """Lightning data module exposing the CUB200 loaders built by ``generate_data``.

    The train/val/test loaders, the imbalance values and the extra dataset
    variables are only available after :meth:`setup` has run.
    """

    def __init__(self, seed, data_dir, batch_size, train_with_c_gt, concept_weight, arch):
        """Store the constructor arguments and assemble the loader config.

        Args:
            seed: Value placed under ``'seed'`` in the config and later
                forwarded to ``generate_data``.
            data_dir: Stored on the instance. NOTE(review): ``setup`` passes a
                hard-coded ``'./data/CUB200/'`` root instead of this value.
            batch_size: Stored on the instance and placed under
                ``'batch_size'`` in the config.
            train_with_c_gt: Stored on the instance; not read in this class.
            concept_weight: Stored on the instance; not read in this class.
            arch: Accepted but not used in this class.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_with_c_gt = train_with_c_gt
        self.concept_weight = concept_weight
        # Options handed to generate_data as its `config` argument.
        self.config = dict(
            batch_size=batch_size,
            num_workers=4,
            weight_loss=True,  # use class-imbalance weights
            sampling_percent=1,
            sampling_groups=True,  # sample by group
            seed=seed,  # random seed
        )

    def prepare_data(self):
        """Do nothing; this module has no preparation step."""
        return None

    @property
    def imbalance_weight(self):
        """Return the ``imbalance`` value stored by :meth:`setup` as a tensor."""
        weights = self.imbalance
        return torch.tensor(weights)

    def setup(self, stage):
        """Call ``generate_data`` and keep everything it returns on the instance.

        Args:
            stage: Stage argument of the Lightning hook; not used here, the
                same call is made whatever its value.
        """
        # NOTE(review): the original comment on `rerun=True` said it reuses
        # cached sampling results when they exist, which reads oddly next to
        # the flag's name and value -- confirm against generate_data.
        generated = generate_data(
            config=self.config,
            seed=self.config['seed'],
            root_dir='./data/CUB200/',
            output_dataset_vars=True,  # also return the extra dataset information
            rerun=True,
        )
        (
            self.train_dl,
            self.val_dl,
            self.test_dl,
            self.imbalance,
            self.dataset_vars,
        ) = generated

    def train_dataloader(self):
        """Return the training loader stored by :meth:`setup`."""
        return self.train_dl

    def val_dataloader(self):
        """Return the validation loader stored by :meth:`setup`."""
        return self.val_dl

    def test_dataloader(self):
        """Return the test loader stored by :meth:`setup`."""
        return self.test_dl


