import torch
from torch.utils.data import Dataset
import torchvision.datasets as datasets
from torchvision import transforms
import random
from constants import DEFAULT_MODULO
import numpy as np
from PIL import Image
import itertools

def one_hot_encode(number, size):
    """Return a length-``size`` vector of zeros (default float dtype) with a 1 at ``number``.

    ``number`` is used directly as an index, so negative values count from the end
    and out-of-range values raise ``IndexError``.
    """
    vector = torch.zeros(size)
    vector[number] = 1
    return vector

def binary_encode(num, bits):
    """Return the base-2 digits of ``num`` (most significant first) as an integer tensor.

    The digits are left-padded with zeros to at least ``bits`` places. Values needing
    more digits are not truncated, and ``bits == 0`` still yields a single digit.
    """
    digit_string = format(num, 'b').zfill(bits)
    return torch.tensor(list(map(int, digit_string)))

def generate_random_one_hot(length):
    """Return a length-``length`` one-hot vector whose hot index is drawn uniformly.

    The index comes from a single ``torch.randint`` draw, so results follow the
    global torch RNG state (``torch.manual_seed``).
    """
    hot_index = int(torch.randint(0, length, (1,))[0])
    vector = torch.zeros(length)
    vector[hot_index] = 1
    return vector

def unique_random_combinations(num_features, num_samples=None):
    """Lazily yield distinct random 0/1 tuples of length ``num_features``.

    Args:
        num_features: length of each tuple.
        num_samples: how many distinct tuples to yield. ``None`` means every one of
            the ``2**num_features`` combinations (in random order). Values larger than
            ``2**num_features`` are clamped to it, with a printed notice.

    Yields:
        Tuples of ints in {0, 1}, each one new.

    Uses rejection sampling on Python's ``random`` module, so the cost grows sharply
    as ``num_samples`` approaches ``2**num_features`` (coupon-collector behaviour).
    Being a generator, the clamping check only runs once iteration starts.
    """
    total = 2**num_features
    if num_samples is None:
        # Previously None crashed on the comparison below; treat it as "all of them".
        num_samples = total
    elif num_samples > total:
        print(f"Number of samples > Possible combinations, setting num_samples to {total}")
        num_samples = total

    seen = set()
    domain = [0, 1]
    while len(seen) < num_samples:
        combination = tuple(random.choice(domain) for _ in range(num_features))
        if combination not in seen:
            seen.add(combination)
            yield combination


class AlgorithmicDataset(Dataset):
    """Every operand pair of a binary operation, one-hot encoded.

    Each sample is ``cat(one_hot(x), one_hot(y))`` for ``x, y`` in
    ``range(input_size)`` (x-major order). Its target is
    ``one_hot(operation(x, y) % p, output_size)``. If ``operation.__name__``
    contains ``'div'``, the ``y == 0`` pairs are left out.
    """

    def __init__(self, operation, p=DEFAULT_MODULO, input_size=None, output_size=None):
        self.p = p
        self.input_size = input_size if input_size else p
        self.output_size = output_size if output_size else self.p
        self.operation = operation

        # Name-based check: only operations whose __name__ mentions 'div' skip y == 0.
        skip_zero_divisor = 'div' in operation.__name__

        inputs = []
        labels = []
        for a, b in itertools.product(range(self.input_size), repeat=2):
            if skip_zero_divisor and b == 0:
                continue
            label = self.operation(a, b) % self.p
            inputs.append(torch.cat((one_hot_encode(a, self.input_size),
                                     one_hot_encode(b, self.input_size)), 0))
            labels.append(one_hot_encode(label, self.output_size))

        self.data = torch.stack(inputs)
        self.targets = torch.stack(labels)

    def __len__(self):
        """Number of (x, y) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for sample ``idx``."""
        return self.data[idx], self.targets[idx]
    

class BinaryAlgorithmicDataset(Dataset):
    """Algorithmic task with operands given in binary and targets one-hot.

    Each sample is ``cat(binary(x), binary(y))``, where every operand uses
    ``(input_size - 1).bit_length()`` bits, for ``x, y`` in ``range(input_size)``.
    Its target is ``one_hot(operation(x, y) % p, output_size)``. If
    ``operation.__name__`` contains ``'div'``, the ``y == 0`` pairs are left out.

    ``self.data`` is an integer tensor, because ``binary_encode`` builds it from
    Python ints. NOTE(review): presumably callers cast it to float before feeding
    a model — confirm.
    """

    def __init__(self, operation, p=DEFAULT_MODULO, input_size=None, output_size=None):
        self.p = p
        if not input_size:
            self.input_size = p
        else:
            self.input_size = input_size
        self.output_size = output_size if output_size else self.p
        self.operation = operation

        self.data = []
        self.targets = []

        # Enough bits to represent the largest operand, input_size - 1.
        input_bits = (self.input_size - 1).bit_length()

        for x in range(0, self.input_size):
            for y in range(0, self.input_size):
                if 'div' in operation.__name__ and y == 0:
                    continue
                result = self.operation(x, y) % self.p

                x_binary = binary_encode(x, input_bits)
                y_binary = binary_encode(y, input_bits)
                result_one_hot = one_hot_encode(result, self.output_size)

                combined_input = torch.cat((x_binary, y_binary), 0)
                self.data.append(combined_input)
                self.targets.append(result_one_hot)

        self.data = torch.stack(self.data)
        self.targets = torch.stack(self.targets)

    def __len__(self):
        """Number of (x, y) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for sample ``idx``."""
        return self.data[idx], self.targets[idx]
    
