import argparse
import requests
import json
import csv
import sys
import os
import pickle as pkl
import pandas as pd
import numpy as np
import os.path as osp
from PIL import Image
from tqdm import tqdm
from multiprocessing import Pool
from retrievers import WebFlickrRetriever, NUSWideRetriever


def run_p(datum, timeout=30):
    """Download, resize and save the image for one annotation record.

    Used as the ``Pool.map`` worker in ``main``. Relies on the module globals
    ``RETRIEVER`` (provides ``get_url``), ``SIZE`` (output width and height
    in pixels) and ``IMAGEOUTPUT`` (destination folder).

    Args:
        datum: Record passed to ``RETRIEVER.get_url``; ``datum + ".jpg"`` is
            used as the output file name, so it is presumably a string id.
        timeout: Seconds to wait on the HTTP connect/read before giving up.

    The download is best-effort: records without a URL are ignored and any
    network, HTTP or image-decoding error simply skips the record.
    """
    url = RETRIEVER.get_url(datum)
    if not url:
        return

    try:
        # Without a timeout a single stalled server can hang a worker forever.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Response.raw is the undecoded socket stream; ask urllib3 to undo
            # gzip/deflate Content-Encoding before PIL parses the bytes.
            response.raw.decode_content = True
            # convert() forces a full decode, so the image no longer needs the
            # stream once the ``with`` block closes the response.
            img = Image.open(response.raw).convert('RGB')

        # Plain resize to SIZE x SIZE (aspect ratio is not preserved).
        resized = img.resize((SIZE, SIZE))
        filename = datum + ".jpg"
        resized.save(osp.abspath(osp.join(IMAGEOUTPUT, filename)))
    except Exception:
        # Best-effort: skip this record. Catching ``Exception`` rather than
        # using a bare except lets KeyboardInterrupt/SystemExit propagate.
        pass


def get_nodes(folder):
    """Return the distinct integer node ids encoded in the entry names of ``folder``.

    Each directory entry is expected to be named ``<int>`` or ``<int>.<ext>``;
    the text before the first '.' is parsed as the id. Duplicate ids (e.g. the
    same id with two extensions) are collapsed. The order of the returned list
    is unspecified.
    """
    directory = osp.abspath(folder)
    unique_ids = {int(entry.split('.')[0]) for entry in os.listdir(directory)}
    return list(unique_ids)


def neighbours(node, adjacency):
    """Return the distinct nodes that share an edge with ``node``, ignoring direction.

    ``adjacency`` is indexed as a 2-D array whose column 0 holds edge sources
    and column 1 holds edge targets. Returns an empty list when ``node``
    appears in no edge.
    """
    # Nodes reached by edges leaving ``node``.
    outgoing = adjacency[adjacency[:, 0] == node, 1]
    # Nodes whose edges point at ``node``.
    incoming = adjacency[adjacency[:, 1] == node, 0]
    combined = outgoing.tolist() + incoming.tolist()
    return list(set(combined))


def DFS(temp, node, visited, adjacency):
    """Depth-first traversal from ``node``, collecting every reachable unvisited node.

    Args:
        temp: Accumulator list; nodes are appended to it in DFS pre-order.
        node: Start node. It is always marked visited and appended.
        visited: Dict mapping node -> bool, updated in place (every node reached
            is set to True). Every neighbour encountered must already be a key,
            otherwise KeyError is raised.
        adjacency: Edge array handed to ``neighbours``; edges are treated as
            undirected.

    Returns:
        ``temp`` itself, extended with the nodes of ``node``'s component that
        were not visited before the call.
    """
    # An explicit stack replaces recursion: a long path through a large
    # component would otherwise exceed Python's recursion limit (RecursionError).
    visited[node] = True
    temp.append(node)

    # Neighbours are pushed in reverse so they are popped in the same order the
    # recursive formulation would visit them (identical pre-order).
    stack = [nxt for nxt in reversed(neighbours(node, adjacency)) if not visited[nxt]]
    while stack:
        current = stack.pop()
        if visited[current]:
            # Reached through an earlier branch after being pushed.
            continue
        visited[current] = True
        temp.append(current)
        stack.extend(
            nxt for nxt in reversed(neighbours(current, adjacency)) if not visited[nxt]
        )

    return temp


