import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
import torch.nn.functional as F

import numpy as np
import math
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from typing import Any, Callable, Optional, Tuple
import pickle
import os.path
#import cv2

from torch.distributions.multivariate_normal import MultivariateNormal

from torch import Tensor


from scipy.stats import special_ortho_group

class Dataset(Dataset):
    """Map-style dataset over a pre-loaded images tensor and a labels sequence.

    NOTE(review): this class shadows ``torch.utils.data.Dataset`` (its own base)
    for the rest of the module.

    Args:
        images: Indexable object stored as ``self.X``; item ``i`` is taken as
            ``images[i, :, :]``, so it must support 3-index slicing (assumed to
            be a tensor of at least 3 dims -- TODO confirm).
        labels: Indexable object stored as ``self.y``; ``labels[i]`` is the
            target of item ``i``.
    """

    def __init__(self, images, labels):
        self.X = images
        self.y = labels

    def __len__(self):
        """Number of samples, i.e. ``len(self.X)``."""
        return len(self.X)

    def __getitem__(self, i):
        """Return the pair ``(X[i, :, :], y[i])``."""
        sample = self.X[i, :, :]
        label = self.y[i]
        return sample, label

def project(sigma, samples, len1, len2):
    """Multiply entries of ``samples`` by ``sigma`` in place.

    For every ``i < len1`` and ``j < len2`` the entry ``samples[i][j]`` is
    overwritten with ``torch.matmul(sigma, samples[i][j].t())``. ``samples``
    is mutated; the function returns ``None``.
    """
    for outer in range(len1):
        # For tensors, samples[outer] is a view, so writes go through to samples.
        row = samples[outer]
        for inner in range(len2):
            vec = row[inner]
            row[inner] = torch.matmul(sigma, vec.t())

def load_data(args):
    """Build train/test DataLoaders from tensors serialized under ``data/``.

    Loads, in order, ``data/W``, ``data/X_train_<P>``, ``data/X_test_<P>``,
    ``data/y_train_<P>`` and ``data/y_test_<P>`` with ``torch.load`` (paths
    are relative to the current working directory), and wraps each split in
    this module's ``Dataset``.

    Args:
        args: Config object; reads ``args.P`` (file-name suffix),
            ``args.train_batch_size`` and ``args.m_test`` (test batch size).

    Returns:
        ``(trainloader, testloader, W)``: the train loader shuffles, the test
        loader does not; both use a single worker process. ``W`` is whatever
        object ``data/W`` contains.
    """

    def _load(name):
        # NOTE(review): torch.load unpickles -- only point it at trusted files.
        return torch.load("data/" + name)

    W = _load("W")

    suffix = str(args.P)
    X_train = _load("X_train_" + suffix)
    X_test = _load("X_test_" + suffix)
    y_train = _load("y_train_" + suffix)
    y_test = _load("y_test_" + suffix)

    train_set = Dataset(X_train, y_train)
    test_set = Dataset(X_test, y_test)

    trainloader = DataLoader(
        train_set,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=1,
    )
    testloader = DataLoader(
        test_set,
        batch_size=args.m_test,
        shuffle=False,
        num_workers=1,
    )
    return trainloader, testloader, W





def sample_hypersphere(
    d: int,
    n: int = 1,
    seed: Optional[int] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Sample uniformly from the unit sphere embedded in ``R^d``.

    Draws standard-normal vectors and scales each to unit Euclidean length.
    For ``d == 1`` the "sphere" is the point set ``{-1, +1}``, and each sample
    is one of the two with equal probability.

    Args:
        d: The ambient dimension (number of coordinates per sample).
        n: The number of samples to return.
        seed: If provided, passed to ``torch.manual_seed`` before sampling
            (this reseeds PyTorch's global RNG). If ``None``, the current RNG
            state is used unchanged.
        device: The torch device of the returned tensor.
        dtype: The torch dtype of the returned tensor (default ``torch.float``).

    Returns:
        An ``n x d`` tensor of uniform samples whose rows have unit norm.

    Example:
        >>> sample_hypersphere(d=5, n=10).shape
        torch.Size([10, 5])
    """
    dtype = torch.float if dtype is None else dtype
    # torch.manual_seed(None) raises TypeError, so only seed when a seed is
    # given. Seeding before the d == 1 branch makes `seed` apply there too.
    if seed is not None:
        torch.manual_seed(seed)
    if d == 1:
        rnd = torch.randint(0, 2, (n, 1), device=device, dtype=dtype)
        return 2 * rnd - 1
    # Drawn without a device argument and moved afterwards, as before.
    rnd = torch.randn(n, d, dtype=dtype)
    samples = rnd / torch.norm(rnd, dim=-1, keepdim=True)
    if device is not None:
        samples = samples.to(device)
    return samples


def is_psd(mat, tol=0.0):
    """Return whether ``mat`` is a symmetric positive semi-definite matrix.

    Args:
        mat: A real-valued tensor. Integer dtypes are promoted to float64
            for the eigenvalue computation.
        tol: Eigenvalues ``>= -tol`` count as non-negative. The default 0.0
            is an exact check; a small positive value absorbs round-off on
            singular matrices.

    Returns:
        ``True`` iff ``mat`` is 2-D and square, is exactly equal to its
        transpose, and all of its eigenvalues are ``>= -tol``. Otherwise
        ``False``.
    """
    # Reject non-square input up front: a (1, n) row would otherwise broadcast
    # against its (n, 1) transpose and pass the symmetry test.
    if mat.dim() != 2 or mat.shape[0] != mat.shape[1]:
        return False
    if not bool((mat == mat.T).all()):
        return False
    # torch.eig has been removed from PyTorch. For a matrix already known to
    # be symmetric, eigvalsh is the supported routine and returns real
    # eigenvalues; it requires a floating dtype.
    if not mat.is_floating_point():
        mat = mat.to(torch.float64)
    eigenvalues = torch.linalg.eigvalsh(mat)
    return bool((eigenvalues >= -tol).all())
