import torch
import torch.nn as nn
from tqdm import tqdm
import cupy as cp
import imageio
import os
import torchvision, torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from  torchvision.datasets import MNIST
import numpy as np
import cupy as cp
import inv_mean
import matplotlib.pyplot as plt

#### loading MNIST data
# The random affine augmentation is applied to the raw PIL image *before*
# ToTensor/Normalize. RandomAffine fills uncovered pixels with 0 by default;
# on the raw image 0 is the black MNIST background, whereas after Normalize
# the background is -0.1307/0.3081 (about -0.424), so filling with 0 there
# produced brighter-than-background corners/edges on every rotated or
# translated digit. With nearest interpolation (the default) the digit pixels
# themselves are unchanged by this reordering.
transform = transforms.Compose([
    transforms.RandomAffine(degrees=(0, 90), translate=(0.1, 0.2)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# Load and transform data.
# Each indexing of a dataset draws a fresh random affine transform, and the
# same augmentation pipeline is used for both splits.
# NOTE(review): data roots are absolute paths at the filesystem root; confirm
# this is intended for the target environment.
trainset = MNIST('/train', train=True, download=True, transform = transform)
testset = MNIST('/test', train=False, download=True, transform = transform)

#### Testing on a set of images of same class
import torch.utils.data as data_utils

DIGIT = 3        # MNIST class whose samples are collected
N_SAMPLES = 10   # how many samples of that class to use
N_COLS = 5       # columns in the preview grid

# Boolean mask over the training labels, then the first N_SAMPLES indices
# where it is True (torch.nonzero returns them in ascending order, matching
# the previous np.where-based selection).
idx = trainset.targets == DIGIT
idxt = torch.nonzero(idx).flatten()[:N_SAMPLES]
print(idxt)
trainset1 = data_utils.Subset(trainset, idxt)

# Stack the selected (1, 28, 28) image tensors into a (28, 28, 1, N) array.
# The buffer is sized from the subset itself rather than a hard-coded 10, so
# it never holds unused all-zero slots.
train10Num = np.zeros((28, 28, 1, len(trainset1)))
for i in range(len(trainset1)):
    img, _label = trainset1[i]
    train10Num[:, :, :, i] = torch.permute(img, (1, 2, 0))

# Preview every stacked image on a grid with N_COLS columns (2x5 for 10).
n_rows = -(-train10Num.shape[3] // N_COLS)  # ceiling division
for i in range(train10Num.shape[3]):
    plt.subplot(n_rows, N_COLS, i + 1)
    # Drop the singleton channel axis so imshow receives a 2-D image.
    plt.imshow(train10Num[:, :, 0, i])
plt.show()

import time  # previously missing: time.time() raised NameError

# Copy the stacked (28, 28, 1, N) host array to the GPU.
train10Num = cp.array(train10Num)

# NOTE(review): `invariant_mean_batch_sigma_schedule_cupy` is not defined in
# this file, while `inv_mean` is imported but otherwise unused; presumably the
# function lives in that module -- confirm. A global definition (e.g. from an
# interactive session) is still preferred if one exists.
_solver = globals().get("invariant_mean_batch_sigma_schedule_cupy")
if _solver is None:
    _solver = inv_mean.invariant_mean_batch_sigma_schedule_cupy

# CuPy kernel launches are asynchronous, so synchronize the device before
# starting and before stopping the clock; otherwise the measured interval
# can exclude GPU work still in flight.
cp.cuda.Device().synchronize()
t1 = time.perf_counter()
tau_u, tau_v, all_phi, all_b, mu, obj_vals = _solver(
    V=train10Num,
    sigma=14,
    sigma_schedule=True,
    center=cp.array([[14], [14]]),
    t=0.1,
    MAX_ITER=70,
    weights=cp.ones(3),
    generate_gif=True,
)
cp.cuda.Device().synchronize()
t2 = time.perf_counter()
print(t2 - t1)  # elapsed wall-clock seconds for the solver call

# `mu` is presumably a CuPy array (its .get() host copy is what gets plotted).
plt.imshow(mu.get())
plt.show()  # match the preview above so the figure is shown when run as a script