import os
import sys
import numpy as np
import time
import matplotlib.pyplot as plt
from IPython import display
import torch
import random
import neptune

import torch.nn.functional as F
import torch.optim as optim
import matplotlib.gridspec as gridspec

# PROJECT_PATH is taken to be two directory levels above the current working
# directory, so this script has to be launched from a fixed depth inside the
# project. CASELLES_PATH is a sibling of the project root named
# "learning-group-structure" (the forked Caselles-Dupre code base).
# os.chdir('/home/learning-group-structure/src/flatland/flat_game')
CWD = os.getcwd()
PROJECT_PATH = os.path.dirname(os.path.dirname(CWD))
CASELLES_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "learning-group-structure")

# Make the project's own packages (data, modules, experiments) importable.
sys.path.append(PROJECT_PATH)

print("Appended ", PROJECT_PATH)
# NOTE(review): CASELLES_PATH itself is never appended; only its "src" and
# "added_modules" subdirectories are (below). This message is informational only.
print("Appended ", CASELLES_PATH)

# Imports from Caselles-Dupre forked code
# Each import below must stay after the sys.path.append that makes it resolvable.
sys.path.append(os.path.join(CASELLES_PATH, "src"))
from environments import LatentWorld

sys.path.append(os.path.join(CASELLES_PATH, "added_modules"))
from architectures import dis_lib, teapot, vgg
from plotting import plotting

# Imports from our project
import data.data_loader
import data.data_environment
from modules.general_metric import general_metric
from experiments import neptune_config
from modules.utils.plotting import yiq_embedding, plot_latent_dimension_combinations

# Expose only physical GPU 3 to this process; torch then addresses it as "cuda:0".
# NOTE(review): CUDA reads this variable when it initialises, so it only takes effect
# if nothing has initialised CUDA before this line -- confirm no earlier import does.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
use_cuda = torch.cuda.is_available()
dev = "cuda:0" if use_cuda else "cpu"
if use_cuda:
    # New float tensors (e.g. torch.zeros(...)) are created on the GPU by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
print("Device", dev)
device = torch.device(dev)
# Experiment configuration.
# Repetition indices are drawn from range(start_repetition, end_repetition).
start_repetition = 0
end_repetition = 10
# Derived from the range (not hard-coded) so the value logged to Neptune always
# matches the number of repetitions that actually run.
repetitions = end_repetition - start_repetition
n_sgd_steps = 3000  # optimisation steps per repetition
ep_steps = 5  # random actions taken per episode
batch_eps = 16  # episodes accumulated per optimisation step
entanglement_target = 0  # target value for the entanglement regulariser
z_dim = 4  # latent dimensionality
architecture = "dis_lib"  # one of "teapot", "dis_lib", "vgg"
parameters = dict(repetitions=repetitions, n_sgd_steps=n_sgd_steps, ep_steps=ep_steps, batch_eps=batch_eps,
                  entanglement_target=entanglement_target, z_dim=z_dim)

# Example parameter vector (module level; the training loop does not read it).
# noinspection PyArgumentList
params = torch.FloatTensor([1, 1, 0.5, 0, 0])


def calc_entanglement(params):
    """Return the squared magnitude of ``params`` that lies outside its largest entry.

    The squared absolute values of all entries are summed and the largest one is
    subtracted. The result is therefore zero exactly when at most one entry is
    non-zero.
    """
    squared_magnitudes = torch.square(torch.abs(params))
    total = squared_magnitudes.sum()
    return total - squared_magnitudes.max()


