"""
Models for zero-shot segmentation testing.
"""

from itertools import chain
from packaging import version
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch_scatter
import torchvision.transforms
from timm.models.layers import trunc_normal_
from torch.nn.utils import weight_norm
import sonata
import os
from transformers import AutoConfig, ViTModel, ViTConfig
from transformers import AutoModel, AutoProcessor
import open_clip
import json
import clip
import pointops
from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from pointcept.models.builder import MODELS, build_model
from pointcept.models.modules import PointModel
from pointcept.models.utils import (
    offset2batch,
    offset2bincount,
    batch2offset,
    bincount2offset,
)
from pointcept.utils.comm import get_world_size, all_gather
from pointcept.utils.scheduler import CosineScheduler
from pointcept.models.utils.visualize import Visualizer2D, Visualizer3D
from pointcept.models.concerto_lseg_pro.models.lseg_net import LSegNet
import time

import torch.nn as nn


class MLP(nn.Module):
    """Plain feed-forward network.

    Each hidden layer is ``Linear -> activation [-> Dropout]``; a final
    ``Linear`` maps to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dims: Widths of the hidden layers, in order. An empty sequence
            yields a single ``Linear(input_dim, output_dim)``.
        activation: Name of an activation class in ``torch.nn`` (e.g.
            ``"ReLU"``); it is instantiated with no arguments.
        dropout: Dropout probability applied after every activation. No
            dropout layer is added unless this is greater than zero.
    """

    def __init__(
        self, input_dim, output_dim, hidden_dims=(256,), activation="ReLU", dropout=0.1
    ):
        super().__init__()
        # Resolve the activation class once instead of once per layer.
        activation_cls = getattr(nn, activation)

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(activation_cls())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        # Output projection (no activation / dropout after it).
        layers.append(nn.Linear(prev_dim, output_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` and return the result."""
        return self.net(x)


@MODELS.register_module("CLIPZSSegmentorV2")
class CLIPZSSegmentorV2(nn.Module):
    """Zero-shot point segmentor driven by CLIP text embeddings.

    Point features produced by ``backbone`` are projected by ``patch_proj``
    and scored by cosine similarity against the text embeddings of
    ``labels_list`` (encoded once at construction time by the CLIP text
    encoder of an LSeg network).

    Args:
        text_weight_name: Passed to ``LSegNet`` as its ``backbone`` argument.
        text_weight_path: Path of the LSeg checkpoint read by ``torch.load``.
        backbone_out_channels: Input width of ``patch_proj``.
        DINOhead_in_channels: Output width of ``patch_proj``.
        up_cast_level: Default number of pooling levels ``up_cast`` ascends.
        backbone: Config handed to ``build_model`` to create the backbone.
        enc_mode: If True, features are taken from ``up_cast``; otherwise
            they are gathered along the ``unpooling_parent`` chain.
        freeze_backbone: If True, backbone parameters are frozen and the
            backbone runs under ``torch.no_grad()``.
        labels_list: Label prompts; each entry is tokenized with
            ``clip.tokenize``. ``None`` is treated as an empty list.
        MLP_mode: Together with ``enc_mode``, use an ``MLP`` projection
            instead of a single ``Linear``.
    """

    def __init__(
        self,
        text_weight_name,
        text_weight_path,
        backbone_out_channels,
        DINOhead_in_channels=256,
        up_cast_level=2,
        backbone=None,
        enc_mode=False,
        freeze_backbone=False,
        labels_list=None,
        MLP_mode=False,
    ):
        super().__init__()
        # Avoid a shared mutable default; must be set before load_lseg(),
        # which reads self.labels_list.
        self.labels_list = [] if labels_list is None else labels_list
        self.up_cast_level = up_cast_level
        self.enc_mode = enc_mode

        # Frozen LSeg network; only its CLIP text encoder is used here.
        self.lseg_model = self.load_lseg(text_weight_name, text_weight_path)
        self.lseg_model.requires_grad_(False)
        self.clip_text_encoder = self.lseg_model.clip_pretrained.encode_text

        self._num_channels = DINOhead_in_channels
        if self.enc_mode and MLP_mode:
            self.patch_proj = MLP(
                input_dim=backbone_out_channels,
                output_dim=self._num_channels,
                hidden_dims=[512, 256],
                activation="ReLU",
                dropout=0.1,
            )
        else:
            self.patch_proj = torch.nn.Linear(backbone_out_channels, self._num_channels)
        # The projection is never trained by this module.
        self.patch_proj.requires_grad_(False)

        self.backbone = build_model(backbone)
        self.CLIPText_forward()

        self.freeze_backbone = freeze_backbone
        if self.freeze_backbone:
            self.backbone.requires_grad_(False)

    def load_lseg(
        self,
        model_name,
        model_weight="openclip/open_clip_pytorch_model.bin",
    ):
        """Build an ``LSegNet``, load its checkpoint and return it on CUDA in eval mode.

        Args:
            model_name: Passed to ``LSegNet`` as ``backbone``.
            model_weight: Path of a checkpoint whose ``"state_dict"`` entry
                holds the weights; a leading ``"net."`` on a key is stripped.

        Returns:
            The loaded ``LSegNet`` instance.
        """
        net = LSegNet(
            labels=self.labels_list,
            backbone=model_name,
            features=256,
            crop_size=480,
            arch_option=0,
            block_depth=0,
            activation="lrelu",
        )
        # Load to CPU so the checkpoint does not depend on the device it was
        # saved from; the network is moved to CUDA below.
        weight = torch.load(str(model_weight), map_location="cpu")["state_dict"]
        prefix = "net."
        for key in list(weight.keys()):
            if key.startswith(prefix):
                # Strip only the leading prefix; str.replace would also remove
                # any later "net." occurrence inside the key.
                weight[key[len(prefix):]] = weight.pop(key)
        net.load_state_dict(weight)
        net.eval()
        net.cuda()
        return net

    def CLIPText_forward(self):
        """Encode ``labels_list`` with the CLIP text encoder into ``self.text_embeddings``.

        Each label is tokenized separately, encoded, L2-normalized along
        dim 1, and the results are concatenated along dim 0. Tokens are moved
        to CUDA, so a CUDA device is required.
        """
        with torch.no_grad():
            prompt = [clip.tokenize(lc).cuda() for lc in self.labels_list]
            text_feat_list = [self.clip_text_encoder(p) for p in prompt]
            self.text_embeddings = torch.cat(
                [torch.nn.functional.normalize(tf) for tf in text_feat_list], dim=0
            )

    def up_cast(self, point, normalize=False, upcast_level=None):
        """Ascend the pooling hierarchy, concatenating child features onto each parent.

        Args:
            point: Structure carrying ``pooling_parent`` and
                ``pooling_inverse`` entries for every level to ascend. Both
                entries are popped from each visited level.
            normalize: Unused; kept for interface compatibility.
            upcast_level: Number of levels to ascend; defaults to
                ``self.up_cast_level``.

        Returns:
            The ancestor reached after ``upcast_level`` steps. Its ``feat`` is
            its own features concatenated (last dim) with the child features
            indexed by ``pooling_inverse``.
        """
        if upcast_level is None:
            upcast_level = self.up_cast_level
        for _ in range(upcast_level):
            assert "pooling_parent" in point.keys()
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent
        return point

    def _gather_decoder_features(self, point):
        """Concatenate the features of every level on the ``unpooling_parent`` chain.

        Each ancestor's features are indexed through the collected
        ``pooling_inverse`` maps so they line up with the rows of the starting
        ``point``. Note that ``unpooling_parent`` / ``pooling_inverse`` entries
        are popped and each ancestor's ``feat`` is overwritten in the process.
        """
        gathered = [point.feat]
        inverse_list = []
        while "unpooling_parent" in point.keys():
            parent = point.pop("unpooling_parent")
            assert "pooling_inverse" in parent.keys()
            inverse_list.append(parent.pop("pooling_inverse"))
            # Apply the most recently collected inverse first, then the
            # earlier ones, to reach the row layout of the starting point.
            for inverse_i in inverse_list[::-1]:
                parent.feat = parent.feat[inverse_i]
            gathered.append(parent.feat)
            point = parent
        return torch.cat(gathered, dim=-1)

    def forward(self, data_dict, return_point=False):
        """Score every point against the label text embeddings.

        Args:
            data_dict: Input wrapped into a ``Point`` and fed to the backbone.
            return_point: Unused; kept for interface compatibility.

        Returns:
            Dict with ``"seg_logits"``: a ``[N, C]`` tensor holding, for each
            of the N feature rows, a softmax over the C text embeddings of the
            cosine similarities.
        """
        point = Point(data_dict)
        if self.freeze_backbone:
            with torch.no_grad():
                point = self.backbone(point)
        else:
            point = self.backbone(point)

        if self.enc_mode:
            feat = self.up_cast(point).feat
        elif isinstance(point, Point):
            feat = self._gather_decoder_features(point)
        else:
            # Backbone returned raw features rather than a Point.
            feat = point

        feat = self.patch_proj(feat)
        # Match the projected features without mutating module state on
        # every call.
        text_embeddings = self.text_embeddings.to(device=feat.device, dtype=feat.dtype)
        feat = torch.nn.functional.normalize(feat, dim=1)

        # Chunk over points: the broadcast inside cosine_similarity
        # materializes a [chunk, C, D] tensor.
        chunk_size = 1000
        text = text_embeddings.unsqueeze(0)
        similarities = [
            F.cosine_similarity(
                feat[start : start + chunk_size].unsqueeze(1), text, dim=2
            )
            for start in range(0, feat.shape[0], chunk_size)
        ]
        if similarities:
            similarities = torch.cat(similarities, dim=0)  # [N, C]
        else:
            # No points: torch.cat would fail on an empty list.
            similarities = feat.new_zeros((0, text_embeddings.shape[0]))

        # Normalize over the class axis so each point gets its own
        # distribution (dim=0 would couple every point in the batch). The
        # result is kept 2-D even when N == 1 or C == 1.
        softmax_scores = F.softmax(similarities, dim=1)
        return {"seg_logits": softmax_scores}
