import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
import torch.nn.functional as F

import numpy as np
import math
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from typing import Any, Callable, Optional, Tuple
import pickle
import os.path
#import cv2

from torch.distributions.multivariate_normal import MultivariateNormal

from torch import Tensor


from scipy.stats import special_ortho_group

class Dataset(Dataset):
    """Map-style dataset that pairs a stack of inputs with their labels.

    Item ``i`` is ``(images[i, :, :], labels[i])``.  Because the inputs are
    indexed with three subscripts, ``images`` must have at least 3 dimensions.
    The inputs and labels are stored as ``self.X`` and ``self.y``.
    """

    def __init__(self, images, labels):
        self.X, self.y = images, labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        sample = self.X[i, :, :]
        target = self.y[i]
        return sample, target

def project(sigma, samples, len1, len2):
    """Apply ``sigma`` in place to every vector of the leading ``len1 x len2`` block.

    For every ``0 <= i < len1`` and ``0 <= j < len2`` this performs
    ``samples[i, j] = sigma @ samples[i, j]``.  Instead of one Python-level
    matrix-vector product per vector, the whole block is transformed with a
    single batched matmul.  Stacking the vectors as rows, ``sigma @ v`` for
    each row ``v`` equals ``rows @ sigma.T``.

    Args:
        sigma: square ``(d, d)`` matrix applied to each vector.
        samples: tensor whose ``samples[i, j]`` entries are length-``d``
            vectors.  In ``load_data`` it is the ``(m, P, d)`` output of
            ``rsample``.  It is modified in place.
        len1: number of leading entries along dim 0 to transform.
        len2: number of leading entries along dim 1 to transform.

    Returns:
        None.  Entries outside the leading ``len1 x len2`` block are left
        untouched.
    """
    block = samples[:len1, :len2]
    # matmul allocates a fresh tensor before the copy back, so reading and
    # writing the same region is safe.
    samples[:len1, :len2] = torch.matmul(block, sigma.t())

def _pattern_sign(index, position):
    """Return entry ``position`` of the ``index``-th training sign pattern.

    The sign is -1 when bit ``position`` of ``index`` is set and +1 otherwise.
    For ``index < 32`` and ``position < 5`` this is exactly the ``pos`` table
    printed by ``load_data``.  It works for any number of features without
    materialising all ``2**k`` patterns.
    """
    return -1 if (index >> position) & 1 else 1


def _label_from_signs(signs):
    """Return the 0/1 label of a sign vector.

    The score is ``sum(signs[:2]) + prod(signs[2:])``.  The label is 1 when the
    score is >= 0 and 0 otherwise, which is the same mapping as
    ``np.sign((np.sign(score) + 1) / 2)``.
    """
    score = sum(signs[:2]) + math.prod(signs[2:])
    return 1 if score >= 0 else 0


