import json
import logging
import os

import numpy as np
import torch
from lavis.common.dist_utils import is_main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from torch.cuda.amp import autocast
from lavis.common.logger import MetricLogger
from lavis.datasets.data_utils import prepare_sample
from lavis.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized

from torchvision.utils import save_image

@registry.register_task("img_spec_cdt")
class ImgSpecCdtTask(BaseTask):
    @autocast(True)
    def train_step(self, model, samples):
        """Run the model's forward pass on one training batch under CUDA autocast.

        Args:
            model: the model being trained; it is called as ``model(samples)``.
            samples: one prepared training batch.

        Returns:
            Whatever ``model(samples)`` returns, unchanged. NOTE(review): this
            is presumably the structure the inherited training loop expects;
            verify against ``BaseTask._train_inner_loop``.
        """
        forward_output = model(samples)
        return forward_output

    def train_epoch(
            self,
            epoch,
            model,
            data_loader,
            optimizer,
            lr_scheduler,
            scaler=None,
            cuda_enabled=False,
            log_freq=50,
            accum_grad_iters=1,
    ):
        """Train for one epoch while keeping the ``ldm`` submodule in eval mode.

        Args:
            epoch: current epoch index, forwarded to the inner loop.
            model: the model to train; it may be wrapped in
                ``DistributedDataParallel`` / ``DataParallel``.
            data_loader: training loader; ``len(data_loader)`` is used as the
                number of iterations for this epoch.
            optimizer: optimizer, forwarded to the inner loop.
            lr_scheduler: LR scheduler, forwarded to the inner loop.
            scaler: optional AMP grad scaler, forwarded to the inner loop.
            cuda_enabled: forwarded to the inner loop.
            log_freq: logging frequency, forwarded to the inner loop.
            accum_grad_iters: gradient-accumulation steps, forwarded to the
                inner loop.

        Returns:
            The value returned by ``self._train_inner_loop``.
        """
        model.train()
        # ``model.train()`` switches every submodule into training mode, so the
        # ``ldm`` submodule is put back into eval mode afterwards. When the model
        # is wrapped for (distributed) data parallelism, the wrapper does not
        # expose ``ldm`` itself; it lives on the wrapped ``.module``.
        parallel_wrappers = (
            torch.nn.parallel.DistributedDataParallel,
            torch.nn.DataParallel,
        )
        base_model = model.module if isinstance(model, parallel_wrappers) else model
        base_model.ldm.eval()

        return self._train_inner_loop(
            epoch=epoch,
            iters_per_epoch=len(data_loader),
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def evaluation(self, model, data_loader, cuda_enabled=True, **kwargs):
        """Produce validation outputs from the first batch of ``data_loader``.

        The method returns as soon as the first batch has been handed to
        :meth:`valid_step`, so at most one batch is ever processed. The metric
        reported is a constant placeholder (``agg_metrics`` is always 1).

        NOTE(review): if evaluating the whole split was intended, the early
        return is a bug. Confirm the first-batch-only behavior is deliberate.

        Args:
            model: forwarded to :meth:`valid_step`.
            data_loader: iterable of raw batches.
            cuda_enabled: forwarded to ``prepare_sample``.
            **kwargs: accepted and ignored.

        Returns:
            ``{"agg_metrics": 1, "eval_output": <valid_step output>}`` for the
            first batch, or an empty list if ``data_loader`` yields nothing.
        """
        meters = MetricLogger(delimiter="  ")
        batch_stream = meters.log_every(data_loader, 10, "Evaluation")

        for raw_batch in batch_stream:
            batch = prepare_sample(raw_batch, cuda_enabled=cuda_enabled)
            first_output = self.valid_step(model=model, samples=batch)
            return {"agg_metrics": 1, "eval_output": first_output}

        # The loader produced no batches.
        return []

    @autocast(True)
    def valid_step(self, model, samples, **kwargs):
        """Generate validation images for one batch under CUDA autocast.

        Args:
            model: must provide ``gene_val_imgs(samples)``.
            samples: one prepared validation batch.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Whatever ``model.gene_val_imgs(samples)`` returns.
        """
        return model.gene_val_imgs(samples)

    def after_evaluation(self, **kwargs):
        """Return every keyword argument plus a constant ``agg_metrics`` of 1.

        Any incoming ``agg_metrics`` value is overridden with the placeholder 1.
        """
        return {**kwargs, "agg_metrics": 1}

    def generate(self, model, data_loader, cuda_enabled=True, ann_path="path to annotation.json"):
        """Generate virtual image/caption annotations for one pass over ``data_loader``.

        Args:
            model: forwarded to :meth:`gene_vir_data`.
            data_loader: loader whose
                ``_dataloader.loader.dataset.datasets[0].img_ids`` maps
                image ids to loader indices.
            cuda_enabled: forwarded to ``prepare_sample``.
            ann_path: JSON file containing a list of annotations with at least
                ``image_id`` and ``image`` keys; used to resolve image file
                names. The default is the original placeholder and has to be
                overridden for the method to find a real file.

        Returns:
            ``(vir_anns, cnt_vir_imgs, loss, loss_dict)``. ``vir_anns`` is the
            concatenation of every batch's annotations and ``cnt_vir_imgs`` the
            summed per-batch counts. ``loss`` and ``loss_dict`` come from the
            last processed batch only.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Generation:"
        print_freq = 50

        # Guarantees the loop body runs at least once, so ``loss``/``loss_dict``
        # are bound at the final return.
        assert len(data_loader) > 0

        vir_anns = []
        cnt_vir_imgs = 0

        # ``img_ids`` maps image_id -> loader index; invert it so that the
        # ``image_id`` values found in a batch can be mapped back.
        dataset = data_loader._dataloader.loader.dataset.datasets[0]
        loader_id_to_image_id = {
            loader_id: image_id for image_id, loader_id in dataset.img_ids.items()
        }

        with open(ann_path, "r", encoding="utf-8") as f:
            ori_anns = json.load(f)
        id2filemap = {ann["image_id"]: ann["image"] for ann in ori_anns}

        for i, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
            # Stop after len(data_loader) batches. NOTE(review): presumably the
            # loader may keep cycling past its length; verify.
            if i == len(data_loader):
                break
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            # ``gene_vir_data`` is the method that consumes ``id2filemap`` /
            # ``ldr2imgmap`` and returns the 4-tuple unpacked here.
            # ``valid_step`` ignores those kwargs and returns the output of
            # ``model.gene_val_imgs``. Autocast is applied here to match the
            # mixed precision ``valid_step`` previously provided.
            with autocast(True):
                vir_a, num_vir_image, loss, loss_dict = self.gene_vir_data(
                    model=model,
                    samples=samples,
                    id2filemap=id2filemap,
                    ldr2imgmap=loader_id_to_image_id,
                )
            vir_anns += vir_a
            cnt_vir_imgs += num_vir_image

        return vir_anns, cnt_vir_imgs, loss, loss_dict


    def gene_vir_data(self, model, samples, *, save_dir="./results/spec_imgs", id_offset=100000, **kwargs):
        """Run ``model`` on a batch, save its virtual images, and build paired annotations.

        For each sample, two annotation entries are appended:

        * the original image paired with the model's virtual caption
          (original ``image_id``), and
        * the generated virtual image, saved as ``vir_<image name>`` under
          ``save_dir``, paired with the original caption (``image_id`` shifted
          by ``id_offset``).

        Args:
            model: called as ``model(samples)``. Its output must expose
                ``virtual_images``, ``virtual_texts``, ``loss`` and ``loss_dict``.
            samples: batch containing ``image_id`` (loader indices) and
                ``text_input`` (original captions).
            save_dir: directory the virtual images are written to. Missing
                directories are created.
            id_offset: added to the original ``image_id`` for the
                virtual-image entry.
            **kwargs: must contain ``ldr2imgmap`` (loader index -> image id) and
                ``id2filemap`` (image id -> image file name).

        Returns:
            ``(annotations, len(image ids in the batch), outputs.loss,
            outputs.loss_dict)``.

        Raises:
            KeyError: if ``ldr2imgmap`` / ``id2filemap`` are missing from
                ``kwargs`` or lack an entry for a sample.
        """
        vir_annotations = []

        outputs = model(samples)

        out_imgs = outputs.virtual_images
        out_texts = outputs.virtual_texts

        ldr2id = kwargs["ldr2imgmap"]
        id2file = kwargs["id2filemap"]
        img_ids = [ldr2id[int(ldr_id)] for ldr_id in samples["image_id"]]
        texts = samples["text_input"]
        for cap, vir_cap, img_id, vir_img in zip(texts, out_texts, img_ids, out_imgs):
            img_name = id2file[int(img_id)]
            vir_img_name = f"vir_{img_name}"
            vir_annotations.append({"image": img_name, "caption": vir_cap, "image_id": int(img_id)})
            vir_annotations.append({"image": vir_img_name, "caption": cap, "image_id": int(img_id) + id_offset})

            out_path = os.path.join(save_dir, vir_img_name)
            # The image name may itself contain subdirectories, so create the
            # file's actual parent directory rather than only ``save_dir``.
            os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
            save_image(vir_img, out_path)

        return vir_annotations, len(img_ids), outputs.loss, outputs.loss_dict