def connected_components(args):
    """Print how large the biggest connected component of the downloaded image graph is.

    Nodes are the integer ids of the files in ``args.imagefolder``. Edges are
    read with ``np.genfromtxt`` (whitespace-delimited, columns 0/1 = source/target)
    from ``args.xmlinput`` and kept only when both endpoints are downloaded
    nodes. Edges are treated as undirected. Prints the node count and the
    fraction of nodes in the largest component.

    NOTE(review): this reads the edge list from ``args.xmlinput`` while
    ``process`` reads ``args.adjacency`` with ``args.separator`` -- confirm
    which file is intended here.

    Raises:
        Exception: if a component is found larger than the node set (sanity check).
    """
    # atleast_2d: genfromtxt returns a 1-D array for a single-edge file.
    adjacency = np.atleast_2d(np.genfromtxt(osp.abspath(args.xmlinput)))
    nodes = get_nodes(args.imagefolder)
    total_nodes = len(nodes)

    if total_nodes == 0:
        # Nothing downloaded: report and stop instead of dividing by zero below.
        print(f'{total_nodes} downloaded and present in the graph')
        return

    # Keep only edges whose source AND target are downloaded nodes.
    in_graph = np.isin(adjacency[:, 0], nodes) & np.isin(adjacency[:, 1], nodes)
    valid_adjacency = adjacency[in_graph]

    visited = {node: False for node in nodes}
    max_connected = 0

    for node in tqdm(nodes, desc="Checking all nodes"):
        if visited[node]:
            continue

        component = DFS([], node, visited, valid_adjacency)
        max_connected = max(max_connected, len(component))

        # A component spanning every node cannot be beaten; stop early.
        if max_connected == total_nodes:
            break

        if max_connected > total_nodes:
            raise Exception("subgraph found larger than input graph")

    print(f'{total_nodes} downloaded and present in the graph')
    print(f'Of which {max_connected / total_nodes} are connected together in one graph')


def hist_embedding(args):
    """Save a normalised RGB colour-histogram feature vector for each selected image.

    Args (attributes of ``args``):
        valid_ids: Text file with one integer image id per line; blank lines
            are ignored.
        imagefolder: Folder of images named ``<id>.<ext>``; only ids listed in
            ``valid_ids`` are processed.
        attroutput: Folder receiving one ``<id>.npy`` file per image.

    Each feature vector concatenates the 256-bin histograms of the R, G and B
    bands (768 values) and is scaled to unit L2 norm.
    """
    # int() tolerates surrounding whitespace, so this handles a last line with
    # or without a trailing newline (slicing off the last character did not).
    with open(args.valid_ids) as id_file:
        valid_ids = {int(line) for line in id_file if line.strip()}

    image_files = [
        filename
        for filename in os.listdir(args.imagefolder)
        if int(filename.split('.')[0]) in valid_ids
    ]

    for filename in image_files:
        with Image.open(osp.join(args.imagefolder, filename)) as img:
            # convert('RGB') guarantees exactly three bands (grayscale/CMYK/RGBA
            # inputs would otherwise break the 3-way unpack).
            r, g, b = img.convert('RGB').split()
        features = np.array(r.histogram() + g.histogram() + b.histogram())
        features = features / np.linalg.norm(features)
        node_id = filename.split('.')[0]
        np.save(osp.join(args.attroutput, node_id), features, allow_pickle=False)


