import numpy as np
import pandas as pd
import pickle
import os
from functools import reduce
import torch
from glob import glob
from itertools import cycle
from pathlib import Path
from PIL import Image
from tqdm import tqdm


class ImdbDataset:
    """IMDB sentiment dataset read from the ``datasets/aclImdb`` directory tree.

    Every ``*.txt`` file under ``<split>/neg`` gets label 0 and every one under
    ``<split>/pos`` gets label 1. The (path, label) pairs are shuffled once with
    NumPy's global RNG; ``__getitem__`` reads and tokenizes a review on demand.

    Args:
        config: object providing ``tokenizer`` (called as
            ``tokenizer(text, max_length=...)``) and ``max_length``.
        split: ``'train'`` or ``'eval'`` (the aclImdb ``test`` folder).
    """

    def __init__(self, config, split='train'):
        data_paths = {'train': "datasets/aclImdb/train", 'eval': "datasets/aclImdb/test"}
        split_path = data_paths[split]
        # glob order is filesystem-dependent; sorting makes a seeded permutation
        # below reproducible across machines.
        neg_files = sorted(glob(split_path + "/neg/*.txt"))
        pos_files = sorted(glob(split_path + "/pos/*.txt"))
        pairs = [(path, 0) for path in neg_files] + [(path, 1) for path in pos_files]
        # np.random.permutation turns the pairs into a 2-D string array, so each
        # label is stored as '0'/'1' and converted back with int() in __getitem__.
        self.data = np.random.permutation(pairs)

        self.tokenizer = config.tokenizer
        self.max_length = config.max_length

    def __getitem__(self, i):
        path, label = self.data[i]
        # Decode explicitly as UTF-8 instead of relying on the platform's locale
        # encoding, which may not be able to decode non-ASCII review text.
        with open(path, 'r', encoding='utf-8') as fo:
            source = fo.read()
        inputs = self.tokenizer(source, max_length=self.max_length)
        return inputs, torch.LongTensor([int(label)])

    def __len__(self):
        return len(self.data)


class ListOpsDataset:
    """ListOps split from the Long Range Arena release.

    Backed by a tab-separated file whose ``Source`` column holds the expression
    text and whose ``Target`` column holds the integer answer.
    """

    def __init__(self, config, split='train'):
        tsv_by_split = dict(
            train="datasets/lra_release/lra_release/listops-1000/basic_train.tsv",
            eval="datasets/lra_release/lra_release/listops-1000/basic_val.tsv",
            test="datasets/lra_release/lra_release/listops-1000/basic_test.tsv",
        )
        tsv_path = tsv_by_split[split]
        self.data = pd.read_csv(tsv_path, delimiter='\t')
        self.tokenizer, self.max_length = config.tokenizer, config.max_length

    def __getitem__(self, i):
        row = self.data.iloc[i]
        # Tokenizer is called without padding/truncation flags; the caller's
        # tokenizer decides how max_length is applied.
        encoded = self.tokenizer(row.Source, max_length=self.max_length)
        label = row.Target
        return encoded, torch.LongTensor([label])

    def __len__(self):
        return self.data.shape[0]


# NOTE(review): Legacy Cifar10Dataset kept inside a bare module-level string
# literal, so none of it is ever executed. It reshaped each image to
# (3, 1024) and collapsed the channels to a single grayscale row before
# tokenizing, whereas the active Cifar10Dataset below passes the flat row
# unchanged. Consider deleting this string; version control keeps the history.
'''class Cifar10Dataset:
    def __init__(self, config, split='train'):
        data_paths = {'train': [f"datasets/cifar-10-batches-py/data_batch_{i}" for i in range(1, 6)],
                      'eval': ["datasets/cifar-10-batches-py/test_batch"]
                      }
        print("loading cifar-10 data...")
        data_dicts = [Cifar10Dataset.unpickle(path) for path in data_paths[split]]
        print("assembling cifar-10 files..")
        self.data = reduce((lambda x, y: {b'data': np.concatenate([x[b'data'], y[b'data']], axis=0),
                                          b'labels': np.concatenate([x[b'labels'], y[b'labels']], axis=0)}),
                           data_dicts)
        # TODO CHECK: i think this is the right shape
        # see: https://www.cs.toronto.edu/~kriz/cifar.html
        #      section "Dataset layouts" discusses the memory layout of the array
        self.data[b'data'] = self.data[b'data'].reshape((-1, 3, 1024))

        self.tokenizer = config.tokenizer
        self.max_length = config.max_length

    @staticmethod
    def unpickle(file):
        with open(file, 'rb') as fo:
            d = pickle.load(fo, encoding='bytes')
        return d

    def __getitem__(self, i):
        r, g, b = self.data[b'data'][i]
        # grayscale image (assume pixels in [0, 255])
        source = (0.2989 * r + 0.5870 * g + 0.1140 * b).astype(int)
        inputs = self.tokenizer(source, max_length=self.max_length)
        target = self.data[b'labels'][i]
        return inputs, torch.LongTensor([target])

    def __len__(self):
        return len(self.data[b'data'])'''


