import copy
import timm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import logging
import numpy as np
import clip
import json
import os

from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.common.pytorch_util import replace_submodules
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
from diffusion_policy.model.diffusion.common_model import DepthOnlyFCBackbone224x224, Fusion, FiLM

logger = logging.getLogger(__name__)


class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map into one vector per sample with multi-head attention.

    The spatial mean of the map is prepended as an extra token and is the only
    query; the mean token plus every spatial position act as keys and values.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One learned position vector per spatial location, plus one for the mean token.
        token_count = spacial_dim ** 2 + 1
        init_scale = embed_dim ** 0.5
        self.positional_embedding = nn.Parameter(torch.randn(token_count, embed_dim) / init_scale)
        # Projections are created in the same order as before so that seeded
        # initialisation and state_dict key order stay unchanged.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output width falls back to embed_dim when output_dim is falsy (None or 0).
        pooled_dim = output_dim or embed_dim
        self.c_proj = nn.Linear(embed_dim, pooled_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return the attention-pooled feature for an NCHW input (one row per sample)."""
        # NCHW -> (HW)NC
        tokens = torch.flatten(x, start_dim=2).permute(2, 0, 1)
        # Prepend the spatial mean as the query token -> (HW+1)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        position = self.positional_embedding[:, None, :].to(tokens.dtype)
        tokens = tokens + position  # (HW+1)NC

        # The functional API takes the q/k/v biases as one concatenated vector.
        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            training=self.training,
            need_weights=False,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )
        # Drop the length-1 query axis.
        return pooled.squeeze(0)


class TimmObsEncoder(ModuleAttrMixin):
    def __init__(self,
                 shape_meta: dict,
                 model_name: str,
                 pretrained: bool,
                 frozen: bool,
                 global_pool: str,
                 transforms: list,
                 # replace BatchNorm with GroupNorm
                 use_group_norm: bool = False,
                 # use single rgb model for all rgb inputs
                 share_rgb_model: bool = False,
                 # renormalize rgb input with imagenet normalization
                 # assuming input in [0,1]
                 imagenet_norm: bool = False,
                 feature_aggregation: str = 'spatial_embedding',
                 downsample_ratio: int = 32,
                 position_encording: str = 'sinusoidal',
                 use_text_fusion: bool = False,
                 vl_fusion: str = 'ResNet_FiLM',
                 use_interpret: bool = False,
                 ):
        """
        Build the per-observation encoders described by ``shape_meta['obs']``.

        Assumes rgb input: B,T,C,H,W
        Assumes depth input: B,T,C,H,W
        Assumes low_dim input: B,T,D

        Args:
            shape_meta: dict with an ``'obs'`` mapping of key -> attributes;
                each attribute dict provides ``'shape'`` and optionally
                ``'type'`` (default ``'low_dim'``) and ``'ignore_by_policy'``.
            model_name: timm model name for the rgb backbone.
            pretrained: load pretrained timm weights.
            frozen: freeze the rgb backbone (requires ``pretrained``).
            global_pool: must be ``''`` (no pooling inside the backbone).
            transforms: ``None``, or a list whose first entry is either an
                ``nn.Module`` or a config object with ``type == 'RandomCrop'``
                and a ``ratio`` attribute; remaining entries must be modules.
            use_group_norm: replace BatchNorm2d with GroupNorm (only applied
                when ``pretrained`` is False).
            share_rgb_model: use one backbone instance for all rgb keys
                instead of a deep copy per key.
            imagenet_norm: accepted for interface compatibility; not read by
                this constructor.
            feature_aggregation: how backbone features are reduced, see
                ``aggregate_feature``.
            downsample_ratio: spatial stride of the backbone output (32, or 16
                for resnet only).
            position_encording: ``'learnable'`` or ``'sinusoidal'``; used by
                the ``'transformer'`` aggregation.
            use_text_fusion: fuse text embeddings into rgb features.
            vl_fusion: which fusion layer to build when ``use_text_fusion``.
            use_interpret: stored on the instance as ``self.use_interpret``.

        Raises:
            NotImplementedError: unsupported ``downsample_ratio``,
                ``vl_fusion``, ``position_encording`` or a ``'pcl'`` key.
            RuntimeError: unknown observation type.
        """
        super().__init__()

        rgb_keys = list()
        depth_keys = list()
        pcl_keys = list()
        low_dim_keys = list()
        text_keys = list()
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = dict()

        assert global_pool == ''
        model = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            global_pool=global_pool,  # '' means no pooling
            num_classes=0  # remove classification layer
        )

        depth_model = DepthOnlyFCBackbone224x224(256)
        # ViT-L/14 ViT-B/16
        # The CLIP image preprocess returned alongside the model is not needed here.
        text_model, _preprocess = clip.load("ViT-B/16", device=self.device)

        if frozen:
            assert pretrained
            for param in model.parameters():
                param.requires_grad = False

        feature_dim = None
        if model_name.startswith('resnet'):
            # the last layer is nn.Identity() because num_classes is 0
            # second last layer is AdaptivePool2d, which is also identity because global_pool is empty
            if downsample_ratio == 32:
                modules = list(model.children())[:-2]
                model = torch.nn.Sequential(*modules)
                feature_dim = 512
            elif downsample_ratio == 16:
                modules = list(model.children())[:-3]
                model = torch.nn.Sequential(*modules)
                feature_dim = 256
            else:
                raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}")
        elif model_name.startswith('convnext'):
            # the last layer is nn.Identity() because num_classes is 0
            # second last layer is AdaptivePool2d, which is also identity because global_pool is empty
            if downsample_ratio == 32:
                modules = list(model.children())[:-2]
                model = torch.nn.Sequential(*modules)
                feature_dim = 1024
            else:
                raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}")

        if use_group_norm and not pretrained:
            model = replace_submodules(
                root_module=model,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=(x.num_features // 16) if (x.num_features % 16 == 0) else (x.num_features // 8),
                    num_channels=x.num_features)
            )

        # All rgb observations must share the same spatial size (shape without
        # its leading dimension).
        image_shape = None
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                assert image_shape is None or image_shape == shape[1:]
                image_shape = shape[1:]

        # A non-module first entry is a RandomCrop config: expand it into a
        # random crop followed by a resize back to the original size.
        # The length check avoids an IndexError on an empty transform list.
        if transforms is not None and len(transforms) > 0 \
                and not isinstance(transforms[0], torch.nn.Module):
            assert transforms[0].type == 'RandomCrop'
            ratio = transforms[0].ratio
            transforms = [
                             torchvision.transforms.RandomCrop(size=int(image_shape[0] * ratio)),
                             torchvision.transforms.Resize(size=image_shape[0], antialias=True)
                         ] + transforms[1:]
        transform = nn.Identity() if transforms is None else torch.nn.Sequential(*transforms)

        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            key_shape_map[key] = shape
            if obs_type == 'rgb':
                rgb_keys.append(key)

                this_model = model if share_rgb_model else copy.deepcopy(model)
                key_model_map[key] = this_model

                this_transform = transform
                key_transform_map[key] = this_transform
            elif obs_type == 'depth':
                depth_keys.append(key)
                key_model_map[key] = depth_model
            elif obs_type == 'pcl':
                # No point-cloud encoder is constructed in this class; the
                # previous code referenced an undefined name and crashed with
                # a NameError. Fail with an explicit message instead.
                raise NotImplementedError(
                    f"Unsupported obs type: pcl (no point-cloud encoder is configured, key: {key})")
            elif obs_type == 'low_dim':
                if not attr.get('ignore_by_policy', False):
                    low_dim_keys.append(key)
            elif obs_type == 'text':
                key_model_map[key] = text_model
                if not attr.get('ignore_by_policy', False):
                    text_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        feature_map_shape = [x // downsample_ratio for x in image_shape]

        rgb_keys = sorted(rgb_keys)
        depth_keys = sorted(depth_keys)
        pcl_keys = sorted(pcl_keys)
        low_dim_keys = sorted(low_dim_keys)
        text_keys = sorted(text_keys)
        print('rgb keys:         ', rgb_keys)
        print('depth keys:       ', depth_keys)
        print('pcl keys:       ', pcl_keys)
        print('low_dim_keys keys:', low_dim_keys)
        print('text keys:        ', text_keys)

        # cond encoder
        self.time_emb = SinusoidalPosEmb(256)
        self.cond_obs_emb = None

        self.model_name = model_name
        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.depth_keys = depth_keys
        self.pcl_keys = pcl_keys
        self.low_dim_keys = low_dim_keys
        self.text_keys = text_keys
        self.key_shape_map = key_shape_map
        self.feature_aggregation = feature_aggregation
        self.use_text_fusion = use_text_fusion
        self.fusion_layer = None
        self.use_interpret = use_interpret

        # NOTE(review): Res_FiLM, CrossAttentionLayer and WeightedAggregation
        # are not among the imports visible at the top of this file; confirm
        # they are importable before enabling the branches that use them.
        if use_text_fusion:
            if vl_fusion == 'fusion':
                self.fusion_layer = Fusion(obs_emd_size=768, lang_emd_size=512)
                print("USE fusion")
            elif vl_fusion == 'ResNet_FiLM':
                self.fusion_layer = Res_FiLM(n_res_blocks=12, input_dim=1024, condition_dim=512)
                print("USE ResNet_FiLM")
            elif vl_fusion == 'Cross_Attention':
                self.fusion_layer = CrossAttentionLayer(input_dim_1=768, input_dim_2=512, hidden_dim=1024)
                print("USE Cross_Attention")
            elif vl_fusion == 'FiLM':
                self.fusion_layer = FiLM(input_dim=768, condition_dim=512)
                print("USE FiLM")
            else:
                raise NotImplementedError(f"Unsupported: {vl_fusion}")
        else:
            print("Do Not USE text fusion")

        if model_name.startswith('vit'):
            if self.feature_aggregation == 'CLS_token':
                # aggregate_feature returns token 0 for this mode; nothing to build
                pass
            elif self.feature_aggregation == 'weight_aggregation':
                self.weight_aggregation = WeightedAggregation(input_size=768)
            elif self.feature_aggregation is not None:
                logger.warning(
                    f'vit will use the CLS token. feature_aggregation ({self.feature_aggregation}) is ignored!')
                self.feature_aggregation = 'CLS_token'

        if self.feature_aggregation == 'soft_attention':
            self.attention = nn.Sequential(
                nn.Linear(feature_dim, 1, bias=False),
                nn.Softmax(dim=1)
            )
        elif self.feature_aggregation == 'spatial_embedding':
            # Store the learnable weights under their own attribute. Assigning
            # them to self.feature_aggregation replaced the mode string with a
            # Parameter and left self.spatial_embedding (read by
            # aggregate_feature) undefined.
            self.spatial_embedding = torch.nn.Parameter(
                torch.randn(feature_map_shape[0] * feature_map_shape[1], feature_dim))
        elif self.feature_aggregation == 'transformer':
            if position_encording == 'learnable':
                self.position_embedding = torch.nn.Parameter(
                    torch.randn(feature_map_shape[0] * feature_map_shape[1] + 1, feature_dim))
            elif position_encording == 'sinusoidal':
                # +1 for the extra leading token added in aggregate_feature
                num_features = feature_map_shape[0] * feature_map_shape[1] + 1
                self.position_embedding = torch.zeros(num_features, feature_dim)
                position = torch.arange(0, num_features, dtype=torch.float).unsqueeze(1)
                div_term = torch.exp(
                    torch.arange(0, feature_dim, 2).float() * (-math.log(2 * num_features) / feature_dim))
                self.position_embedding[:, 0::2] = torch.sin(position * div_term)
                self.position_embedding[:, 1::2] = torch.cos(position * div_term)
            else:
                # Fail at construction instead of with an AttributeError on the
                # first forward pass.
                raise NotImplementedError(f"Unsupported position_encording: {position_encording}")
            self.aggregation_transformer = nn.TransformerEncoder(
                encoder_layer=nn.TransformerEncoderLayer(d_model=feature_dim, nhead=4),
                num_layers=4)
        elif self.feature_aggregation == 'attention_pool_2d':
            self.attention_pool_2d = AttentionPool2d(
                spacial_dim=feature_map_shape[0],
                embed_dim=feature_dim,
                num_heads=feature_dim // 64,
                output_dim=feature_dim
            )
        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    def aggregate_feature(self, feature):
        """Reduce backbone output according to ``self.feature_aggregation``.

        For model names starting with ``'vit'`` the modes ``'CLS_token'``
        (token 0 of a 3-D input) and ``'weight_aggregation'`` are handled
        first. Every other case expects a 4-D feature map, which is either fed
        to the attention pool or flattened to (B, H*W, C) and reduced over the
        spatial axis. With mode ``None`` the flattened map is returned as is.
        """
        mode = self.feature_aggregation

        if self.model_name.startswith('vit'):
            if mode == 'CLS_token':
                return feature[:, 0, :]
            if mode == 'weight_aggregation':
                return self.weight_aggregation(feature)
            # NOTE(review): any other mode on a vit model falls through to the
            # 4-D path below and hits its rank assertion for 3-D token input.

        # conv backbones: feature map is expected to be 4-D
        assert feature.ndim == 4
        if mode == 'attention_pool_2d':
            return self.attention_pool_2d(feature)

        # merge the two spatial axes, then move channels last: B, C, H, W -> B, H*W, C
        tokens = feature.flatten(start_dim=-2).transpose(1, 2)

        if mode == 'avg':
            return tokens.mean(dim=[1])
        if mode == 'max':
            return tokens.amax(dim=[1])
        if mode == 'soft_attention':
            attn = self.attention(tokens)
            return (tokens * attn).sum(dim=1)
        if mode == 'spatial_embedding':
            return (tokens * self.spatial_embedding).mean(dim=1)
        if mode == 'transformer':
            # A zero token is prepended; its transformer output is the result.
            lead_token = torch.zeros(tokens.shape[0], 1, tokens.shape[-1], device=tokens.device)
            # The position embedding may be a plain tensor, so it is moved to
            # the input's device here when the two differ.
            if self.position_embedding.device != tokens.device:
                self.position_embedding = self.position_embedding.to(tokens.device)
            sequence = torch.cat([lead_token, tokens], dim=1) + self.position_embedding
            encoded = self.aggregation_transformer(sequence)
            return encoded[:, 0]

        assert mode is None
        return tokens

    def _encode_text_obs(self, key, data, device):
        """Decode a character-code text observation and embed it with the text model.

        Args:
            key: text observation key; selects the encoder in ``self.key_model_map``.
            data: tensor with leading dims (B, T); the remaining dims are
                flattened into one row of character codes per (b, t) entry.
            device: device the tokens and the returned embedding are moved to.

        Returns:
            float32 tensor produced by ``encode_text``, computed without
            gradient tracking.
        """
        num_rows = data.shape[0] * data.shape[1]
        code_rows = data.reshape(num_rows, -1).tolist()
        decoded = [''.join(chr(int(code)) for code in row) for row in code_rows]
        # The round-trip through a NumPy unicode array is kept on purpose: it
        # drops trailing NUL characters (e.g. zero padding), which str.strip()
        # alone would not remove.
        texts = [text.strip() for text in np.array(decoded).reshape(num_rows).tolist()]

        tokens = clip.tokenize(texts).to(device)
        with torch.no_grad():
            text_features = self.key_model_map[key].encode_text(tokens)
            text_features = text_features.clone().detach()
            return text_features.to(torch.float32).to(device)

    def forward(self, obs_dict):
        """Encode every configured observation and concatenate along the last dim.

        Args:
            obs_dict: mapping of observation key -> tensor with leading dims
                (B, T). Text observations hold character codes.

        Returns:
            Tensor with B rows: the per-key features, each reshaped to
            (B, -1), concatenated in the order rgb (optionally fused with
            text), depth, pcl, low_dim.
        """
        features = list()
        batch_size = next(iter(obs_dict.values())).shape[0]

        fuse_text = bool(self.use_text_fusion and self.text_keys)
        # Text embeddings do not depend on the camera, so each text key is
        # encoded once per device and reused for every rgb key.
        text_embed_cache = dict()

        # process rgb input
        for key in self.rgb_keys:
            img = obs_dict[key]
            B, T = img.shape[:2]
            assert B == batch_size
            assert img.shape[2:] == self.key_shape_map[key]
            img = img.reshape(B * T, *img.shape[2:])
            img = self.key_transform_map[key](img)
            raw_feature = self.key_model_map[key](img)

            # NOTE(review): only 3-D backbone output passes this check, while
            # aggregate_feature also has a 4-D path — confirm this is intended.
            assert len(raw_feature.shape) == 3 and raw_feature.shape[0] == B * T

            if not fuse_text:
                feature = self.aggregate_feature(raw_feature)
                features.append(feature.reshape(B, -1))
                continue

            device = raw_feature.device
            for text_key in self.text_keys:
                cache_key = (text_key, device)
                if cache_key not in text_embed_cache:
                    text_data = obs_dict[text_key]
                    # same batch-size invariant as the other modalities
                    assert text_data.shape[0] == batch_size
                    text_embed_cache[cache_key] = self._encode_text_obs(text_key, text_data, device)
                text_embeds = text_embed_cache[cache_key]

                if len(text_embeds.shape) == 3:
                    # 3-D text embedding: fuse first, then aggregate
                    raw_cross_feature = self.fusion_layer(raw_feature, text_embeds)
                    cross_feature = self.aggregate_feature(raw_cross_feature)
                    features.append(cross_feature.reshape(B, -1))
                elif len(text_embeds.shape) == 2:
                    # 2-D text embedding: aggregate first, then fuse
                    feature = self.aggregate_feature(raw_feature)
                    fusion_feature = self.fusion_layer(feature, text_embeds)
                    features.append(fusion_feature.reshape(B, -1))

        # process depth input
        for key in self.depth_keys:
            img = obs_dict[key]
            B, T = img.shape[:2]
            assert B == batch_size
            assert img.shape[2:] == self.key_shape_map[key]
            img = img.reshape(B * T, *img.shape[2:])
            raw_feature = self.key_model_map[key](img)
            assert len(raw_feature.shape) == 2 and raw_feature.shape[0] == B * T
            features.append(raw_feature.reshape(B, -1))

        # process point-cloud input
        for key in self.pcl_keys:
            pointcloud = obs_dict[key]
            B, T = pointcloud.shape[:2]
            assert B == batch_size
            assert pointcloud.shape[2:] == self.key_shape_map[key]
            pointcloud = pointcloud.reshape(B * T, *pointcloud.shape[2:])
            raw_feature = self.key_model_map[key](pointcloud)
            assert len(raw_feature.shape) == 2 and raw_feature.shape[0] == B * T
            features.append(raw_feature.reshape(B, -1))

        # process lowdim input
        for key in self.low_dim_keys:
            data = obs_dict[key]
            B, T = data.shape[:2]
            assert B == batch_size
            assert data.shape[2:] == self.key_shape_map[key]
            if key == 'timestep':  # (B, T, 1)
                # the timestep is embedded instead of being passed through raw
                time_emb = self.time_emb(data)
                features.append(time_emb.reshape(B, -1))
            else:
                features.append(data.reshape(B, -1))

        # concatenate all features
        result = torch.cat(features, dim=-1)

        return result


    @torch.no_grad()
    def output_shape(self):
        """Return the shape of ``forward``'s output for a batch of one.

        An all-zero observation of shape ``(1, horizon) + shape`` is built for
        every entry in ``shape_meta['obs']`` and passed through ``forward``.
        """
        dummy_obs = {}
        for name, spec in self.shape_meta['obs'].items():
            per_step_shape = tuple(spec['shape'])
            dummy_obs[name] = torch.zeros(
                (1, spec['horizon'], *per_step_shape),
                dtype=self.dtype,
                device=self.device,
            )

        encoded = self.forward(dummy_obs)

        # forward must yield one flat feature row for the single sample
        assert len(encoded.shape) == 2
        assert encoded.shape[0] == 1
        return encoded.shape


def save_features(ori_texts, features, output_file, min_data_num):
    """Append text/feature records to a JSON file until it holds enough entries.

    If ``output_file`` already exists and contains at least ``min_data_num``
    entries, nothing is written. Otherwise the existing entries are kept, one
    ``{"text": ..., "features": ...}`` record is added for every non-empty
    text, and the file is rewritten.

    Args:
        ori_texts: iterable of texts, paired positionally with ``features``.
        features: object exposing ``tolist()`` that yields one entry per text.
        output_file: path of the JSON file to read and overwrite.
        min_data_num: entry count at which the file is considered complete.
    """
    records = []

    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as reader:
            existing = json.load(reader)
        if len(existing) >= min_data_num:
            print("save finish")
            return
        records.extend(existing)

    for text, vector in zip(ori_texts, features.tolist()):
        # entries with an empty text are skipped
        if not text:
            continue
        records.append({"text": text, "features": vector})

    with open(output_file, "w", encoding="utf-8") as writer:
        json.dump(records, writer, indent=4)
    print("saved data num: ", len(records))


if __name__ == '__main__':
    # Smoke test: construct an encoder for a single rgb camera.
    # The previous call omitted the required `frozen` argument and passed
    # shape_meta=None, which the constructor indexes with ['obs'].
    example_shape_meta = {
        'obs': {
            'camera0_rgb': {
                'shape': [3, 224, 224],
                'type': 'rgb',
                'horizon': 2,
            },
        },
    }
    timm_obs_encoder = TimmObsEncoder(
        shape_meta=example_shape_meta,
        model_name='resnet18.a1_in1k',
        pretrained=False,
        frozen=False,
        global_pool='',
        transforms=None
    )
