import torch
import torch.nn as nn
import sys

from prompts.imagenet_template import *

from mmseg.models.segmentors import BaseSegmentor
from mmseg.models.data_preprocessor import SegDataPreProcessor
from mmengine.structures import PixelData
from mmseg.registry import MODELS

import torch.nn.functional as F
from clip import tokenizer, create_model_and_transforms, create_model
from simfeatup_dev.upsamplers import get_upsampler
import numpy as np
import torchvision.transforms as transforms
from utils import *
import mmcv
from PIL import Image
from glob import glob
import os
import pickle as pkl
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

@MODELS.register_module()
class AlignCLIP(BaseSegmentor):
    """Training-free segmentor matching dense CLIP features to text queries.

    Patch-level CLIP image features are compared against text embeddings of
    the class names read from ``name_path``. The resulting logits are
    propagated through an affinity matrix obtained from a second vision model
    ("DINO" or "SAM") and, when ``feature_up`` is enabled, blended with logits
    computed from upsampled CLIP features.

    NOTE(review): the feature-upsampling path reshapes with a hard-coded
    batch dimension of 1, so inference is presumably meant to run with one
    image per batch -- confirm against the data loader configuration.

    Args:
        clip_type: CLIP family; only ``"CLIP"`` is supported.
        clip_name: CLIP architecture, ``"ViT-B/16"`` or ``"ViT-L/14"``.
        vfm_type: Auxiliary vision model, ``"DINO"`` or ``"SAM"``.
        vit_name: Auxiliary backbone name. For DINO, ``"ViT-B/16"`` selects
            ``dino_vitb16`` and any other value selects ``dino_vitb8``.
            For SAM only ``"ViT-B/16"`` is supported.
        model_type: Forwarded unchanged to ``self.clip.encode_image``.
        name_path: Text file with one class per line (comma-separated
            synonyms allowed), parsed by ``get_cls_idx``.
        device: Device on which models and text features are placed.
        ignore_residual: Forwarded unchanged to ``self.clip.encode_image``.
        prob_thd: Pixels whose best probability is below this value are
            assigned ``bg_idx``.
        logit_scale: Multiplier applied to logits before the softmax.
        slide_stride: Stride of the sliding window.
        slide_crop: Sliding-window size; ``<= 0`` disables sliding.
        cls_token_lambda: Weight of the global-token logits added to the
            upsampled-feature logits (``0`` disables them).
        bg_idx: Class index used for low-confidence pixels.
        feature_up: Whether to use the feature upsampler.
        feature_up_cfg: Dict with ``model_name`` and ``model_path`` of the
            upsampler. It is only read, never modified.
        aug_text: Use ``remote_template`` instead of
            ``openai_imagenet_template`` for the main text queries.
        logits_lambda: Weight of the affinity-refined logits when blending
            with the upsampled-feature logits.
        visual_query_lambda: Weight of the visual prototypes when mixing
            them with the text queries.
        cluster_num: Number of clusters requested from
            ``perform_clustering``.

    Raises:
        ValueError: If the CLIP or auxiliary-model configuration is not
            supported.
    """

    def __init__(self,
                 clip_type,
                 clip_name,
                 vfm_type,
                 vit_name,
                 model_type,
                 name_path,
                 device=torch.device('cuda'),
                 ignore_residual=True,
                 prob_thd=0.0,
                 logit_scale=40,
                 slide_stride=112,
                 slide_crop=224,
                 cls_token_lambda=0,
                 bg_idx=0,
                 feature_up=True,
                 feature_up_cfg=dict(
                     model_name='jbu_one',
                     model_path='your/model/path'),
                 aug_text=True,
                 logits_lambda=0.1,
                 visual_query_lambda=0.1,
                 cluster_num=3):
        data_preprocessor = SegDataPreProcessor(bgr_to_rgb=True)
        super().__init__(data_preprocessor=data_preprocessor)

        self.device = device
        self.aug_text = aug_text

        # ---- CLIP backbone ------------------------------------------------
        # Fail early with a clear message instead of an AttributeError on
        # ``self.clip`` further down.
        if clip_type != "CLIP" or clip_name not in ("ViT-B/16", "ViT-L/14"):
            raise ValueError(
                f"Unsupported CLIP configuration: clip_type={clip_type!r}, "
                f"clip_name={clip_name!r}")
        self.clip = create_model(clip_name, pretrained='openai', device=device)
        self.clip_transform = transforms.Compose([
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))
        ])
        self.clip_patch_size = self.clip.visual.patch_size[0]

        # ---- auxiliary vision model ---------------------------------------
        if vfm_type == "DINO":
            hub_entry = "dino_vitb16" if vit_name == "ViT-B/16" else "dino_vitb8"
            # Both variants must live on ``device``; inputs are on that device.
            self.vfm = torch.hub.load("facebookresearch/dino:main", hub_entry).to(device)
            patch_size = self.vfm.patch_embed.patch_size
            self.vfm_patch_size = patch_size if isinstance(patch_size, int) else patch_size[0]
            self.vfm_transform = transforms.Compose([
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
        elif vfm_type == "SAM":
            # "Vit-B/16" is kept as an accepted alias of the canonical
            # "ViT-B/16" spelling for backward compatibility.
            if vit_name not in ("ViT-B/16", "Vit-B/16"):
                raise ValueError(f"Unsupported SAM backbone: vit_name={vit_name!r}")
            sam = sam_model_registry["vit_b"](checkpoint="models/sam_vit_b_01ec64.pth")
            sam.to(device=device)
            self.mask_generator = SamAutomaticMaskGenerator(sam)
            self.vfm = sam
            # SAM's ViT encoder uses 16x16 patches (1024 px -> 64x64 embedding).
            self.vfm_patch_size = 16
            # SAM consumes the raw 0-255 image, so no normalisation is needed;
            # an identity transform keeps ``predict`` uniform.
            self.vfm_transform = transforms.Compose([])
        else:
            raise ValueError(f"Unsupported vfm_type={vfm_type!r}; expected 'DINO' or 'SAM'")

        self.tokenizer = tokenizer.tokenize
        self.clip.eval()
        self.vfm.eval()

        self.clip_type = clip_type
        self.vfm_type = vfm_type
        self.model_type = model_type
        self.feature_up = feature_up
        self.logits_lambda = logits_lambda
        self.visual_query_lambda = visual_query_lambda
        self.cluster_num = cluster_num
        self.cls_token_lambda = cls_token_lambda
        self.output_cls_token = cls_token_lambda != 0
        self.bg_idx = bg_idx

        # One query word per synonym; ``query_idx`` maps each word to its class.
        self.query_words, self.query_idx = get_cls_idx(name_path)
        self.num_queries = len(self.query_words)
        self.num_classes = max(self.query_idx) + 1
        self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device)

        # ---- text queries --------------------------------------------------
        main_template = remote_template if self.aug_text else openai_imagenet_template
        self.query_features = self._encode_query_words(main_template, device)
        if self.feature_up:
            # The upsampled-feature branch always uses the ImageNet template.
            # When the main queries already use it, reuse them (neither
            # tensor is modified in place afterwards).
            if self.aug_text:
                self.feature_up_query_features = self._encode_query_words(
                    openai_imagenet_template, device)
            else:
                self.feature_up_query_features = self.query_features

        self.dtype = self.query_features.dtype
        self.ignore_residual = ignore_residual
        self.logit_scale = logit_scale
        self.prob_thd = prob_thd
        self.slide_stride = slide_stride
        self.slide_crop = slide_crop

        # ---- feature upsampler --------------------------------------------
        if feature_up:
            self.feat_dim = self.query_features.shape[-1]
            self.upsampler = get_upsampler(feature_up_cfg['model_name'], self.feat_dim).to(device)
            ckpt = torch.load(feature_up_cfg['model_path'], map_location=device)['state_dict']
            # Drop the first 10 characters of every key -- presumably an
            # "upsampler." prefix added by the training wrapper; TODO confirm.
            weights_dict = {k[10:]: v for k, v in ckpt.items()}
            self.upsampler.load_state_dict(weights_dict, strict=True)

    def _encode_query_words(self, templates, device):
        """Return one L2-normalised CLIP text embedding per query word.

        Every word is inserted into each callable of ``templates``; the
        normalised embeddings of all prompts are averaged and normalised
        again. The result has one row per entry of ``self.query_words``.
        """
        per_word = []
        with torch.no_grad():
            for word in self.query_words:
                tokens = self.tokenizer([fill(word) for fill in templates]).to(device)
                embedding = self.clip.encode_text(tokens)
                embedding /= embedding.norm(dim=-1, keepdim=True)
                embedding = embedding.mean(dim=0)
                embedding /= embedding.norm()
                per_word.append(embedding.unsqueeze(0))
        return torch.cat(per_word, dim=0)

    @staticmethod
    def _resize_to_model_input(x, size=(224, 224)):
        """Bilinearly resize ``x`` to ``size`` unless it already has that size."""
        if tuple(x.shape[-2:]) == tuple(size):
            return x
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

    def get_vfm_feature(self, img):
        """Extract patch features and a patch affinity from the auxiliary model.

        Args:
            img: Image batch; only the first image is used by the SAM branch.
                For SAM the values must be raw 0-255 pixels because the
                tensor is cast to ``uint8`` before being handed to the
                predictor.

        Returns:
            ``(vfm_feature, vfm_attn)``: L2-normalised per-patch features and
            a patch-to-patch affinity matrix.

        Raises:
            ValueError: If ``self.vfm_type`` is neither "DINO" nor "SAM".
        """
        if self.vfm_type == "DINO":
            # NOTE(review): ``get_intermediate_layers_attn_weights`` must be
            # provided by the loaded DINO model -- verify the hub checkpoint
            # in use exposes it.
            vfm_feature, vfm_attn = self.vfm.get_intermediate_layers_attn_weights(
                img, n=len(self.vfm.blocks))

            vfm_feature = torch.cat(vfm_feature, dim=0)
            vfm_attn = torch.cat(vfm_attn, dim=0)
            # Keep the last entry and drop index 0 (presumably the CLS token).
            vfm_attn = vfm_attn[-1, 1:, 1:]
            vfm_feature = vfm_feature[-1, 1:]
            vfm_feature /= vfm_feature.norm(dim=-1, keepdim=True)

            return vfm_feature, vfm_attn
        elif self.vfm_type == "SAM":
            hwc_uint8 = img.detach().cpu().numpy()[0].transpose(1, 2, 0).astype(np.uint8)
            self.mask_generator.predictor.set_image(hwc_uint8)
            vfm_feature = self.mask_generator.predictor.get_image_embedding()

            # The predictor scales the longest side to 1024 and pads; keep
            # only the embedding cells that cover real image content.
            orig_h, orig_w = img.shape[-2:]
            scale = 1024 / max(orig_h, orig_w)
            scaled_h, scaled_w = int(orig_h * scale), int(orig_w * scale)
            valid_h = int(np.ceil(scaled_h / self.vfm_patch_size))
            valid_w = int(np.ceil(scaled_w / self.vfm_patch_size))
            vfm_feature = vfm_feature[:, :, :valid_h, :valid_w]  # ignore padding

            downsample_size = (img.shape[-2] // self.vfm_patch_size,
                               img.shape[-1] // self.vfm_patch_size)
            vfm_feature = F.interpolate(vfm_feature, size=downsample_size,
                                        mode='bilinear', align_corners=False)
            vfm_feature = vfm_feature[0].view(256, -1).permute(1, 0)
            vfm_feature /= vfm_feature.norm(dim=-1, keepdim=True)

            # Cosine similarity between patches serves as the affinity.
            vfm_attn = vfm_feature @ vfm_feature.T
            return vfm_feature, vfm_attn
        raise ValueError(f"Unsupported vfm_type={self.vfm_type!r}")

    def forward_feature(self, img, vfm_img, sam_img, logit_size=None):
        """Compute per-query logits for one (crop of an) image.

        All three inputs are resized to 224x224 before the models run.

        Args:
            img: CLIP-normalised image tensor ``(B, C, H, W)``.
            vfm_img: Image normalised for the DINO model.
            sam_img: Raw 0-255 image, used when ``vfm_type == "SAM"``.
            logit_size: Spatial size of the returned logits. Defaults to the
                spatial size ``img`` had on entry, so that callers get logits
                aligned with the tensor they passed in.

        Returns:
            Logits of shape ``(B, num_queries, *output_size)``.
        """
        out_size = tuple(img.shape[-2:]) if logit_size is None else logit_size

        # Each tensor is checked on its own spatial size: the sliding window
        # may pad ``img`` without padding the other two.
        img = self._resize_to_model_input(img)
        vfm_img = self._resize_to_model_input(vfm_img)
        sam_img = self._resize_to_model_input(sam_img)

        # SAM needs raw pixels; the normalised ``vfm_img`` would be destroyed
        # by the uint8 cast in ``get_vfm_feature``.
        vfm_input = sam_img if self.vfm_type == "SAM" else vfm_img
        vfm_feature, vfm_attn = self.get_vfm_feature(vfm_input)

        _, labels_map = perform_clustering(vfm_feature, n_clusters=self.cluster_num)

        image_features, global_feature = self.clip.encode_image(
            img, model_type=self.model_type, ignore_residual=self.ignore_residual)

        img_h, img_w = img.shape[-2:]
        grid_h, grid_w = img_h // self.clip_patch_size, img_w // self.clip_patch_size

        if self.feature_up:
            up_feats = image_features.permute(0, 2, 1).view(1, self.feat_dim, grid_h, grid_w)
            with torch.cuda.amp.autocast():
                up_feats = self.upsampler(up_feats, img)
            up_feats = up_feats.view(1, self.feat_dim, img_h * img_w).permute(0, 2, 1)
            up_feats /= up_feats.norm(dim=-1, keepdim=True)  # [1, H*W, feat_dim]

            feature_up_logits = up_feats @ self.feature_up_query_features.T

            if self.output_cls_token:
                global_feature /= global_feature.norm(dim=-1, keepdim=True)
                cls_logits = global_feature @ self.feature_up_query_features.T
                feature_up_logits = feature_up_logits + cls_logits * self.cls_token_lambda

        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Visual prototypes: for every query take the single best-matching
        # patch feature and mix it with the text embedding.
        sim = (image_features @ self.query_features.T).permute(0, 2, 1).softmax(dim=-1)
        _, index = sim.topk(1, dim=-1)
        gathered = torch.gather(
            image_features.unsqueeze(0).repeat(1, self.num_queries, 1, 1),
            dim=2,
            index=index.unsqueeze(-1).repeat(1, 1, 1, image_features.shape[-1]))
        visual_query_features = gathered.mean(dim=2).clone().squeeze()
        mix_query_features = (self.visual_query_lambda * visual_query_features
                              + (1 - self.visual_query_lambda) * self.query_features)
        mix_query_features /= mix_query_features.norm(dim=-1, keepdim=True)

        logits = image_features @ mix_query_features.T

        # Restrict the affinity to patch pairs that share a cluster label.
        final_attn = torch.zeros((vfm_attn.shape[0], vfm_attn.shape[1]),
                                 device=vfm_attn.device)
        for label_id in np.unique(labels_map):
            mask = (labels_map == label_id).reshape(vfm_attn.shape[0])
            mask = torch.tensor(mask).to(vfm_attn.device)
            final_attn[mask] = vfm_attn[mask, :] * mask[None, ...]
        final_attn = final_attn.to(image_features.dtype)
        logits = propagate_aff_cam_with_bkg(logits, aff=final_attn.unsqueeze(0))

        if self.feature_up:
            # Bring patch-level logits to pixel resolution, then blend them
            # with the logits of the upsampled features.
            logits = logits.permute(0, 2, 1).reshape(
                logits.shape[0], logits.shape[2], grid_h, grid_w)
            logits = F.interpolate(logits, size=img.shape[-2:], mode='bilinear')
            logits = logits.reshape(logits.shape[0], logits.shape[1], -1).permute(0, 2, 1)
            logits = self.logits_lambda * logits + (1 - self.logits_lambda) * feature_up_logits
            map_h, map_w = img_h, img_w
        else:
            map_h, map_w = grid_h, grid_w

        out_dim = logits.shape[-1]
        logits = logits.permute(0, 2, 1).reshape(-1, out_dim, map_h, map_w)
        return F.interpolate(logits, size=out_size, mode='bilinear')

    def forward_slide(self, clip_img, vfm_img, sam_img, img_metas, stride=112, crop_size=224):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding. Overlapping predictions are averaged and the
        result is resized to ``img_metas[0]['ori_shape'][:2]``.

        Args:
            clip_img: CLIP-normalised images ``(B, C, H, W)``.
            vfm_img: Images normalised for the auxiliary model.
            sam_img: Raw 0-255 images.
            img_metas: List of meta dicts; only ``ori_shape`` of the first
                entry is read.
            stride: Window stride, int or ``(h, w)``.
            crop_size: Window size, int or ``(h, w)``.

        Returns:
            Logits of shape ``(B, num_queries, ori_h, ori_w)``.
        """
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)

        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = clip_img.shape
        out_channels = self.num_queries
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = clip_img.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = clip_img.new_zeros((batch_size, 1, h_img, w_img))

        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the last window back so it keeps the full crop size.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = clip_img[:, :, y1:y2, x1:x2]
                crop_vfm_img = vfm_img[:, :, y1:y2, x1:x2]
                crop_sam_img = sam_img[:, :, y1:y2, x1:x2]

                # pad image when (image_size % patch_size != 0)
                H, W = crop_img.shape[2:]
                pad = self.compute_padsize(H, W, self.clip_patch_size)

                if any(pad):
                    crop_img = nn.functional.pad(crop_img, pad)

                crop_seg_logit = self.forward_feature(crop_img, crop_vfm_img, crop_sam_img)
                # mask cutting for padded image
                if any(pad):
                    l, t = pad[0], pad[2]
                    crop_seg_logit = crop_seg_logit[:, :, t:t + H, l:l + W]

                preds += nn.functional.pad(crop_seg_logit,
                                           (int(x1), int(preds.shape[3] - x2), int(y1),
                                            int(preds.shape[2] - y2)))
                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0

        preds = preds / count_mat
        img_size = img_metas[0]['ori_shape'][:2]
        logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear')

        return logits

    @torch.no_grad()
    def predict(self, inputs, data_samples):
        """Segment a batch of images.

        Args:
            inputs: Image batch ``(B, C, H, W)`` with raw 0-255 values (they
                are divided by 255 before normalisation).
            data_samples: Per-image data samples carrying ``metainfo``, or
                ``None`` to derive the meta information from ``inputs``.

        Returns:
            Whatever ``postprocess_result`` returns for the computed logits.
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        # SAM works on the unnormalised pixels.
        sam_img_inputs = inputs.clone()

        clip_img_inputs = torch.stack(
            [self.clip_transform(image / 255) for image in inputs], dim=0)
        vfm_img_inputs = torch.stack(
            [self.vfm_transform(image / 255) for image in inputs], dim=0)

        if self.slide_crop > 0:
            seg_logits = self.forward_slide(
                clip_img_inputs, vfm_img_inputs, sam_img_inputs,
                batch_img_metas, self.slide_stride, self.slide_crop)
        else:
            # ``[:2]`` keeps only (H, W), consistent with ``forward_slide``.
            seg_logits = self.forward_feature(
                clip_img_inputs, vfm_img_inputs, sam_img_inputs,
                batch_img_metas[0]["ori_shape"][:2])

        return self.postprocess_result(seg_logits, data_samples)

    def postprocess_result(self, seg_logits, data_samples):
        """Turn query logits into class probabilities and a label map.

        Args:
            seg_logits: Logits of shape ``(B, num_queries, H, W)``.
            data_samples: Data samples to fill, or ``None``.

        Returns:
            ``data_samples`` with ``seg_logits`` and ``pred_sem_seg`` set for
            every sample. When ``data_samples`` is ``None`` the label map of
            the first image, shape ``(1, H, W)``, is returned instead.
        """
        batch_size = seg_logits.shape[0]
        num_queries = len(self.query_idx)
        num_cls = int(self.query_idx.max()) + 1

        for i in range(batch_size):
            # Use a per-sample variable: overwriting ``seg_logits`` here
            # would corrupt the input for every following sample.
            probs = (seg_logits[i] * self.logit_scale).softmax(0)  # n_queries * w * h

            if num_cls != num_queries:
                # Several query words map to one class: keep, per class, the
                # highest probability among its words.
                probs = probs.unsqueeze(0)
                cls_index = nn.functional.one_hot(self.query_idx)
                cls_index = cls_index.T.view(num_cls, num_queries, 1, 1)
                probs = (probs * cls_index).max(1)[0]

            seg_pred = probs.argmax(0, keepdim=True)
            seg_pred[probs.max(0, keepdim=True)[0] < self.prob_thd] = self.bg_idx

            if data_samples is None:
                return seg_pred
            data_samples[i].set_data({
                'seg_logits':
                    PixelData(**{'data': probs}),
                'pred_sem_seg':
                    PixelData(**{'data': seg_pred})
            })
        return data_samples

    def compute_padsize(self, H: int, W: int, patch_size: int):
        """Return ``(left, right, top, bottom)`` padding to reach patch multiples.

        The padding needed to make ``W`` and ``H`` divisible by
        ``patch_size`` is split between both sides; when it is odd the extra
        pixel goes to the right/bottom.
        """
        l, r, t, b = 0, 0, 0, 0
        if W % patch_size:
            lr = patch_size - (W % patch_size)
            l = lr // 2
            r = lr - l

        if H % patch_size:
            tb = patch_size - (H % patch_size)
            t = tb // 2
            b = tb - t

        return l, r, t, b

    # The methods below are intentionally empty and return ``None``; they are
    # presumably only present to satisfy the base-class interface.

    def _forward(self, inputs, data_samples=None):
        """Not implemented; returns ``None``."""

    def inference(self, img, batch_img_metas):
        """Not implemented; returns ``None``."""

    def encode_decode(self, inputs, batch_img_metas):
        """Not implemented; returns ``None``."""

    def extract_feat(self, inputs):
        """Not implemented; returns ``None``."""

    def loss(self, inputs, data_samples):
        """Not implemented; returns ``None``."""


def get_cls_idx(path):
    """Parse a class-name file into query words and their class indices.

    Every line of the file describes one class and may hold several
    comma-separated names for it. The zero-based line number is used as the
    class index, so all names on one line share the same index.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(class_names, class_indices)`` of two equally long lists:
        the names in file order (newline characters removed, other
        whitespace kept) and the class index of each name.
    """
    with open(path, 'r') as handle:
        lines = handle.readlines()

    class_names, class_indices = [], []
    for cls_id, line in enumerate(lines):
        for raw_name in line.split(','):
            class_names.append(raw_name.replace('\n', ''))
            class_indices.append(cls_id)
    return class_names, class_indices