import logging
import os
from functools import partial

import torch
from torch.utils.tensorboard import SummaryWriter

from .EvaluationHandling import evaluate_models_on_datasets, load_evaluation_results, \
    save_evaluation_results, filter_result
from .TrainHandling import train_model
from ..data_handling.EvalDatasetFactory import create_eval_dataset
from ..data_handling.SupportedDatasets import SupportedEvalDatasets
from ..data_handling.datasets.SelfCollageDataset import SelfCollageDataset
from ..models.ModelFactory import create_model
from ..util import transforms, Constants
from ..util.Constants import TB_SUB_DIR
from ..util.losses import DensityMSELoss

# Module-wide logger. NOTE: calling getLogger without a name yields the *root*
# logger, so records from this module are emitted under the root logger.
LOGGER = logging.getLogger(name=None)


def _build_test_transforms(img_size, normalise_transform):
    """Build the evaluation-time transform pipeline.

    The pipeline converts the input with ``transforms.ToTensor``, resizes it to
    ``(img_size, img_size)`` and, when ``normalise_transform`` is not ``None``,
    appends it as the final step.
    """
    test_trans = [
        transforms.ToTensor(),
        transforms.Resize(size=(img_size, img_size)),
    ]
    if normalise_transform is not None:
        test_trans.append(normalise_transform)
    return transforms.Compose(test_trans)


def _create_visualise_test_samples(args, base_data_dir, normalise_transform):
    """Load a fixed set of FSC147_low samples (indices 1, 5, 6) for visualisation.

    Returns a dict ``{'test': [sample, ...]}`` in the form passed to
    ``train_model`` as ``visualise_test_samples``.
    """
    test_dataset_small = create_eval_dataset(SupportedEvalDatasets.FSC147_low, base_data_dir,
                                             transform=_build_test_transforms(args.img_size, normalise_transform),
                                             disable_tqdm=True, use_reference_crops=True,
                                             reference_crop_size=args.reference_crop_size,
                                             density_scaling=args.density_scaling)
    test_samples_small = [test_dataset_small[i] for i in (1, 5, 6)]
    return {'test': test_samples_small}


def train_and_evaluate(args, base_data_dir, model_dir, plot_dir, model_dict, device):
    """Train (or reload) a self-supervised model, then evaluate it (or reload its results).

    If ``<model_dir>/<model>.pt`` already exists, its state dict is loaded instead of
    training. Otherwise ``args`` are saved to ``model_dir``, the model is trained on a
    ``SelfCollageDataset`` and its state dict is written to that path.
    If evaluation results can be loaded from ``model_dir``, they are logged. Otherwise
    the model is evaluated on ``args.eval_dataset`` and the results are saved.

    Args:
        args: Run configuration; many attributes are read (e.g. ``img_size``,
            ``normalise``, ``batch_size``, ``eval_dataset``, ``visualise_test``).
        base_data_dir: Data root. Also passed as ``weights_dir`` to model and dataset creation.
        model_dir: Directory holding the model checkpoint, saved args,
            TensorBoard logs (``TB_SUB_DIR``) and evaluation results.
        plot_dir: Forwarded to ``SelfCollageDataset``.
        model_dict: Forwarded to ``save_evaluation_results``.
        device: Torch device the model is moved to / the checkpoint is mapped onto.

    Returns:
        The trained or reloaded model, on ``device``.
    """
    # Weights are looked up under the data root (weights_dir simply aliases it).
    weights_dir = base_data_dir

    model = create_model(args, weights_dir=weights_dir)
    criterion = DensityMSELoss(args.density_scaling, args.density_loss_mask_prob,
                               args.density_loss_use_independent_masks, args.density_loss_keep_object_pixels,
                               args.density_loss_keep_all_object_pixels,
                               args.density_loss_penalise_wrong_cluster_objects,
                               args.density_loss_wrong_cluster_penality, args.img_size)

    LOGGER.info(f'Self-supervised training of {model}')

    normalise_transform = model.backbone.normalise_transform if args.normalise else None
    model = model.to(device)
    model_path = os.path.join(model_dir, f'{model}.pt')
    args_path = os.path.join(model_dir, Constants.model_args_file_name)
    summary_writer = SummaryWriter(os.path.join(model_dir, TB_SUB_DIR))

    # NOTE: parameter names `model` / `logging` are kept because this function is
    # handed to train_model as `eval_func`, which may pass them by keyword.
    # `logging` shadows the module only inside this closure, where the module is unused.
    def _test_model(model, logging=True):
        test_transforms = _build_test_transforms(args.img_size, normalise_transform)
        return evaluate_models_on_datasets(args, base_data_dir, args.eval_dataset, model, criterion, device,
                                           test_transforms=test_transforms, batch_size=args.batch_size,
                                           num_count_classes=args.num_count_classes,
                                           disable_tqdm=args.disable_tqdm or not logging,
                                           num_workers=args.num_workers, logging=logging)

    try:
        if os.path.isfile(model_path):
            # Return the saved model if it was already trained.
            LOGGER.info(f'Returning saved trained {model}')
            with open(model_path, 'rb') as f:
                state_dict = torch.load(f, map_location=device)
            model.load_state_dict(state_dict)
        else:
            with open(args_path, 'wb') as f:
                LOGGER.info('Saving arguments')
                torch.save(args, f)

            # Fall back to a patch size of 16 when the model does not define one.
            patch_size = model.patch_size
            if patch_size is None:
                patch_size = 16
            training_dataset = SelfCollageDataset(args, base_data_dir, normalise_transform=normalise_transform,
                                                  device=device, patch_size=patch_size, plot_dir=plot_dir,
                                                  weights_dir=weights_dir)
            training_dataset.measure_dataset_indexing_time(disable_tqdm=args.disable_tqdm)

            # Evaluations run during training are silent.
            _test_model_training = partial(_test_model, logging=False)

            visualise_test_samples = (
                _create_visualise_test_samples(args, base_data_dir, normalise_transform)
                if args.visualise_test else None
            )

            model = train_model(args, model_path, model, training_dataset, criterion=criterion,
                                num_count_classes=args.num_count_classes, summary_writer=summary_writer,
                                device=device, eval_func=_test_model_training,
                                visualise_test_samples=visualise_test_samples)

            with open(model_path, 'wb') as f:
                LOGGER.info(f'Saving trained {model}')
                torch.save(model.state_dict(), f)

        results = load_evaluation_results(model_dir)
        if results is not None:
            print_results = {dataset: filter_result(result) for dataset, result in results.items()}
            LOGGER.info(print_results)
        else:
            results = _test_model(model)
            save_evaluation_results(results, [model_dir], [summary_writer], [model_dict])
    finally:
        # close() also flushes pending events. Closing here releases the writer
        # even if training or evaluation raised.
        summary_writer.close()

    return model
