import sys
sys.path.append("/Users/kosio/Repos/network-relative-homology/src/")
from NetRelHom import *
from TopologicalMethods import *
from torch.utils.data import Dataset, TensorDataset
import torch.optim as optim
from tqdm import tqdm
import persim

def optimize_nn(model, dat_loader, optim, crit, num_epochs = 100, print_every=5):
    """Run a mini-batch training loop and record every batch loss.

    Each batch's inputs are flattened to ``(batch, -1)`` before the forward
    pass, and the last element of the model's return value is what gets
    compared against the targets.

    Args:
        model: Callable network; ``model(x)[-1]`` is used as the prediction.
        dat_loader: Iterable of ``(inputs, targets)`` batches, re-iterated
            once per epoch.
        optim: Optimizer exposing ``zero_grad()`` and ``step()``; ``step()``
            updates whatever parameters it was constructed with (inside this
            function the name shadows the ``torch.optim`` module).
        crit: Loss callable ``crit(outputs, targets)`` returning a scalar
            tensor that supports ``.backward()`` and ``.item()``.
        num_epochs: Number of passes over ``dat_loader``.
        print_every: Print the most recent batch loss every this many
            epochs; ``0`` or ``None`` disables the progress print.

    Returns:
        ``(outputs, loss_hist)`` -- the model output for the final batch
        processed (``None`` if no batch was seen) and the list of all
        per-batch loss values, in training order.
    """
    outputs = None  # defined up front so an empty loader / zero epochs can still return
    loss_hist = []
    for epoch in tqdm(range(num_epochs)):
        for inputs, targets in dat_loader:
            # Forward pass on flattened inputs.
            outputs = model(inputs.reshape(len(inputs), -1))[-1]
            loss = crit(outputs, targets)  # + 0.001*torch.abs(outputs).sum() #require sparse outputs?

            # Backward pass and optimization.
            optim.zero_grad()
            loss.backward()
            optim.step()
            loss_hist.append(loss.item())
        # Only report when at least one batch has produced a loss.
        if print_every and loss_hist and (epoch + 1) % print_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_hist[-1]:.5f}')
    return outputs, loss_hist

#%% Load MNIST and build the data loaders
from torchvision import datasets, transforms
from sklearn.decomposition import PCA  # only needed by the optional PCA preprocessing below

MNIST_ROOT = '/Users/kosio/Data/MNIST/'

# n_dims = 50
# Bind the pipeline to `transform` (singular) so the torchvision `transforms`
# module imported above is not shadowed by a Compose object.
# ToTensor() converts each uint8 image to a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
# CIFAR10 setup would append:
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

data_train = datasets.MNIST(MNIST_ROOT, train=True, transform=transform, download=True)
data_test = datasets.MNIST(MNIST_ROOT, train=False, transform=transform, download=True)

# Optional: apply PCA to the data (disabled). Fit the projection on the training
# split only and reuse it for the test split so both live in the same basis:
# pca = PCA(n_components=n_dims).fit(data_train.data.reshape(len(data_train.data), -1))
# data_train.data = torch.Tensor(pca.transform(data_train.data.reshape(len(data_train.data), -1)))
# data_test.data = torch.Tensor(pca.transform(data_test.data.reshape(len(data_test.data), -1)))
# data_train.data = (data_train.data - data_train.data.mean()) / data_train.data.std()
# data_test.data = (data_test.data - data_test.data.mean()) / data_test.data.std()

