import json
import os
from argparse import ArgumentParser
from hashlib import md5
from time import time
import numpy as np

import torch
import wandb
from chip.datasets import TomogramDataset
from chip.models.forward_models import fourier_filtering
from chip.models.iterative_model import TomographicReconstruction
from torch.utils.data import Subset

from chip.datasets.superres_dataset import SuperresolutionDS
from chip.utils import create_circle_filter
from chip.utils.metrics import get_metrics
from tqdm.auto import tqdm
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur


def check_rotate_type(value):
    """
    Convert a command line rotation value into the value used in the config.

    :param value: raw argument; ``'random'``, ``'None'`` or something ``int()`` accepts
    :return: ``None`` for ``'None'``, the string ``'random'`` unchanged, otherwise
        ``int(value)``
    :raises ValueError: if ``value`` is neither keyword and cannot be converted by ``int()``
    """
    if value == 'None':
        return None
    # 'random' is passed through as-is; everything else must be an integer.
    return value if value == 'random' else int(value)


def active_learning_benchmark(config, exclude_keys_from_group=None, generator=None):
    """
    Generator that wraps every item of the benchmark dataset in its own wandb run.

    For each ``(prior_data, target, idx)`` item of the dataset a wandb run is
    initialized, ``(prior_data[0], target[0])`` is yielded to the caller, and when
    the caller asks for the next item the run is flagged ``completed`` and finished.
    If the caller stops iterating early, the code after the ``yield`` never runs, so
    the current run is neither flagged ``completed`` nor finished by this function.

    :param config: flat mapping with dotted keys. Read here: ``dataset``,
        ``dataset.index``, ``wandb.entity``, ``wandb.debug``, ``wandb.disabled``,
        ``wandb.skip_completed`` and ``wandb.tag``; the whole mapping is stored as
        the run config.
    :param exclude_keys_from_group: currently unused, kept for backward compatibility
    :param generator: currently unused, kept for backward compatibility
    :return: yields ``(prior_data[0], target[0])`` per dataset item
    """
    # load dataset
    dataset = get_dataset(config)

    if config['dataset.index'] is not None:
        dataset = Subset(dataset, [config['dataset.index']])

    wandb_project = f"benchmark-{config['dataset']}"
    if config['wandb.debug']:
        wandb_project += '-debug'

    # The public API needs credentials and network access, so it is only contacted
    # when wandb is enabled; a disabled run never logs to the project anyway.
    wandb_enabled = not config['wandb.disabled']
    api = None
    if wandb_enabled:
        api = wandb.Api()
        # create project if it doesn't exist
        existing_projects = {p.name for p in api.projects(config['wandb.entity'])}
        if wandb_project not in existing_projects:
            print(f"Creating project {wandb_project} for entity {config['wandb.entity']}")
            api.create_project(wandb_project, entity=config['wandb.entity'])

    for prior_data, target, idx in tqdm(dataset, desc="Dataset", position=0):
        if config['wandb.skip_completed'] and wandb_enabled:
            print("Fetching completed experiments from wandb")
            # Match on every non-wandb config entry plus the item index.
            filters = {k: v for k, v in config.items() if not k.startswith('wandb')}
            filters['completed'] = True
            filters['idx'] = idx

            # Non-None entries are matched on 'config.<key>.value', None entries on
            # the bare 'config.<key>'.
            # NOTE(review): presumably mirrors how wandb stores config values - confirm.
            filters = {f'config.{k}' + ('.value' if v is not None else ""): v for k, v in filters.items()}
            runs = api.runs(
                path=f"{config['wandb.entity']}/{wandb_project}",
                filters=filters
            )
            if len(runs) > 0:
                print("Skipping", filters)
                continue

        # initialize wandb
        run = wandb.init(project=wandb_project,
                         entity=config['wandb.entity'],
                         config=config,
                         tags=[config['wandb.tag']] if config['wandb.tag'] else None,
                         mode='online' if wandb_enabled else 'disabled')

        # Bookkeeping entries; 'last_iteration' and 'last_step' are advanced by log_iteration.
        wandb.config.update({
            'completed': False,
            'idx': idx,
            'last_iteration': 0,
            'last_step': 0,
        })

        # Start time of the run, used to report elapsed seconds.
        wandb.config.update({'_t0': time()}, allow_val_change=True)

        # NOTE(review): the [0] presumably strips a leading dimension - confirm
        # against the dataset classes.
        yield prior_data[0], target[0]

        # Only reached when the caller requests the next item.
        wandb.config.update({'completed': True}, allow_val_change=True)
        run.finish()


