import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import pickle
import torch.nn.functional as F
import os
from torch.utils.data import random_split
import numpy as np
import random

class MDataset(Dataset):
    """Expose an indexable of ``(sample, label)`` pairs as a ``Dataset``.

    Args:
        data: Indexable of ``(sample, label)`` pairs supporting ``len()``.
        transform: Optional callable applied to each sample on access; the
            label is returned untouched. Skipped when falsy.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item, target = self.data[idx]
        return (self.transform(item) if self.transform else item), target

class IKDataset(Dataset):
    """Pre-generated (image, action, resulting image) samples from a labelled dataset.

    For every ``(image, label)`` pair in ``mnist_dataset`` a random action in
    ``[0, num_classes)`` is drawn. When the action equals the label the target
    label is ``label + 1``, otherwise it is ``label - 1`` (both modulo
    ``num_classes``). A random image with the target label is then drawn from
    ``mnist_dataset``. All samples are generated eagerly at construction.

    Each item is ``(transform(image), action_one_hot, transform(y_image),
    label, y_label)``, where ``action_one_hot`` is a float tensor of length
    ``num_classes``.

    Args:
        mnist_dataset: Indexable of ``(image, label)`` pairs supporting
            ``len()``. Labels are used as list indices, so they must be
            integers in ``[0, 10)``.
        transform: Callable applied to both the source and the target image.
        seed: Seed applied to Python's global ``random`` module and NumPy's
            global RNG before generation. Note that this mutates global RNG
            state.
    """

    def __init__(self, mnist_dataset, transform, seed=0):
        np.random.seed(seed)
        random.seed(seed)
        self.transform = transform
        self.mnist_dataset = mnist_dataset
        self.num_classes = 10
        self.partitioned_indices = self.partition_dataset()
        self.post_load_setup()

    def __len__(self):
        return len(self.final_dataset)

    def process_dataset(self):
        """Generate one sample per item of ``mnist_dataset``, in order."""
        final_dataset = []
        for idx in range(len(self.mnist_dataset)):
            image, label = self.mnist_dataset[idx]
            action_one_hot = torch.zeros(self.num_classes)
            action = random.randint(0, self.num_classes - 1)
            action_one_hot[action] = 1
            # The "correct" action advances the target class by one; any
            # other action moves it back by one.
            step = 1 if action == label else -1
            y_label = (label + step) % self.num_classes

            y_image = self.get_image_with_label(y_label)
            final_dataset.append(
                (self.transform(image), action_one_hot, self.transform(y_image), label, y_label)
            )

        return final_dataset

    def __getitem__(self, idx):
        return self.final_dataset[idx]

    def partition_dataset(self):
        """Return, per class, the list of indices in ``mnist_dataset`` carrying that label."""
        partitioned_indices = [[] for _ in range(self.num_classes)]
        # Index explicitly (like process_dataset) instead of relying on the
        # __getitem__ iteration fallback, which only stops on IndexError.
        for idx in range(len(self.mnist_dataset)):
            _, label = self.mnist_dataset[idx]
            partitioned_indices[label].append(idx)
        return partitioned_indices

    def get_image_with_label(self, label):
        """Return the image of a uniformly random item whose label is ``label``.

        Raises:
            IndexError: if ``mnist_dataset`` contains no item with ``label``.
        """
        indices = self.partitioned_indices[label]
        if not indices:
            raise IndexError(f"no samples with label {label} to draw from")
        image, _ = self.mnist_dataset[random.choice(indices)]
        return image

    def post_load_setup(self):
        """(Re)generate ``final_dataset`` from ``mnist_dataset``."""
        self.final_dataset = self.process_dataset()

    def __setstate__(self, state):
        # NOTE(review): unpickling discards the pickled final_dataset and
        # regenerates it from the *current* global RNG state (the constructor
        # seed is not re-applied), so an unpickled copy generally holds
        # different samples than the original. Confirm this is intended.
        self.__dict__.update(state)
        self.post_load_setup()

class IGLDataset(Dataset):
    """Pre-generated (image, next-class image, previous-class image) samples.

    For every ``(image, label)`` pair in ``mnist_dataset``, one random image of
    class ``label + 1`` and one of class ``label - 1`` (both modulo
    ``num_classes``) are drawn from ``mnist_dataset``. All samples are
    generated eagerly at construction using Python's global ``random`` module.

    Each item is ``(transform_cls(image), transform_ik(image),
    transform_ik(c_image), transform_ik(w_image), label)``, where ``c_image``
    has class ``label + 1`` and ``w_image`` has class ``label - 1``.

    Args:
        mnist_dataset: Indexable of ``(image, label)`` pairs supporting
            ``len()``. Labels are used as list indices, so they must be
            integers in ``[0, 10)``.
        transform_cls: Optional callable for the first image copy; ``None``
            leaves the image unchanged.
        transform_ik: Optional callable for the remaining three images;
            ``None`` leaves them unchanged.
    """

    def __init__(self, mnist_dataset, transform_cls=None, transform_ik=None):
        self.transform_ik = transform_ik
        self.transform_cls = transform_cls
        self.mnist_dataset = mnist_dataset
        self.num_classes = 10
        self.partitioned_indices = self.partition_dataset()
        self.post_load_setup()

    def __len__(self):
        return len(self.final_dataset)

    @staticmethod
    def _apply_transform(transform, image):
        """Apply ``transform`` to ``image``, treating ``None`` as identity."""
        return image if transform is None else transform(image)

    def process_dataset(self):
        """Generate one sample per item of ``mnist_dataset``, in order."""
        final_dataset = []
        for idx in range(len(self.mnist_dataset)):
            image, label = self.mnist_dataset[idx]
            c_label = (label + 1) % self.num_classes
            w_label = (label - 1) % self.num_classes

            c_image = self.get_image_with_label(c_label)
            w_image = self.get_image_with_label(w_label)
            # Transforms are applied in a fixed order in case they are random.
            final_dataset.append((
                self._apply_transform(self.transform_cls, image),
                self._apply_transform(self.transform_ik, image),
                self._apply_transform(self.transform_ik, c_image),
                self._apply_transform(self.transform_ik, w_image),
                label,
            ))

        return final_dataset

    def __getitem__(self, idx):
        return self.final_dataset[idx]

    def partition_dataset(self):
        """Return, per class, the list of indices in ``mnist_dataset`` carrying that label."""
        partitioned_indices = [[] for _ in range(self.num_classes)]
        # Index explicitly (like process_dataset) instead of relying on the
        # __getitem__ iteration fallback, which only stops on IndexError.
        for idx in range(len(self.mnist_dataset)):
            _, label = self.mnist_dataset[idx]
            partitioned_indices[label].append(idx)
        return partitioned_indices

    def get_image_with_label(self, label):
        """Return the image of a uniformly random item whose label is ``label``.

        Raises:
            IndexError: if ``mnist_dataset`` contains no item with ``label``.
        """
        indices = self.partitioned_indices[label]
        if not indices:
            raise IndexError(f"no samples with label {label} to draw from")
        image, _ = self.mnist_dataset[random.choice(indices)]
        return image

    def post_load_setup(self):
        """(Re)generate ``final_dataset`` from ``mnist_dataset``."""
        self.final_dataset = self.process_dataset()

    def __setstate__(self, state):
        # NOTE(review): unpickling discards the pickled final_dataset and
        # regenerates it with fresh random draws, so an unpickled copy
        # generally holds different samples than the original. Confirm
        # this is intended.
        self.__dict__.update(state)
        self.post_load_setup()

class OfflineDataset(Dataset):
    """Pre-generated (image, integer action, resulting image) samples.

    For every ``(image, label)`` pair in ``mnist_dataset`` a random action in
    ``[0, num_classes)`` is drawn. When the action equals the label the target
    label is ``label + 1``, otherwise ``label - 1`` (both modulo
    ``num_classes``). A random image of the target class is then drawn from
    ``mnist_dataset``. All samples are generated eagerly at construction
    using Python's global ``random`` module.

    Each item is ``(transform_cls(image), transform_ik(image), action,
    transform_ik(y_image), label)``. ``action`` is a plain ``int``, and the
    target label itself is not stored.

    Args:
        mnist_dataset: Indexable of ``(image, label)`` pairs supporting
            ``len()``. Labels are used as list indices, so they must be
            integers in ``[0, 10)``.
        transform_cls: Optional callable for the first image copy; ``None``
            leaves the image unchanged.
        transform_ik: Optional callable for the second copy and the target
            image; ``None`` leaves them unchanged.
    """

    def __init__(self, mnist_dataset, transform_cls=None, transform_ik=None):
        self.transform_ik = transform_ik
        self.transform_cls = transform_cls
        self.mnist_dataset = mnist_dataset
        self.num_classes = 10
        self.partitioned_indices = self.partition_dataset()
        self.post_load_setup()

    def __len__(self):
        return len(self.final_dataset)

    @staticmethod
    def _apply_transform(transform, image):
        """Apply ``transform`` to ``image``, treating ``None`` as identity."""
        return image if transform is None else transform(image)

    def process_dataset(self):
        """Generate one sample per item of ``mnist_dataset``, in order."""
        final_dataset = []
        for idx in range(len(self.mnist_dataset)):
            image, label = self.mnist_dataset[idx]
            action = random.randint(0, self.num_classes - 1)
            # The "correct" action advances the target class by one; any
            # other action moves it back by one.
            step = 1 if action == label else -1
            y_label = (label + step) % self.num_classes

            y_image = self.get_image_with_label(y_label)
            final_dataset.append((
                self._apply_transform(self.transform_cls, image),
                self._apply_transform(self.transform_ik, image),
                action,
                self._apply_transform(self.transform_ik, y_image),
                label,
            ))

        return final_dataset

    def __getitem__(self, idx):
        return self.final_dataset[idx]

    def partition_dataset(self):
        """Return, per class, the list of indices in ``mnist_dataset`` carrying that label."""
        partitioned_indices = [[] for _ in range(self.num_classes)]
        # Index explicitly (like process_dataset) instead of relying on the
        # __getitem__ iteration fallback, which only stops on IndexError.
        for idx in range(len(self.mnist_dataset)):
            _, label = self.mnist_dataset[idx]
            partitioned_indices[label].append(idx)
        return partitioned_indices

    def get_image_with_label(self, label):
        """Return the image of a uniformly random item whose label is ``label``.

        Raises:
            IndexError: if ``mnist_dataset`` contains no item with ``label``.
        """
        indices = self.partitioned_indices[label]
        if not indices:
            raise IndexError(f"no samples with label {label} to draw from")
        image, _ = self.mnist_dataset[random.choice(indices)]
        return image

    def post_load_setup(self):
        """(Re)generate ``final_dataset`` from ``mnist_dataset``."""
        self.final_dataset = self.process_dataset()

    def __setstate__(self, state):
        # NOTE(review): unpickling discards the pickled final_dataset and
        # regenerates it with fresh random draws, so an unpickled copy
        # generally holds different samples than the original. Confirm
        # this is intended.
        self.__dict__.update(state)
        self.post_load_setup()

class TestCorrectDataset(Dataset):
    """Evaluation samples where the action is always the correct class.

    For every ``(image, label)`` pair in ``mnist_dataset`` the action is
    ``label`` itself (one-hot encoded), and the target label is ``label + 1``
    (modulo ``num_classes``). A random image of the target class is drawn
    from ``mnist_dataset`` using Python's global ``random`` module. All
    samples are generated eagerly at construction.

    Each item is ``(transform(image), action_one_hot, transform(y_image),
    label, y_label)``, where ``action_one_hot`` is a float tensor of length
    ``num_classes``.

    Args:
        mnist_dataset: Indexable of ``(image, label)`` pairs supporting
            ``len()``. Labels are used as list indices, so they must be
            integers in ``[0, 10)``.
        transform: Callable applied to both the source and the target image.
    """

    def __init__(self, mnist_dataset, transform):
        self.transform = transform
        self.mnist_dataset = mnist_dataset
        self.num_classes = 10
        self.partitioned_indices = self.partition_dataset()
        self.post_load_setup()

    def __len__(self):
        return len(self.final_dataset)

    def process_dataset(self):
        """Generate one sample per item of ``mnist_dataset``, in order."""
        final_dataset = []
        for idx in range(len(self.mnist_dataset)):
            image, label = self.mnist_dataset[idx]
            action_one_hot = torch.zeros(self.num_classes)
            action_one_hot[label] = 1

            y_label = (label + 1) % self.num_classes
            y_image = self.get_image_with_label(y_label)
            final_dataset.append(
                (self.transform(image), action_one_hot, self.transform(y_image), label, y_label)
            )

        return final_dataset

    def __getitem__(self, idx):
        return self.final_dataset[idx]

    def partition_dataset(self):
        """Return, per class, the list of indices in ``mnist_dataset`` carrying that label."""
        partitioned_indices = [[] for _ in range(self.num_classes)]
        # Index explicitly (like process_dataset) instead of relying on the
        # __getitem__ iteration fallback, which only stops on IndexError.
        for idx in range(len(self.mnist_dataset)):
            _, label = self.mnist_dataset[idx]
            partitioned_indices[label].append(idx)
        return partitioned_indices

    def get_image_with_label(self, label):
        """Return the image of a uniformly random item whose label is ``label``.

        Raises:
            IndexError: if ``mnist_dataset`` contains no item with ``label``.
        """
        indices = self.partitioned_indices[label]
        if not indices:
            raise IndexError(f"no samples with label {label} to draw from")
        image, _ = self.mnist_dataset[random.choice(indices)]
        return image

    def post_load_setup(self):
        """(Re)generate ``final_dataset`` from ``mnist_dataset``."""
        self.final_dataset = self.process_dataset()

    def __setstate__(self, state):
        # NOTE(review): unpickling discards the pickled final_dataset and
        # regenerates it with fresh random draws, so an unpickled copy
        # generally holds different target images than the original.
        # Confirm this is intended.
        self.__dict__.update(state)
        self.post_load_setup()

class TestWrongDataset(Dataset):
    """Evaluation samples where the action is always an incorrect class.

    For every ``(image, label)`` pair in ``mnist_dataset`` the action is drawn
    uniformly from the ``num_classes - 1`` classes other than ``label``
    (one-hot encoded), and the target label is ``label - 1`` (modulo
    ``num_classes``). A random image of the target class is drawn from
    ``mnist_dataset``. Randomness comes from Python's global ``random``
    module, and all samples are generated eagerly at construction.

    Each item is ``(transform(image), action_one_hot, transform(y_image),
    label, y_label)``, where ``action_one_hot`` is a float tensor of length
    ``num_classes``.

    Args:
        mnist_dataset: Indexable of ``(image, label)`` pairs supporting
            ``len()``. Labels are used as list indices, so they must be
            integers in ``[0, 10)``.
        transform: Callable applied to both the source and the target image.
    """

    def __init__(self, mnist_dataset, transform):
        self.transform = transform
        self.mnist_dataset = mnist_dataset
        self.num_classes = 10
        self.partitioned_indices = self.partition_dataset()
        self.post_load_setup()

    def __len__(self):
        return len(self.final_dataset)

    def process_dataset(self):
        """Generate one sample per item of ``mnist_dataset``, in order."""
        final_dataset = []
        for idx in range(len(self.mnist_dataset)):
            image, label = self.mnist_dataset[idx]

            wrong_actions = [a for a in range(self.num_classes) if a != label]
            action = random.choice(wrong_actions)

            action_one_hot = torch.zeros(self.num_classes)
            action_one_hot[action] = 1
            y_label = (label - 1) % self.num_classes
            y_image = self.get_image_with_label(y_label)
            final_dataset.append(
                (self.transform(image), action_one_hot, self.transform(y_image), label, y_label)
            )

        return final_dataset

    def __getitem__(self, idx):
        return self.final_dataset[idx]

    def partition_dataset(self):
        """Return, per class, the list of indices in ``mnist_dataset`` carrying that label."""
        partitioned_indices = [[] for _ in range(self.num_classes)]
        # Index explicitly (like process_dataset) instead of relying on the
        # __getitem__ iteration fallback, which only stops on IndexError.
        for idx in range(len(self.mnist_dataset)):
            _, label = self.mnist_dataset[idx]
            partitioned_indices[label].append(idx)
        return partitioned_indices

    def get_image_with_label(self, label):
        """Return the image of a uniformly random item whose label is ``label``.

        Raises:
            IndexError: if ``mnist_dataset`` contains no item with ``label``.
        """
        indices = self.partitioned_indices[label]
        if not indices:
            raise IndexError(f"no samples with label {label} to draw from")
        image, _ = self.mnist_dataset[random.choice(indices)]
        return image

    def post_load_setup(self):
        """(Re)generate ``final_dataset`` from ``mnist_dataset``."""
        self.final_dataset = self.process_dataset()

    def __setstate__(self, state):
        # NOTE(review): unpickling discards the pickled final_dataset and
        # regenerates it with fresh random draws, so an unpickled copy
        # generally holds different samples than the original. Confirm
        # this is intended.
        self.__dict__.update(state)
        self.post_load_setup()