# Mini-batch loaders (batch size 64).
trainloader = torch.utils.data.DataLoader(data_train, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(data_test, batch_size=64, shuffle=True)

# Single-batch loaders that yield each whole split at once.
trainloader_full = torch.utils.data.DataLoader(data_train, batch_size=len(data_train), shuffle=True)
testloader_full = torch.utils.data.DataLoader(data_test, batch_size=len(data_test), shuffle=True)

#%% Train a small feed-forward ReLU network on MNIST
net_width = 10
hidden_sizes = [net_width, net_width]  # two hidden layers of equal width

# Accumulators for the (currently disabled) analysis sketched at the end of this cell.
decompositions, homs, out_homs = [], [], []

criterion = nn.CrossEntropyLoss()

# 784 = 28 * 28 flattened pixels in, 10 digit classes out.
model = FeedforwardNetwork(
    input_size=784,
    hidden_sizes=hidden_sizes,
    out_layer_sz=10,
    init_type='none',
    activation=nn.ReLU,
)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

outputs, loss_hist = optimize_nn(
    model,
    trainloader,
    optimizer,
    criterion,
    num_epochs=100,
)

# Disabled analysis (left over from an earlier loop over several models):
# decomposer = NetworkDecompositions(model)
# decomps = decomposer.compute_overlap_decomp(X_test,1)
# decompositions.append(decomposer)
#
# d, diag, cycles = Phom.relative_homology(X_test, decomps[-1],pairwise_distances,False,[1,None])
# d_out, diag_out, _ = Phom.homology_analysis(out_knots[-1].T, pairwise_distances, False, [1,None])
#
# homs.append(diag)
# out_homs.append(Phom.normalize(diag_out))
#%% Evaluate on a small sample and compute the network's input-space decompositions
n_points = 1  # NOTE(review): not referenced in the visible code
n_samples = 20

# The training split is stacked first, so for small n_samples these are training images.
X = torch.vstack([data_train.data, data_test.data])[:n_samples]
y = torch.cat([data_train.targets, data_test.targets])[:n_samples]

# `.data` holds raw uint8 pixels (0-255), while the model was trained on
# ToTensor() outputs scaled to [0, 1]; apply the same scaling before feeding it.
X_flat = X.reshape(-1, 28 * 28).float() / 255.0

# Accuracy on the sample; no autograd graph is needed for this forward pass.
with torch.no_grad():
    predictions = torch.argmax(model(X_flat)[-1], dim=1)
performance = (predictions == y).float().mean().item()

decomposer = NetworkDecompositions(model)
poly_decomp = decomposer.compute_codeword_eq_classes(X_flat, -1)
# NOTE(review): sensitivity=5000 was chosen when inputs were on the 0-255 scale;
# re-check whether it needs retuning for [0, 1] inputs.
decomps = decomposer.compute_overlap_decomp(X_flat, sensitivity=5000)

#%% Convexity test: how many sample points fall into each region, per decomposition level
# NOTE(review): assumes poly_decomp[0] holds one entry per level (presumably per layer)
# and that each entry is a collection of regions whose len() is the number of points in it.
samples_in_poly = [[len(poly) for poly in poly_dec] for poly_dec in poly_decomp[0]]

# histograms[i] is a one-element list holding the (counts, bin_edges) pair for level i.
histograms = []
for sizes in samples_in_poly:
    # One unit-width bin per integer count 1..max(sizes). np.histogram closes its LAST
    # bin on the right, so the edges must run to max(sizes) + 1; stopping at max(sizes)
    # merges counts max-1 and max, and yields no bins at all when every region has 1 point.
    edges = np.arange(1, max(sizes) + 2)
    histograms.append([np.histogram(sizes, bins=edges)])

fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5))
for level_hist in histograms:
    counts, bin_edges = level_hist[0]
    if len(counts) != 1:
        ax.plot(bin_edges[:-1], counts)
    else:
        # A single-bin histogram is a single point; draw a marker so it stays visible.
        ax.plot(bin_edges[:-1], counts, 's-')
ax.set_yscale('log')
# ax.set_xscale('log')
# ax.grid('on')

# Inset: the last level on its own.
ax2 = fig.add_axes([0.55, 0.55, 0.25, 0.25])
last_counts, last_edges = histograms[-1][0]
ax2.plot(last_edges[:-1], last_counts, color='#d62728')
ax2.set_yscale('log')
ax.set_xlabel('# of points in region')
ax.set_title(f'width = {net_width}, performance = {round(performance,3)}')
# fig.savefig(f'../figures/convexity_{net_width}_net.png',dpi=500, transparent=True,bbox_inches='tight')
#%% Hypersample nearby numbers
# Take the (A, b) description of entry [-1][0] of the polyhedral decomposition and
# ask the decomposer to populate it with new points.
# NOTE(review): the meaning of the positional `1` (presumably a sample count) and of
# shrink_factor is defined in NetworkDecompositions -- confirm there.
A, b = (
    decomposer.polyhedral_decomposition[-1][0].A,
    decomposer.polyhedral_decomposition[-1][0].b,
)
new_samples = decomposer.populate_polyhedra(
    (A, b),
    1,
    shrink_factor=1e-16,
)