def log_iteration(selected_angles, reconstruction, target, commit=True):
    """
    Log one active learning iteration to the current wandb run.

    Reads ``last_iteration``, ``last_step`` and ``_t0`` from the run config, logs
    the selection together with the result of ``get_metrics(target, reconstruction)``
    and stores the advanced counters back into the config.

    :param selected_angles: sequence of the angles selected in this iteration
    :param reconstruction: reconstruction logged under the key 'reconstruction'
    :param target: reference handed to ``get_metrics`` together with the reconstruction
    :param commit: forwarded to the ``log`` call of the wandb run
    """
    run = wandb.run
    assert run is not None, "Must initialize wandb run."

    num_selected = len(selected_angles)
    if num_selected != wandb.run.config['experiment.batch_size']:
        # Only reported, logging continues with the actual batch size.
        print("Number of selected angles does not match config batch size.")

    # Counters: iterations advance by one, steps by the number of selected angles.
    iteration = run.config['last_iteration'] + 1
    step = run.config['last_step'] + num_selected
    # Taken before the metrics are computed, relative to the '_t0' config entry.
    elapsed = time() - run.config['_t0']

    log_data = {
        'selected_angles': selected_angles,
        'selected_batch_size': num_selected,
        'reconstruction': reconstruction,
        'iteration': iteration,
        'step': step,
        'seconds': elapsed,
    }
    # Metric entries are merged last, so they win on a key collision.
    log_data.update(get_metrics(target, reconstruction))

    if num_selected == 1:
        log_data['selected_angle'] = selected_angles[0]

    wandb.config.update(
        {'last_iteration': iteration, 'last_step': step},
        allow_val_change=True,
    )

    wandb.run.log(log_data, commit=commit)


def get_default_argument_parser(description=None):
    """
    Build the default command line parser for the active learning benchmark.

    Option names are dotted (e.g. ``--dataset.path``). argparse keeps the dots in
    the destination names, so parsed values are looked up by keys such as
    ``'dataset.path'``.

    :param description: optional description shown in the ``--help`` output
    :return: the configured ``ArgumentParser``
    """
    parser = ArgumentParser(description=description)

    # # dataset parameters
    parser.add_argument('--dataset', type=str, help='Name of the benchmark dataset')
    parser.add_argument('--dataset.path', type=str, default=None, help='Path to dataset')
    parser.add_argument('--dataset.sigma', type=int, default=None,
                        help='Sigma used for gaussian blur for low res image')
    parser.add_argument('--dataset.kernel_size', type=int, default=None, help='Kernel size of gaussian blur')
    parser.add_argument('--dataset.rotation_angle', type=check_rotate_type, default=None,
                        help='Rotation of the tomogram in degrees, "random" or "None".')
    parser.add_argument('--dataset.num_defects', type=int, default=None, help='')
    parser.add_argument('--dataset.index', type=int, default=None, help='Only run on selected index')
    # parser.add_argument('--dataset.num_data_points', type=int, default=None, help='Only run first {num_data_points} data points')
    # parser.add_argument('--dataset.shuffle', type=int, default=None, help='Shuffle dataset')
    # parser.add_argument('--dataset.seed', type=int, default=None, help='Seed generator')  # not used yet
    # parser.add_argument('--dataset.interpolation_mode', type=str, default='bilinear', choices=['bilinear', 'nearest'], help='Interpolation mode for rotation')
    # parser.add_argument('--dataset.prior', type=str, choices=['sinogram', 'lr'], default=None, help='')
    # parser.add_argument('--dataset.prior_angles', type=int, default=180, help='')

    # observation parameters
    # parser.add_argument('--sinogram.sigma', type=float, default=10, help='')

    # experiment parameters
    parser.add_argument('--experiment.num_iterations', type=int, default=4, help='active learning number of iterations')
    parser.add_argument('--experiment.theta', type=int, default=180,
                        help='number of angles available for active learning')
    parser.add_argument('--experiment.batch_size', type=int, default=1, help='active learning batch size')

    # wandb parameters; the three store_true flags default to False
    parser.add_argument('--wandb.disabled', action='store_true', help='Disable logging to wandb')
    parser.add_argument('--wandb.entity', type=str, default='sdsc-chip', help='wandb user/organization')
    parser.add_argument('--wandb.skip_completed', action='store_true', help='Skip completed experiments')
    parser.add_argument('--wandb.debug', action='store_true',
                        help='Log to the "-debug" variant of the wandb project')
    parser.add_argument('--wandb.tag', type=str, default=None, help='add tag to wandb run')

    return parser


