import os
import os.path as osp
import csv
import flickrapi
import requests
import shutil
import numpy as np
import xml.etree.ElementTree as ET
from tqdm import tqdm
from collections import defaultdict
from pprint import pprint


def add_element(etree, name, text):
    """Append a ``<name>text</name>`` child to *etree* and return *etree*.

    The parent is returned (not the new child) so calls can be chained or
    re-assigned, e.g. ``data = add_element(data, 'id', '42')``.
    """
    child = ET.SubElement(etree, name)
    child.text = text
    return etree


class Retriever:
    """Base class that turns a dataset into graph files: node labels, annotations and edges.

    Subclasses implement :meth:`run` to fill ``self.labels``,
    ``self.connection_potentials`` and ``self.annotations``. They also implement either
    :meth:`process_adjacency` (edges read from ``args.adjacency``) or
    :meth:`get_matches` (edges inferred from shared tokens by :meth:`compute_edges`).

    ``args`` is an argparse-style namespace. This class reads ``verbose``, ``debug``,
    ``outputpath``, ``adjacency``, ``statistics`` and ``undirected``.
    """

    def __init__(self, args):
        self.labels = defaultdict(list)      # label -> node ids carrying that label
        self.connection_potentials = {}      # node id -> tokens used by compute_edges
        self.nodes = set()                   # node ids, reported by the statistics step
        self.connected_nodes = set()         # node ids seen on an edge (filled when statistics are on)
        self.annotations = ET.Element("annotations")

        self.args = args
        # Progress bars are hidden unless verbose, and always hidden in debug mode.
        self.disable = (not self.args.verbose) or (self.args.debug)

    def start(self):
        """Run the retriever, write its outputs and optionally print graph statistics.

        When ``args.outputpath`` is set, writes:

        * ``node_labels``: CSV rows ``node_id,label_index``.
        * ``Annotations/<output dir name>.xml``.

        Edges are then built from ``args.adjacency`` if given, otherwise from shared tokens.
        """
        self.run()

        if self.args.outputpath:
            # Label indices follow the insertion order of self.labels.
            node_labels = [[idx, lab]
                           for lab, ids in enumerate(self.labels.values())
                           for idx in ids]
            with open(osp.join(self.args.outputpath, "node_labels"), "w", newline="") as node_label_file:
                csv.writer(node_label_file).writerows(node_labels)

            # Name the XML after the output directory itself. normpath drops a trailing
            # separator, so "out/voc" and "out/voc/" both give "voc.xml".
            dataset_name = osp.basename(osp.normpath(self.args.outputpath))
            xml_path = osp.join(self.args.outputpath, "Annotations", dataset_name + ".xml")
            with open(xml_path, "wb") as xml_file:
                xml_file.write(ET.tostring(self.annotations))

        if self.args.adjacency:
            self.process_adjacency()
        else:
            self.compute_edges()

        if self.args.statistics:
            isolated_nodes = self.nodes.difference(self.connected_nodes)
            # An empty graph would otherwise divide by zero.
            proportion = len(isolated_nodes) * 100 / len(self.nodes) if self.nodes else 0.0
            print(f'Created {len(self.nodes)} nodes')
            print(f'{len(isolated_nodes)} unconnected')
            print(f'Proportion isolated: {proportion}')

    def run(self):
        """Populate labels, connection potentials and annotations (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _tally_tokens(token_connections, idx, tokens):
        """Register node ``idx``'s tokens and score its overlap with nodes seen so far.

        ``token_connections`` maps each token to a ``defaultdict(int)`` of
        ``node id -> occurrences``; it is updated in place.

        Returns a plain dict ``node id -> score``, where the score is summed over
        every already-known token of ``idx``. The score may include ``idx`` itself.
        Tokens seen for the first time only register ``idx`` and add nothing.
        """
        potential_connections = defaultdict(int)

        for token in tokens:
            count = token_connections.get(token)
            if count is None:
                # Key the first occurrence by the node id itself.
                token_connections[token] = defaultdict(int, {idx: 1})
                continue

            count[idx] += 1
            for match_id, occurrences in count.items():
                potential_connections[match_id] += occurrences

        return dict(potential_connections)

    def compute_edges(self):
        """Infer edges from shared tokens and write them to ``<outputpath>/edges``.

        Nodes are visited in ``connection_potentials`` order. Token scores cover only
        nodes visited earlier (see :meth:`_tally_tokens`). :meth:`get_matches` decides
        which candidates become edges.

        Self-loops are dropped. With ``args.undirected`` every edge is also emitted
        reversed. Edges are written as CSV rows ``source,target``.
        """
        edges = []
        token_connections = {}

        for idx, tokens in tqdm(self.connection_potentials.items(), desc="Finding matches for edges",
                                disable=self.disable):
            potential_connections = self._tally_tokens(token_connections, idx, tokens)

            for match_id in self.get_matches(potential_connections, idx):
                if match_id == idx:
                    continue

                edges.append([idx, match_id])
                if self.args.undirected:
                    edges.append([match_id, idx])

                if self.args.statistics:
                    self.connected_nodes.add(idx)
                    self.connected_nodes.add(match_id)

        if self.args.outputpath:
            with open(osp.join(self.args.outputpath, "edges"), "w", newline="") as edge_file:
                csv.writer(edge_file).writerows(edges)

        if self.args.verbose:
            print("Found ", str(len(edges)), " edges")

    def process_adjacency(self):
        """Build edges from the adjacency file in ``args.adjacency`` (subclass hook)."""
        raise NotImplementedError

    def get_matches(self, potential_connections, idx):
        """Return the node ids that ``idx`` should be linked to (subclass hook)."""
        raise NotImplementedError


class WebFlickrRetriever(Retriever):
    """Build a graph of Flickr photos whose edges come from an adjacency file.

    Photos are read from one of two places:

    * ``<inputpath>/Files`` when ``args.fromFiles`` is set. Each file is named
      after its Flickr id.
    * Flickr XML dumps in ``<inputpath>/flickrXml`` otherwise.

    ``args.outputpath`` and ``args.adjacency`` are both required.

    NOTE(review): ``cv2`` is used in :meth:`process` but is not imported in this
    file's import block. Confirm OpenCV is imported before that method is used.
    """

    # Read lazily-checked: a missing variable no longer breaks importing the module.
    API_KEY = os.environ.get('FLICKR_API_KEY')
    API_SECRET = os.environ.get('FLICKR_API_SECRET')

    def __init__(self, args):
        super().__init__(args)

        if self.API_KEY is None or self.API_SECRET is None:
            raise Exception("FLICKR_API_KEY and FLICKR_API_SECRET must be set")

        self.flickr = flickrapi.FlickrAPI(self.API_KEY, self.API_SECRET, format='etree')
        self.flickr_ids = []  # Flickr photo ids (strings) registered by process_file

        if args.outputpath:
            self.image_paths = osp.join(args.outputpath, "Image")
        else:
            raise Exception("No outputpath given")

        if not args.adjacency:
            raise Exception("No adjacency given")

    def run(self):
        """Register every photo found in the input directory.

        Raises:
            Exception: if the input directory does not exist.
        """
        files = osp.join(self.args.inputpath, "Files" if self.args.fromFiles else "flickrXml")

        if not osp.exists(files):
            raise Exception(f"No file {files} with xml data found")

        for filename in tqdm(os.listdir(files), desc="Processing XML data", disable=self.disable):
            if self.args.fromFiles:
                self.process_file(filename)
            else:
                # TODO(review): the returned URLs and tags are not consumed yet.
                self.process_annotation(files, filename)

    def process_file(self, filename):
        """Register the photo whose Flickr id is ``filename`` without its extension."""
        idx = osp.splitext(filename)[0]
        self.flickr_ids.append(idx)

        if self.args.statistics:
            self.nodes.add(idx)

    def process_annotation(self, files, filename):
        """Parse one Flickr XML dump in directory ``files`` and look up image URLs.

        The dump is read line by line, and each ``<photo>`` element must end on a
        line consisting of ``</photo>``. The first line is skipped (presumably an
        XML declaration or wrapper tag; confirm against the data).

        Returns:
            A tuple ``(img_data, tag_data)``:

            * ``img_data``: a list of ``{'idx': photo id, 'url': source}`` for
              photos whose sizes could be fetched.
            * ``tag_data``: maps every parsed photo id to its tag texts.
        """
        img_data = []
        tag_data = {}

        with open(osp.join(files, filename)) as file:
            file.readline()
            block_lines = []

            for line in file:
                block_lines.append(line)
                if line.strip() != "</photo>":
                    continue

                photo = ET.fromstring("".join(block_lines))
                block_lines = []

                idx = photo.get('id')
                if self.args.debug:
                    print(idx)

                tags = photo.find('tags')
                tag_data[idx] = [tag.text for tag in tags.findall('tag')] if tags is not None else []

                try:
                    urls = self.flickr.photos.getSizes(api_key=self.API_KEY, photo_id=idx)
                except flickrapi.exceptions.FlickrError:
                    continue

                # The first <size> entry of the response is used as the download source.
                url = urls.find('sizes').find('size').get('source')
                img_data.append({'idx': idx, 'url': url})

        return img_data, tag_data

    def process(self, img_data, idx):
        """Decode an image, resize it to ``args.size`` square, and save it.

        ``img_data`` is presumably the encoded image bytes (it goes through
        ``bytearray``). Writes two files:

        * ``<outputpath>/JPEGImages/<idx>.jpg``: the resized image.
        * ``<outputpath>/Files/<idx>.csv``: the flattened pixels, comma-separated.
        """
        img_path = osp.join(self.args.outputpath, "JPEGImages", f"{idx}.jpg")

        img = np.asarray(bytearray(img_data), dtype="uint8")
        img = cv2.imdecode(img, cv2.IMREAD_COLOR)
        img = cv2.resize(img, (self.args.size, self.args.size))
        cv2.imwrite(img_path, img)

        data = np.array(img).flatten()
        data.tofile(osp.join(self.args.outputpath, "Files", f'{idx}.csv'), sep=',')

    def process_adjacency(self):
        """Keep the adjacency edges whose endpoints are both registered photos.

        Surviving edges are written to ``<outputpath>/edges`` as CSV rows, matching
        :meth:`compute_edges`. Ids are parsed as uint64 because Flickr photo ids can
        exceed 32 bits.
        """
        adjacency = np.genfromtxt(self.args.adjacency, delimiter=self.args.sep, dtype=np.uint64)
        # genfromtxt returns a 1-D array for a one-line or empty file; restore (edge, column).
        if adjacency.ndim < 2:
            adjacency = adjacency.reshape(1, -1) if adjacency.size else adjacency.reshape(0, 2)

        present_ids = np.array(self.flickr_ids, dtype=np.uint64)
        valid_sources = adjacency[np.isin(adjacency[:, 0], present_ids)]
        valid_edges = valid_sources[np.isin(valid_sources[:, 1], present_ids)]

        if self.args.outputpath:
            with open(osp.join(self.args.outputpath, "edges"), "w", newline="") as edge_file:
                csv.writer(edge_file).writerows(valid_edges.tolist())

        if self.args.statistics:
            # Node ids are kept as strings (see process_file); compare like with like.
            self.connected_nodes = {str(node) for node in np.ravel(valid_edges[:, :2]).tolist()}

    def get_matches(self, potential_connections, idx):
        """Token matching is not supported; edges must come from an adjacency file."""
        raise NotImplementedError("Use the [-a] adjacency file option")


class VOCCORetriever(Retriever):
    """Build a graph from Pascal VOC-style annotations, one node per object or part.

    Nodes come from ``<inputpath>/Annotations/*.xml``. Edges come from one of:

    * An adjacency file of Flickr photo ids (``args.adjacency``).
    * :meth:`get_matches`: nodes from the same photo, or with a shared-comment
      token score of at least ``MIN_TOKEN_SCORE``.

    NOTE(review): ``cv2`` is used below but is not imported in this file's import
    block. Confirm OpenCV is imported before this class is used.
    """

    # Read lazily-checked: a missing variable no longer breaks importing the module.
    API_KEY = os.environ.get('FLICKR_API_KEY')
    API_SECRET = os.environ.get('FLICKR_API_SECRET')
    # Minimum accumulated shared-token score for two nodes to be linked by comments.
    MIN_TOKEN_SCORE = 10

    def __init__(self, args):
        super().__init__(args)

        if self.API_KEY is None or self.API_SECRET is None:
            raise Exception("FLICKR_API_KEY and FLICKR_API_SECRET must be set")

        self.flickr = flickrapi.FlickrAPI(self.API_KEY, self.API_SECRET, format='etree')
        self.idx_to_flickr = {}                 # node id -> Flickr photo id
        self.flickr_to_idx = defaultdict(list)  # Flickr photo id -> node ids

    def run(self):
        """Parse every annotation file and register its objects and parts as nodes.

        Raises:
            FileNotFoundError: if ``<inputpath>/Annotations`` does not exist.
        """
        annotations = osp.join(self.args.inputpath, "Annotations")

        if not osp.exists(annotations):
            raise FileNotFoundError(f"No annotation directory found at {annotations}")

        for filename in tqdm(os.listdir(annotations), desc="Processing XML data", disable=self.disable):
            with open(osp.join(annotations, filename)) as file:
                root = ET.fromstring(file.read())

            flickr_id = root.find('source').find('flickrid').text

            # Edges from an adjacency file do not depend on comments, so skip the API call.
            token_string = "" if self.args.adjacency else self.get_token_string(flickr_id)

            # Without an adjacency file, photos with no usable comments are skipped.
            if not (self.args.adjacency or token_string):
                continue

            img_suffix = root.find('filename').text

            if self.args.statistics and not self.args.process:
                img = None
            else:
                img_file = osp.join(self.args.inputpath, "JPEGImages", img_suffix)
                img = cv2.imread(img_file, cv2.IMREAD_COLOR)

                if self.args.debug:
                    print(img.shape)

            img_data = (flickr_id, img_suffix, token_string)

            offset = 1
            for obj in root.findall('object'):
                offset = self.process(obj, img, img_data, offset)

                for part in obj.findall('part'):
                    offset = self.process(part, img, img_data, offset)

    def process(self, element, img, img_data, offset):
        """Register one object/part element as a node.

        With ``args.process``, also writes the node's pixel features to
        ``<outputpath>/Files/<idx>``.

        Returns:
            The offset for the next element of the same image, so every
            object/part gets a distinct node id.
        """
        flickr_id, _, token_string = img_data
        label, img_path, next_offset = self.get_annotation(element, img_data, offset)
        # The node id is the generated file stem; int() accepts its '_' digit separator.
        idx = int(img_path.split('.')[0])

        if self.args.process:
            x_data = self.get_data(element, img, img_path).tolist()
            with open(osp.join(self.args.outputpath, 'Files', str(idx)), "w", newline="") as node_attr_file:
                csv.writer(node_attr_file).writerow(x_data)

        self.labels[label].append(idx)
        self.idx_to_flickr[idx] = flickr_id
        self.flickr_to_idx[flickr_id].append(idx)
        self.connection_potentials[idx] = token_string.split()

        if self.args.statistics:
            self.nodes.add(idx)

        return next_offset

    def get_data(self, element, img, img_path):
        """Crop the element's bounding box, resize it, save it, and return it.

        The crop is resized to ``args.size`` square and saved as
        ``<outputpath>/JPEGImages/<img_path>``.

        Returns:
            The flattened pixel array of the resized crop.
        """
        box = element.find('bndbox')
        xmin = int(float(box.find('xmin').text))
        ymin = int(float(box.find('ymin').text))
        xmax = int(float(box.find('xmax').text))
        ymax = int(float(box.find('ymax').text))

        cropped_img = img[ymin:ymax, xmin:xmax]
        new_img = cv2.resize(cropped_img, (self.args.size, self.args.size))
        cv2.imwrite(osp.join(self.args.outputpath, "JPEGImages", img_path), new_img)

        return np.array(new_img).flatten()

    def get_annotation(self, element, img_data, offset):
        """Append an ``<annotation>`` record for ``element`` to ``self.annotations``.

        The new image name is the source file name with its first two characters
        replaced by ``offset`` zero-padded to two digits. This presumably assumes
        VOC names like ``2008_000008.jpg``; confirm.

        Returns:
            A tuple ``(label name, new image name, offset + 1)``.
        """
        flickr_id, img_suffix, comments = img_data
        label = element.find('name').text

        new_img_suffix = str(offset).zfill(2) + img_suffix[2:]

        data = ET.Element("annotation")
        data = add_element(data, 'id', new_img_suffix.split('.')[0])
        data = add_element(data, 'img', new_img_suffix)
        data = add_element(data, 'label', label)
        data = add_element(data, 'flickr_id', flickr_id)
        data = add_element(data, 'tokens', comments)
        self.annotations.append(data)

        return label, new_img_suffix, offset + 1

    def get_matches(self, potential_connections, idx):
        """Return the nodes to link to ``idx``.

        Two kinds of node qualify:

        * Nodes from the same Flickr photo (``idx`` itself included).
        * Nodes whose shared-token score reaches ``MIN_TOKEN_SCORE``.
        """
        # flickr_to_idx already groups node ids by photo; no need to scan every node.
        same_img = set(self.flickr_to_idx[self.idx_to_flickr[idx]])
        same_comment = {match_id for match_id, score in potential_connections.items()
                        if score >= self.MIN_TOKEN_SCORE}
        return same_img | same_comment

    def process_adjacency(self):
        """Expand photo-level adjacency edges into node-level edges.

        Every edge between two registered photos becomes the cartesian product of
        their node ids. The result is written to ``<outputpath>/edges`` as CSV rows,
        matching :meth:`compute_edges`. Ids are parsed as uint64 because Flickr photo
        ids can exceed 32 bits.
        """
        adjacency = np.genfromtxt(self.args.adjacency, delimiter=self.args.sep, dtype=np.uint64)
        # genfromtxt returns a 1-D array for a one-line or empty file; restore (edge, column).
        if adjacency.ndim < 2:
            adjacency = adjacency.reshape(1, -1) if adjacency.size else adjacency.reshape(0, 2)

        present_ids = np.array(list(self.flickr_to_idx.keys()), dtype=np.uint64)
        valid_sources = adjacency[np.isin(adjacency[:, 0], present_ids)]
        valid_edges = valid_sources[np.isin(valid_sources[:, 1], present_ids)]

        edges = []
        for source_id, target_id in valid_edges[:, :2].tolist():
            # .get avoids inserting empty entries into the defaultdict on a missed lookup.
            sources = self.flickr_to_idx.get(str(source_id), [])
            targets = self.flickr_to_idx.get(str(target_id), [])
            edges.extend([source, target] for target in targets for source in sources)

        if self.args.outputpath:
            with open(osp.join(self.args.outputpath, "edges"), "w", newline="") as edge_file:
                csv.writer(edge_file).writerows(edges)

        if self.args.statistics:
            self.connected_nodes = {node for edge in edges for node in edge}

    def get_token_string(self, flickr_id):
        """Return the words of all comments on a Flickr photo, space-joined.

        Returns ``None`` if the comments could not be fetched or the response has
        no children. API errors are swallowed (printed in debug mode), so a single
        failing photo does not stop the run.
        """
        comments = None

        try:
            comments = self.flickr.photos.comments.getList(photo_id=flickr_id)

            if self.args.debug:
                print("Found photo")

        except Exception as e:
            # Best effort: a photo whose comments cannot be fetched is skipped by run().
            if self.args.debug:
                print(e)

        # Check the child count explicitly; truth-testing an Element is deprecated.
        if comments is None or not len(comments):
            return None

        comment_tokens = []
        for comment in comments[0]:
            comment_tokens += (comment.text or "").split()

        return ' '.join(comment_tokens)