def load_data(args):
    """Build the synthetic planted-feature dataset and return data loaders.

    All RNGs are seeded with 0, so the same ``args`` always give the same data.

    Every sample is a ``(P, d)`` matrix, built in two steps:

    * Start with ``P`` rows of standard Gaussian noise, each multiplied by
      ``d**(-5/6) * (I - W W^T)``.
    * Overwrite ``k`` randomly chosen rows with ``sign_j * W[:, j]``.  Only
      ``min(k, P)`` features are planted, because the rows come from a
      permutation of ``range(P)``.

    Training signs are deterministic: sample ``i`` uses sign -1 for feature
    ``j`` exactly when bit ``j`` of ``i`` is set.  This cycles evenly through
    the sign patterns.  Test signs are drawn at random.  The label is
    ``1 if sum(signs[:2]) + prod(signs[2:]) >= 0 else 0``.

    ``W`` and the generated tensors are saved under ``data/``, which is created
    if missing.  The first ``k`` columns of ``W`` are saved as ``data/W``.

    Args:
        args: namespace providing ``d``, ``k``, ``P``, ``m_train``, ``m_test``
            and ``train_batch_size``.

    Returns:
        ``(trainloader, testloader, W)``, where ``W`` is the ``(d, k)`` feature
        matrix.  Its columns are orthonormal because they are taken from a
        random rotation.
    """
    # The seed we generate the data is always fixed
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    # Every torch.save below writes into "data/"; make sure it exists.
    os.makedirs("data", exist_ok=True)

    y_train = torch.ones(args.m_train, dtype=torch.int32)
    y_test = torch.ones(args.m_test, dtype=torch.int32)

    # Features: the first k columns of a random rotation of R^d.  The rotation
    # is drawn from NumPy's global RNG, which is seeded above.
    W = special_ortho_group.rvs(args.d)
    W = torch.tensor(W[:, :args.k], dtype=torch.float32)

    # W can be saved and loaded to make sure we got the same W in every run
    torch.save(W, "data/W")

    print("features")
    print(W)

    # Noise transform d^(-5/6) * (I - W W^T).  It removes the part of the
    # noise lying in the span of the (orthonormal) features, then rescales.
    feature_proj = torch.matmul(W, W.t())
    cov_noise = (1 / np.power(args.d, 5 / 12) ** 2) * torch.eye(args.d).sub(feature_proj)
    distrib = MultivariateNormal(loc=torch.zeros(args.d), covariance_matrix=torch.eye(args.d))

    X_train = distrib.rsample(sample_shape=(args.m_train, args.P))
    project(cov_noise, X_train, args.m_train, args.P)
    X_test = distrib.rsample(sample_shape=(args.m_test, args.P))
    project(cov_noise, X_test, args.m_test, args.P)

    X_train.requires_grad = False
    X_test.requires_grad = False

    # Debug view: the first 32 training sign patterns over the first 5 features.
    pos = [[_pattern_sign(c, j) for j in range(5)] for c in range(32)]
    print("pos")
    print(pos)

    for i in range(args.m_train):
        A = []
        # Plant feature j, with a deterministic sign, at a random row l.
        for j, l in enumerate(np.random.permutation(args.P)[:args.k]):
            r_sign = _pattern_sign(i, j)
            A.append(r_sign)
            X_train[i, l, :] = W[:, j] * r_sign
        y_train[i] = _label_from_signs(A)
        if i < 10:
            print(A)
    y_train.requires_grad = False

    for i in range(args.m_test):
        A = []
        # Same planting as for training, but with uniformly random signs.
        for j, l in enumerate(np.random.permutation(args.P)[:args.k]):
            r_sign = np.sign(np.random.uniform() - 0.5)
            A.append(r_sign)
            X_test[i, l, :] = W[:, j] * r_sign
        y_test[i] = _label_from_signs(A)
    y_test.requires_grad = False

    y_train = y_train.type(torch.LongTensor)
    y_test = y_test.type(torch.LongTensor)

    # data can be saved and loaded to make sure we use the same data in each run
    torch.save(y_train, "data/y_train_" + str(args.P))
    torch.save(y_test, "data/y_test_" + str(args.P))
    torch.save(X_train, "data/X_train_" + str(args.P))
    torch.save(X_test, "data/X_test_" + str(args.P))

    train_set = Dataset(X_train, y_train)
    test_set = Dataset(X_test, y_test)

    trainloader = DataLoader(train_set, batch_size=args.train_batch_size,
                             shuffle=True, num_workers=1)
    testloader = DataLoader(test_set, batch_size=args.m_test,
                            shuffle=False, num_workers=1)

    return trainloader, testloader, W





def sample_hypersphere(
    d: int,
    n: int = 1,
    seed: Optional[int] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Sample uniformly from the unit sphere in R^d.

    Args:
        d: The ambient dimension; every sample is a unit-norm vector in R^d.
        n: The number of samples to return.
        seed: If provided, the global torch RNG is re-seeded with it (via
            ``torch.manual_seed``) before sampling.  This also affects later
            random calls.  If None, the RNG state is left untouched.
        device: The torch device of the returned tensor.
        dtype: The torch dtype (defaults to ``torch.float``).

    Returns:
        An ``n x d`` tensor of uniform samples from the unit sphere in R^d.
        For ``d == 1`` the entries are -1 or +1.

    Example:
        >>> sample_hypersphere(d=5, n=10).shape
        torch.Size([10, 5])
    """
    dtype = torch.float if dtype is None else dtype
    # Seed before either branch so that ``seed`` also makes d == 1 reproducible.
    # torch.manual_seed(None) raises, so only seed when a seed was given.
    if seed is not None:
        torch.manual_seed(seed)
    if d == 1:
        rnd = torch.randint(0, 2, (n, 1), device=device, dtype=dtype)
        return 2 * rnd - 1
    # An isotropic Gaussian normalised to unit length is uniform on the sphere.
    # Sampling on the CPU and moving afterwards keeps results device-independent.
    rnd = torch.randn(n, d, dtype=dtype)
    samples = rnd / torch.norm(rnd, dim=-1, keepdim=True)
    if device is not None:
        samples = samples.to(device)
    return samples


def is_psd(mat, tol=0.0):
    """Return True if ``mat`` is exactly symmetric and positive semi-definite.

    The symmetry test is exact (``mat == mat.T`` element-wise).  Only when it
    passes are the eigenvalues computed.  That uses ``torch.linalg.eigvalsh``,
    which returns the real eigenvalues of a symmetric matrix; it replaces
    ``torch.eig``, which is deprecated and removed in recent PyTorch releases.

    Args:
        mat: square floating-point 2-D tensor.
        tol: non-negative slack for round-off; eigenvalues ``>= -tol`` count as
            non-negative.  The default 0.0 keeps the strict ``>= 0`` check.

    Returns:
        A Python ``bool``.
    """
    if not bool((mat == mat.T).all()):
        return False
    return bool((torch.linalg.eigvalsh(mat) >= -tol).all())
