"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import json
import logging
import os
import time

import torch
import torch.distributed as dist
import webdataset as wds
from lavis.common.dist_utils import download_cached_file, is_main_process, main_process
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split, prepare_sample
from lavis.runners.runner_base import RunnerBase
from torch.utils.data.dataset import ChainDataset

from lavis.common.logger import MetricLogger, SmoothedValue
from torchvision.utils import save_image


@registry.register_runner("runner_gene_image")
class RunnerGenerateImage(RunnerBase):
    """
    Runner that drives image generation on the "test" split.

    The heavy lifting is delegated to the task object: ``task.generate`` for
    producing images and ``task.get_txt_specs`` for extracting text-specific
    conditions. This class only prepares the model / dataloader, persists the
    results, and knows how to restore weights from the two checkpoint formats
    handled in ``load_checkpoint``.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        super().__init__(cfg, task, model, datasets, job_id)

    def set_model(self, model):
        """Replace the runner's model, moving it to ``self.device`` first."""
        self._model = model.to(self.device)

    def _prepare_generation_split(self, split_name="test"):
        """Put the unwrapped model in eval mode and fetch the split's dataloader.

        Shared preamble of ``generate`` and ``get_txt_spec_conds``.

        Args:
            split_name: dataset split to run on.
                todo: replace mixed_mllmu with val split

        Returns:
            ``(model, data_loader)`` where ``data_loader`` is ``None`` when no
            dataloader is registered for ``split_name``.

        Raises:
            KeyError: if ``self.datasets`` has no entry for ``split_name``.
        """
        model = self.unwrap_dist_model(self.model)
        model.eval()

        data_loader = self.dataloaders.get(split_name, None)
        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        return model, data_loader

    def generate(self, save_path="", spec_ebds=None):
        """Run ``task.generate`` on the test split and optionally dump the results.

        Args:
            save_path: when non-empty, the object returned by the task is
                written to this path as JSON (presumably the list of generated
                image filenames -- verify against the task implementation).
            spec_ebds: forwarded unchanged to ``task.generate``.
        """
        model, data_loader = self._prepare_generation_split()
        results = self.task.generate(
            model,
            data_loader,
            cuda_enabled=True,
            spec_ebds=spec_ebds,
        )

        if save_path:
            # Context manager guarantees the file is flushed and closed even if
            # serialization fails part-way.
            with open(save_path, "w", encoding="utf-8") as f:
                json.dump(results, f)
            print(f"Saved {len(results)} virtual images, filename list saved to {save_path}.")

    def get_txt_spec_conds(self):
        """Run ``task.get_txt_specs`` on the test split.

        Returns:
            The ``(results, id_order)`` pair produced by the task, unchanged.
        """
        model, data_loader = self._prepare_generation_split()
        results, id_order = self.task.get_txt_specs(model, data_loader, cuda_enabled=True)
        return results, id_order

    def load_checkpoint(self, ldm_ckpt=None, url_or_filename=None):
        """
        Resume from a checkpoint.

        Either or both sources may be given; when both are, ``ldm_ckpt`` is
        applied first and ``url_or_filename`` is loaded on top of it.

        Args:
            ldm_ckpt: path to a checkpoint whose weights live under the
                ``"state_dict"`` key (an optional ``"global_step"`` entry is
                printed when present).
            url_or_filename: URL or local path of a checkpoint whose weights
                live under the ``"model"`` key; only entries prefixed with
                ``"ldm."`` are used, with that prefix stripped.

        Raises:
            AssertionError: if neither checkpoint source is provided.
            RuntimeError: if ``url_or_filename`` is neither a URL nor an
                existing file.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type and message are kept for existing callers.
        if ldm_ckpt is None and url_or_filename is None:
            raise AssertionError("No checkpoint is provided.")

        # NOTE(review): torch.load unpickles arbitrary objects. Only load
        # checkpoints from trusted sources (or consider ``weights_only=True``).
        if ldm_ckpt is not None:
            print(f"Loading model from {ldm_ckpt}")
            pl_sd = torch.load(ldm_ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")

            # Load into the unwrapped model, as the other branch does, so the
            # key names match even when ``self.model`` is a distributed wrapper.
            missing, unexpected = self.unwrap_dist_model(self.model).load_state_dict(
                pl_sd["state_dict"], strict=False
            )
            if missing:
                print("missing keys:")
                print(missing)
            if unexpected:
                print("unexpected keys:")
                print(unexpected)
            # Use the runner's configured device rather than a hard-coded
            # ``.cuda()``, matching ``set_model`` and ``map_location`` below.
            self.model.to(self.device)

        if url_or_filename is not None:
            if is_url(url_or_filename):
                cached_file = download_cached_file(
                    url_or_filename, check_hash=False, progress=True
                )
                checkpoint = torch.load(cached_file, map_location=self.device)
            elif os.path.isfile(url_or_filename):
                checkpoint = torch.load(url_or_filename, map_location=self.device)
            else:
                raise RuntimeError("checkpoint url or path is invalid")

            # Keep only the "ldm."-prefixed entries and drop the prefix so the
            # keys line up with this runner's model.
            prefix = "ldm."
            state_dict = {
                name[len(prefix):]: param
                for name, param in checkpoint["model"].items()
                if name.startswith(prefix)
            }

            self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)

            logging.info("Resume checkpoint from %s", url_or_filename)