import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from lib.data.utils import split_data
from lib.data.base import VLInputs
from lib.data.base import VLDataset

# Maps a dataset name to the CSV table that describes it.
# NOTE(review): "abspath-to-folder" looks like a placeholder prefix — it
# presumably has to be replaced with the real absolute data root before use.
data_config = {
    "cub200": "abspath-to-folder/ICLR-Datasets/cub200/table.csv",
    "flowers": "abspath-to-folder/ICLR-Datasets/flowers/table.csv",
    "dvm_cars": "abspath-to-folder/ICLR-Datasets/dvm_cars/table.csv",
}

###
        
def collate_fn(data):
    """Collate ``(inputs, target)`` samples into one padded vision-language batch.

    Each sample ``d`` is read as ``d[0]`` -- a mapping with the keys
    ``"pixel_values"``, ``"input_ids"`` and ``"attention_mask"`` -- and
    ``d[1]``, its target. ``input_ids`` and ``attention_mask`` are
    right-padded with ``0`` up to the longest ``attention_mask`` in the batch.

    Args:
        data: Sequence of samples shaped as described above.

    Returns:
        A tuple ``(VLInputs, targets)``. ``VLInputs`` carries the float
        ``pixel_values`` batch plus the long ``input_ids`` and
        ``attention_mask`` batches; ``targets`` is a long tensor built from
        the collected ``d[1]`` values.

    Raises:
        ValueError: if ``data`` is empty (``max`` over no samples).
    """
    features = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]
    longest = max(len(feat["attention_mask"]) for feat in features)

    def right_pad(tokens):
        # `tokens` is extended with `+`, so it is expected to be a Python list.
        return tokens + [0] * (longest - len(tokens))

    # `[None, :]` adds a leading batch axis before concatenation, so
    # pixel_values must be an array/tensor rather than a plain list.
    images = torch.cat(
        [torch.tensor(feat["pixel_values"][None, :]).float() for feat in features],
        dim=0,
    )
    token_ids = torch.stack(
        [torch.tensor(right_pad(feat["input_ids"])).long() for feat in features]
    )
    mask = torch.stack(
        [torch.tensor(right_pad(feat["attention_mask"])).long() for feat in features]
    )

    batch = VLInputs(pixel_values=images, input_ids=token_ids, attention_mask=mask)
    return batch, torch.tensor(labels).long()
        
###

class VLExperienceStream:
    """Split one vision-language CSV table into a stream of train/test experiences.

    The table at ``data_path`` is loaded with pandas. Its ``image_path`` and
    ``text`` columns form the inputs and its ``label`` column the targets.
    ``split_data`` turns these into per-experience chunks: items indexed
    ``0 .. n_experiences - 1``, each a mapping with ``"inputs"`` and
    ``"targets"``. Every chunk is then split into a stratified train/test
    pair. ``test_size`` is not passed, so scikit-learn's default split sizes
    apply. Each half is wrapped in a ``VLDataset``.

    Attributes:
        n_experiences: Number of experiences requested.
        train_stream: One training ``VLDataset`` per experience.
        test_stream: One test ``VLDataset`` per experience.
        n_classes_per_experience: Number of distinct values in each training
            dataset's ``labels`` attribute.
    """

    def __init__(self,
        data_path: str,
        n_experiences: int,
        image_processor: object,
        text_processor: object,
        seed: int = 42
    ):
        self.n_experiences = n_experiences
        table = pd.read_csv(data_path)
        pairs = table[["image_path", "text"]].values
        # `seed` drives both the experience split and every train/test split.
        experiences = split_data(pairs, table.label.values, n_experiences, seed)

        def wrap(inputs, targets):
            # Every split of every experience shares the caller's processors.
            return VLDataset(inputs, targets, image_processor, text_processor)

        self.train_stream = []
        self.test_stream = []
        for idx in range(n_experiences):
            chunk = experiences[idx]
            train_x, test_x, train_y, test_y = train_test_split(
                chunk["inputs"],
                chunk["targets"],
                stratify=chunk["targets"],
                random_state=seed,
            )
            train_set = wrap(train_x, train_y)
            test_set = wrap(test_x, test_y)
            self.train_stream.append(train_set)
            self.test_stream.append(test_set)

        # train_stream holds exactly one dataset per experience.
        self.n_classes_per_experience = [
            len(np.unique(dataset.labels)) for dataset in self.train_stream
        ]
        