import torch
import logging
import os
from fcclip import FCCLIP
from torch import nn

from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.modeling import META_ARCH_REGISTRY

from detectron2.utils.registry import Registry
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from detectron2.structures import ImageList
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.modeling.postprocessing import sem_seg_postprocess


from fcclip.modeling.transformer_decoder import MultiScaleMaskedTransformerDecoder
from fcclip.modeling.transformer_decoder.fcclip_transformer_decoder import get_classification_logits
from fcclip.modeling import FCCLIPHead
from fcclip.fcclip import VILD_PROMPT

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

# Registry of transformer decoder modules, registered under the name "TRANSFORMER_MODULE".
TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
TRANSFORMER_DECODER_REGISTRY.__doc__ = (
    "\nRegistry for transformer module in MaskFormer.\n"
)


class PromptEmbedding(nn.Module):
    """Pair of query embeddings used for prompt tuning.

    ``embedding_freeze`` is a frozen reference copy and ``embedding_warm`` is the
    trainable copy. Both have ``num_queries`` rows of size ``embedding_dim``, and
    :meth:`replace_twice` initialises them to the same weights.
    """

    def __init__(self, num_queries, embedding_dim):
        super().__init__()
        self.num_queries = num_queries
        self.embedding_freeze = nn.Embedding(num_queries, embedding_dim)
        self.embedding_warm = nn.Embedding(num_queries, embedding_dim)
        # Only the "warm" copy is optimised; the frozen copy is the reference point.
        self.embedding_freeze.requires_grad_(False)

    def replace_twice(self, weight):
        """Copy ``weight`` into both the frozen and the trainable embedding tables."""
        with torch.no_grad():
            self.embedding_freeze.weight.copy_(weight)
            self.embedding_warm.weight.copy_(weight)

    def task_arithmetic(self, weight):
        """Return ``freeze + weight * (warm - freeze)``.

        ``weight`` scales the task vector (the difference learned by the trainable
        copy) that is added back onto the frozen weights.
        """
        tau = weight * (self.embedding_warm.weight - self.embedding_freeze.weight)
        return self.embedding_freeze.weight + tau

    def concat_weight(self):
        """Return frozen and trainable weights stacked along dim 0 (``2 * num_queries`` rows)."""
        return torch.cat([self.embedding_freeze.weight, self.embedding_warm.weight], dim=0)

    def forward(self, x):
        """Look up indices ``x`` in the trainable (warm) embedding table."""
        return self.embedding_warm(x)