class Cifar10Dataset:
    """CIFAR-10 (python pickle release) served as flat RGB pixel rows.

    Each row of ``self.data[b'data']`` holds one image as 3 * 1024 values. Per the
    CIFAR-10 "Dataset layouts" documentation, these are the 1024 red values, then
    the 1024 green, then the 1024 blue. Rows are passed to the tokenizer as-is.
    Only the ``b'data'`` and ``b'labels'`` entries of the batch pickles are kept.

    Args:
        config: object providing ``tokenizer`` (called as
            ``tokenizer(pixels, max_length=...)``) and ``max_length``.
        split: ``'train'`` (data_batch_1..5) or ``'eval'`` (test_batch).
    """

    def __init__(self, config, split='train'):
        data_paths = {'train': [f"datasets/cifar-10-batches-py/data_batch_{i}" for i in range(1, 6)],
                      'eval': ["datasets/cifar-10-batches-py/test_batch"]}
        print("loading cifar-10 data...")
        data_dicts = [Cifar10Dataset.unpickle(path) for path in data_paths[split]]
        print("assembling cifar-10 files..")
        # Concatenate every batch in one call; a pairwise reduce re-copies the
        # growing array once per batch. Labels become an int64 array for every
        # split, including single-batch 'eval'. The CIFAR-10 docs describe the
        # pickled labels as a plain list.
        images = np.concatenate([d[b'data'] for d in data_dicts], axis=0)
        labels = np.concatenate([np.asarray(d[b'labels'], dtype=np.int64) for d in data_dicts], axis=0)
        # One flat row of 3 * 1024 values per image (the raw rows already have this size).
        self.data = {b'data': images.reshape((-1, 3 * 1024)), b'labels': labels}

        self.tokenizer = config.tokenizer
        self.max_length = config.max_length

    @staticmethod
    def unpickle(file):
        """Load one CIFAR-10 batch pickle (keys are ``bytes``).

        NOTE: pickle can execute arbitrary code on load; only use trusted dataset files.
        """
        with open(file, 'rb') as fo:
            d = pickle.load(fo, encoding='bytes')
        return d

    def __getitem__(self, i):
        # The row is already flat (R plane, G plane, B plane); no reshaping needed.
        rgb_flat = self.data[b'data'][i]
        inputs = self.tokenizer(rgb_flat, max_length=self.max_length)
        target = int(self.data[b'labels'][i])
        return inputs, torch.LongTensor([target])

    def __len__(self):
        return len(self.data[b'data'])


