import sys
sys.path.append("/Users/kosio/Repos/network-relative-homology/src/")
from NetRelHom import *
from TopologicalMethods import *
from torch.utils.data import Dataset, TensorDataset
import torch.optim as optim
from tqdm import tqdm
import persim

def optimize_nn(model, dat_loader, optim, crit, num_epochs=100, print_every=5):
    """Train ``model`` in place and record the loss of every optimization step.

    Each item yielded by ``dat_loader`` is unpacked as ``(inputs, targets)``.
    A leading axis of size 1 is added to ``inputs`` (``inputs[None]``) before
    the forward pass. The last element of the model's return value is taken
    (``[-1]``), and that leading axis is removed again (``[0]``) before the
    loss is computed.

    NOTE(review): this assumes ``model`` returns a sequence (e.g. per-layer
    activations) whose last entry is the network output. Confirm against the
    model class in use.

    Args:
        model: Network whose parameters ``optim`` updates.
        dat_loader: Iterable of ``(inputs, targets)`` pairs. It is iterated
            once per epoch, so it must be re-iterable (a Dataset or
            DataLoader, not a one-shot generator).
        optim: Optimizer over ``model``'s parameters. Inside this function
            the name shadows the module-level ``torch.optim`` import.
        crit: Loss function, called as ``crit(outputs, targets)``.
        num_epochs: Number of passes over ``dat_loader``.
        print_every: Print the most recent step loss every ``print_every``
            epochs. ``0`` or ``None`` disables printing.

    Returns:
        ``(outputs, loss_hist)``: the model output for the last item
        processed (``None`` if ``num_epochs`` is 0), and the list of per-step
        loss values.

    Raises:
        ValueError: if ``dat_loader`` yields no items during an epoch. The
            original code instead failed later with an ``UnboundLocalError``,
            or silently skipped training once an iterator was exhausted.
    """
    loss_hist = []
    outputs = None
    for epoch in tqdm(range(num_epochs)):
        steps_this_epoch = 0
        for inputs, targets in dat_loader:
            # Forward pass: add a size-1 leading axis, take the last entry of
            # the model's return value, then drop that axis for the loss.
            outputs = model(inputs[None])[-1]
            # Optional: add an L1 term such as 0.001 * outputs.abs().sum()
            # to encourage sparse outputs.
            loss = crit(outputs[0], targets)

            # Backward pass and optimization
            optim.zero_grad()
            loss.backward()
            optim.step()
            loss_hist.append(loss.item())
            steps_this_epoch += 1
        if steps_this_epoch == 0:
            raise ValueError(
                f"dat_loader yielded no (inputs, targets) pairs in epoch {epoch + 1}; "
                "pass a non-empty, re-iterable dataset or DataLoader"
            )
        if print_every and (epoch + 1) % print_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_hist[-1]:.5f}')
    return outputs, loss_hist
#%%
from sklearn.datasets import load_digits, load_breast_cancer, load_iris
# Load the dataset once and build tensors with explicit dtypes (float32
# features, int64 class labels for CrossEntropyLoss). This replaces the legacy
# torch.Tensor constructor followed by a .type(torch.LongTensor) cast.
iris = load_iris()
X = torch.as_tensor(iris.data, dtype=torch.float32)
y = torch.as_tensor(iris.target, dtype=torch.long)
data_train = TensorDataset(X, y)
# NOTE(review): trainloader is not used below. optimize_nn receives
# data_train, i.e. one sample per step in a fixed (unshuffled) order, which
# matches its inputs[None] / outputs[0] single-sample handling. Switch to the
# loader only if optimize_nn is adapted to batched input.
trainloader = torch.utils.data.DataLoader(data_train, batch_size=8, shuffle=True)
criterion = nn.CrossEntropyLoss()

hidden_sizes = [X.shape[1] * 4, X.shape[1] * 4]
n_classes = len(torch.unique(y))
model = FeedforwardNetwork(input_size=X.shape[1], hidden_sizes=hidden_sizes,
                           out_layer_sz=n_classes, init_type='none', activation=nn.ReLU)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Per-sample training: num_epochs * len(data_train) optimizer steps
# (iris has 150 samples, so this is 15M steps).
outputs, loss_hist = optimize_nn(model, data_train, optimizer, criterion, num_epochs=100000)


decomposer = NetworkDecompositions(model)
decomps = decomposer.compute_overlap_decomp(X, sensitivity=100)

#%%
# Persistent homology of the raw inputs X, and relative homology of X with
# respect to the last entry of `decomps` (the result of
# compute_overlap_decomp in the training cell).
# NOTE(review): `pairwise_distances` is not imported explicitly in this file;
# presumably it comes from one of the star imports. Confirm.
# NOTE(review): the positional arguments `False, [1,None]` are options of the
# PersistentHomology API whose meaning is not visible here. Check their
# semantics in TopologicalMethods.
Phom = PersistentHomology()
hom = Phom.homology_analysis(X, pairwise_distances, False, [1,None])
hom_quot = Phom.relative_homology(X, decomps[-1], pairwise_distances, False, [1,None])

#%%
# Plot the absolute and relative persistence diagrams side by side.
# persim.plot_diagrams draws on plt.gca() when no `ax` is passed, so two bare
# calls in the same cell would overlay both diagrams on one set of axes.
# Calling through the imported `persim` module also avoids relying on a star
# import to provide the unqualified `plot_diagrams` name.
import matplotlib.pyplot as plt

fig, (ax_abs, ax_rel) = plt.subplots(1, 2, figsize=(10, 5))
persim.plot_diagrams(hom[1], ax=ax_abs, title="Homology of X")
persim.plot_diagrams(hom_quot[1], ax=ax_rel, title="Relative homology w.r.t. decomps[-1]")