@META_ARCH_REGISTRY.register()
class FCCLIPrompt(FCCLIP):
    """FC-CLIP meta-architecture with prompt (query) tuning.

    On top of ``FCCLIP`` this adds:

    * optional additional query prompts (``PromptEmbedding``) that are copied
      from the predictor's ``query_feat`` / ``query_embed`` on the first training
      forward, with ``num_queries`` doubled;
    * optional one-shot task-arithmetic merging of those prompts back into the
      predictor's queries at inference time;
    * selective freezing / unfreezing of segmentation-head parameters;
    * an optional regulariser pulling head parameters toward their initial values;
    * an optional per-class loss analysis that periodically saves loss plots.
    """

    @configurable
    def __init__(self, *args, **kwargs):
        # Prompt-tuning options are consumed here; the remaining kwargs go to FCCLIP.
        self.additional_prompt = kwargs.pop('additional_prompt', None)
        self.task_arithmetic = kwargs.pop('task_arithmetic', None)
        self.task_arithmetic_lambda = kwargs.pop('task_arithmetic_lambda', None)
        add_num_queries = kwargs.pop('add_num_queries', None)
        freeze_param_names = kwargs.pop('freeze_param_names', None)
        hidden_dim = kwargs.pop('hidden_dim', None)
        self.analysis = kwargs.pop('analysis', False)
        self.num_classes = kwargs.pop('num_classes', None)
        self.all_train = kwargs.pop('all_train', False)
        self.output_dir = kwargs.pop('output_dir', None)
        self.l2_reg = kwargs.pop('l2_reg', False)

        # Read (not popped) so it is still forwarded to FCCLIP.__init__.
        train_metadata = kwargs.get('train_metadata', None)
        self.num_classes_names = train_metadata.thing_classes_orig
        super().__init__(*args, **kwargs)

        self.criterion.analysis = self.analysis

        if self.additional_prompt:
            # Prompts are filled from the predictor's queries on the first training step.
            self.replace_flag = True
            self.add_query_feat = PromptEmbedding(add_num_queries, hidden_dim)
            self.add_query_embed = PromptEmbedding(add_num_queries, hidden_dim)
            self.num_queries *= 2
        else:
            self.replace_flag = False
            self.add_query_feat = None
            self.add_query_embed = None

        # Per-class loss analysis state.
        self.loss_each_class_list = []
        self.loss_each_class = {
            "mask": torch.zeros(self.num_classes),
            "dice": torch.zeros(self.num_classes),
            "ce": torch.zeros(self.num_classes),
        }

        self.save_fig_path = os.path.join(self.output_dir, 'loss_graph')
        os.makedirs(self.save_fig_path, exist_ok=True)
        self.class_inst_count = torch.zeros(self.num_classes)
        self.iter_count, self.iter_period = 0, 500

        if not self.all_train:
            self.freeze_everything()
            self.tuning_only_prompt(freeze_param_names)
        else:
            if self.additional_prompt:
                self.freeze_some_param(['predictor.query_feat', 'predictor.query_embed'])

        if self.l2_reg:
            # Snapshot of the head parameters used as the regularisation anchor.
            self.pretrained_head = [i.detach().clone() for i in self.sem_seg_head.parameters()]
        self.logger = logging.getLogger("detectron2.trainer")

    def freeze_everything(self):
        """Disable gradients for every segmentation-head parameter."""
        for param in self.sem_seg_head.parameters():
            param.requires_grad = False

    def tuning_only_prompt(self, freeze_param_names):
        """Re-enable gradients for head parameters whose name contains any given substring.

        Despite the argument name, the substrings select parameters to *train*.
        ``None`` is treated as an empty list.
        """
        names = freeze_param_names or []
        for module_param_name, param in self.sem_seg_head.named_parameters():
            if any(name in module_param_name for name in names):
                param.requires_grad = True

    def freeze_some_param(self, freeze_param_names):
        """Disable gradients for head parameters whose name contains any given substring."""
        names = freeze_param_names or []
        for module_param_name, param in self.sem_seg_head.named_parameters():
            if any(name in module_param_name for name in names):
                param.requires_grad = False

    # For analysis
    def return_loss_each_class(self):
        """Return per-class mean losses accumulated since the last reset.

        The 1e-6 term avoids division by zero for classes that were not seen.
        """
        denom = self.class_inst_count + 1e-6
        return {k: v / denom for k, v in self.loss_each_class.items()}

    def init_loss_each_class(self):
        """Reset the per-class loss accumulators and instance counters."""
        self.loss_each_class = {
            "mask": torch.zeros(self.num_classes),
            "dice": torch.zeros(self.num_classes),
            "ce": torch.zeros(self.num_classes),
        }
        self.class_inst_count = torch.zeros(self.num_classes)

    def make_graph_each_class(self):
        """Save one loss-vs-iteration plot per class, then turn analysis off."""
        # Snapshot k (0-based) is taken at iteration (k + 1) * iter_period.
        x = np.arange(1, len(self.loss_each_class_list) + 1) * self.iter_period
        for idx in range(self.num_classes):
            y_mask = [(ep['mask'][idx] + ep['dice'][idx]).item() * 1000
                      for ep in self.loss_each_class_list]
            y_class = [ep['ce'][idx].item() for ep in self.loss_each_class_list]

            fig = plt.figure(figsize=(6, 4))
            plt.plot(x, y_mask, label='Mask Loss * 1000')
            plt.plot(x, y_class, label='Class Loss')

            # Set the plot title and axis labels.
            plt.title(f'Class: {self.num_classes_names[idx]}, Iter: {self.iter_count}')
            plt.xlabel('iteration')
            plt.ylabel('loss_values')
            plt.legend()
            plt.grid(True)

            # Save the plot to disk.
            plt.savefig(os.path.join(self.save_fig_path, f'{self.num_classes_names[idx]}_{self.iter_count}.png'),
                        dpi=300, facecolor='white', edgecolor='black', orientation='portrait',
                        format='png', transparent=False, bbox_inches='tight', pad_inches=0.1)
            # Release the figure; otherwise each call leaks num_classes open figures.
            plt.close(fig)

        # One-shot analysis: stop collecting after the first set of plots.
        self.analysis = False

    @classmethod
    def from_config(cls, cfg):
        config_dict = super().from_config(cfg)

        config_dict['additional_prompt'] = cfg.MODEL.PROMPT_TUNING.ADDITIONAL_PROMPT
        config_dict['add_num_queries'] = cfg.MODEL.PROMPT_TUNING.NUM_QUERIES
        config_dict['freeze_param_names'] = cfg.MODEL.PROMPT_TUNING.FREEZE_PARAM_NAMES
        config_dict['hidden_dim'] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        config_dict['task_arithmetic'] = cfg.MODEL.PROMPT_TUNING.TASK_ARITHMETIC
        config_dict['task_arithmetic_lambda'] = cfg.MODEL.PROMPT_TUNING.TASK_ARITHMETIC_LAMBDA
        config_dict['all_train'] = cfg.MODEL.PROMPT_TUNING.ALL_TRAIN

        # For Analysis
        config_dict['analysis'] = cfg.MODEL.PROMPT_TUNING.ANALYSIS
        config_dict['num_classes'] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        config_dict['output_dir'] = cfg.OUTPUT_DIR

        # For Regularization
        config_dict['l2_reg'] = cfg.MODEL.PROMPT_TUNING.L2_REG

        return config_dict

    def prepare_class_names_from_metadata(self, metadata, train_metadata):
        """Build templated class prompts and the seen/unseen category mask.

        Returns ``(category_overlapping_mask, num_templates, class_names)``:
        ``category_overlapping_mask[i]`` is 1 when any synonym of test class ``i``
        appears among the training class names, ``num_templates[i]`` is the number
        of synonyms of class ``i``, and ``class_names`` is the flat list of every
        synonym formatted with every ``VILD_PROMPT`` template.
        """
        def split_labels(x):
            res = []
            for x_ in x:
                x_ = x_.replace(', ', ',')
                x_ = x_.split(',')  # there can be multiple synonyms for single class
                res.append(x_)
            return res

        # get text classifier
        try:
            class_names = split_labels(metadata.stuff_classes)  # it includes both thing and stuff
            train_class_names = split_labels(train_metadata.stuff_classes)
        except AttributeError:
            # this could be for insseg, where only thing_classes are available
            class_names = split_labels(metadata.thing_classes)
            train_class_names = split_labels(train_metadata.thing_classes)
        train_class_names = {l for label in train_class_names for l in label}
        category_overlapping_mask = torch.tensor(
            [not train_class_names.isdisjoint(test_class_names) for test_class_names in class_names],
            dtype=torch.long)

        def fill_all_templates_ensemble(x_=''):
            res = []
            for x in x_:
                for template in VILD_PROMPT:
                    res.append(template.format(x))
            return res, len(res) // len(VILD_PROMPT)

        num_templates = []
        templated_class_names = []
        for x in class_names:
            templated_classes, templated_classes_num = fill_all_templates_ensemble(x)
            templated_class_names += templated_classes
            num_templates.append(templated_classes_num)  # how many templates for current classes
        class_names = templated_class_names
        return category_overlapping_mask, num_templates, class_names

    def l2_regularize(self):
        """Return the mean absolute deviation of head parameters from their snapshot.

        NOTE: despite the name (and the ``L2_REG`` config key) this is an L1-style
        penalty: the mean of ``|p - p0|`` per parameter tensor, averaged over tensors.
        """
        if not self.pretrained_head:
            return torch.tensor(0., device=self.device)
        # Move the snapshot to the model device once, instead of copying it every step.
        if self.pretrained_head[0].device != self.device:
            self.pretrained_head = [p.to(self.device) for p in self.pretrained_head]
        loss_output = torch.tensor(0., device=self.device)
        for param, pretrained_param in zip(self.sem_seg_head.parameters(), self.pretrained_head):
            loss_output = loss_output + torch.abs(param - pretrained_param).mean()
        return loss_output / len(self.pretrained_head)

    def replace_param(self):
        """Initialise both prompt copies from the predictor's current query weights."""
        self.add_query_feat.replace_twice(self.sem_seg_head.predictor.query_feat.weight)
        self.add_query_embed.replace_twice(self.sem_seg_head.predictor.query_embed.weight)

    def _apply_task_arithmetic(self):
        """Write the task-arithmetic-merged prompts into the predictor queries and drop the extras."""
        lam = self.task_arithmetic_lambda
        predictor = self.sem_seg_head.predictor
        # In-place writes into (possibly grad-requiring) leaf parameters must not be tracked.
        with torch.no_grad():
            predictor.query_feat.weight.copy_(self.add_query_feat.task_arithmetic(lam))
            predictor.query_embed.weight.copy_(self.add_query_embed.task_arithmetic(lam))
        # To use the original queries.
        self.num_queries = self.num_queries // 2
        self.additional_prompt = False

    def _accumulate_class_losses(self, losses):
        """Add per-instance losses to per-class accumulators; snapshot and plot periodically."""
        with torch.no_grad():
            target_classes_o = losses['target_classes_o'].detach().cpu().long()
            flat_idx = target_classes_o.reshape(-1)
            self.class_inst_count.index_add_(
                0, flat_idx, torch.ones(flat_idx.numel(), dtype=self.class_inst_count.dtype))
            for lec_k, acc in self.loss_each_class.items():
                inst_loss = losses[f'loss_{lec_k}_inst'].detach().cpu().to(acc.dtype)
                inst_loss = torch.broadcast_to(inst_loss, target_classes_o.shape).reshape(-1)
                # index_add_ sums repeated class ids; `acc[idx] += v` would keep only one of them.
                acc.index_add_(0, flat_idx, inst_loss)
        self.iter_count += 1
        if self.iter_count % self.iter_period == 0:
            self.loss_each_class_list.append(self.return_loss_each_class())
            self.init_loss_each_class()

        if self.iter_count % (self.iter_period * 10) == 0:
            self.make_graph_each_class()

    def forward(self, batched_inputs):
        """Run training (returns a loss dict) or inference (returns per-image results)."""
        if self.replace_flag and self.additional_prompt and self.training:
            self.replace_flag = False
            self.replace_param()

        if self.replace_flag and self.task_arithmetic and not self.training:
            self.replace_flag = False
            self._apply_task_arithmetic()

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        text_classifier, num_templates = self.get_text_classifier()
        # Append void class weight
        text_classifier = torch.cat([text_classifier, F.normalize(self.void_embedding.weight, dim=-1)], dim=0)
        features['text_classifier'] = text_classifier
        features['num_templates'] = num_templates
        outputs = self.sem_seg_head(features,
                                    add_query_feat=self.add_query_feat,
                                    add_query_embed=self.add_query_embed,
                                    additional_prompt=self.additional_prompt)

        if self.training:
            # mask classification target
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            # bipartite matching-based loss
            losses = self.criterion(outputs, targets)

            if self.analysis:
                self._accumulate_class_losses(losses)

            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:
                    # remove this loss if not specified in `weight_dict`
                    losses.pop(k)

            if self.l2_reg:
                losses['loss_l2_reg'] = self.l2_regularize()

            return losses
        else:
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]

            # We ensemble the pred logits of in-vocab and out-vocab
            clip_feature = features["clip_vis_dense"]
            mask_for_pooling = F.interpolate(mask_pred_results, size=clip_feature.shape[-2:],
                                             mode='bilinear', align_corners=False)
            if "convnext" in self.backbone.model_name.lower():
                pooled_clip_feature = self.mask_pooling(clip_feature, mask_for_pooling)
                pooled_clip_feature = self.backbone.visual_prediction_forward(pooled_clip_feature)
            elif "rn" in self.backbone.model_name.lower():
                pooled_clip_feature = self.backbone.visual_prediction_forward(clip_feature, mask_for_pooling)
            else:
                raise NotImplementedError

            out_vocab_cls_results = get_classification_logits(pooled_clip_feature, text_classifier,
                                                              self.backbone.clip_model.logit_scale, num_templates)
            in_vocab_cls_results = mask_cls_results[..., :-1]  # remove void
            out_vocab_cls_results = out_vocab_cls_results[..., :-1]  # remove void

            # Reference: https://github.com/NVlabs/ODISE/blob/main/odise/modeling/meta_arch/odise.py#L1506
            out_vocab_cls_probs = out_vocab_cls_results.softmax(-1)
            in_vocab_cls_results = in_vocab_cls_results.softmax(-1)
            category_overlapping_mask = self.category_overlapping_mask.to(self.device)

            if self.ensemble_on_valid_mask:
                # Only include out_vocab cls results on masks with valid pixels
                # We empirically find that this is important to obtain reasonable AP/mIOU score with ResNet CLIP models
                valid_masking = (mask_for_pooling > 0).to(mask_for_pooling).sum(-1).sum(-1) > 0
                valid_masking = valid_masking.to(in_vocab_cls_results.dtype).unsqueeze(-1)
                alpha = torch.ones_like(in_vocab_cls_results) * self.geometric_ensemble_alpha
                beta = torch.ones_like(in_vocab_cls_results) * self.geometric_ensemble_beta
                alpha = alpha * valid_masking
                beta = beta * valid_masking
            else:
                alpha = self.geometric_ensemble_alpha
                beta = self.geometric_ensemble_beta

            cls_logits_seen = (
                    (in_vocab_cls_results ** (1 - alpha) * out_vocab_cls_probs ** alpha).log()
                    * category_overlapping_mask
            )
            cls_logits_unseen = (
                    (in_vocab_cls_results ** (1 - beta) * out_vocab_cls_probs ** beta).log()
                    * (1 - category_overlapping_mask)
            )
            cls_results = cls_logits_seen + cls_logits_unseen

            # This is used to filtering void predictions.
            is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
            mask_cls_probs = torch.cat([
                cls_results.softmax(-1) * (1.0 - is_void_prob),
                is_void_prob], dim=-1)
            mask_cls_results = torch.log(mask_cls_probs + 1e-8)

            # upsample masks
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )

            del outputs

            processed_results = []
            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
                    mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                processed_results.append({})

                if self.sem_seg_postprocess_before_inference:
                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                        mask_pred_result, image_size, height, width
                    )
                    mask_cls_result = mask_cls_result.to(mask_pred_result)

                # semantic segmentation inference
                if self.semantic_on:
                    r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
                    if not self.sem_seg_postprocess_before_inference:
                        r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
                    processed_results[-1]["sem_seg"] = r

                # panoptic segmentation inference
                if self.panoptic_on:
                    panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
                    processed_results[-1]["panoptic_seg"] = panoptic_r

                # instance segmentation inference
                if self.instance_on:
                    instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
                    processed_results[-1]["instances"] = instance_r

            return processed_results


@SEM_SEG_HEADS_REGISTRY.register()
class FCCLIPHeadPrompt(FCCLIPHead):
    """FC-CLIP head that forwards extra keyword arguments (e.g. prompt modules) to its predictor."""

    def forward(self, features, mask=None, **kwargs):
        return self.layers(features, mask, **kwargs)

    def layers(self, features, mask=None, **kwargs):
        """Run the pixel decoder, then the predictor on multi-scale features.

        ``features`` must also carry ``"text_classifier"`` and ``"num_templates"``;
        any extra ``kwargs`` are passed through to the predictor unchanged.
        Only the ``"multi_scale_pixel_decoder"`` input mode is supported.
        """
        mask_features, _, multi_scale_features = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature != "multi_scale_pixel_decoder":
            raise NotImplementedError
        return self.predictor(
            multi_scale_features,
            mask_features,
            mask,
            text_classifier=features["text_classifier"],
            num_templates=features["num_templates"],
            **kwargs,
        )