import torch
import json
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset
import sys
import os
sys.path.insert(1, os.path.dirname(os.getcwd()))
import datasets
import architectures as archs
import numpy as np
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
from argparse import Namespace
import argparse
from PIL import Image
from tqdm import tqdm
from sklearn.manifold import TSNE
import umap

############ HELPFUL UTIL FOR JSON SERIALIZATION OF NUMPY ARRAYS ################
def default(obj):
    """JSON ``default=`` hook that converts NumPy values to native Python types.

    ``json.dump``/``json.dumps`` call this only for objects they cannot
    encode on their own.

    Args:
        obj: The object the JSON encoder could not serialize.

    Returns:
        A (nested) list for ``np.ndarray`` inputs (via ``tolist()``), or the
        equivalent Python scalar for NumPy scalars such as ``np.float32`` or
        ``np.int64`` (via ``item()``).

    Raises:
        TypeError: For any other type; the message names the offending type
            so the unserializable value is easy to track down.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # NumPy scalars (np.float32, np.int64, np.bool_, ...) are rejected by the
    # json encoder too, but convert losslessly to Python scalars via item().
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError('Object of type {} is not JSON serializable'.format(type(obj).__name__))

########## MAIN ###############
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(
    description="Embed a dataset split with a trained checkpoint and save 2-D "
                "t-SNE and UMAP projections of the embeddings as CSV.")
parser.add_argument('--filepath', type=str, required=True,
                    help="Directory containing the checkpoint file model_<job_id>.pth.tar.")
parser.add_argument('--job_id', type=str, required=True,
                    help="Job identifier; selects the checkpoint file and names the output subdirectory.")
parser.add_argument('--split', type=str, default='combined', choices=['testing', 'evaluation', 'combined'],
                    help="Dataset split to embed and project; 'combined' concatenates evaluation and testing.")
parser.add_argument('--evaltype', type=str, default='discriminative',
                    help="Key selecting the embedding when the model returns a dict of outputs; "
                         "also used in the results CSV filename.")
parser.add_argument('--output_dir', type=str, default=os.path.join(os.getcwd(), "Dimensionality_Reduction"),
                    help="Parent output directory; files are written to <output_dir>/<job_id>/.")
args = parser.parse_args()

# Load the training checkpoint. map_location remaps tensors that were saved on
# a GPU onto the device selected above, so GPU-trained checkpoints also load on
# CPU-only machines.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
PATH = os.path.join(args.filepath, "model_{}.pth.tar".format(args.job_id))
ld = torch.load(PATH, map_location=device)
# The stored training options are rebuilt as a Namespace so the project's
# architecture/dataset factories can read them as attributes.
opt = Namespace(**ld['opt'])

# "parade" runs swap in the multifeature_* architecture that matches the
# original backbone (resnet50 vs. bninception).
if "parade" in opt.method:
    opt.arch = "multifeature_resnet50" if "resnet" in opt.arch else "multifeature_bninception"
model = archs.select(opt.arch, opt)
model = model.to(device)
model.load_state_dict(ld['model_state_dict'])

# Build one sequential (shuffle=False) loader per split plus a "combined" loader
# over evaluation followed by testing, so embedding rows follow dataset order.
dataloaders = {}
dsets    = datasets.select(opt.dataset, opt, opt.source_path)

dataloaders['evaluation'] = DataLoader(dsets['evaluation'], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)
dataloaders['testing']    = DataLoader(dsets['testing'],    num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)

eval_set, test_set = dsets["evaluation"], dsets["testing"]
dset = ConcatDataset([eval_set, test_set])
# ConcatDataset does not carry these dataset attributes, so rebuild them for
# the concatenation (evaluation entries first, then testing entries).
dset.image_list = eval_set.image_list + test_set.image_list
dset.image_paths = dset.image_list
# Each image_dict entry is an [item, index] pair whose index points into
# image_list; testing indices are shifted by the evaluation list length so
# they point into the concatenated list. Keys are ordered deterministically:
# evaluation keys first, then keys that only occur in testing.
offset = len(eval_set.image_list)
all_keys = dict.fromkeys([*eval_set.image_dict, *test_set.image_dict])
dset.image_dict = {
    key: eval_set.image_dict.get(key, [])
         + [[x[0], offset + x[1]] for x in test_set.image_dict.get(key, [])]
    for key in all_keys
}
# The combined dataset only gets metadata when both splits provide it.
if hasattr(eval_set, "metadata") and hasattr(test_set, "metadata"):
    dset.metadata = pd.concat([eval_set.metadata, test_set.metadata])

dataloaders["combined"] = DataLoader(dset, num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)

# Embed only the split analysed below (args.split). The "combined" loader
# already concatenates evaluation and testing, so running every loader would
# push each image through the model twice for results that are never used.
feature_dict = {}
target_dict = {}
# Put layers such as dropout/batch-norm into inference mode once, up front.
_ = model.eval()

key = args.split
dataloader = dataloaders[key]
features = []
target_labels = []
with torch.no_grad():
    for inp in tqdm(dataloader, desc='Embedding Data...'):
        # NOTE(review): batches are indexed as (label, image, ...): element 0
        # is used as the target and element 1 as the model input -- confirm
        # against the dataset's __getitem__.
        input_img, target = inp[1], inp[0]
        target_labels.append(target.numpy().reshape(-1))
        out = model(input_img.to(device))
        # Some architectures return a tuple whose first element is the
        # embedding output; the remaining (auxiliary) outputs are unused here.
        if isinstance(out, tuple):
            out = out[0]
        # Multi-head models return a dict of embeddings; pick the requested one.
        if isinstance(out, dict):
            out = out[args.evaltype]
        # Keep per-batch arrays rather than Python lists of floats.
        features.append(out.cpu().detach().numpy())

    features = np.vstack(features).astype('float32')
    target_labels = np.concatenate(target_labels).reshape(-1, 1)

feature_dict[key] = features
target_dict[key] = target_labels

# Align the split's metadata with the embedding order, attach the labels, and
# add 2-D t-SNE and UMAP projections of the embeddings as new columns.
key = args.split
split_dataset = dataloaders[key].dataset
if not hasattr(split_dataset, "metadata"):
    # For 'combined', metadata exists only if both evaluation and testing
    # datasets provide it.
    raise AttributeError(
        "Dataset for split '{}' has no 'metadata' attribute; cannot build the "
        "results table.".format(key))
df = split_dataset.metadata
# Reorder metadata rows to the loader's (unshuffled) image order so row i
# lines up with feature_dict[key][i] and target_dict[key][i].
df = df.set_index('imagepath').loc[np.array([x[0] for x in split_dataset.image_list])].reset_index()
df['targets'] = target_dict[key].ravel()

embeddings = feature_dict[key]

tsne = TSNE(random_state=0, n_jobs=-1, n_components=2).fit_transform(embeddings)
df['tsne_x'] = tsne[:, 0]
df['tsne_y'] = tsne[:, 1]

# NOTE: umap-learn runs single-threaded (and warns) when random_state is set,
# so n_jobs=-1 only takes effect without a fixed seed.
um = umap.UMAP(random_state=0, n_jobs=-1, n_components=2).fit_transform(embeddings)
df['um_x'] = um[:, 0]
df['um_y'] = um[:, 1]

# Write the per-image results table and the run's hyper-parameters to
# <output_dir>/<job_id>/.
OUTPUT_PATH = os.path.join(args.output_dir, args.job_id)
os.makedirs(OUTPUT_PATH, exist_ok=True)
df.to_csv(os.path.join(OUTPUT_PATH, "results_{}.csv".format(args.evaltype)), sep=',', header=True, index=False)

# Collapse all "parade" method variants to a single name in the saved hparams.
if "parade" in opt.method:
    opt.method = "parade"
# Drop opt.device before serializing.
# NOTE(review): presumably a torch.device, which json cannot encode -- confirm.
if hasattr(opt, "device"):
    delattr(opt, "device")
# Serialize fully before opening the file: if some value is not
# JSON-serializable, the TypeError is raised without leaving a truncated
# hparam.json behind.
hparams_json = json.dumps(vars(opt), default=default)
with open(os.path.join(OUTPUT_PATH, "hparam.json"), 'w') as fp:
    fp.write(hparams_json)
