import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
from PIL import Image
import pickle

from random import randint

from matplotlib import pyplot as plt
from model import VAE, loss_fn

def load_images_from_folder(folder):
    """Load every ``.png`` file directly inside *folder* into one stacked array.

    Files are read in sorted filename order so the result is reproducible
    (``os.listdir`` returns entries in arbitrary order). The scan is not
    recursive.

    Args:
        folder: Path of the directory to scan.

    Returns:
        ``np.ndarray`` of shape ``(num_images, *image_shape)``. All images must
        decode to the same array shape for ``np.stack`` to succeed.

    Raises:
        ValueError: if *folder* contains no ``.png`` files, or if the decoded
            images have mismatched shapes.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        if not filename.endswith('.png'):
            continue
        # The context manager closes the underlying file handle; np.array
        # forces PIL to decode the pixels before the file is closed.
        with Image.open(os.path.join(folder, filename)) as img:
            images.append(np.array(img))
    if not images:
        # np.stack([]) would fail with an unhelpful message; say what's wrong.
        raise ValueError(f"No .png files found in {folder!r}")
    return np.stack(images)



def batch_generator(data, batch_size):
    """Yield consecutive slices of *data*, each holding up to *batch_size* items.

    The final slice is shorter when ``len(data)`` is not a multiple of
    *batch_size*. Plain slicing is used, so NumPy arrays yield views and
    lists/strings yield new lists/strings.
    """
    total = len(data)
    starts = range(0, total, batch_size)
    yield from (data[start:start + batch_size] for start in starts)

if __name__ == '__main__':

    epochs = 500
    bs = 64
    # A checkpoint is written every `checkpoint_every` epochs and after the final epoch.
    checkpoint_every = 10
    checkpoint_dir = 'model_outputs'
    # Fall back to CPU so the script still runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Resizes each (C, H, W) tile to 64x64. Applied once to the whole dataset
    # below; the Normalize step is currently commented out.
    transform_segmented = transforms.Compose([
        transforms.Resize((64, 64)),
        #transforms.Normalize(mean=[0.24787295, 0.23788302, 0.21599289], std=[0.18843222, 0.17737571, 0.17882799])
    ])

    cell_size = 14
    # Load the pre-extracted image segments. pickle can execute arbitrary code
    # on load, so only point this at files produced by a trusted pipeline.
    with open('pickled_output/all_segments_14__760000.pkl', 'rb') as fh:
        images1 = pickle.load(fh)

    # Flatten into (N, cell_size, cell_size, 3) channel-last tiles.
    images1 = np.stack(images1).reshape(-1, cell_size, cell_size, 3)
    print(images1.shape)
    # NHWC -> NCHW for torchvision, resize every tile to 64x64, back to numpy.
    images1 = transform_segmented(torch.tensor(np.transpose(images1, (0, 3, 1, 2)))).numpy()
    # Visual sanity check of one resized sample (CHW -> HWC for matplotlib).
    plt.imshow(np.moveaxis(images1[10], 0, -1))
    plt.show()

    print("CONCATENATING IMAGES")
    images = images1
    print(images.shape)

    print("SPLITTING DATASET")

    X_train = images
    print(X_train.shape)

    model = VAE(image_channels=3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Create the output directory up front so torch.save cannot fail only
    # after a full epoch of training has already been spent.
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.train()

    for epoch in range(epochs):
        for batch in batch_generator(X_train, bs):
            # X_train is already 64x64 (resized above), so re-applying the
            # Resize transform per batch would be redundant; just move the
            # batch to the device as float.
            batch_images = torch.tensor(batch).to(device).float()

            recon_images, mu, logvar = model(batch_images)
            loss, bce, kld = loss_fn(recon_images, batch_images, mu, logvar)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Assuming loss_fn returns batch-summed values (hence the
            # division), divide by the real batch size: the last batch may
            # hold fewer than bs samples.
            n = len(batch)
            to_print = "Epoch[{}/{}] Loss: {:.4f} {:.4f} {:.4f}".format(
                epoch + 1, epochs, loss.item() / n, bce.item() / n, kld.item() / n)
            print(to_print)

        # Periodic checkpoints, plus one after the final epoch so the fully
        # trained weights are always saved (epochs - 1 = 499 is not a
        # multiple of checkpoint_every).
        if epoch % checkpoint_every == 0 or epoch == epochs - 1:
            print("Saving model")
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_dir, 'vae_14_14_3_' + str(epoch) + '.torch'))