def produce_embeddings(images_factor_dataset, encoder_pytorch, device="cpu", latent_dim=4):
    """Encode every image on the first two factor axes and return the codes on that grid.

    Args:
        images_factor_dataset: object exposing ``factors_shape`` (tuple),
            ``factor_values_list`` (one sequence of values per factor) and
            ``images`` indexed as ``images[i_factor1, i_factor2, ...]``. Each image
            is presumably channels-last (H, W, C); it is permuted to channels-first
            before encoding.
        encoder_pytorch: callable mapping one channels-first float tensor to a
            latent tensor with ``latent_dim`` elements.
        device: device the image tensor is moved to before encoding.
        latent_dim: size of each latent code (defaults to 4, the previously
            hard-coded value).

    Returns:
        numpy array of shape ``factors_shape + (latent_dim,)``. Only the first two
        factors are iterated. The encoder is used in whatever train/eval mode it is
        currently in.
    """
    latent = np.zeros(images_factor_dataset.factors_shape + (latent_dim,))
    # Inference only: no_grad avoids building (and holding) an autograd graph per image.
    with torch.no_grad():
        for num_factor1, factor1 in enumerate(images_factor_dataset.factor_values_list[0]):
            for num_factor2, factor2 in enumerate(images_factor_dataset.factor_values_list[1]):
                print("Encoding image ({},{})".format(factor1, factor2))
                img = images_factor_dataset.images[num_factor1, num_factor2, ...]
                # Move the last (channel) axis first, as the encoder expects.
                img_tensor = torch.from_numpy(img).permute(-1, 0, 1).float().to(device)
                z = encoder_pytorch(img_tensor)
                latent[num_factor1, num_factor2] = z.detach().to("cpu").numpy()
    return latent


# Dataset configurations. Each entry of ``dataset_parameters_list`` is
# (name used in the ICML results path, keyword arguments passed to
# data.data_loader.load_factor_data). The "data" value also names the output
# files and Neptune experiments.
rotate_hueshift_arrow_params = {
    "data": "arrow",
    "arrow_size": 64,
    "n_hues": 64,
    "n_rotations": 64,
}

wrapped_pixel_params = {
    "data": "pixel",
    "height": 64,
    "width": 64,
    "step_size_vert": 1,
    "step_size_hor": 1,
    "square_size": 4,
}

modelnet_params = {
    "dataset_filename": "modelnet_color_single_64_64.h5",
    "data": "modelnet_colors",
}

# Datasets are processed in this order.
dataset_parameters_list = [
    ("pixel4", wrapped_pixel_params),
    ("modelnet", modelnet_params),
    ("arrow", rotate_hueshift_arrow_params),
]
def _save_environment_preview(obs_env, out_path):
    """Roll out 9 random actions in ``obs_env`` and save the visited frames as a 3x3 grid.

    Each frame is also pushed to the IPython display with a short pause, so the
    rollout can be watched in a notebook. The figure is closed after saving so
    it does not stay alive for the rest of the run.
    """
    fig = plt.figure(figsize=(3, 3))
    grid = gridspec.GridSpec(3, 3)
    grid.update(wspace=0.02, hspace=0.02)
    plt.grid(None)
    state = obs_env.reset()
    for cell in range(9):
        ax = plt.subplot(grid[cell])
        ax.axis('off')
        ax.set_aspect('equal')
        # Single-channel frames are shown as a 2D image.
        ax.imshow(state[:, :, 0] if state.shape[-1] == 1 else state)
        display.display(plt.gcf())
        time.sleep(0.2)
        display.clear_output(wait=True)
        action = random.sample([0, 1, 2, 3], k=1)[0]
        state = obs_env.step(action)
    plt.savefig(out_path, bbox_inches='tight')
    plt.close(fig)


def _select_architecture(name):
    """Return the architecture module (providing ``Encoder``/``Decoder``) named ``name``.

    Raises:
        ValueError: if ``name`` is not "teapot", "dis_lib" or "vgg".
    """
    modules = {"teapot": teapot, "dis_lib": dis_lib, "vgg": vgg}
    try:
        return modules[name]
    except KeyError:
        raise ValueError("Unknown architecture {!r}; expected one of {}".format(
            name, sorted(modules))) from None