def overwrite_kwargs(config, kwargs):
    """
    Replace default dataset arguments in place with values set in the config.

    For every key ``k`` of ``kwargs`` the config entry ``'dataset.<k>'`` is
    consulted; if it exists and is not ``None`` it replaces ``kwargs[k]`` and a
    warning is printed. Keys that are not already in ``kwargs`` are never added.

    :param config: mapping that may contain ``'dataset.<k>'`` entries
    :param kwargs: dict of default arguments, modified in place
    :return: None
    """
    config_keys = {name: f'dataset.{name}' for name in kwargs}
    for name, config_key in config_keys.items():
        if config_key not in config:
            continue
        override = config[config_key]
        if override is None:
            # None means "not set", keep the default.
            continue
        print(f"Warning: Overriding default dataset value ({kwargs[name]}) for {name} with {override}.")
        kwargs[name] = override


# Dataset names backed by the synthetic tomogram file.
_SYNTHETIC_TOMOGRAM_DATASETS = ('tomogram_synthetic_10', 'tomogram_synthetic_gray_10')

# Presets for the chip tomograms, keyed by dataset name prefix.
#   path           -> value passed as 'path' to TomogramDataset
#   lr_path        -> value passed as 'lr_path', or None to pass an
#                     'lr_forward_function' built from fourier_filtering instead
#   rotation_angle -> default rotation angle (can be overridden through the config)
#   crop           -> value passed as 'crop' to TomogramDataset
#   index_range    -> (first, last) dataset index; 10 evenly spaced indices are used
_CHIP_TOMOGRAM_PRESETS = {
    'tomogram_chip1': {
        'path': '/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC/tomogram_delta.mat',
        'lr_path': None,
        'rotation_angle': 0,
        'crop': (200, 200, 900),
        'index_range': (300, 400),
    },
    'tomogram_chip2': {
        'path': '/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC_v0/tomogram_delta.mat',
        'lr_path': None,
        'rotation_angle': 30,
        'crop': (100, 100, 900),
        'index_range': (50, 170),
    },
    'tomogram_chip3_paired': {
        'path': '/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC/lowres_vs_highres_aligned/HighResolutionTomogram.mat',
        'lr_path': '/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC/lowres_vs_highres_aligned/LowResolutionTomogram.mat',
        'rotation_angle': 30,
        'crop': (80, 80, 1350),
        'index_range': (190, 450),
    },
    'tomogram_chip3': {
        'path': '/mydata/chip/shared/ra.psi.ch/p17299/data_for_SDSC/lowres_vs_highres_aligned/HighResolutionTomogram.mat',
        'lr_path': None,
        'rotation_angle': 30,
        'crop': (80, 80, 1350),
        'index_range': (190, 450),
    },
}

