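'''Render a matrix of template shapes by sweeping two covariates with a trained model.

For every pair of covariate values a mesh is reconstructed under
``<LoggingRoot>/<ExperimentName>/TemplateShapeMatrix/average`` and the full grid is plotted.
'''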

import argparse
import logging
import os
import sys

import numpy as np
import pandas as pd

# Enable import from parent package
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.utils.data as data_utils
import naisr.modules

import naisr_meshing
import naisr
import naisr.workspace as ws
from utils import cond_mkdir
from naisr import *
from visualizer import *


# Covariate sweep grids, keyed by dataset class (specs["Class"]). Each entry maps
# two covariate names to the 7 values swept along one axis of the shape matrix.
dict_list_cov = {
    'Airway': {
        'age': np.linspace(-2.0, 2.0, 7),
        'weight': np.linspace(-2., 4., 7),
    },
    'starman': {
        'cov_1': np.linspace(-1., 1., 7),
        'cov_2': np.linspace(-1., 1., 7),
    },
    'ADNI': {
        'age': np.linspace(-3.0, 3.0, 7),
        'AD': np.linspace(-1., 2., 7),
    },
}




def extract_latent_vector(latent_vectors, idx):
    """Select the latent code used to condition the decoder.

    Args:
        latent_vectors: tensor of learned latent codes. The average is taken
            over ``dim=-2`` and a single code is picked with
            ``latent_vectors[int(idx)]``, so a 2-D ``(num_codes, code_length)``
            layout is presumably expected -- TODO confirm against the loader.
        idx: ``'mean'`` for the average code, ``'zero'`` for an all-zero code
            with the average code's shape/dtype/device, or any value accepted
            by ``int()`` to pick one learned code.

    Returns:
        The selected code with two leading singleton axes (``[None, None, :]``).
        If ``idx`` cannot be converted to an int or is out of range, a message
        is printed and the average code is returned instead.
    """
    def average_code():
        return torch.mean(latent_vectors, dim=-2)[None, None, :]

    if idx == 'mean':
        return average_code()
    if idx == 'zero':
        # zeros_like keeps the dtype and device of the real codes.
        return torch.zeros_like(average_code())
    try:
        return latent_vectors[int(idx)][None, None, :]
    except (ValueError, TypeError, OverflowError, IndexError):
        # Only bad-index errors fall back; KeyboardInterrupt etc. propagate.
        print("Wrong index of latent vectors. Return average latent vector")
        return average_code()

def get_template_idx():
    """Placeholder for choosing a template index; not implemented.

    Always returns ``None``.
    """
    return None


def _build_arg_parser():
    """Return the command-line parser for this script."""
    arg_parser = argparse.ArgumentParser(
        description="Render a matrix of template shapes over two covariates")

    arg_parser.add_argument(
        "--networksetting",
        "-e",
        dest="networksetting",
        default='examples/starman/naigsr.json',
        help="Path to the experiment specification file "
             + "(e.g. examples/starman/naigsr.json).",
    )

    arg_parser.add_argument(
        "--backbone",
        "-b",
        dest="backbone",
        default='siren',
        help="mlp or siren. Currently ignored: the network backbone is read "
             + "from the specification's 'Backbone' entry.",
    )

    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicating an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    arg_parser.add_argument(
        "--idx",
        "-i",
        dest="index",
        # 'zero' as the default keeps the default output an all-zero latent code.
        default='zero',
        type=str,
        help="Latent code used for every shape: 'zero' (default), 'mean', "
             + "or the integer index of a learned latent code.",
    )
    return arg_parser


def _load_model(specs, root_path, checkpoint):
    """Instantiate the network named by ``specs['Network']`` and load its weights.

    Weights are the ``"model_state_dict"`` entry of
    ``<root_path>/<ws.model_params_subdir>/<checkpoint>.pth``. The model is
    moved to ``specs['Device']`` and switched to eval mode before returning.
    """
    device = specs['Device']
    # NOTE(review): eval() executes the configured string as Python. This is
    # only acceptable because the specification file is trusted local input.
    network_cls = eval(specs['Network'])
    model = network_cls(
        template_attributes=specs['TemplateAttributes'],
        in_features=specs['InFeatures'],
        hidden_features=specs['HiddenFeatures'],
        # Key is spelled 'HidenLayers'; presumably matches the spec files -- keep in sync.
        hidden_layers=specs['HidenLayers'],
        out_features=specs['OutFeatures'],
        device=device,
        backbone=specs['Backbone'],
        outermost_linear=False,
        pos_enc=specs['PosEnc'],
        latent_size=specs["CodeLength"])

    checkpoint_path = os.path.join(root_path, ws.model_params_subdir, checkpoint + '.pth')
    print(checkpoint_path)
    state = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(state["model_state_dict"])
    model.to(device)
    model.eval()
    return model


def _attributes_to_tensors(attributes, device):
    """Convert each attribute value to a float32 tensor on ``device``.

    A scalar value becomes a tensor of shape ``(1, 1)``.
    """
    return {key: torch.from_numpy(np.array([[value]])).to(device).float()
            for key, value in attributes.items()}


def main():
    """Reconstruct and plot template shapes over a 2-D covariate grid.

    Reads the specification given by ``--networksetting`` and loads the network
    from ``--checkpoint``. For every pair of values in
    ``dict_list_cov[specs['Class']]`` (with the second covariate negated) it
    reconstructs a mesh under
    ``<LoggingRoot>/<ExperimentName>/TemplateShapeMatrix/average``. The
    collected mesh paths are then rendered as one shape matrix.

    Raises:
        KeyError: if ``specs['Class']`` has no grid in ``dict_list_cov``, or
            the first two ``specs['Attributes']`` are not keys of that grid.
        ValueError: if ``specs['Attributes']`` names fewer than two covariates.
    """
    args = _build_arg_parser().parse_args()
    specs = ws.load_experiment_specifications(args.networksetting)

    experiment_name = specs["ExperimentName"]
    print(experiment_name)
    device = specs['Device']
    shapetype = specs["Class"]
    num_dim = specs["InFeatures"]
    covariate_names = specs['Attributes']

    # Fail fast, before any expensive loading, if no sweep grid applies.
    if shapetype not in dict_list_cov:
        raise KeyError("No covariate grid in dict_list_cov for class {!r}; "
                       "expected one of {}".format(shapetype, sorted(dict_list_cov)))
    if len(covariate_names) < 2:
        raise ValueError("specs['Attributes'] must name at least two covariates, "
                         "got {!r}".format(covariate_names))
    grid = dict_list_cov[shapetype]
    values_1 = grid[covariate_names[0]]
    values_2 = grid[covariate_names[1]]

    root_path = os.path.join(specs['LoggingRoot'], experiment_name)
    cond_mkdir(root_path)

    # NOTE(review): latent codes always come from the 'latest' checkpoint even
    # when --checkpoint selects older network weights. With --idx other than
    # 'zero' this may pair a code with weights it was not trained with.
    latent_vectors = ws.load_latent_vectors(root_path, 'latest', torch.device(device)).to(device)
    latent_code = extract_latent_vector(latent_vectors, args.index)

    model = _load_model(specs, root_path, args.checkpoint)

    savepath_evo = os.path.join(root_path, 'TemplateShapeMatrix')
    cond_mkdir(savepath_evo)
    savepath_evo_type = os.path.join(savepath_evo, 'average')
    cond_mkdir(savepath_evo_type)

    reconstruct_mesh = naisr_meshing.create_mesh_reconstruction(shapetype)
    dict_of_evolution = {}  # cov_1 -> cov_2 -> path returned by the reconstruction
    dict_text = {}          # cov_1 -> cov_2 -> attribute tensors used for that cell

    for cov_1 in values_1:
        dict_of_evolution[cov_1] = {}
        dict_text[cov_1] = {}

        for raw_cov_2 in values_2:
            # The second covariate is negated before use. The reason is not
            # visible here (presumably plot orientation) -- TODO confirm.
            cov_2 = -raw_cov_2
            logging.info("evolving %s %s", cov_1, cov_2)

            attributes = specs["TemplateAttributes"].copy()
            attributes[covariate_names[0]] = cov_1
            attributes[covariate_names[1]] = cov_2
            attributes = _attributes_to_tensors(attributes, device)

            savedir = os.path.join(savepath_evo_type, 'cov1_' + str(cov_1) + '_cov2_' + str(cov_2))
            cond_mkdir(savedir)
            savepath = reconstruct_mesh(model,
                                        latent_code,
                                        attributes,
                                        {},
                                        savedir,
                                        output_type='model_out',
                                        dim=num_dim,
                                        shapetype=shapetype,
                                        N=256,
                                        device=device,
                                        EVALUATE=False,
                                        MAKE_GT=False,
                                        MAKE_TEMPLATE=True)

            dict_of_evolution[cov_1][cov_2] = savepath
            dict_text[cov_1][cov_2] = attributes

    # Render once, after the grid is complete, instead of re-rendering the
    # growing partial matrix after every cell.
    plotter_evolution_shapematrix(shapetype)(dict_of_evolution,
                                             savepath_evo_type,
                                             dict_text0=dict_text,
                                             dict_colors0=None)


if __name__ == "__main__":
    main()