def _reconstruction_loss(obs_env, lat_env, encoder, decoder, n_episodes, n_steps, device):
    """Return the summed binary cross-entropy over ``n_episodes`` random episodes, normalised.

    Each episode encodes the reset observation into the latent world, then takes
    ``n_steps`` random actions in both the observation and the latent
    environment. The latent state is decoded and compared with the observation
    after the reset and after every action.

    NOTE: each episode contributes ``n_steps + 1`` reconstructions, but the sum is
    divided by ``n_steps * n_episodes``. This normalisation is kept unchanged from
    the original script so loss values stay comparable with earlier runs; the
    result is therefore not an exact per-frame mean.
    """
    # Created with the default tensor type (on the GPU when CUDA is in use).
    loss = torch.zeros(1)
    for _ in range(n_episodes):
        obs_x = obs_env.reset().permute(-1, 0, 1).float().to(device)
        obs_z = lat_env.reset(encoder(obs_x))
        loss += F.binary_cross_entropy(decoder(obs_z.to(device)), obs_x)
        for _ in range(n_steps):
            action = obs_env.action_space.sample().item()
            obs_x = obs_env.step(action).permute(-1, 0, 1).float().to(device)
            obs_z = lat_env.step(action)
            loss += F.binary_cross_entropy(decoder(obs_z.to(device)), obs_x)
    return loss / (n_steps * n_episodes)


def _plot_embeddings(latent_embeddings):
    """Plot all pairwise latent-dimension projections of an embedding grid and return the figure.

    Assumes ``latent_embeddings`` has shape ``(n_factor1, n_factor2, latent_dim)``,
    as produced by ``produce_embeddings``. Factor index k of n is mapped to the
    angle ``2*pi*k/n``, and each point is coloured from its (factor1, factor2)
    angle pair via ``yiq_embedding``.
    """
    *factor_sizes, latent_dim = latent_embeddings.shape
    angles = [2 * np.pi * np.arange(n) / n for n in factor_sizes]
    factor_mesh = np.stack(np.meshgrid(*angles, indexing="ij"), axis=-1)
    flat_factor_mesh = factor_mesh.reshape(-1, len(factor_sizes))
    embeddings_flat = latent_embeddings.reshape(-1, latent_dim)
    colors_flat = yiq_embedding(flat_factor_mesh[:, 0], flat_factor_mesh[:, 1])
    fig, _ = plot_latent_dimension_combinations(embeddings_flat, colors_flat)
    return fig


def _log_and_close(fig, image_name):
    """Upload ``fig`` to the Neptune "plots" channel, then close it to release its memory."""
    neptune.log_image("plots", fig, image_name=image_name)
    plt.close(fig)


# Resolve the architecture once, before any dataset is loaded or Neptune
# experiment is created, so an invalid name fails immediately with a clear error.
architecture_module = _select_architecture(architecture)

