
import torch
import sys
import os
import time
import numpy as np

import torch.nn as nn
import torch.nn.functional as F


# Project locations are resolved relative to the current working directory:
# PROJECT_PATH is two levels above it, and disentanglement_lib is looked up
# next to PROJECT_PATH (i.e. inside PROJECT_PATH's parent directory).
PROJECT_PATH = os.path.split(os.path.split(os.getcwd())[0])[0]
DIS_PROJECT_PATH = os.path.join(os.path.split(PROJECT_PATH)[0], "disentanglement_lib")

# Make the project packages importable; the order of the entries is preserved.
sys.path.extend(
    [
        DIS_PROJECT_PATH,
        PROJECT_PATH,
        "/home/learning-group-structure/added_modules/",
    ]
)

# Restrict CUDA to device index 1; set before the project imports that follow.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from experiments.disentanglement_lib import evaluate_metrics_dis_lib
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
from data import data_loader
from experiments.quessard import quessard_utils
from modules.general_metric import general_metric


# --- Run configuration -------------------------------------------------------
# Repetition indices to evaluate: range(start_repetition, trained_repetitions).
start_repetition, trained_repetitions = 0, 1

# Dataset identifiers.
dataset_name = "coil100"
dataset_list = [dataset_name]

# Sample counts handed to the metric evaluation; the test count is half the
# training count.
num_train = 100
num_test = num_train // 2
# Number of repetitions done to estimate a metric (the estimates depend on the seed).
metric_estimate_repetitions = 10

z_dim = 7
# Name of the architecture; it becomes part of the results directory path.
architecture = "dis_lib"

# Evaluate every configured dataset, one trained repetition at a time.
for dataset in dataset_list:
    print("Evaluating dataset", dataset)
    ground_truth_data = get_named_ground_truth_data(dataset)
    for repetition in range(start_repetition, trained_repetitions):
        print("Evaluating repetition", repetition)
        # Locate the checkpoint through the loop variable `dataset` so that the
        # encoder always matches the ground-truth data loaded above (previously
        # the global `dataset_name` was used, which only agrees with `dataset`
        # while dataset_list holds that single name).
        model_path = os.path.join("/home/learning-group-structure/results_dis_lib", dataset, "models")
        model_filename = "{}_{}_encoder.pth".format(dataset, repetition)
        # NOTE: torch.load unpickles the file, so only load checkpoints from a
        # trusted location. The loaded object is used as a module further down
        # (its conv1..conv4 / fc1 / fc2 attributes are read), i.e. a whole model
        # was saved, not a state_dict.
        # NOTE(review): recent PyTorch releases may refuse whole-module pickles
        # unless weights_only=False is passed -- confirm against the installed
        # version.
        encoder = torch.load(os.path.join(model_path, model_filename))


        class Encoder(nn.Module):
            """Wrap the layers of an already-trained encoder and normalise its output.

            The wrapper does not create any layers of its own: it re-uses
            ``conv1``..``conv4``, ``fc1`` and ``fc2`` of an existing model and adds
            a final ``F.normalize`` step to the forward pass.

            Args:
                n_out, n_channels, image_size, conv_hid, conv_kernel, conv_stride:
                    Accepted to keep the constructor signature unchanged; none of
                    them influences the wrapped layers.
                pretrained: Object exposing ``conv1``..``conv4``, ``fc1`` and
                    ``fc2``. When omitted, the ``encoder`` name of the enclosing
                    script (the model loaded from disk) is used.
            """

            def __init__(self,
                         n_out=4,
                         n_channels=3,
                         image_size=(64, 64),
                         conv_hid=32,
                         conv_kernel=(4, 4),
                         conv_stride=(2, 2),
                         pretrained=None):
                super().__init__()

                # Default to the model loaded by the surrounding script; the
                # name is only looked up when no model is passed explicitly.
                if pretrained is None:
                    pretrained = encoder

                # The unused `final_size = np.product(...)` computation was
                # dropped: np.product no longer exists in NumPy 2.0 and the
                # value was never read.
                self.conv1 = pretrained.conv1
                self.conv2 = pretrained.conv2
                self.conv3 = pretrained.conv3
                self.conv4 = pretrained.conv4
                self.fc1 = pretrained.fc1
                self.fc2 = pretrained.fc2

            def forward(self, x):
                """Run the wrapped layers and return the normalised embedding.

                ReLU follows each convolution and ``fc1``; everything after the
                first dimension is flattened before the fully connected layers.
                The ``fc2`` output is normalised with ``F.normalize`` (defaults:
                L2 norm along dim 1) and size-1 dimensions are squeezed away.
                """
                for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
                    x = F.relu(conv(x))
                x = torch.flatten(x, 1)
                x = F.relu(self.fc1(x))
                x = self.fc2(x)
                # NOTE(review): squeeze() also removes the batch dimension when
                # a single sample is passed -- confirm callers expect that.
                return F.normalize(x).squeeze()


        # Swap the loaded model for its normalising wrapper, moved to the GPU
        # (Module.cuda() returns the module itself), and show its structure.
        encoder = Encoder().cuda()
        print(encoder)


        def representation_function(img):
            """Encode a batch of images with the current ``encoder``.

            Args:
                img: NumPy array with four axes; the last axis is moved to
                    position 1 before encoding (presumably channels-last
                    (batch, H, W, C) input -- confirm against the caller).

            Returns:
                The encoder output as a NumPy array on the CPU.
            """
            batch = torch.from_numpy(img).permute(0, 3, 1, 2).float().cuda()
            # Inference only: do not record an autograd graph for the forward
            # pass (it was previously built and then thrown away by detach()).
            with torch.no_grad():
                z = encoder(batch)
            return z.detach().cpu().numpy()


        # Results folder: <ICML results>/<dataset>/quessard/<architecture>/<repetition>
        ICML_RESULTS_PATH = "/home/ICML/results"
        experiment_metrics_save_path = os.path.join(
            ICML_RESULTS_PATH,
            dataset,
            "quessard",
            architecture,
            str(repetition),
        )
        os.makedirs(experiment_metrics_save_path, exist_ok=True)

        start_time = time.time()

        # Loader parameters per dataset; unknown datasets get an empty dict.
        # They are only consumed by the (currently disabled) LSBD computation
        # kept in the comment at the end of this section.
        known_dataset_parameters = {
            "modelnet40": {
                "root_path": "/data/aligned64",
                "data": "modelnet40",
                "collection_list": ["airplane"],
                "data_type": "train",
                "dataset_directory": "",
            },
            "coil100": {
                "root_path": PROJECT_PATH,
                "data": "coil100",
            },
        }
        dataset_parameters = known_dataset_parameters.get(dataset, {})

        # Evaluate the disentanglement_lib metrics for this encoder; the save
        # path created above is passed as the last argument.
        evaluate_metrics_dis_lib.evaluate_metrics(
            representation_function,
            ground_truth_data,
            metric_estimate_repetitions,
            num_train,
            num_test,
            experiment_metrics_save_path,
        )

        # Disabled: LSBD metric computation.
        # images_dataset = data_loader.load_factor_data(**dataset_parameters)
        #
        # latent_embeddings = quessard_utils.produce_embeddings(images_dataset, encoder, z_dim, device="cpu")
        # k_values = np.arange(-10, 11)
        # lsbd_score, k_min = general_metric.calculate_metric_rotations(latent_embeddings, k_values, verbose=1)

        elapsed_seconds = time.time() - start_time
        print("--- %s seconds ---" % elapsed_seconds)