def process(args):
    """Build node attribute, node label and edge files for the image graph.

    Steps:
      1. Load the JSON tag mapping from ``args.tagfile`` (keys are node ids as strings).
      2. Treat every file in ``args.imagefolder`` named ``<id>.<ext>`` as a downloaded node.
      3. Load the uint32 edge list from ``args.adjacency`` (split on
         ``args.separator``; column 0 = source, column 1 = target) and derive
         ``valid_adjacency`` (see the review note on the join below).
      4. For every node in ``valid_adjacency``, write its flattened pixel values
         as one CSV row to ``args.attroutput/<id>`` and record its label: the
         tag itself, the first tag if it is a list, or 'null' if empty.
      5. Write ``node_labels`` (id,label rows) and ``edges.csv`` (every edge
         plus its reverse, i.e. undirected) into ``args.dataoutput``.
    """
    with open(osp.abspath(args.tagfile), 'r') as tag_file:
        tags = json.load(tag_file)

    node_list = [int(filename.split('.')[0]) for filename in os.listdir(args.imagefolder)]
    available_nodes = np.array(node_list)
    # atleast_2d: genfromtxt returns a 1-D array for a single-edge file.
    adjacency = np.atleast_2d(
        np.genfromtxt(args.adjacency, delimiter=args.separator, dtype=np.uint32)
    )
    sources = adjacency[:, 0].tolist()
    targets = adjacency[:, 1].tolist()

    print(f'There are originally {len(adjacency)} edges')
    print(f'There are {len(set(node_list))} downloaded nodes')
    print(f'There are {len(set(sources + targets))} adjacency nodes')

    valid_sources = adjacency[np.isin(adjacency[:, 0], available_nodes)]
    valid_targets = adjacency[np.isin(adjacency[:, 1], available_nodes)]

    print(f'Test that {len(set(valid_sources[:,0].tolist() + valid_targets[:,1].tolist()))} = {len(set(node_list))}')

    # NOTE(review): this inner join pairs an edge (s, j) whose source is
    # downloaded with an edge (j, t) whose target is downloaded, so each output
    # row (s, t) is a two-hop path s -> j -> t (j itself need not be downloaded).
    # If the intent is only to keep edges whose two endpoints are downloaded
    # (as connected_components filters), this should be a mask on both
    # columns instead -- confirm before changing.
    sources_df = pd.DataFrame(valid_sources, columns=['sources', 'join'])
    targets_df = pd.DataFrame(valid_targets, columns=['join', 'targets'])
    adjacency_df = sources_df.merge(targets_df, on='join', how='inner')[['sources', 'targets']]
    valid_adjacency = adjacency_df.to_numpy()

    sources = valid_adjacency[:, 0].tolist()
    targets = valid_adjacency[:, 1].tolist()
    present_nodes = set(sources + targets)

    print(f'There are now {len(valid_adjacency)} edges')
    print(f'There are now {len(present_nodes)} present nodes')

    labels = []
    for node in tqdm(present_nodes, desc="Saving files"):
        idx = str(node)

        with Image.open(osp.join(args.imagefolder, idx + '.jpg')) as img:
            x_data = np.ravel(np.asarray(img))

        # newline='' as required by the csv module so rows are not double-terminated.
        with open(osp.abspath(osp.join(args.attroutput, idx)), 'w', newline='') as node_attr_file:
            csv.writer(node_attr_file).writerow(x_data)

        tag = tags[idx]
        if len(tag) < 1:
            labels.append([idx, 'null'])
        elif isinstance(tag, list):
            labels.append([idx, tag[0]])
        else:
            labels.append([idx, tag])

    with open(osp.abspath(osp.join(args.dataoutput, "node_labels")), 'w', newline='') as node_label_file:
        csv.writer(node_label_file).writerows(labels)

    # Store each edge in both directions so consumers see an undirected graph.
    reverse_adjacency = np.flip(valid_adjacency, axis=1)
    undirected_adjacency = np.vstack((valid_adjacency, reverse_adjacency))
    np.savetxt(osp.abspath(osp.join(args.dataoutput, "edges.csv")), undirected_adjacency, fmt='%d', delimiter=',')