for name_dataset, dataset_parameters in dataset_parameters_list:
    data_name = dataset_parameters["data"]

    print("LOADING DATASET {}".format(data_name))
    # Load images and create data environment
    images_dataset = data.data_loader.load_factor_data(root_path=PROJECT_PATH, **dataset_parameters)
    obs_env = data.data_environment.ImageWorld(images_dataset.images)
    n_channels = images_dataset.images.shape[-1]

    # Save example images from the dataset
    _save_environment_preview(obs_env, data_name + "_env.png")
    # The original script reset the environment here (result unused); the call is
    # kept so the sequence of environment calls is unchanged.
    obs_env.reset()

    # Neptune Experiment
    group = "TUe"
    api_token = neptune_config.API_KEY  # read api token from neptune config file
    upload_source_files = ["caselles_code.py"]  # OPTIONAL: save the source code used for the experiment
    neptune.init(project_qualified_name=group + "/sandbox", api_token=api_token)

    for repetition in reversed(range(start_repetition, end_repetition)):
        experiment_name = "caselles_" + data_name + "_" + str(repetition)
        with neptune.create_experiment(name=experiment_name, params=parameters,
                                       upload_source_files=upload_source_files):
            neptune.append_tag(architecture)
            print("TRAINING REPETITION {}".format(repetition))

            # DEFINE NETWORK FOR TRAINING (construction order kept: latent world, decoder, encoder)
            lat_env = LatentWorld(dim=z_dim, n_actions=obs_env.action_space.n)
            decoder = architecture_module.Decoder(n_in=z_dim, n_channels=n_channels)
            encoder = architecture_module.Encoder(n_out=z_dim, n_channels=n_channels)
            encoder.to(device)
            decoder.to(device)
            print(encoder)
            print(decoder)
            # noinspection PyUnresolvedReferences
            optimizer_dec = optim.Adam(decoder.parameters(), lr=1e-2, weight_decay=0)
            # noinspection PyUnresolvedReferences
            optimizer_enc = optim.Adam(encoder.parameters(), lr=1e-2, weight_decay=0)
            optimizer_rep = optim.Adam(lat_env.get_representation_params(), lr=1e-2, weight_decay=0)

            # Per-step history; not used below but available for inspection after a run.
            losses = []
            entanglement = []

            # START TRAINING
            t_start = time.time()
            print("START TRAINING")
            for step in range(1, n_sgd_steps + 1):
                loss = _reconstruction_loss(obs_env, lat_env, encoder, decoder,
                                            batch_eps, ep_steps, device)
                raw_loss = loss.item()
                # Log plain floats rather than graph-attached tensors.
                neptune.log_metric('cross_entropy_loss', raw_loss)

                # Mean entanglement over all action representations (the divisor was
                # previously hard-coded to 4, the number of actions in these datasets).
                action_reps = list(lat_env.action_reps)
                reg_loss = sum([calc_entanglement(r.thetas) for r in action_reps]) / len(action_reps)
                loss = loss + (reg_loss - entanglement_target).abs() * 1e-2

                # log complete loss
                reg_value = reg_loss.item()
                neptune.log_metric('entanglement_loss', reg_value)
                neptune.log_metric('batch_loss', loss.item())

                losses.append(raw_loss)
                entanglement.append(reg_value)

                optimizer_dec.zero_grad()
                optimizer_enc.zero_grad()
                optimizer_rep.zero_grad()
                loss.backward()
                optimizer_enc.step()
                optimizer_dec.step()
                optimizer_rep.step()

                # Remember to clear the cached action representations after we update the parameters!
                lat_env.clear_representations()

                if step % 10 == 0:
                    print("iter {} : loss={:.3f} : entanglement={:.2e} : last 10 iters in {:.3f}s".format(
                        step, raw_loss, reg_value, time.time() - t_start
                    ), end="\r" if step % 100 else "\n")
                    t_start = time.time()

            # Get plots
            fig, _ = plotting.plot_action_distribution(lat_env)
            _log_and_close(fig, "action_dist")
            fig, _ = plotting.plot_reconstructions(obs_env, encoder, decoder, 8, device=device)
            _log_and_close(fig, "reconstruction")

            filename = data_name + "_z_" + str(repetition) + ".npy"
            save_folder = os.path.join("/home/learning-group-structure/results_" + architecture, data_name)
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, filename)

            # Save model
            model_save_path = os.path.join(save_folder, "models")
            os.makedirs(model_save_path, exist_ok=True)
            model_prefix = os.path.join(model_save_path, data_name + "_" + str(repetition))
            torch.save(encoder, model_prefix + "_encoder.pth")
            torch.save(decoder, model_prefix + "_decoder.pth")

            print("START CREATING EMBEDDINGS")
            latent_embeddings = produce_embeddings(images_dataset, encoder, device=device)
            _log_and_close(_plot_embeddings(latent_embeddings), "embeddings")
            np.save(save_path, latent_embeddings)

            print("EMBEDDINGS SAVED")

            # Calculate our metric
            k_values = general_metric.create_combinations_k_values_range()
            lsbd_score, k_min = general_metric.calculate_metric_k_list(latent_embeddings, k_values)
            print("k_min", k_min)
            print("LSBD score", lsbd_score)
            saving_icml = os.path.join("/home/ICML/results", name_dataset, "quessard", architecture,
                                       str(repetition))
            os.makedirs(saving_icml, exist_ok=True)

            np.save(os.path.join(saving_icml, "lsbd.npy"), lsbd_score)
            neptune.log_metric("LSBD", lsbd_score)
