import json
import logging
import os

import numpy as np
import torch
from lavis.common.dist_utils import is_main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from torch.cuda.amp import autocast
from lavis.common.logger import MetricLogger
from lavis.datasets.data_utils import prepare_sample
from lavis.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized

from torchvision.utils import save_image

@registry.register_task("txt_spec_cdt")
class TxtSpecCdtTask(BaseTask):
    """Task registered as ``txt_spec_cdt``.

    Training delegates to ``BaseTask._train_inner_loop`` while keeping the
    model's ``ldm`` submodule in eval mode. Evaluation runs
    ``model.gene_val_txts`` on a single batch and reports a constant
    placeholder ``agg_metrics`` of 1 (no real metric is computed here).
    """

    @autocast(True)  # TODO: fast adjust for use_fp16
    def train_step(self, model, samples):
        """Run one forward pass under CUDA autocast and return its output."""
        return model(samples)

    def train_epoch(
            self,
            epoch,
            model,
            data_loader,
            optimizer,
            lr_scheduler,
            scaler=None,
            cuda_enabled=False,
            log_freq=50,
            accum_grad_iters=1,
    ):
        """Train for one epoch with the ``ldm`` submodule held in eval mode.

        Args:
            epoch: Current epoch index, forwarded to the inner loop.
            model: The model to train; may be a bare module or a wrapper
                exposing the real module as ``.module`` (e.g. DDP).
            data_loader: Training loader; its ``len`` sets iters per epoch.
            optimizer, lr_scheduler, scaler: Forwarded to the inner loop.
            cuda_enabled, log_freq, accum_grad_iters: Forwarded as-is.

        Returns:
            Whatever ``BaseTask._train_inner_loop`` returns.
        """
        # Put everything in train mode first, then switch ``ldm`` back to eval;
        # the order matters because ``train()`` recurses into all children.
        model.train()
        # A DistributedDataParallel wrapper only exposes its child ``module``
        # via attribute lookup, so unwrap before reaching for ``ldm``.
        getattr(model, "module", model).ldm.eval()
        return self._train_inner_loop(
            epoch=epoch,
            iters_per_epoch=len(data_loader),
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def evaluation(self, model, data_loader, cuda_enabled=True, **kwargs):
        """Generate validation outputs for the FIRST batch only.

        Only one batch is processed: the method returns as soon as the first
        batch has been run through ``valid_step``. Extra ``kwargs`` are ignored.

        Args:
            model: Model exposing ``gene_val_txts(samples)``.
            data_loader: Iterable of validation batches.
            cuda_enabled: Forwarded to ``prepare_sample``.

        Returns:
            ``{"agg_metrics": 1, "eval_output": <valid_step output>}`` for the
            first batch, or an empty list if ``data_loader`` yields nothing.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            eval_output = self.valid_step(model=model, samples=samples)
            # Deliberate early exit: a single sample batch is enough here and
            # "agg_metrics" is a constant placeholder, not a computed score.
            return {"agg_metrics": 1, "eval_output": eval_output}

        # Empty loader: nothing was generated.
        return []

    @autocast(True)  # TODO: fast adjust for use_fp16
    def valid_step(self, model, samples, **kwargs):
        """Return ``model.gene_val_txts(samples)`` computed under autocast."""
        return model.gene_val_txts(samples)

    def after_evaluation(self, **kwargs):
        """Echo the keyword arguments back with ``agg_metrics`` forced to 1."""
        outs = kwargs
        outs["agg_metrics"] = 1
        return outs