def embedding(args):
    """Write one CSV attribute file per image from a pickled embedding dump.

    ``args.pklinput`` must unpickle to an indexable object whose item ``[0]``
    is a sequence of image filenames and item ``[1]`` the matching attribute
    vectors (the two are zipped pairwise). Each vector is written as a single
    CSV row to ``args.attroutput/<id>``, where ``<id>`` is the filename up to
    its first '.'.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --pklinput at files produced by a trusted source.
    with open(osp.abspath(args.pklinput), "rb") as pkl_file:
        embeddings = pkl.load(pkl_file)

    files = embeddings[0]
    attrs = embeddings[1]

    for filename, attribute in tqdm(zip(files, attrs), desc="saving embedding attributes"):
        idx = filename.split('.')[0]
        # newline='' as required by the csv module so rows are not double-terminated.
        with open(osp.abspath(osp.join(args.attroutput, idx)), 'w', newline='') as node_attr_file:
            csv.writer(node_attr_file).writerow(attribute)


def _init_download_worker(retriever, size, imageoutput):
    """Install the download settings as module globals inside a Pool worker.

    ``run_p`` reads ``RETRIEVER``, ``SIZE`` and ``IMAGEOUTPUT`` as globals, which
    are only assigned in the ``__main__`` block of the parent process. Under the
    'spawn' start method, workers re-import this module without running that
    block, so the values are handed over explicitly here.
    """
    global RETRIEVER, SIZE, IMAGEOUTPUT
    RETRIEVER = retriever
    SIZE = size
    IMAGEOUTPUT = imageoutput


def main(args):
    """Run each pipeline stage whose command-line flag is set.

    Stages, in order: ``--download`` (parse annotations, merge tags into
    ``args.tagfile``, download images in parallel), ``--statistics``,
    ``--process`` and ``--embedding``.
    """
    if args.download:
        data, tags = RETRIEVER.process_annotation(args.xmlinput)

        print(len(data))
        print(len(tags))

        tag_path = osp.abspath(args.tagfile)
        # Read existing tags BEFORE opening for writing: opening with 'w'/'w+'
        # truncates the file, which would discard tags saved by earlier runs.
        try:
            with open(tag_path, 'r') as file:
                file_data = json.load(file)
            file_data.update(tags)
        except (OSError, ValueError, TypeError, AttributeError):
            # Missing, unreadable or non-dict tag file: start from this run's tags.
            file_data = tags

        print(len(file_data))
        with open(tag_path, 'w') as file:
            json.dump(file_data, file)

        with Pool(processes=32, initializer=_init_download_worker,
                  initargs=(RETRIEVER, SIZE, IMAGEOUTPUT)) as pool:
            pool.map(run_p, data)

    if args.statistics:
        connected_components(args)

    if args.process:
        process(args)

    if args.embedding:
        # Colour histograms are used here; the pickle-based `embedding(args)`
        # path is currently not wired in.
        hist_embedding(args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
            description='parameterised slurm runs')

    # Conditional `required=` values inspect raw sys.argv, so they only react
    # to the exact flag spelling (e.g. '--download'), not abbreviations.
    parser.add_argument('--statistics', action='store_true', help='Calculate statistics from input file')
    parser.add_argument('--process', action='store_true', help='Create pytorch geometric loader files')
    parser.add_argument('--download', action='store_true', help='Download and resize images from xml files')
    parser.add_argument('--embedding', action='store_true', help='Extract attribute files from embedding pickle files')
    parser.add_argument('--pklinput', required='--embedding' in sys.argv, help='file where pickle data exists')
    parser.add_argument('--adjacency', required='--process' in sys.argv, help='location of adjacency file data')
    parser.add_argument('--separator', required='--adjacency' in sys.argv)
    parser.add_argument('--retriever', required='--download' in sys.argv, choices=["WebFlickr", "NUSWide"], help='retriever to use')
    # --statistics also reads its edge list from --xmlinput.
    parser.add_argument('--xmlinput', required='--download' in sys.argv or '--statistics' in sys.argv, help='file where xml data exists')
    parser.add_argument('--attroutput', required='--process' in sys.argv or '--embedding' in sys.argv, help='folder where attribute files to be stored')
    parser.add_argument('--dataoutput', required='--process' in sys.argv, help='folder where dataset files stored')
    parser.add_argument('--imagefolder', required=True, help='folder for images to be stored')
    # The tag file is written by --download and read by --process.
    parser.add_argument('--tagfile', required='--download' in sys.argv or '--process' in sys.argv, help='file for tags to be stored')
    parser.add_argument('--size', required='--download' in sys.argv, help='image resize')
    parser.add_argument('--valid_ids', required='--embedding' in sys.argv, help='which image ids are to be embedded')
    args = parser.parse_args()

    if args.download:
        # Module-level globals consumed by run_p (and forwarded to Pool workers in main).
        if args.retriever == "WebFlickr":
            RETRIEVER = WebFlickrRetriever
        elif args.retriever == "NUSWide":
            RETRIEVER = NUSWideRetriever
        else:
            raise Exception("Retriever unknown")

        SIZE = int(args.size)
        IMAGEOUTPUT = args.imagefolder

    main(args)