class PathFinderDataset(torch.utils.data.Dataset):
    """Path Finder dataset.

    Samples are listed in whitespace-separated metadata files. Each non-blank line
    is read as ``<sub_folder> <image_filename> <unused> <label> ...``. It resolves
    to ``<data_dir>/<diff_level>/<sub_folder>/<image_filename>``, where
    ``diff_level`` is the name of the metadata file's grandparent directory.
    """

    # Root used when no ``data_dir`` is passed (the originally hard-coded location).
    _DEFAULT_DATA_DIR = "/home/mxm6982/data/codes/AstroTransformer/AstroTransformer/RMAAT/datasets/lra_release/lra_release/pathfinder32/"
    # Images to skip, relative to data_dir: there's an empty file in the dataset.
    _BLACKLIST_RELATIVE = ("curv_baseline/imgs/0/sample_172.png",)

    def __init__(self, config, split='train', transform=None, metadata_files=None, data_dir=None):
        """
        Args:
            config (Config): Configuration object providing ``tokenizer`` and ``max_length``.
            split (str): 'train', 'val', or 'test'. Not read here; the split is
                determined by ``metadata_files``.
            transform (callable, optional): Optional transform applied to the
                grayscale pixel array before tokenization.
            metadata_files (list): Metadata files for this split. Required.
            data_dir (str or Path, optional): Root of the pathfinder data;
                defaults to ``_DEFAULT_DATA_DIR``.

        Raises:
            TypeError: if ``metadata_files`` is None.
        """
        self.data_dir = Path(self._DEFAULT_DATA_DIR if data_dir is None else data_dir).expanduser()
        assert self.data_dir.is_dir(), f"data_dir {str(self.data_dir)} does not exist"
        # Build the blacklist from data_dir so its entries are spelled exactly like
        # the sample paths constructed in _load_samples, wherever the data lives.
        self.blacklist = {str(self.data_dir / rel) for rel in self._BLACKLIST_RELATIVE}
        if metadata_files is None:
            raise TypeError("metadata_files must be an iterable of metadata file paths for this split")
        self.transform = transform
        self.metadata_files = metadata_files  # Store the list of metadata files
        self.tokenizer = config.tokenizer  # Store the tokenizer
        self.max_length = config.max_length  # Store the max_length
        self.samples = self._load_samples(metadata_files)

    def _load_samples(self, metadata_files):
        """Parse metadata files into ``(image_path, label)`` pairs, skipping blacklisted images."""
        samples = []
        for metadata_file in metadata_files:
            # Grandparent directory of the metadata file names the difficulty level.
            diff_level = Path(metadata_file).parts[-3]
            with open(metadata_file, "r") as f:
                for line in f.read().splitlines():
                    fields = line.split()
                    if not fields:  # blank line: nothing to index
                        continue
                    image_path = self.data_dir / diff_level / fields[0] / fields[1]
                    if str(image_path) in self.blacklist:
                        continue
                    samples.append((str(image_path), int(fields[3])))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        # Sample paths already include data_dir, so open them directly.
        full_path = Path(path)

        try:
            with open(full_path, "rb") as f:
                sample = Image.open(f).convert("L")  # Open in grayscale
        except FileNotFoundError:
            print(f"Error opening file: {full_path}")
            raise  # Re-raise the exception to stop training

        sample = np.array(sample)
        if self.transform is not None:
            sample = self.transform(sample)

        # Tokenize the 2-D grayscale pixel array (or whatever the transform returned).
        inputs = self.tokenizer(sample, max_length=self.max_length)

        return inputs, torch.LongTensor([target])


# NOTE(review): Legacy AanDataset kept inside a bare module-level string
# literal, so none of it is ever executed. Unlike the active AanDataset below,
# it used a hard-coded absolute data directory, read the TSV in 10000-row
# chunks, and tokenized text_1 and text_2 joined by " [SEP] " inside
# __getitem__. Consider deleting this string; version control keeps the history.
'''class AanDataset(torch.utils.data.Dataset):
    """AAN dataset."""

    def __init__(self, config, split='train'):
        """
        Args:
            config (Config): Configuration object.
            split (str): 'train', 'eval', or 'test'.
        """
        self.data_dir = "/home/mxm6982/data/codes/AstroTransformer/AstroTransformer/RMAAT/datasets/lra_release/lra_release/tsv_data/"
        self.data_dir = Path(self.data_dir).expanduser()
        assert self.data_dir.is_dir(), f"data_dir {str(self.data_dir)} does not exist"
        self.tokenizer = config.tokenizer
        self.max_length = config.max_length
        self.split = split
        self.batch_size = config.batch_size  # Add Batch Size

        data_files = {
            'train': self.data_dir / "new_aan_pairs.train.tsv",
            'eval': self.data_dir / "new_aan_pairs.eval.tsv",
            'test': self.data_dir / "new_aan_pairs.test.tsv"
        }

        # Create a cache file path
        cache_file = data_files[split].with_suffix(".pkl")

        if os.path.exists(cache_file):
            print(f"Loading cached data from {cache_file}")
            with open(cache_file, "rb") as f:
                self.data = pickle.load(f)
        else:
            print(f"Reading data from {data_files[split]} and caching to {cache_file}")
            chunks = []
            with tqdm(desc="Reading TSV in chunks") as pbar:
                for chunk in pd.read_csv(data_files[split], delimiter='\t', header=None, chunksize=10000):
                    chunks.append(chunk)
                    pbar.update(1)
            self.data = pd.concat(chunks)
            with open(cache_file, "wb") as f:
                pickle.dump(self.data, f, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = self.data.iloc[idx][0]
        text_1 = self.data.iloc[idx][3]
        text_2 = self.data.iloc[idx][4]

        # Combine text_1 and text_2 with [SEP] token
        source = str(text_1) + " [SEP] " + str(text_2)

        inputs = self.tokenizer(source, max_length=self.max_length, truncation=True, padding='max_length')
        return inputs, torch.LongTensor([label])
'''


