from itertools import chain
from packaging import version
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch_scatter
import torchvision.transforms
from timm.models.layers import trunc_normal_
from torch.nn.utils import weight_norm
import sonata
import os
from transformers import AutoConfig, ViTModel, ViTConfig
from transformers import AutoModel, AutoProcessor
import open_clip
import json
import clip
import pointops
from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from pointcept.models.builder import MODELS, build_model
from pointcept.models.modules import PointModel
from pointcept.models.utils import (
    offset2batch,
    offset2bincount,
    batch2offset,
    bincount2offset,
)
from pointcept.utils.comm import get_world_size, all_gather
from pointcept.utils.scheduler import CosineScheduler
from pointcept.models.utils.visualize import Visualizer2D, Visualizer3D
from pointcept.models.concerto_lseg_pro.models.lseg_net import LSegNet
import time

import torch.nn as nn


class MLP(nn.Module):
    """Feed-forward network built from ``nn.Linear`` layers.

    Every entry of ``hidden_dims`` contributes a ``Linear`` layer followed by
    the chosen activation and, when ``dropout > 0``, an ``nn.Dropout`` layer.
    A final ``Linear`` layer maps to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dims: Iterable of hidden layer widths; an empty iterable
            yields a single ``Linear(input_dim, output_dim)``.
        activation: Name of an activation class in ``torch.nn`` (for example
            ``"ReLU"``); it is instantiated without arguments.
        dropout: Dropout probability; no dropout layer is added when ``<= 0``.

    Raises:
        AttributeError: If ``activation`` is not an attribute of ``torch.nn``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dims=(256,),
        activation="ReLU",
        dropout=0.1,
    ):
        super().__init__()
        # Resolve the activation class once instead of once per hidden layer.
        activation_cls = getattr(nn, activation)

        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers.append(nn.Linear(in_features, width))
            layers.append(activation_cls())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_features = width
        layers.append(nn.Linear(in_features, output_dim))

        # The attribute name and Sequential layout define the state_dict keys,
        # so they are kept as-is.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)


@MODELS.register_module("LSegImgAligner_MSC")
class LSegImgAligner(nn.Module):
    """Align projected point-backbone features with frozen LSeg image features.

    In training mode ``forward`` returns a cosine-distance loss between LSeg
    image features and the projected backbone features averaged over the
    points mapped to each image location. Outside training mode it scores the
    projected features against CLIP text embeddings of ``labels_list``.

    Args:
        image_weight_name: Passed to ``LSegNet`` as its ``backbone`` argument.
        image_weight_path: Checkpoint path read with ``torch.load``; the
            loaded object must contain a ``"state_dict"`` entry.
        backbone: Config handed to ``build_model`` to create the backbone.
        backbone_out_channels: Input width of the projection ``patch_proj``.
        patch_w: Width used when flattening ``(row, col)`` image indices.
        patch_h: Height used when flattening ``(row, col)`` image indices.
        DINOhead_in_channels: Output width of ``patch_proj``.
        up_cast_level: Default number of iterations used by ``up_cast``.
        freeze_backbone: If true, backbone parameters get
            ``requires_grad = False`` and the backbone runs under ``no_grad``.
        labels_list: Labels given to ``LSegNet`` and tokenized for the text
            embeddings. ``None`` is treated as an empty list.
        ignore_index: Stored as a CUDA tensor on ``self.ignore_index``; it is
            not read anywhere else in this class.
        criteria: Config handed to ``build_criteria``; the result is applied
            in ``forward`` when ``"segment"`` is present outside training.
        enc_mode: Together with ``MLP_mode`` selects an ``MLP`` projection
            instead of a single ``Linear`` layer.
        sonata_load: If true, load backbone weights from the default path of
            ``load_sonata``.
        concerto_path: Used only when ``sonata_load`` is false; if truthy,
            backbone weights are loaded from this path.
        MLP_mode: See ``enc_mode``.
    """

    def __init__(
        self,
        image_weight_name,
        image_weight_path,
        backbone,
        backbone_out_channels,
        patch_w,
        patch_h,
        DINOhead_in_channels=256,
        up_cast_level=2,
        freeze_backbone=True,
        labels_list=None,
        ignore_index=(-1,),
        criteria=None,
        enc_mode=False,
        sonata_load=False,
        concerto_path=None,
        MLP_mode=False,
    ):
        super().__init__()
        # ``None`` replaces the former mutable ``[]`` default.
        self.labels_list = [] if labels_list is None else labels_list
        self.ignore_index = torch.tensor(ignore_index).cuda()
        self.up_cast_level = up_cast_level
        self.enc_mode = enc_mode

        self.patch_h = patch_h
        self.patch_w = patch_w

        # The image model is only a feature provider; none of it is trained.
        self.lseg_model = self.load_lseg(image_weight_name, image_weight_path)
        self.lseg_model.requires_grad_(False)
        self.clip_text_encoder = self.lseg_model.clip_pretrained.encode_text

        self._num_channels = DINOhead_in_channels
        if self.enc_mode and MLP_mode:
            self.patch_proj = MLP(
                input_dim=backbone_out_channels,
                output_dim=self._num_channels,
                hidden_dims=[512, 256],
                activation="ReLU",
                dropout=0.1,
            )
        else:
            self.patch_proj = torch.nn.Linear(backbone_out_channels, self._num_channels)

        self.backbone = build_model(backbone)
        if sonata_load:
            self.backbone = self.load_sonata(self.backbone)
        elif concerto_path:
            self.backbone = self.load_sonata(self.backbone, path=concerto_path)

        self.criteria = build_criteria(criteria)
        self.CLIPText_forward()

        self.freeze_backbone = freeze_backbone
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def load_lseg(
        self,
        model_name,
        model_weight="openclip/open_clip_pytorch_model.bin",
    ):
        """Build an ``LSegNet``, load its weights, and return it in eval mode on CUDA.

        Args:
            model_name: Passed to ``LSegNet`` as ``backbone``.
            model_weight: Checkpoint path; the loaded object must contain a
                ``"state_dict"`` entry.

        Returns:
            The loaded ``LSegNet`` instance.
        """
        net = LSegNet(
            labels=self.labels_list,
            backbone=model_name,
            features=256,
            crop_size=480,
            arch_option=0,
            block_depth=0,
            activation="lrelu",
        )
        state_dict = torch.load(str(model_weight))["state_dict"]
        # Strip only the leading "net." prefix. A plain str.replace would also
        # remove any later "net." occurrence inside the key.
        prefix = "net."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        net.load_state_dict(state_dict)
        net.eval()
        net.cuda()
        return net

    def load_sonata(
        self,
        model,
        path="sonata_release/ckpt/sonata.pth",
    ):
        """Load checkpoint weights into ``model`` non-strictly and return it.

        If the checkpoint has a ``"state_dict"`` entry, keys containing
        ``"module.backbone."`` are renamed and loaded; when no such key
        exists (or there is no ``"state_dict"`` entry) the checkpoint mapping
        is loaded directly. Missing and unexpected keys are printed.

        Args:
            model: Module that receives the weights.
            path: Checkpoint path; ``None`` falls back to the default path.

        Returns:
            The same ``model`` object.
        """
        if path is None:
            path = "sonata_release/ckpt/sonata.pth"
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage.cuda())

        renamed = {}
        if "state_dict" in checkpoint.keys():
            checkpoint = checkpoint["state_dict"]
            for key, value in checkpoint.items():
                if "module.backbone." in key:
                    key = key.replace("module.backbone.", "module.")
                    # module.xxx.xxx -> xxx.xxx
                    renamed[key[7:]] = value

        load_state_info = model.load_state_dict(
            renamed if renamed else checkpoint, strict=False
        )
        print(f"Missing keys: {load_state_info[0]}")
        print(f"Unexpected keys: {load_state_info[1]}")
        return model

    @torch.no_grad()
    def LSeg_forward(self, x):
        """Run the LSeg model on ``x`` and move dimension 1 to the end.

        For an empty batch (``x.shape[0] == 0``) an empty float32 tensor of
        shape ``(0, patch_h, patch_w, DINOhead_in_channels)`` is returned on
        the device of ``x`` without calling the model.
        """
        if x.shape[0] > 0:
            return self.lseg_model(x).permute(0, 2, 3, 1)
        return torch.empty(
            (0, self.patch_h, self.patch_w, self._num_channels),
            dtype=torch.float32,
            device=x.device,
        )

    @torch.no_grad()
    def CLIPText_forward(self):
        """Encode every label with the CLIP text encoder and cache the result.

        Each label is tokenized and encoded on its own, normalized with
        ``F.normalize``, and the results are concatenated along dimension 0
        into ``self.text_embeddings``.
        """
        embeddings = []
        for label in self.labels_list:
            tokens = clip.tokenize(label).cuda()
            embeddings.append(F.normalize(self.clip_text_encoder(tokens)))
        self.text_embeddings = torch.cat(embeddings, dim=0)

    def up_cast(self, point, normalize=False, upcast_level=None):
        """Walk up the pooling hierarchy, concatenating features on the way.

        Each step pops ``pooling_parent`` and ``pooling_inverse`` from the
        current point, appends ``point.feat[inverse]`` to the parent's
        features along the last dimension, and continues from the parent.

        Args:
            point: Point structure holding ``feat`` and the pooling entries.
            normalize: Currently unused; kept for interface compatibility.
            upcast_level: Number of steps; ``None`` uses ``self.up_cast_level``.

        Returns:
            The point reached after the requested number of steps.
        """
        if upcast_level is None:
            upcast_level = self.up_cast_level
        for _ in range(upcast_level):
            assert "pooling_parent" in point.keys()
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent
        return point

    def _alignment_loss(self, feature3d, data_dict):
        """Return the 2D-3D cosine alignment loss for one training batch."""
        batch_points_3d = offset2batch(data_dict["offset"])
        imgs = data_dict["imgs"]
        mask_index = data_dict["mask_index"]

        # An entry is kept unless every coordinate in its last dimension is -1.
        # presumably mask_index is (points, views, 2) - verify against the dataset
        mask = torch.any(mask_index != -1, dim=2)
        valid_index = torch.where(mask)  # 0: 3d points index, 1: view index

        offset_img_num = bincount2offset(data_dict["img_num"])
        total_img_num = offset_img_num[-1]

        feature2d = self.LSeg_forward(imgs)
        feature2d = feature2d.contiguous().view(-1, feature2d.shape[-1])

        if total_img_num > 0:
            # Shift the cumulative image counts right by one so each batch
            # entry maps to the index of its first image.
            start = torch.zeros(1, dtype=torch.long, device=offset_img_num.device)
            first_img = torch.cat([start, offset_img_num])[:-1]

            point_idx, view_idx = valid_index
            img_idx = (first_img[batch_points_3d[point_idx]] + view_idx).long()
            pixel = mask_index[valid_index].long()

            # Flatten (image, row, col) into a row index of ``feature2d``.
            flat_index = (
                img_idx * self.patch_h * self.patch_w
                + pixel[:, 0] * self.patch_w
                + pixel[:, 1]
            )

            # Average the features of all points that share an image location.
            pooled3d = torch_scatter.scatter_mean(
                feature3d[point_idx], flat_index, dim=0, dim_size=feature2d.shape[0]
            )
            hit = torch.unique(flat_index)
            # Without any mapped location the mean below would be NaN.
            if hit.numel() > 0:
                similarity = F.cosine_similarity(
                    feature2d[hit], pooled3d[hit], dim=1, eps=1e-6
                )
                return (1 - similarity).mean()

        # Nothing to align. Return a zero tensor that is still connected to
        # ``feature3d`` so that backward() and dist.all_reduce keep working;
        # a Python float supports neither.
        return feature3d.sum() * 0.0

    def _text_scores(self, feat):
        """Return softmax-normalized cosine similarities to the text embeddings."""
        self.text_embeddings = self.text_embeddings.to(feat.dtype)
        feat = F.normalize(feat, dim=1)
        # presumably text is (classes, dim) and feat is (points, dim) - TODO confirm
        similarities = F.cosine_similarity(
            self.text_embeddings.unsqueeze(0),
            feat.unsqueeze(1),
            dim=2,
        )
        # NOTE(review): the softmax runs over dim=0, which appears to be the
        # point axis rather than the class axis. Behavior is kept unchanged;
        # confirm whether a softmax over classes was intended.
        return F.softmax(similarities, dim=0).squeeze()

    def forward(self, data_dict, return_point=False):
        """Compute the training loss or the evaluation scores.

        Args:
            data_dict: Batch dictionary. Training reads ``"offset"``,
                ``"imgs"``, ``"mask_index"`` and ``"img_num"``; outside
                training ``"segment"`` is optional.
            return_point: Currently unused; kept for interface compatibility.

        Returns:
            A dict. In training mode it holds ``"loss"``. Otherwise it holds
            ``"seg_logits"`` and, when ``"segment"`` is in ``data_dict``,
            also ``"loss"`` from ``self.criteria``.
        """
        point = Point(data_dict)
        if self.freeze_backbone:
            with torch.no_grad():
                feat = self.backbone(point)
        else:
            feat = self.backbone(point)
        feat = self.patch_proj(feat)

        result_dict = {}
        if self.training:
            result_dict["loss"] = self._alignment_loss(feat, data_dict)
            torch.cuda.empty_cache()

            if get_world_size() > 1:
                for loss in result_dict.values():
                    dist.all_reduce(loss, op=dist.ReduceOp.AVG)
        else:
            scores = self._text_scores(feat)
            result_dict["seg_logits"] = scores
            # Evaluation supplies labels; pure inference does not.
            if "segment" in data_dict.keys():
                result_dict["loss"] = self.criteria(scores, data_dict["segment"])
        return result_dict
