import math
import torch
import math
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader


class rotMNIST:
    """MNIST split whose images can be re-emitted with random rotations.

    The split is loaded with ``ToTensor`` followed by ``Normalize``.
    :meth:`create_dataset` then rotates every image by an independent angle
    drawn uniformly from ``[0, 2*pi)``.
    """

    def __init__(self, mean_norm, std_norm, train=True,
                 root='./datasets/', download=True):
        """Load the (normalized) MNIST split that will later be rotated.

        Args:
            mean_norm: mean passed to ``transforms.Normalize``.
            std_norm: standard deviation passed to ``transforms.Normalize``.
            train: load the training split if True, otherwise the test split.
            root: directory holding (or receiving) the MNIST files.
            download: download MNIST into ``root`` if it is not present.
        """
        # Kept so create_dataset can compute the normalized background value.
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((mean_norm,), (std_norm,)),
        ])
        self.pre_rot_dataset = torchvision.datasets.MNIST(
            root=root, download=download, train=train, transform=transform)

    def create_dataset(self, seed=None):
        """Return a ``TensorDataset`` of randomly rotated images and labels.

        Args:
            seed: optional integer seed. When given, the rotation angles are
                drawn from a private ``torch.Generator`` so the result is
                reproducible without touching the global RNG. When None, the
                global RNG is used (the original behaviour).

        Returns:
            ``TensorDataset(rot_data, labels)``, where ``rot_data`` has the
            same shape as the stacked source images.
        """
        # One full-size batch without shuffling: the whole split is loaded
        # into memory, in dataset order.
        loader = DataLoader(self.pre_rot_dataset,
                            batch_size=len(self.pre_rot_dataset),
                            shuffle=False)
        data, labels = next(iter(loader))
        n = data.shape[0]
        assert n == len(self.pre_rot_dataset)

        gen = None
        if seed is not None:
            gen = torch.Generator().manual_seed(seed)
        angles = torch.rand(n, generator=gen) * 2 * math.pi

        # Value a blank (0) pixel takes after Normalize. grid_sample pads
        # out-of-bounds samples with 0, which in normalized space is NOT
        # background. Shifting by the background before sampling and back
        # afterwards makes the padding equal to the background instead.
        background = (0.0 - self.mean_norm) / self.std_norm

        with torch.no_grad():
            # Pure rotation (zero translation) in normalized [-1, 1]
            # coordinates, i.e. about the image centre.
            cos, sin = angles.cos(), angles.sin()
            affine_matrices = torch.zeros(n, 2, 3)
            affine_matrices[:, 0, 0] = cos
            affine_matrices[:, 1, 1] = cos
            affine_matrices[:, 0, 1] = sin
            affine_matrices[:, 1, 0] = -sin

            flowgrid = F.affine_grid(affine_matrices, size=data.size(),
                                     align_corners=True)
            rot_data = F.grid_sample(data - background, flowgrid,
                                     align_corners=True) + background

        return TensorDataset(rot_data, labels)

