import copy
import numpy as np
import timm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import logging
import time
import clip

from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
# from transformers import CLIPProcessor, CLIPModel
from diffusion_policy.common.pytorch_util import replace_submodules
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb

logger = logging.getLogger(__name__)


class AttentionPool2d(nn.Module):
    """
    Pool a 2D feature map into one vector per sample with multi-head attention.

    The spatial mean of the map is prepended as an extra token and used as the
    single attention query; all tokens (mean + spatial positions) serve as
    keys and values. A learned positional embedding with
    ``spacial_dim ** 2 + 1`` rows is added to the tokens, so the flattened
    spatial size of the input must equal ``spacial_dim ** 2``.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        """
        Args:
            spacial_dim: side length used to size the positional embedding.
            embed_dim: channel count of the incoming feature map.
            num_heads: number of attention heads.
            output_dim: width of the output projection; falls back to
                ``embed_dim`` when falsy.
        """
        super().__init__()
        # one row per spatial position plus one for the prepended mean token
        token_count = spacial_dim ** 2 + 1
        init_scale = embed_dim ** 0.5
        self.positional_embedding = nn.Parameter(torch.randn(token_count, embed_dim) / init_scale)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """
        Args:
            x: feature map laid out as N, C, H, W.

        Returns:
            Tensor with the leading (query) dimension squeezed away, i.e. one
            pooled vector per sample.
        """
        # NCHW -> (HW)NC
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)
        # prepend the spatial mean as the query token -> (HW+1)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        # the functional API expects the q/k/v biases packed in this order
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return pooled.squeeze(0)


class LangEmbedding(nn.Module):
    """
    Project an embedded language vector to ``emd_size`` and predict a 3-way label.

    lang -> [batch, input_channel]
    out -> ([batch, emd_size], [batch, 3])
    """

    def __init__(self, emd_size, input_channel=512):
        """
        Args:
            emd_size: width of the projected language embedding.
            input_channel: width of the incoming language embedding.
        """
        super().__init__()
        self.emd_size = emd_size
        self.linear1 = nn.Linear(in_features=input_channel, out_features=emd_size)
        self.linear2 = nn.Linear(in_features=emd_size, out_features=emd_size)
        self.cls = nn.Linear(in_features=emd_size, out_features=3)

    def forward(self, lang):
        """
        Returns:
            Tuple ``(embedding, label)``. ``embedding`` is the output of
            ``linear2`` *before* the final ReLU; ``label`` is computed from
            the ReLU-activated embedding.
        """
        hidden = F.relu(self.linear1(lang))
        embedding = self.linear2(hidden)
        # only the classifier head sees the activated embedding
        label = self.cls(F.relu(embedding))
        return embedding, label


class Fusion(nn.Module):
    """
    Fuse an observation embedding with a language embedding by concatenation.

    The language embedding first passes through a linear layer + ReLU, is
    concatenated to the observation embedding along the last dimension, and
    the result is mapped back to ``obs_emd_size`` by a two-layer MLP.
    """

    def __init__(self, obs_emd_size, lang_emd_size):
        """
        Args:
            obs_emd_size: width of the observation embedding (also the output width).
            lang_emd_size: width of the language embedding.
        """
        super().__init__()
        self.input_size = obs_emd_size + lang_emd_size
        self.linear1 = nn.Linear(in_features=self.input_size, out_features=obs_emd_size)
        self.linear2 = nn.Linear(in_features=obs_emd_size, out_features=obs_emd_size)
        self.linear_add1 = nn.Linear(in_features=lang_emd_size, out_features=lang_emd_size)

    def forward(self, obs, lang):
        """Return the fused embedding; its last dimension is ``obs_emd_size``."""
        lang_hidden = F.relu(self.linear_add1(lang))
        joint = torch.cat([obs, lang_hidden], dim=-1)
        joint_hidden = F.relu(self.linear1(joint))
        return self.linear2(joint_hidden)


class FiLM(nn.Module):
    """
    Feature-wise linear modulation: scale and shift ``x`` using ``condition``.

    Two linear layers map the condition to a per-feature scale (gamma) and
    shift (beta); the output is ``gamma * x + beta``.
    """

    def __init__(self, input_dim, condition_dim):
        """
        Args:
            input_dim: width of the features being modulated.
            condition_dim: width of the conditioning vector.
        """
        super().__init__()
        self.fc_gamma = nn.Linear(condition_dim, input_dim)
        self.fc_beta = nn.Linear(condition_dim, input_dim)

    def forward(self, x, condition):
        """Return ``x`` modulated by the scale and shift predicted from ``condition``."""
        scale = self.fc_gamma(condition)
        shift = self.fc_beta(condition)
        return scale * x + shift


class ClipTimmObsEncoder(ModuleAttrMixin):
    def __init__(self,
                 shape_meta: dict,
                 model_name: str,
                 pretrained: bool,
                 frozen: bool,
                 global_pool: str,
                 transforms: list,
                 # replace BatchNorm with GroupNorm
                 use_group_norm: bool = False,
                 # use single rgb model for all rgb inputs
                 share_rgb_model: bool = False,
                 # renormalize rgb input with imagenet normalization
                 # assuming input in [0,1]
                 imagenet_norm: bool = False,
                 vl_fusion: str = 'fusion',
                 feature_aggregation: str = 'spatial_embedding',
                 downsample_ratio: int = 32,
                 position_encording: str = 'sinusoidal',
                 ):
        """
        Build the per-camera image backbones, the CLIP text encoder and the
        vision-language fusion layer.

        Assumes rgb input: B,T,C,H,W
        Assumes low_dim input: B,T,D

        Args:
            shape_meta: dict whose ``'obs'`` entry maps each observation key to
                a dict with ``'shape'`` and optionally ``'type'`` (``'rgb'`` or
                ``'low_dim'``, default ``'low_dim'``) and ``'ignore_by_policy'``.
            model_name: timm model name. Names starting with ``'resnet'`` or
                ``'convnext'`` are truncated to their feature-map trunk; names
                starting with ``'vit'`` use the CLS token for aggregation.
            pretrained: forwarded to ``timm.create_model``.
            frozen: disable gradients on the backbone (requires ``pretrained``).
            global_pool: must be ``''`` (no pooling inside the timm model).
            transforms: ``None``, a list of ``nn.Module`` transforms, or a list
                whose first element is a config object with ``type ==
                'RandomCrop'`` and a ``ratio`` attribute.
            use_group_norm: replace ``BatchNorm2d`` with ``GroupNorm``; only
                applied when ``pretrained`` is False.
            share_rgb_model: use one backbone instance for all rgb keys instead
                of a deep copy per key.
            imagenet_norm: accepted for interface compatibility; not used by
                this constructor.
            vl_fusion: ``'fusion'`` selects :class:`Fusion`; any other value
                selects :class:`FiLM`.
            feature_aggregation: how a feature map is reduced to one vector
                (``'soft_attention'``, ``'spatial_embedding'``,
                ``'transformer'``, ``'attention_pool_2d'``, ``'avg'``,
                ``'max'`` or ``None``).
            downsample_ratio: spatial stride of the truncated backbone
                (32, or 16 for resnet).
            position_encording: ``'learnable'`` or ``'sinusoidal'``; only used
                with ``feature_aggregation == 'transformer'``.

        Raises:
            NotImplementedError: unsupported ``downsample_ratio`` for the model.
            RuntimeError: an observation declares an unsupported type.
            ValueError: unsupported ``position_encording`` for the transformer
                aggregation.
        """
        super().__init__()

        rgb_keys = list()
        low_dim_keys = list()
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = dict()

        assert global_pool == ''
        model = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            global_pool=global_pool,  # '' means no pooling
            num_classes=0  # remove classification layer
        )

        if frozen:
            assert pretrained
            for param in model.parameters():
                param.requires_grad = False

        # NOTE(review): the feature_dim constants below are hard-coded per model
        # family; presumably they match the specific variants used in this
        # project (e.g. resnet18/34 widths) -- confirm before using others.
        feature_dim = None
        if model_name.startswith('resnet'):
            # the last layer is nn.Identity() because num_classes is 0
            # second last layer is AdaptivePool2d, which is also identity because global_pool is empty
            if downsample_ratio == 32:
                modules = list(model.children())[:-2]
                model = torch.nn.Sequential(*modules)
                feature_dim = 512
            elif downsample_ratio == 16:
                modules = list(model.children())[:-3]
                model = torch.nn.Sequential(*modules)
                feature_dim = 256
            else:
                raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}")
        elif model_name.startswith('convnext'):
            # the last layer is nn.Identity() because num_classes is 0
            # second last layer is AdaptivePool2d, which is also identity because global_pool is empty
            if downsample_ratio == 32:
                modules = list(model.children())[:-2]
                model = torch.nn.Sequential(*modules)
                feature_dim = 1024
            else:
                raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}")

        if use_group_norm and not pretrained:
            model = replace_submodules(
                root_module=model,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=(x.num_features // 16) if (x.num_features % 16 == 0) else (x.num_features // 8),
                    num_channels=x.num_features)
            )

        # all rgb observations must share one spatial size (shape minus channels)
        image_shape = None
        obs_shape_meta = shape_meta['obs']
        for attr in obs_shape_meta.values():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                assert image_shape is None or image_shape == shape[1:]
                image_shape = shape[1:]

        # A non-Module first entry is a RandomCrop config: expand it into a
        # crop followed by a resize back to the original size. The length
        # check keeps an empty list from raising IndexError.
        if transforms is not None and len(transforms) > 0 \
                and not isinstance(transforms[0], torch.nn.Module):
            assert transforms[0].type == 'RandomCrop'
            ratio = transforms[0].ratio
            transforms = [
                             torchvision.transforms.RandomCrop(size=int(image_shape[0] * ratio)),
                             torchvision.transforms.Resize(size=image_shape[0], antialias=True)
                         ] + transforms[1:]
        transform = nn.Identity() if transforms is None else torch.nn.Sequential(*transforms)

        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            key_shape_map[key] = shape
            if obs_type == 'rgb':
                rgb_keys.append(key)

                this_model = model if share_rgb_model else copy.deepcopy(model)
                key_model_map[key] = this_model

                this_transform = transform
                key_transform_map[key] = this_transform
            elif obs_type == 'low_dim':
                if not attr.get('ignore_by_policy', False):
                    low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        feature_map_shape = [x // downsample_ratio for x in image_shape]

        rgb_keys = sorted(rgb_keys)
        low_dim_keys = sorted(low_dim_keys)
        print('rgb keys:         ', rgb_keys)
        print('low_dim_keys keys:', low_dim_keys)

        # Kept registered as a submodule even when every rgb key holds its own
        # deep copy, so existing state_dict keys stay unchanged.
        self.model = model

        # self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # the image preprocessing pipeline returned by clip.load is not needed here
        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)

        # NOTE(review): 768 is hard-coded rather than derived from feature_dim;
        # confirm it matches the output width of the chosen backbone.
        if vl_fusion == 'fusion':
            self.fusion_layer = Fusion(obs_emd_size=768, lang_emd_size=512)
        else:
            self.fusion_layer = FiLM(input_dim=768, condition_dim=512)

        # for param in self.clip_model.parameters():
        #     param.requires_grad = False

        # cond encoder
        self.time_emb = SinusoidalPosEmb(256)
        self.cond_obs_emb = None

        self.model_name = model_name
        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.low_dim_keys = low_dim_keys
        self.key_shape_map = key_shape_map
        self.vl_fusion = vl_fusion
        self.feature_aggregation = feature_aggregation
        if model_name.startswith('vit'):
            # assert self.feature_aggregation is None # vit uses the CLS token
            if self.feature_aggregation == 'all_tokens':
                # Use all tokens from ViT
                pass
            elif self.feature_aggregation is not None:
                logger.warning(
                    f'vit will use the CLS token. feature_aggregation ({self.feature_aggregation}) is ignored!')
                self.feature_aggregation = None

        if self.feature_aggregation == 'soft_attention':
            self.attention = nn.Sequential(
                nn.Linear(feature_dim, 1, bias=False),
                nn.Softmax(dim=1)
            )
        elif self.feature_aggregation == 'spatial_embedding':
            # Must live in its own attribute: aggregate_feature reads
            # self.spatial_embedding, and self.feature_aggregation has to stay
            # the mode string for the dispatch there to work.
            self.spatial_embedding = torch.nn.Parameter(
                torch.randn(feature_map_shape[0] * feature_map_shape[1], feature_dim))
        elif self.feature_aggregation == 'transformer':
            if position_encording == 'learnable':
                self.position_embedding = torch.nn.Parameter(
                    torch.randn(feature_map_shape[0] * feature_map_shape[1] + 1, feature_dim))
            elif position_encording == 'sinusoidal':
                num_features = feature_map_shape[0] * feature_map_shape[1] + 1
                # plain tensor (not a parameter/buffer); aggregate_feature moves
                # it to the feature device on first use
                self.position_embedding = torch.zeros(num_features, feature_dim)
                position = torch.arange(0, num_features, dtype=torch.float).unsqueeze(1)
                div_term = torch.exp(
                    torch.arange(0, feature_dim, 2).float() * (-math.log(2 * num_features) / feature_dim))
                self.position_embedding[:, 0::2] = torch.sin(position * div_term)
                self.position_embedding[:, 1::2] = torch.cos(position * div_term)
            else:
                # fail at construction instead of with an AttributeError in forward
                raise ValueError(f"Unsupported position_encording: {position_encording}")
            # NOTE(review): TransformerEncoderLayer defaults to batch_first=False,
            # while aggregate_feature feeds it a (B, tokens, C) tensor -- confirm
            # whether batch_first=True was intended. Left unchanged so existing
            # checkpoints keep their behavior.
            self.aggregation_transformer = nn.TransformerEncoder(
                encoder_layer=nn.TransformerEncoderLayer(d_model=feature_dim, nhead=4),
                num_layers=4)
        elif self.feature_aggregation == 'attention_pool_2d':
            self.attention_pool_2d = AttentionPool2d(
                spacial_dim=feature_map_shape[0],
                embed_dim=feature_dim,
                num_heads=feature_dim // 64,
                output_dim=feature_dim
            )
        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    def aggregate_feature(self, feature):
        """
        Reduce a backbone output according to ``self.feature_aggregation``.

        For model names starting with ``'vit'`` the first token is returned.
        Otherwise ``feature`` must be 4-dimensional; it is either handed to the
        2D attention pool as-is, or flattened to (B, H*W, C) and reduced by the
        selected mode. With no aggregation mode the flattened tokens are
        returned unreduced.
        """
        mode = self.feature_aggregation

        if self.model_name.startswith('vit'):
            # vit uses the CLS token
            assert mode is None
            return feature[:, 0, :]

        # convolutional feature map
        assert feature.ndim == 4
        if mode == 'attention_pool_2d':
            return self.attention_pool_2d(feature)

        # B, C, H, W -> B, C, H*W -> B, H*W, C
        tokens = feature.flatten(start_dim=-2).transpose(1, 2)

        if mode == 'avg':
            return tokens.mean(dim=[1])
        if mode == 'max':
            return tokens.amax(dim=[1])
        if mode == 'soft_attention':
            attn_weight = self.attention(tokens)
            return (tokens * attn_weight).sum(dim=1)
        if mode == 'spatial_embedding':
            return (tokens * self.spatial_embedding).mean(dim=1)
        if mode == 'transformer':
            # an all-zero token is prepended; its output slot is the aggregate
            batch = tokens.shape[0]
            channels = tokens.shape[-1]
            summary_slot = torch.zeros(batch, 1, channels, device=tokens.device)
            # the positional table may still sit on another device; move it lazily
            if self.position_embedding.device != tokens.device:
                self.position_embedding = self.position_embedding.to(tokens.device)
            sequence = torch.concat([summary_slot, tokens], dim=1) + self.position_embedding
            encoded = self.aggregation_transformer(sequence)
            return encoded[:, 0]

        assert mode is None
        return tokens

    def _encode_language(self, obs_dict, num_rows, device):
        """
        Decode the character-code language instruction and embed it with CLIP.

        ``obs_dict["robot0_language_instruction"]`` is reshaped to
        ``(num_rows, -1)``; each row is read as a sequence of character codes
        and decoded up to (not including) the first code 12290 (U+3002,
        ideographic full stop), which acts as the end-of-text marker.

        Returns:
            float32 text features on ``device``, computed without gradients.
        """
        end_of_text_code = 12290
        code_rows = obs_dict["robot0_language_instruction"].reshape(num_rows, -1).tolist()

        texts = []
        for row in code_rows:
            chars = []
            for char_code in row:
                char_code = int(char_code)
                if char_code == end_of_text_code:
                    break
                chars.append(chr(char_code))
            texts.append(''.join(chars))
        # debug-level only: this runs on every forward pass
        logger.debug("language instructions: %s", texts)

        tokens = clip.tokenize(texts).to(device)
        with torch.no_grad():
            text_features = self.clip_model.encode_text(tokens)
        return text_features.detach().to(torch.float32).to(device)

    def forward(self, obs_dict):
        """
        Encode an observation dict into one flat feature vector per sample.

        Each rgb key is transformed, passed through its backbone, aggregated
        and fused with the CLIP embedding of
        ``obs_dict["robot0_language_instruction"]``. Low-dim keys are flattened
        as-is, except ``'timestep'`` (embedded with ``self.time_emb``) and the
        language-instruction keys (skipped).

        Args:
            obs_dict: maps observation keys to tensors whose first two
                dimensions are batch and time.

        Returns:
            Tuple ``(result, None)`` where ``result`` concatenates all features
            along the last dimension.
        """
        features = list()
        batch_size = next(iter(obs_dict.values())).shape[0]

        # The language embedding does not depend on the camera, so it is
        # computed once per (row count, device) instead of once per rgb key.
        lang_cache = dict()

        # process rgb input
        for key in self.rgb_keys:
            img = obs_dict[key]
            B, T = img.shape[:2]
            assert B == batch_size
            assert img.shape[2:] == self.key_shape_map[key]
            img = img.reshape(B * T, *img.shape[2:])
            img = self.key_transform_map[key](img)
            raw_feature = self.key_model_map[key](img)
            feature = self.aggregate_feature(raw_feature)
            assert len(feature.shape) == 2 and feature.shape[0] == B * T

            cache_key = (B * T, feature.device)
            if cache_key not in lang_cache:
                lang_cache[cache_key] = self._encode_language(obs_dict, B * T, feature.device)
            text_embeds = lang_cache[cache_key]

            # Fusion and FiLM share the same call signature
            emb = self.fusion_layer(feature, text_embeds)
            features.append(emb.reshape(B, -1))

        # process lowdim input
        for key in self.low_dim_keys:
            # language instructions are consumed through CLIP above
            if key in ('robot0_language_instruction', 'robot1_language_instruction'):
                continue
            data = obs_dict[key]
            B, T = data.shape[:2]
            assert B == batch_size
            assert data.shape[2:] == self.key_shape_map[key]
            if key == 'timestep':  # (B, T, 1)
                features.append(self.time_emb(data).reshape(B, -1))
            else:
                features.append(data.reshape(B, -1))

        # concatenate all features
        result = torch.cat(features, dim=-1)

        return result, None

    @torch.no_grad()
    def output_shape(self):
        """
        Infer the encoder output shape by running ``forward`` on all-zero inputs.

        One zero tensor of shape ``(1, horizon) + shape`` is built for every
        entry in ``shape_meta['obs']``, using this module's dtype and device.

        Returns:
            Tuple ``(shape, None)`` where ``shape`` is the shape of the first
            ``forward`` output (2-dimensional with batch size 1).
        """
        dummy_obs = dict()
        for obs_key, obs_attr in self.shape_meta['obs'].items():
            obs_shape = tuple(obs_attr['shape'])
            dummy_obs[obs_key] = torch.zeros(
                (1, obs_attr['horizon']) + obs_shape,
                dtype=self.dtype,
                device=self.device)

        primary_out, secondary_out = self.forward(dummy_obs)
        assert len(primary_out.shape) == 2
        assert primary_out.shape[0] == 1

        primary_shape = primary_out.shape
        secondary_shape = None if secondary_out is None else secondary_out.shape
        print("output_shape:", primary_shape, secondary_shape)

        return primary_shape, None


if __name__ == '__main__':
    # Construction smoke test. The constructor subscripts shape_meta['obs']
    # and requires `frozen`, so both must be supplied.
    # NOTE(review): the keys/shapes below are illustrative -- adjust them to
    # the real task config.
    example_shape_meta = {
        'obs': {
            'camera0_rgb': {
                'shape': [3, 224, 224],
                'type': 'rgb',
                'horizon': 2,
            },
            'robot0_language_instruction': {
                'shape': [77],
                'type': 'low_dim',
                'horizon': 2,
            },
        }
    }
    timm_obs_encoder = ClipTimmObsEncoder(
        shape_meta=example_shape_meta,
        model_name='resnet18.a1_in1k',
        pretrained=False,
        frozen=False,
        global_pool='',
        transforms=None
    )