# Full dataset name -> (preset, contrast). Every preset exists as '<prefix>_bw_10'
# (contrast 15) and '<prefix>_10' (no contrast).
_CHIP_TOMOGRAM_DATASETS = {
    f'{prefix}{suffix}': (preset, contrast)
    for prefix, preset in _CHIP_TOMOGRAM_PRESETS.items()
    for suffix, contrast in (('_bw_10', 15), ('_10', None))
}


def _make_lr_forward_function():
    """Return a function applying ``fourier_filtering`` with the ``create_circle_filter(30, 512)`` filter."""
    circle_filter = create_circle_filter(30, 512)
    return lambda x: fourier_filtering(x, circle_filter)


def get_dataset(config):
    """
    Create the benchmark dataset selected by ``config['dataset']``.

    Supported names are ``'superresolution'``, the synthetic tomogram datasets and
    the chip tomogram datasets (each as ``'<prefix>_10'`` and ``'<prefix>_bw_10'``).
    For the tomogram datasets the preset arguments can be overridden by non-None
    ``'dataset.<name>'`` config entries (see ``overwrite_kwargs``).

    :param config: mapping with dotted keys; ``'dataset'`` selects the dataset
    :return: a ``torch.utils.data.Subset`` of the selected dataset
    :raises ValueError: if ``config['dataset']`` is not a known dataset name
        (previously ``None`` was returned silently)
    """
    name = config['dataset']

    if name == 'superresolution':
        path = config['dataset.path']
        files = os.listdir(os.path.join(path, 'imgs_synthetic'))
        # NOTE(review): 'dataset.rotate_angle' and 'dataset.interpolation_mode' are
        # not registered by the default argument parser visible in this file (it
        # registers 'dataset.rotation_angle'); the caller has to provide them,
        # otherwise these lookups raise KeyError.
        ds = SuperresolutionDS(files,
                               data_path=path,
                               sigma=config['dataset.sigma'],
                               num_defects=config['dataset.num_defects'],
                               rotate_angle=config['dataset.rotate_angle'],
                               interpolation_mode=config['dataset.interpolation_mode'])
        # Keep the last 10% of the items (in os.listdir order).
        return Subset(ds, range(round(0.9 * len(ds)), len(ds)))

    if name in _SYNTHETIC_TOMOGRAM_DATASETS:
        kwargs = {
            'path': '/mydata/chip/shared/data/tomogram_synthetic.h5',
            'lr_forward_function': _make_lr_forward_function(),
            'rescale': None,
            'clip_range': None,
            'normalize_range': None,
            'rotation_angle': 30,
            'num_defects': None,
            'contrast': None,
            'gray_background': 'gray' in name,
        }
        overwrite_kwargs(config, kwargs)
        dataset = TomogramDataset(**kwargs)
        # First 10 items of the last 5% of the dataset.
        dataset = Subset(dataset, range(round(0.95 * len(dataset)), len(dataset)))
        return Subset(dataset, range(10))

    if name in _CHIP_TOMOGRAM_DATASETS:
        preset, contrast = _CHIP_TOMOGRAM_DATASETS[name]
        kwargs = {'path': preset['path']}
        # Either a paired low-resolution file or a forward function is passed.
        if preset['lr_path'] is None:
            kwargs['lr_forward_function'] = _make_lr_forward_function()
        else:
            kwargs['lr_path'] = preset['lr_path']
        kwargs.update({
            'rescale': 512,
            'clip_range': True,
            'normalize_range': True,
            'rotation_angle': preset['rotation_angle'],
            'num_defects': None,
            'contrast': contrast,
            'crop': preset['crop'],
        })
        overwrite_kwargs(config, kwargs)
        dataset = TomogramDataset(**kwargs)
        first, last = preset['index_range']
        return Subset(dataset, np.linspace(first, last, 10, dtype=int))

    known = ['superresolution', *_SYNTHETIC_TOMOGRAM_DATASETS, *_CHIP_TOMOGRAM_DATASETS]
    raise ValueError(f"Unknown dataset {name!r}. Available datasets: {', '.join(sorted(known))}")
