"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import json
import logging
import os
import time

import torch
import torch.distributed as dist
import webdataset as wds
from lavis.common.dist_utils import download_cached_file, is_main_process, main_process
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split, prepare_sample
from lavis.runners.runner_base import RunnerBase
from torch.utils.data.dataset import ChainDataset

from lavis.common.logger import MetricLogger, SmoothedValue
from torchvision.utils import save_image


@registry.register_runner("runner_gene_text")
class RunnerGenerateText(RunnerBase):
    """
    Generate captions for Df images.

    Inference-only runner: it puts the (unwrapped) model in eval mode, takes the
    dataloader of one split and delegates the actual work to the task
    (``task.generate`` for text generation, ``task.get_img_specs`` for the
    per-image conditions). Checkpoints are loaded non-strictly into the model.
    """

    # Decoding settings forwarded to ``task.generate`` by ``captioning``.
    # Any of them can be overridden per call via ``captioning(**generate_kwargs)``.
    _DEFAULT_GENERATE_KWARGS = dict(
        use_nucleus_sampling=True,
        num_beams=3,
        max_length=50,
        min_length=10,
        top_p=0.7,
        repetition_penalty=1.1,
        length_penalty=1.1,
        num_captions=5,
        temperature=1.2,
    )

    def __init__(self, cfg, task, model, datasets, job_id):
        super().__init__(cfg, task, model, datasets, job_id)

    def _prepare_split(self, split_name):
        """Put the model in eval mode and fetch the dataloader of ``split_name``.

        Also runs ``task.before_evaluation`` on that split's dataset.

        Returns:
            tuple: ``(model, data_loader)`` where ``model`` is the model with any
            distributed wrapper removed and ``data_loader`` is
            ``self.dataloaders.get(split_name)`` (``None`` if absent).

        Raises:
            KeyError: if ``split_name`` is not in ``self.datasets``.
        """
        model = self.unwrap_dist_model(self.model)
        model.eval()

        data_loader = self.dataloaders.get(split_name, None)
        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        return model, data_loader

    def captioning(self, save_path="", spec_conds=None, split_name="train", **generate_kwargs):
        """Generate texts for the samples of ``split_name`` and optionally save them.

        Args:
            save_path: if non-empty, the results are written to this path as JSON
                (missing parent directories are created).
            spec_conds: forwarded unchanged to ``task.generate``.
            split_name: split to generate for. Defaults to ``"train"``.
                TODO: replace mixed_mllmu with val split.
            **generate_kwargs: overrides for ``_DEFAULT_GENERATE_KWARGS``. Must not
                contain ``cuda_enabled`` or ``spec_conds`` (passed explicitly).

        Returns:
            The object returned by ``task.generate`` (the same object that is
            saved to ``save_path``).
        """
        model, data_loader = self._prepare_split(split_name)

        decoding_kwargs = {**self._DEFAULT_GENERATE_KWARGS, **generate_kwargs}
        results = self.task.generate(
            model,
            data_loader,
            cuda_enabled=True,
            spec_conds=spec_conds,
            **decoding_kwargs,
        )

        if save_path:
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            # Context manager guarantees the file is flushed and closed.
            # NOTE(review): in a distributed run every rank may reach this write;
            # confirm whether task.generate gathers results / only rank 0 should save.
            with open(save_path, "w") as f:
                json.dump(results, f)
            print(f"Saved {len(results)} virtual texts to {save_path}.")

        return results

    def get_img_spec_conds(self, split_name="train"):
        """Compute per-image conditions for ``split_name`` via ``task.get_img_specs``.

        Args:
            split_name: split to process. Defaults to ``"train"``.
                TODO: replace mixed_mllmu with val split.

        Returns:
            tuple: ``(results, id_order)`` exactly as returned by
            ``task.get_img_specs``.
        """
        model, data_loader = self._prepare_split(split_name)
        results, id_order = self.task.get_img_specs(model, data_loader, cuda_enabled=True)
        return results, id_order

    def load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.

        Loads ``checkpoint["model"]`` into the unwrapped model with
        ``strict=False``; optimizer/scheduler state is not restored.

        Args:
            url_or_filename: URL (downloaded and cached) or local file path.

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
            KeyError: if the checkpoint has no ``"model"`` entry.
        """
        if is_url(url_or_filename):
            checkpoint_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
        elif os.path.isfile(url_or_filename):
            checkpoint_file = url_or_filename
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_file, map_location=self.device)

        state_dict = checkpoint["model"]
        msg = self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)
        # strict=False silently skips mismatched keys; log them so a wrong or
        # partial checkpoint does not go unnoticed.
        missing_keys = getattr(msg, "missing_keys", None)
        unexpected_keys = getattr(msg, "unexpected_keys", None)
        if missing_keys:
            logging.info("Missing keys {}".format(missing_keys))
        if unexpected_keys:
            logging.info("Unexpected keys {}".format(unexpected_keys))

        logging.info("Resume checkpoint from {}".format(url_or_filename))
