"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import json
import logging
import time
import os
import torch

from lavis.common.dist_utils import (
    is_main_process,
    main_process,
)
from lavis.common.registry import registry
from lavis.runners.runner_base import RunnerBase


@registry.register_runner("runner_cons_unlearn")
class RunnerConsUnlearn(RunnerBase):
    """
    A runner class to train and evaluate a model given a task and datasets.

    The runner uses pytorch distributed data parallel by default. Future release
    will support other distributed frameworks.

    ``train`` scores every validation pass by the distance of the
    ``df_metrics`` / ``dr_metrics`` ``r_mean`` values from fixed reference
    values, and ``eval_epoch`` additionally asks the task to dump logits for
    the "best" / final epoch.
    """

    # Epoch (0-based) whose validation pass is checkpointed as "best".
    # NOTE(review): this replaces the metric-based selection that is commented
    # out in ``train``; it presumably matches a 20-epoch schedule. If
    # ``max_epoch`` is smaller, no "best" checkpoint is ever written -- confirm
    # against the run configs before relying on the final "best" evaluation.
    BEST_CHECKPOINT_EPOCH = 19

    def train(self):
        """
        Run the training loop, periodic validation and the final test pass.

        Per epoch (skipped entirely when ``self.evaluate_only`` is set):
          1. train one epoch and log both the summary stats and the
             accumulated per-epoch detailed logs;
          2. every ``run_cfg.val_freq`` epochs, evaluate each validation split
             and, on the main process, compute the aggregate distance to the
             reference metrics and log it.

        After the loop the detailed logs are written to ``detailed_log.json``
        in ``self.output_dir`` (training runs only) and ``evaluate`` is called
        on the test splits.
        """
        start_time = time.time()
        best_agg_metric = 1000
        best_epoch = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)
            # Models flagged with ``use_distill`` get ``copy_params()`` called
            # after the checkpoint load. Both the flag lookup and the call must
            # target the unwrapped model: previously the return value of
            # ``copy_params()`` was handed to ``unwrap_dist_model`` and the
            # flag was looked up on the (possibly wrapped) ``self.model``.
            model = self.unwrap_dist_model(self.model)
            if getattr(model, "use_distill", False):
                model.copy_params()

        # epoch -> detailed training log returned by ``train_epoch``
        detailed_logs = {}

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # In evaluation-only mode nothing below applies; leave the loop
            # with ``cur_epoch`` bound to the first epoch index.
            if self.evaluate_only:
                break

            # training phase
            logging.info("Start training")
            train_stats, detailed_log = self.train_epoch(cur_epoch)
            self.log_stats(split_name="train", stats=train_stats)
            detailed_logs[cur_epoch] = detailed_log
            # Rewrites train_stat.txt with everything collected so far.
            self.log_stats(detailed_logs, is_train=True)

            # evaluation phase
            val_freq = self.config.run_cfg.val_freq
            if len(self.valid_splits) > 0 and (cur_epoch + 1) % val_freq == 0:
                # Reference values the unlearned model is compared against.
                # TODO: make these configable or obtained by evaluation on retrain model
                retrain_metric_df = 0
                retrain_metric_dr = 100
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=cur_epoch
                    )
                    # Only the main process aggregates, checkpoints and logs.
                    if val_log is None or not is_main_process():
                        continue

                    df_metric = abs(retrain_metric_df - val_log["df_metrics"]["r_mean"])
                    dr_metric = retrain_metric_dr - val_log["dr_metrics"]["r_mean"]
                    assert dr_metric > 0
                    assert df_metric >= 0
                    total_metrics = df_metric + dr_metric
                    # Metric-based selection is currently disabled:
                    # if total_metrics <= best_agg_metric:
                    if cur_epoch == self.BEST_CHECKPOINT_EPOCH:
                        best_epoch, best_agg_metric = cur_epoch, total_metrics
                        self._save_checkpoint(cur_epoch, is_best=True)
                    val_log.update({"best_epoch": best_epoch})
                    self.log_stats(val_log, split_name)

            # else:
            #     # if no validation split is provided, we just save the checkpoint at the end of each epoch.
            #     if not self.evaluate_only:
            #         self._save_checkpoint(cur_epoch, is_best=False)

            # dist.barrier()  # TODO: temp adjust for non-distribute training

        if not self.evaluate_only:
            detailed_log_save_path = os.path.join(self.output_dir, "detailed_log.json")
            # Context manager so the handle is flushed and closed deterministically.
            with open(detailed_log_save_path, "w") as f:
                json.dump(detailed_logs, f)

        # testing phase
        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def evaluate(self, cur_epoch="best", skip_reload=False):
        """
        Evaluate the model on every test split.

        Args:
            cur_epoch (int or str): epoch index, or "best" to evaluate the best checkpoint.
            skip_reload (bool): forwarded to ``eval_epoch``.

        Returns:
            dict mapping each test split name to the result of ``eval_epoch``,
            or None when there are no test splits.
        """
        test_logs = dict()

        if len(self.test_splits) > 0:
            for split_name in self.test_splits:
                test_logs[split_name] = self.eval_epoch(
                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
                )

            return test_logs

    @torch.no_grad()
    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int or str): current epoch, or "best".
            skip_reload (bool): whether to skip reloading the best checkpoint.
                The best checkpoint is reloaded only when ``cur_epoch`` is "best"
                and this flag is False; otherwise the current weights are used.

        Returns:
            The value returned by ``self.task.after_evaluation`` when the task's
            evaluation produced results, otherwise None.
        """
        data_loader = self.dataloaders.get(split_name, None)
        assert data_loader, "data_loader for split {} is None.".format(split_name)

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
            logging.info('------------------------------ Reload best checkpoint ------------------------------')
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        # This call had been commented out while ``results`` was still read
        # below, so every evaluation ended in a NameError. ``train`` relies on
        # the metrics derived from these results.
        results = self.task.evaluation(model, data_loader, split_name=split_name)

        # Dump logits only for the best checkpoint or the final epoch.
        if cur_epoch == "best" or cur_epoch == self.max_epoch - 1:
            self.task.collect_and_save_logits(
                model,
                data_loader,
                output_path=self.output_dir,
                device="cuda",
            )

        if results is not None:
            return self.task.after_evaluation(
                val_result=results,
                split_name=split_name,
                epoch=cur_epoch,
            )

    @main_process
    def log_stats(self, stats, split_name=None, is_train=False):
        """
        Persist statistics as JSON lines inside ``self.output_dir``.

        Args:
            stats: JSON-serialisable statistics. Outside of train mode only
                dicts are written; any other type is ignored.
            split_name (str): prefix added to every key; required unless
                ``is_train`` is True.
            is_train (bool): when True, overwrite ``train_stat.txt`` with
                ``stats``; otherwise append the prefixed dict to ``log.txt``.
        """
        if is_train:
            # "w": the file always holds just the latest snapshot.
            with open(os.path.join(self.output_dir, "train_stat.txt"), "w") as f:
                f.write(json.dumps(stats) + "\n")
            return

        assert split_name is not None
        if isinstance(stats, dict):
            log_stats = {f"{split_name}_{k}": v for k, v in stats.items()}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")
        # Non-dict stats (e.g. lists) are intentionally not logged.