class ScalarAlgorithmicDataset(Dataset):
    """Algorithmic task with each operand given as a single scaled scalar.

    Each sample is the float pair ``[x / p, y / p]`` for ``x, y`` in
    ``range(input_size)``. Its target is ``one_hot(operation(x, y) % p, output_size)``.
    If ``operation.__name__`` contains ``'div'``, the ``y == 0`` pairs are left out.
    """

    def __init__(self, operation, p=DEFAULT_MODULO, input_size=None, output_size=None):
        self.p = p
        if not input_size:
            self.input_size = p
        else:
            self.input_size = input_size
        self.output_size = output_size if output_size else self.p
        self.operation = operation

        self.data = []
        self.targets = []

        for x in range(0, self.input_size):
            for y in range(0, self.input_size):
                if 'div' in operation.__name__ and y == 0:
                    continue
                result = self.operation(x, y) % self.p
                result_one_hot = one_hot_encode(result, self.output_size)

                # Scale by the modulus (previously a hard-coded 113, wrong for any
                # other p), so operands lie in [0, 1) whenever input_size <= p.
                combined_input = torch.tensor((x / self.p, y / self.p))
                self.data.append(combined_input)
                self.targets.append(result_one_hot)

        self.data = torch.stack(self.data)
        self.targets = torch.stack(self.targets)

    def __len__(self):
        """Number of (x, y) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for sample ``idx``."""
        return self.data[idx], self.targets[idx]
    
class SpuriousAlgorithmicDataset(Dataset):
    """One-hot algorithmic task whose inputs also carry a spurious copy of the target.

    Each sample is ``cat(one_hot(x), one_hot(y), spurious)`` for ``x, y`` in
    ``range(1, input_size)`` (zero operands are excluded). Its target is
    ``one_hot(operation(x, y) % p, output_size)``. For each pair, ``spurious``
    equals the target with probability 0.97; otherwise it is a uniformly random
    one-hot vector, which can still match the target by chance.
    """

    def __init__(self, operation, p=DEFAULT_MODULO, input_size=None, output_size=None):
        self.p = p
        if not input_size:
            self.input_size = p
        else:
            self.input_size = input_size
        self.output_size = output_size if output_size else self.p
        self.operation = operation

        self.data = []
        self.targets = []
        # One uniform draw per (x, y); a draw >= 0.97 marks the pair as bias-conflicting.
        # Sized by input_size (not p) so [x, y] indexing stays in bounds when input_size > p.
        self.bias_conflicting = torch.rand((self.input_size, self.input_size))
        for x in range(1, self.input_size):
            for y in range(1, self.input_size):
                # NOTE(review): unreachable because y starts at 1; kept to mirror the other datasets.
                if 'div' in operation.__name__ and y == 0:
                    continue
                result = self.operation(x, y) % self.p

                x_one_hot = one_hot_encode(x, self.input_size)
                y_one_hot = one_hot_encode(y, self.input_size)
                result_one_hot = one_hot_encode(result, self.output_size)
                if self.bias_conflicting[x, y] < 0.97:
                    spurious_feature = result_one_hot
                else:
                    # Same length as the target (was p), so every row has the same width.
                    spurious_feature = generate_random_one_hot(self.output_size)
                combined_input = torch.cat((x_one_hot, y_one_hot, spurious_feature), 0)
                self.data.append(combined_input)
                self.targets.append(result_one_hot)

        # Stack into tensors like the sibling datasets (previously left as Python lists).
        self.data = torch.stack(self.data)
        self.targets = torch.stack(self.targets)

    def __len__(self):
        """Number of (x, y) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for sample ``idx``."""
        return self.data[idx], self.targets[idx]

