import os.path
from argparse import ArgumentParser
import json
from hashlib import md5
from time import time
import pandas as pd
import numpy as np
import git
from datetime import datetime

import os
import json
import pandas as pd
import torch
import shutil


from collections.abc import MutableMapping

from argparse import ArgumentParser

class FlaggedArgumentParser(ArgumentParser):
    """
    ArgumentParser that can additionally parse only a "flagged" subset of its arguments.

    Passing ``flag=True`` to :meth:`add_argument` registers the argument both on this
    parser and on an internal shadow parser. ``parse_args(only_flagged=True)`` then parses
    with the shadow parser and ignores every argument it does not know.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Shadow parser built with the same settings; it only receives flagged arguments.
        self._flagged_parser = ArgumentParser(*args, **kwargs)

    def add_argument(self, *args, **kwargs):
        """Register an argument; with ``flag=True`` it is also added to the flagged subset."""
        # 'flag' is our own keyword and must be stripped before argparse sees it.
        # argparse's own __init__ may call this method (e.g. to add -h) before
        # _flagged_parser exists; that is safe because such calls never pass flag=True.
        flagged = kwargs.pop('flag', False)
        if flagged:
            self._flagged_parser.add_argument(*args, **kwargs)
        return super().add_argument(*args, **kwargs)

    def parse_args(self, args=None, namespace=None, only_flagged=False):
        """
        Parse command-line arguments.

        :param args: argument list to parse (``None`` means ``sys.argv[1:]``).
        :param namespace: optional namespace object to populate.
        :param only_flagged: if True, parse only arguments registered with ``flag=True``;
            any other arguments on the command line are ignored.
        :return: the populated namespace.
        """
        if not only_flagged:
            return super().parse_args(args, namespace)
        flagged_namespace, _unrecognized = self._flagged_parser.parse_known_args(args, namespace)
        return flagged_namespace


def _flatten_dict_gen(d, parent_key, sep):
    """
    Yield ``(flat_key, value)`` pairs for every non-mapping leaf of the nested mapping *d*.

    :param d: mapping to flatten; values that are themselves mutable mappings are descended into.
    :param parent_key: prefix for every key produced (falsy means "no prefix").
    :param sep: separator inserted between nesting levels.
    """
    for k, v in d.items():
        # Format instead of concatenating so non-string keys (e.g. ints) do not raise TypeError.
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, MutableMapping):
            # Recurse into the generator directly; no intermediate dict per nesting level.
            yield from _flatten_dict_gen(v, new_key, sep)
        else:
            yield new_key, v


def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.'):
    """
    Flatten a nested dictionary.

    Nested keys are joined with *sep*, e.g. ``{'a': {'b': 1}}`` becomes ``{'a.b': 1}``.

    :param d: the (possibly nested) mapping to flatten.
    :param parent_key: optional prefix prepended to every resulting key.
    :param sep: separator placed between nesting levels.
    :return: a new flat ``dict``.
    """
    return dict(_flatten_dict_gen(d, parent_key, sep))


def read_experiments(output_path):
    """
    Read experiment metadata from the immediate sub-directories of ``output_path``.

    Every immediate sub-directory containing a ``meta.json`` is treated as one experiment;
    sub-directories without one are skipped.

    :param output_path: root directory whose immediate sub-directories hold experiments.
    :return: DataFrame with one row per experiment, indexed by the ``hash`` field of each
        ``meta.json``. Nested metadata is flattened into dotted column names
        (e.g. ``config.lr``). If a ``meta.json`` has no ``files`` entry, the names of the
        other files in that sub-directory are used instead. The ``path`` column holds
        ``output_path`` itself (not the sub-directory). If no experiments are found, an
        empty DataFrame indexed by ``hash`` is returned.
    :raises FileNotFoundError: if ``output_path`` is not an existing directory.
    """
    # os.walk yields nothing for a missing path, which would leak a bare StopIteration.
    if not os.path.isdir(output_path):
        raise FileNotFoundError(f"Experiment directory ({output_path}) does not exist.")

    data = []
    path, directories, _ = next(os.walk(output_path))
    for directory in directories:
        experiment_dir = os.path.join(path, directory)
        meta_path = os.path.join(experiment_dir, 'meta.json')
        if not os.path.exists(meta_path):
            continue

        with open(meta_path) as f:
            meta = json.load(f)

        if 'files' not in meta:
            _, _, files = next(os.walk(experiment_dir))
            files.remove('meta.json')
            meta['files'] = files
        # Store the root; consumers join it with the experiment hash themselves.
        meta['path'] = path
        data.append(flatten_dict(meta))

    if not data:
        # set_index('hash') on a column-less frame raises KeyError.
        return pd.DataFrame(columns=['hash']).set_index('hash')
    return pd.DataFrame(data).set_index('hash')


def read_dataframes(df_experiments):
    """
    Load every file listed for each experiment as a pandas DataFrame.

    :param df_experiments: DataFrame indexed by experiment hash, with a ``files`` column
        (list of file names) and a ``path`` column (root directory), e.g. the output of
        ``read_experiments``.
    :return: nested dict ``{hash: {file_name: DataFrame}}``. Each file is read with
        ``pd.read_csv`` from ``<path>/<hash>/<file_name>``.
    """
    df_dict = {}
    for experiment_hash, row in df_experiments.iterrows():
        experiment_dir = os.path.join(row['path'], experiment_hash)
        # NOTE(review): every listed file goes through read_csv. Non-CSV entries are not
        # filtered out, so they may fail to parse.
        df_dict[experiment_hash] = {
            file: pd.read_csv(os.path.join(experiment_dir, file))
            for file in row['files']
        }
    return df_dict


class Tracker:
    """
    A class to set up and track iterative experiments.

    Each tracker owns one experiment directory ``<path>/<md5 of config>`` (plus an
    ``<index>`` sub-directory when ``index`` is given). Values logged with ``log``,
    arrays staged with ``log_file`` and metadata from ``log_meta`` are kept in memory
    and written by ``save``.
    """
    def __init__(self, config, path, index=None, delete_existing=False, delete_incomplete=True):
        """
        Create the experiment directory and write an initial ``meta.json``.

        :param config: JSON-serializable configuration; its hash names the directory.
        :param path: root directory under which experiments are stored.
        :param index: optional run index; adds a ``str(index)`` sub-directory.
        :param delete_existing: delete an existing experiment directory, even a complete one.
        :param delete_incomplete: delete an existing directory whose ``meta.json`` is
            missing, unreadable, or not marked ``complete``.
        :raises FileExistsError: if the directory exists and neither rule deletes it.
        """
        self.files = {}
        self.data = {}
        self.meta = {}
        self.hash = self._compute_hash(config)
        self.path = os.path.join(path, self.hash)
        self.index = index
        if index is not None:
            self.path = os.path.join(self.path, str(index))

        if os.path.exists(self.path):
            # Check incomplete first (keeps its specific message), then honour
            # delete_existing regardless of delete_incomplete.
            if delete_incomplete and not self._is_complete(self.path):
                print(f"Deleting incomplete experiment at {self.path}")
                shutil.rmtree(self.path)
            elif delete_existing:
                print(f"Deleting existing experiment at {self.path}")
                shutil.rmtree(self.path)
            else:
                raise FileExistsError(f"Experiment ({self.path}) already exists.")

        os.makedirs(self.path)

        self.meta['config'] = config
        self.meta['hash'] = self.hash
        self.meta['index'] = index
        self.meta['complete'] = False
        # save meta data
        with open(os.path.join(self.path, 'meta.json'), 'w') as f:
            json.dump(self.meta, f, indent=2)

    @staticmethod
    def _is_complete(experiment_path):
        """Return True if ``experiment_path/meta.json`` exists, parses, and has ``complete`` set."""
        meta_path = os.path.join(experiment_path, 'meta.json')
        if not os.path.exists(meta_path):
            return False
        try:
            with open(meta_path) as f:
                meta = json.load(f)
        except json.JSONDecodeError:
            # A truncated meta.json (e.g. the process died mid-write) counts as incomplete.
            return False
        return bool(meta.get('complete', False))

    def log(self, **kwargs):
        """
        Record values for one iteration.

        :param kwargs: must include ``iteration`` (the row key); every other keyword is
            stored as a column value. Tensors and NumPy arrays are converted with
            ``.item()``, so they must contain exactly one element.
        :raises KeyError: if ``iteration`` is not given.
        """
        iteration = kwargs.pop('iteration')
        row = self.data.setdefault(iteration, {})
        for key, value in kwargs.items():
            if isinstance(value, (torch.Tensor, np.ndarray)):
                value = value.item()
            row[key] = value

    def log_meta(self, **kwargs):
        """Add or overwrite top-level metadata entries; written to ``meta.json`` by ``save``."""
        self.meta.update(kwargs)

    def log_file(self, name, data, iteration=None):
        """
        Stage an array to be written into ``files.npz`` by ``save``.

        :param name: key of the array inside ``files.npz``.
        :param data: array-like data; torch tensors are detached and moved to CPU NumPy.
        :param iteration: if given (including 0), the key becomes ``f"{iteration}_{name}"``.
        """
        # `is not None` so that iteration 0 is also prefixed.
        if iteration is not None:
            name = f"{iteration}_{name}"

        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        self.files[name] = data

    def save(self, complete=False):
        """
        Write the current in-memory state to the experiment directory.

        Writes ``files.npz`` (arrays staged via ``log_file``) and ``data.csv`` (one row per
        logged iteration, index column ``iteration``). It also rewrites ``meta.json``.

        :param complete: if True, mark the experiment ``complete`` and print the location.
            Once set, the flag stays True for subsequent saves.
        """
        # save data
        np.savez(os.path.join(self.path, 'files.npz'), **self.files)

        # save logs
        df = pd.DataFrame.from_dict(self.data, orient='index')
        df.index.rename('iteration', inplace=True)
        df.to_csv(os.path.join(self.path, 'data.csv'))

        # save meta data
        if complete:
            self.meta['complete'] = True
        with open(os.path.join(self.path, 'meta.json'), 'w') as f:
            json.dump(self.meta, f, indent=2)
        if complete:
            print(f"Saved files to {self.path}")

    # def _parse(self):
    #     parser = ArgumentParser(description=self._description)
    #     for name, values in self._parameters.items():
    #         # kwargs for argparse .add_argument()
    #         kwargs = dict(type=values['type'], help=values['help'], choices=values['choices'],
    #                             default=values['default'], action=values['action'])
    #         # drop keys with value None
    #         kwargs = {k : v for k,v in kwargs.items() if v is not None}
    #         parser.add_argument(f"--{name}", **kwargs)
    #     return vars(parser.parse_args())

    def _compute_hash(self, config):
        """Return the md5 hex digest of *config* serialized as JSON with sorted keys."""
        # sort_keys makes the hash independent of dict insertion order.
        return md5(json.dumps(config, sort_keys=True).encode()).hexdigest()

    # def _git_status(self, args):
    #     try:
    #         repo = git.Repo(search_parent_directories=True, path=args.get('code_path'))
    #     except git.InvalidGitRepositoryError:
    #         return None
    #     return dict(
    #         sha=repo.head.object.hexsha,
    #         branch=repo.active_branch.name,
    #         modified=[a.a_path for a in repo.index.diff(None)] + [a.a_path for a in repo.index.diff('HEAD')],
    #         untracked=repo.untracked_files
    #     )

    def run(self, fct):
        """
        Run ``fct`` with parsed parameters, time it, and persist its outputs.

        NOTE(review): this method relies on ``self._parse``, ``self._git_status`` and
        ``self._description``, whose definitions are commented out in this class, so as the
        file stands calling it raises ``AttributeError``.

        :param fct: callable invoked as ``fct(**args)``; it must return a tuple
            ``(extra_metadata, df_dict)``, either element may be ``None``.
        :return: tuple ``(metadata, df_dict)``.
        """
        args = self._parse()

        metadata = dict()
        metadata['timestamp'] = datetime.now().strftime("%d.%m.%Y %H:%M:%S")
        metadata['hash'] = self._compute_hash(args)

        # add git information
        git_status = self._git_status(args)
        if git_status is None:
            print("git repository not found")
        else:
            metadata['git'] = git_status

        metadata['parameters'] = args
        if self._description:
            metadata['description'] = self._description

        # run function, time it
        t0 = time()
        extra_metadata, df_dict = fct(**args)
        metadata['seconds'] = f"{time() - t0:.2f}"

        # additional meta data
        if extra_metadata is not None:
            metadata.update(extra_metadata)

        if args['output_path']:
            # Build the path only once output_path is known to be set; joining a None
            # output_path would raise TypeError.
            path = os.path.join(args['output_path'], metadata['hash'])
            os.makedirs(path, exist_ok=True)

            # save dataframes
            if df_dict is not None:
                for file, df in df_dict.items():
                    df.to_csv(os.path.join(path, file))

                metadata['files'] = [*df_dict.keys()]

            # save meta data
            with open(os.path.join(path, 'meta.json'), 'w') as f:
                json.dump(metadata, f, indent=2)

            print(f"Experiment completed in {metadata['seconds']} seconds. Saved to {path}.")

        return metadata, df_dict

    # def add_parameter(self, name, type=None, default=None, required=True, description="", hash=True, help=None,
    #                   choices=None, action=None):
    #     """
    #     Add input parameter. Follows the argparse syntax.
    #     """
    #     self._parameters[name] = dict(type=type,
    #                                   default=default,
    #                                   required=required,
    #                                   description=description,
    #                                   hash=hash,
    #                                   help=help,
    #                                   choices=choices,
    #                                   action=action)