class AanDataset(torch.utils.data.Dataset):
    """AAN document-pair dataset with a pickle cache.

    Reads ``new_aan_pairs.<split>.tsv`` (``split`` is 'train', 'eval' or 'test')
    from ``config.data_dir``. In this file, column 0 is the integer label and
    columns 3 and 4 are the two texts; these are copied into ``label``, ``text_1``
    and ``text_2`` columns. The parsed DataFrame is pickled next to the TSV as
    ``*.cache.pkl``. It is reused on later runs unless the TSV is newer than the
    cache or the cache cannot be unpickled.

    ``__getitem__`` returns raw ``(text_1, text_2, label)`` tuples; no
    tokenization is done here.
    """

    def __init__(self, config, split='train'):
        self.data_dir = Path(config.data_dir).expanduser()
        assert self.data_dir.is_dir(), f"data_dir {self.data_dir} does not exist"
        self.split = split
        self.max_length = config.max_length
        # Per-text length budgets; each defaults to half of max_length.
        self.text_1_max_length = (config.text_1_max_length if hasattr(config, 'text_1_max_length')
                                  else self.max_length // 2)
        self.text_2_max_length = (config.text_2_max_length if hasattr(config, 'text_2_max_length')
                                  else self.max_length // 2)

        data_files = {
            'train': self.data_dir / "new_aan_pairs.train.tsv",
            'eval': self.data_dir / "new_aan_pairs.eval.tsv",
            'test': self.data_dir / "new_aan_pairs.test.tsv"
        }
        tsv_file = data_files[split]
        cache_file = tsv_file.with_suffix(".cache.pkl")

        self.data = self._load_cache(cache_file, tsv_file)
        if self.data is None:
            print(f"Reading and caching data to {cache_file}")
            self.data = self._read_tsv(tsv_file)
            self._write_cache(self.data, cache_file)

    @staticmethod
    def _read_tsv(tsv_file):
        """Parse the headerless TSV and add string ``text_1``/``text_2`` and int ``label`` columns."""
        data = pd.read_csv(tsv_file, delimiter='\t', header=None)
        data['text_1'] = data[3].astype(str)  # Ensure correct column indices and string type
        data['text_2'] = data[4].astype(str)
        data['label'] = data[0].astype(int)
        return data

    @staticmethod
    def _load_cache(cache_file, tsv_file):
        """Return the cached DataFrame, or None if the cache is missing, stale, or unreadable."""
        if not cache_file.exists():
            return None
        # A cache older than its source TSV was built from outdated data.
        if tsv_file.exists() and cache_file.stat().st_mtime < tsv_file.stat().st_mtime:
            print(f"Cache {cache_file} is older than {tsv_file}; rebuilding")
            return None
        print(f"Loading cached data from {cache_file}")
        try:
            # NOTE: pickle can execute arbitrary code; only load caches this class wrote itself.
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as err:
            # e.g. a truncated file left behind by an interrupted write
            print(f"Ignoring unreadable cache {cache_file} ({err}); rebuilding")
            return None

    @staticmethod
    def _write_cache(data, cache_file):
        """Pickle ``data`` to ``cache_file`` via a temp file and rename."""
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        with open(tmp_file, "wb") as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
        # os.replace is atomic on POSIX, so readers never see a half-written cache.
        os.replace(tmp_file, cache_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        return sample['text_1'], sample['text_2'], sample['label']