class AlgorithmicDatasetTransformer(Dataset):
    """Token-sequence form of the algorithmic task.

    Each input is the integer triple ``(i, j, p)`` for ``i, j`` in ``range(1, p)``.
    The trailing ``p`` is presumably an "=" / separator token — TODO confirm against
    the model code. Targets are ``one_hot(operation(i, j) % p, output_size)`` as floats.
    ``input_size`` is stored but does not affect the operand range.
    """

    def __init__(self, operation, p=DEFAULT_MODULO, input_size=None, output_size=None):
        self.p = p
        if not input_size:
            self.input_size = p
        else:
            self.input_size = input_size
        self.output_size = output_size if output_size else self.p
        self.operation = operation

        pairs = [(i, j) for i in range(1, p) for j in range(1, p)]
        self.data = torch.tensor([(i, j, p) for i, j in pairs])
        # Evaluate on Python ints (not 0-d tensors) and reduce mod p as the other datasets
        # do; without the reduction, results >= p (e.g. from add/mul) made one_hot raise.
        results = torch.tensor([operation(i, j) % p for i, j in pairs], dtype=torch.long)
        self.targets = torch.nn.functional.one_hot(results, self.output_size).float()

    def __len__(self):
        """Number of (i, j) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_triple, target)`` for sample ``idx``."""
        return self.data[idx], self.targets[idx]
    
class SpuriousParityDataset(Dataset):
    """Parity task with one extra input bit that usually leaks the label.

    Each row holds ``num_features`` parity bits, then ``num_noise`` irrelevant bits,
    then one spurious bit. The spurious bit equals the parity label with probability
    0.97 and its complement otherwise. All input bits are mapped from {0, 1} to
    {-1, 1}. Targets are one-hot parity labels (2 classes). ``__getitem__`` also
    returns the spurious bit in its {0, 1} form.

    ``num_samples`` defaults to ``2**num_features``. After construction,
    ``self.num_samples`` is the number of rows actually generated, which may be fewer
    when more were requested than ``2**(num_features + num_noise)``.
    """

    def __init__(self, num_features, num_spurious, num_noise=0, num_samples=None):
        self.num_features = num_features
        # NOTE(review): stored but unused; exactly one spurious bit is appended.
        self.num_spurious = num_spurious
        if not num_samples:
            requested = 2**num_features
        else:
            requested = int(num_samples)
        self.data = torch.tensor(list(unique_random_combinations(num_features + num_noise, requested)))
        # The generator clamps to the number of distinct rows; use the real count so
        # the per-sample tensors below (and __len__) line up with the data.
        self.num_samples = len(self.data)

        self.targets = (self.data[:, :num_features].sum(dim=1) % 2).float()
        spurious_feature = torch.zeros(self.num_samples)

        bias_conflicting = torch.rand(self.num_samples) > 0.97
        spurious_feature[~bias_conflicting] = self.targets[~bias_conflicting]
        # Flip the label bit (1 -> 0, 0 -> 1) for bias-conflicting rows.
        spurious_feature[bias_conflicting] = (self.targets[bias_conflicting] - 1).abs()
        self.spurious_feature = spurious_feature
        self.data = torch.hstack((self.data, spurious_feature.view(-1, 1))) * 2 - 1
        self.targets = torch.nn.functional.one_hot(self.targets.long(), 2).float()

    def __len__(self):
        """Number of generated rows."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(input, one_hot_target, spurious_bit)`` for row ``idx``."""
        return self.data[idx], self.targets[idx], self.spurious_feature[idx]
    

class SparseParityDataset(Dataset):
    """Sparse parity: the label is the parity of the first ``num_features`` bits.

    Each row holds ``num_features`` relevant bits followed by ``num_noise_features``
    irrelevant bits, all in {0, 1}. The rows are distinct and come from
    ``unique_random_combinations``. Targets are one-hot parity labels (2 classes).

    ``num_samples`` defaults to ``2**num_features`` (matching
    ``SpuriousParityDataset``); passing ``None`` previously crashed. After
    construction, ``self.num_samples`` is the number of rows actually generated.

    NOTE(review): ``binariy_noise`` is accepted but unused. ``self.data`` keeps the
    integer dtype produced by ``torch.tensor`` on int tuples — confirm callers cast it.
    """

    def __init__(self, num_features, num_noise_features, num_samples=None, binariy_noise=False):
        self.num_features = num_features
        self.num_noise_features = num_noise_features
        if not num_samples:
            num_samples = 2**num_features

        self.data = torch.tensor(list(unique_random_combinations(num_features + self.num_noise_features, num_samples)))
        # The generator clamps to the number of distinct rows, so report the real count.
        self.num_samples = len(self.data)
        self.targets = (self.data[:, :num_features].sum(dim=1) % 2).float()
        self.targets = torch.nn.functional.one_hot(self.targets.long(), 2).float()

    def __len__(self):
        """Number of generated rows."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(input_bits, one_hot_target)`` for row ``idx``."""
        return self.data[idx], self.targets[idx]