import numpy as np
import torch.nn.functional as F
from torch import nn
from sgEncoderTraining.sgEncoder.module import GraphTripleConv, GraphTripleConvNet, Attention
from configs.configs_laion import CLIPGraphCfg
from sgEncoderTraining.global_var import *

def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    num_images_per_prompt: int = 1,
):
    """Encode ``prompt`` with a CLIP-style text encoder.

    Args:
        text_encoder: text model called as ``text_encoder(ids, output_hidden_states=True)``;
            its ``device`` and ``dtype`` attributes decide where inputs/outputs live.
        tokenizer: tokenizer paired with ``text_encoder``; inputs are truncated to at most
            77 tokens and are NOT padded.
        prompt: a single string or a list of strings.
        device: accepted for signature compatibility but unused -- ``text_encoder.device``
            is always used instead.
        num_images_per_prompt: how many copies of each prompt's token embeddings to emit.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: ``prompt_embeds`` is the penultimate
        hidden state, tiled to ``(len(prompts) * num_images_per_prompt, seq_len, hidden)``;
        ``pooled_prompt_embeds`` is the encoder's first output, returned without tiling.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    n_prompts = len(prompts)

    token_ids = tokenizer(
        prompts,
        padding=False,
        max_length=77,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    encoder_output = text_encoder(token_ids.to(text_encoder.device), output_hidden_states=True)
    pooled = encoder_output[0]
    # Token features come from the second-to-last hidden layer, not the final one.
    hidden = encoder_output.hidden_states[-2].to(dtype=text_encoder.dtype, device=text_encoder.device)

    _, seq_len, _ = hidden.shape
    # Tile along the sequence axis, then fold the copies into the batch axis
    # (mps-friendly alternative to repeat_interleave on dim 0).
    tiled = hidden.repeat(1, num_images_per_prompt, 1)
    tiled = tiled.view(n_prompts * num_images_per_prompt, seq_len, -1)

    return tiled, pooled

def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
):
    """Encode ``prompt`` with a T5-style text encoder.

    Args:
        text_encoder: text model whose first output (``text_encoder(ids)[0]``) is the token
            embedding tensor; its ``device`` and ``dtype`` attributes decide placement.
        tokenizer: tokenizer paired with ``text_encoder``; special tokens are added and the
            input is truncated to ``max_sequence_length`` tokens without padding.
        max_sequence_length: truncation length passed to the tokenizer.
        prompt: a single string or a list of strings (``None`` is not supported).
        num_images_per_prompt: how many copies of each prompt's embeddings to emit.
        device: accepted for signature compatibility but unused -- ``text_encoder.device``
            is always used instead.

    Returns:
        Token embeddings of shape ``(len(prompts) * num_images_per_prompt, seq_len, hidden)``.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    n_prompts = len(prompts)

    token_ids = tokenizer(
        prompts,
        padding=False,
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    ).input_ids

    hidden = text_encoder(token_ids.to(text_encoder.device))[0]
    hidden = hidden.to(dtype=text_encoder.dtype, device=text_encoder.device)

    _, seq_len, _ = hidden.shape
    # Tile along the sequence axis, then fold the copies into the batch axis
    # (mps-friendly alternative to repeat_interleave on dim 0).
    tiled = hidden.repeat(1, num_images_per_prompt, 1)
    return tiled.view(n_prompts * num_images_per_prompt, seq_len, -1)


class EMAParameter(nn.Module):
    """A learnable scalar that is clamped to ``[min_val, max_val]`` whenever it is read.

    NOTE(review): despite the name, no exponential moving average is applied anywhere in
    this class. ``ema_decay`` is only stored as ``self.ema_decay`` so the argument is not
    silently lost; nothing here reads it.

    Args:
        init_val: initial value of the underlying parameter (ints are accepted).
        min_val: lower clamp bound applied on read.
        max_val: upper clamp bound applied on read.
        ema_decay: stored for reference; unused by this class.
    """

    def __init__(self, init_val=0.01, min_val=-1, max_val=1, ema_decay=0.999):
        super().__init__()

        # float() so an integer init value (e.g. 0) still yields a floating-point tensor:
        # nn.Parameter cannot require grad on an integer tensor.
        self.raw_param = nn.Parameter(torch.tensor(float(init_val)))
        self.min_val = min_val
        self.max_val = max_val
        self.ema_decay = ema_decay

    def forward(self):
        """Return the parameter clamped to ``[min_val, max_val]`` (gradient flows inside the range)."""
        return torch.clamp(self.raw_param, self.min_val, self.max_val)

    def get_raw(self):
        """Return the clamped value, identical to ``forward()``.

        NOTE(review): the name suggests the unclamped ``raw_param``; the clamped value is
        kept for backward compatibility with existing callers.
        """
        return torch.clamp(self.raw_param, self.min_val, self.max_val)


class sgEncoder(nn.Module):
    """Scene-graph encoder that adds graph-refined token embeddings onto a prompt embedding.

    For every sample, the scene graph (triples plus isolated items) is encoded twice with
    the frozen text encoder(s) under ``torch.no_grad()``:

    * once as a flat sentence ("item1 relation item2 ... isolated ..."), and
    * once item-by-item, where every item token becomes an object node and every relation
      token (plus a "-" attribute token per intra-item link) becomes a predicate node.

    The node embeddings are projected to ``embed_dim``, refined by ``graph_conv`` and
    ``graph_net``, L2-normalised, projected back, and added -- scaled by the learnable
    ``alpha_obj`` -- onto the leading tokens of the sentence embedding.

    ``forward`` runs this once in 'clip' mode (``self.text_encoders[:2]``) and once in 't5'
    mode (``self.text_encoders[-1]``) and concatenates the two results along dim -2.
    """

    def __init__(self,
                 graph_cfg: CLIPGraphCfg,
                 text_encoders: list,
                 tokenizers: list,
                 embed_dim=512,
                 max_sample_per_img: int = 15,
                 clip_dim=2048,
                 t5_dim=4096,
                 ):
        """Build the graph modules and the in/out projections.

        Args:
            graph_cfg: graph configuration, or a dict of its fields; ``width`` and
                ``layers`` are read.
            text_encoders: text encoders; ``[:2]`` are used in 'clip' mode, ``[-1]`` in 't5'.
            tokenizers: tokenizers aligned index-by-index with ``text_encoders``.
            embed_dim: node dimension inside the graph network.
            max_sample_per_img: stored on the instance; not read by this class.
            clip_dim: width of the concatenated CLIP token embeddings.
            t5_dim: width of the T5 token embeddings.
        """
        super().__init__()
        if isinstance(graph_cfg, dict):
            graph_cfg = CLIPGraphCfg(**graph_cfg)

        self.max_sample_per_img = max_sample_per_img

        # Submodules are created in a fixed order; keep it stable so random initialisation
        # and state_dict key order do not change.
        self.graph_conv = GraphTripleConv(embed_dim, output_dim=embed_dim, hidden_dim=graph_cfg.width, pooling='avg', mlp_normalization='none')
        self.graph_net = GraphTripleConvNet(embed_dim, num_layers=graph_cfg.layers, hidden_dim=graph_cfg.width, pooling='avg', mlp_normalization='none')

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Learnable, clamped weight of the graph residual added in
        # generate_total_embedding_for_unet.
        self.alpha_obj = EMAParameter(
            init_val=1e-2,
            min_val=-1,
            max_val=1,
            ema_decay=0.999
        )

        # Plain lists (not nn.ModuleList): the text encoders are not registered as
        # submodules, so their weights are not part of parameters() / state_dict().
        self.tokenizers = tokenizers
        self.text_encoders = text_encoders

        # 'clip' or 't5'; selects encoders, slicing and projections. Set by forward().
        self.mode = 'clip'

        self.projection_mlp_1 = nn.Linear(clip_dim, embed_dim, bias=False)
        self.projection_mlp_2 = nn.Linear(embed_dim, clip_dim, bias=False)

        self.projection_mlp_1_t5 = nn.Linear(t5_dim, embed_dim, bias=False)
        self.projection_mlp_2_t5 = nn.Linear(embed_dim, t5_dim, bias=False)

    def initialize_parameters(self):
        """Reset ``logit_scale`` and let the graph modules re-initialise themselves if they can."""
        nn.init.constant_(self.logit_scale, np.log(1 / 0.07))

        if hasattr(self.graph_conv, 'init_parameters'):
            self.graph_conv.init_parameters()
        if hasattr(self.graph_net, 'init_parameters'):
            self.graph_net.init_parameters()

    def _get_triple_embeddings(self, triples, isolated_items, start):
        """Build graph nodes/edges for ``triples`` from per-item text embeddings.

        Shared implementation of ``get_triple_embeddings_clip`` (``start=1``: drop the
        leading BOS token) and ``get_triple_embeddings_t5`` (``start=0``). The final token
        of every encoded item is always dropped.

        Graph layout, per triple:
          * the subject's tokens are appended as object nodes; the LAST one is the subject
            "head", and each earlier token is linked to it by an edge ``[token, head]``
            whose predicate is the embedding of "-";
          * the object item is handled the same way;
          * each relation token becomes a predicate node on edge ``[subject_head, object_head]``.
        Predicate row ``k`` always corresponds to edge row ``k``.

        Returns:
            ``(isolated_embeddings, obj_embeddings, pred_embeddings, edges, index_list)``:
            lists of ``(1, n_tokens, D)`` slices for isolated items, object nodes and
            predicate nodes; an ``(E, 2)`` long tensor of edges; and, per triple,
            ``[subject_head, [first_pred_row, last_pred_row], object_head]``.
        """
        obj_embeddings = []
        pred_embeddings = []
        edges = []
        triple_o_p_o_index_list = []

        attri_embedding, _ = self.encode_prompt("-")
        attri_embedding = attri_embedding[:, start:-1, :]
        device = attri_embedding.device
        # Each attribute edge consumes exactly one predicate row, so keep a single token.
        # When "-" encodes to one token (the expected case) this is a no-op.
        attri_embedding = attri_embedding[:, -1:, :]

        num_obj = 0   # rows appended to obj_embeddings so far
        num_pred = 0  # rows appended to pred_embeddings so far

        for triple in triples:
            s_embedding, _ = self.encode_prompt(str(triple['item1']))
            p_embedding, _ = self.encode_prompt(str(triple['relation']))
            o_embedding, _ = self.encode_prompt(str(triple['item2']))

            heads = []
            for item_embedding in (s_embedding, o_embedding):
                tokens = item_embedding[:, start:-1, :]
                n_tokens = tokens.shape[1]
                obj_embeddings.append(tokens)

                head = num_obj + n_tokens - 1
                # Every non-final token is an attribute of the item's final (head) token.
                for offset in range(n_tokens - 1):
                    pred_embeddings.append(attri_embedding)
                    num_pred += attri_embedding.shape[1]
                    edges.append([num_obj + offset, head])

                num_obj += n_tokens
                heads.append(head)
            s_head, o_head = heads

            p_tokens = p_embedding[:, start:-1, :]
            n_p = p_tokens.shape[1]
            p_start = num_pred
            pred_embeddings.append(p_tokens)
            num_pred += n_p
            edges.extend([s_head, o_head] for _ in range(n_p))

            triple_o_p_o_index_list.append([s_head, [p_start, p_start + n_p - 1], o_head])

        isolated_embeddings = [
            self.encode_prompt(item)[0][:, start:-1, :] for item in isolated_items
        ]

        edges = torch.tensor(edges, dtype=torch.long, device=device).view(-1, 2)

        return isolated_embeddings, obj_embeddings, pred_embeddings, edges, triple_o_p_o_index_list

    def get_triple_embeddings_clip(self, triples, isolated_items):
        """CLIP variant of the graph construction: drop BOS and EOS from every item.

        See ``_get_triple_embeddings`` for the graph layout and return values.
        """
        return self._get_triple_embeddings(triples, isolated_items, start=1)

    def get_triple_embeddings_t5(self, triples, isolated_items):
        """T5 variant of the graph construction: drop only the final (EOS) token of every item.

        See ``_get_triple_embeddings`` for the graph layout and return values.
        """
        return self._get_triple_embeddings(triples, isolated_items, start=0)

    def get_text_embeddings(self, triples, isolated_items):
        """Encode the whole scene graph as one flat sentence.

        The sentence is every triple as "item1 relation item2 " followed by every isolated
        item, each word followed by a single space.

        Returns:
            ``(sos, content, eos, pooled)``: in 'clip' mode ``sos``/``eos`` are the first/last
            token embeddings and ``content`` the tokens between them; in 't5' mode ``sos`` is
            None and ``content`` is every token except the last. ``pooled`` is the pooled
            embedding from ``encode_prompt`` (None in 't5' mode).
        """
        words = []
        for triple in triples:
            words.extend((str(triple['item1']), str(triple['relation']), str(triple['item2'])))
        words.extend(isolated_items)
        triple_str = ''.join(f'{word} ' for word in words)

        triple_embedding, pooled_triple_embedding = self.encode_prompt(triple_str)

        if self.mode == 'clip':
            sos_embedding = triple_embedding[:, 0, :]
            eos_embedding = triple_embedding[:, -1, :]
            triple_embedding = triple_embedding[:, 1:-1, :]
            return sos_embedding, triple_embedding, eos_embedding, pooled_triple_embedding
        else:
            eos_embedding = triple_embedding[:, -1, :]
            triple_embedding = triple_embedding[:, 0:-1, :]
            return None, triple_embedding, eos_embedding, pooled_triple_embedding

    def generate_total_embedding_for_unet(self,
                                          be_prompt_embeddings,
                                          sos_embeddings,
                                          eos_embeddings,
                                          pooled_embeddings,
                                          obj_embeddings,
                                          pred_embeddings,
                                          triple_o_p_o_index_list):
        """Add the graph residual to the sentence tokens and assemble a 77-token sequence.

        The refined nodes are laid out in sentence order (subject tokens, relation tokens,
        object tokens for every triple), scaled by ``alpha_obj`` and added IN PLACE onto the
        leading tokens of ``be_prompt_embeddings``. At most 75 tokens are modified, and never
        more than ``be_prompt_embeddings`` holds.

        Args:
            be_prompt_embeddings: ``(1, L, D)`` content tokens of the sentence embedding;
                modified in place.
            sos_embeddings: BOS embedding (``(1, D)``) in 'clip' mode, unused in 't5' mode.
            eos_embeddings: ``(1, D)`` final-token embedding; also used as padding.
            pooled_embeddings: returned unchanged.
            obj_embeddings: refined object-node embeddings (rows indexed by the index list);
                not read when the index list is empty.
            pred_embeddings: refined predicate-node embeddings; not read when the list is empty.
            triple_o_p_o_index_list: per triple,
                ``[subject_head, [first_pred_row, last_pred_row], object_head]``.

        Returns:
            ``(final_embeddings, pooled_embeddings)`` where ``final_embeddings`` has exactly
            77 tokens (padded with copies of the EOS embedding or truncated).
        """
        segments = []
        obj_cursor = 0
        for s_head, (p_first, p_last), o_head in triple_o_p_o_index_list:
            segments.append(obj_embeddings[obj_cursor:s_head + 1])   # subject tokens
            segments.append(pred_embeddings[p_first:p_last + 1])     # relation tokens
            segments.append(obj_embeddings[s_head + 1:o_head + 1])   # object tokens
            obj_cursor = o_head + 1

        clamped_alpha_o = self.alpha_obj()

        if segments:
            prompt_embeddings = torch.cat(segments, dim=0).unsqueeze(0)
            # Never write past the 75 content slots, nor past the sentence's own length
            # (tokenising items separately can yield more tokens than the joint sentence).
            c = min(prompt_embeddings.size(1), be_prompt_embeddings.size(1), 75)
            be_prompt_embeddings[:, :c, :] += clamped_alpha_o * prompt_embeddings[:, :c, :]

        if self.mode == 'clip':
            final_embeddings = torch.cat(
                (sos_embeddings.unsqueeze(0), be_prompt_embeddings, eos_embeddings.unsqueeze(0)), dim=1)
        else:
            final_embeddings = torch.cat(
                (be_prompt_embeddings, eos_embeddings.unsqueeze(0)), dim=1)

        target_length = 77
        current_length = final_embeddings.shape[1]

        if current_length < target_length:
            pad_embeddings = eos_embeddings.repeat(1, target_length - current_length, 1)
            final_embeddings = torch.cat((final_embeddings, pad_embeddings), dim=1)
        elif current_length > target_length:
            final_embeddings = final_embeddings[:, :target_length, :]

        return final_embeddings, pooled_embeddings

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=False):
        """Record the gradient-checkpointing flag and forward it to graph modules that support it.

        The previous implementation delegated to a nonexistent ``self.visual`` and always
        raised AttributeError.
        """
        self.grad_checkpointing = enable
        for module in (self.graph_conv, self.graph_net):
            if hasattr(module, 'set_grad_checkpointing'):
                module.set_grad_checkpointing(enable)

    def do_avg_pool(self, embeddings, adaptive_avg_pool):
        """Pool ``embeddings`` over its token axis with ``adaptive_avg_pool``.

        A batch axis is prepended and the token/feature axes are swapped around the pool
        call, so the pool operates over tokens.
        """
        embeddings = embeddings.unsqueeze(0)
        embeddings_pooled = adaptive_avg_pool(embeddings.permute(0, 2, 1)).permute(0, 2, 1)
        return embeddings_pooled

    def tokenize_prompt(self, tokenizer, prompt):
        """Tokenise ``prompt`` (truncated, unpadded) and return the ``input_ids`` tensor."""
        text_inputs = tokenizer(
            prompt,
            padding=False,
            truncation=True,
            return_tensors="pt",
        )
        return text_inputs.input_ids

    def encode_prompt(self, prompt):
        """Encode ``prompt`` with the frozen text encoder(s) selected by ``self.mode``.

        Outputs are produced under ``torch.no_grad()`` and cloned/detached, so no gradient
        reaches the text encoders.

        Args:
            prompt: a string or a list of strings.

        Returns:
            'clip': ``(embeds, pooled)``, each concatenated along the last dim across the two
            CLIP encoders ``self.text_encoders[:2]``.
            't5': ``(embeds, None)`` from ``self.text_encoders[-1]``.

        Raises:
            ValueError: if ``self.mode`` is neither 'clip' nor 't5'.
        """
        if self.mode == 'clip':
            clip_prompt_embeds_list = []
            clip_pooled_prompt_embeds_list = []

            for tokenizer, text_encoder in zip(self.tokenizers[:2], self.text_encoders[:2]):
                with torch.no_grad():
                    prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
                        text_encoder=text_encoder,
                        tokenizer=tokenizer,
                        prompt=prompt,
                        device=text_encoder.device,
                    )
                    prompt_embeds = prompt_embeds.clone().detach()
                    pooled_prompt_embeds = pooled_prompt_embeds.clone().detach()

                clip_prompt_embeds_list.append(prompt_embeds)
                clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

            clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
            pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
            return clip_prompt_embeds, pooled_prompt_embeds

        if self.mode == 't5':
            t5_encoder = self.text_encoders[-1]
            with torch.no_grad():
                t5_prompt_embed = _encode_prompt_with_t5(
                    t5_encoder,
                    self.tokenizers[-1],
                    max_sequence_length=77,
                    prompt=prompt,
                    device=t5_encoder.device,
                )
                t5_prompt_embed = t5_prompt_embed.clone().detach()
            return t5_prompt_embed, None

        raise ValueError(f"Unknown encoding mode {self.mode!r}; expected 'clip' or 't5'")

    def encode_graph_local_global(self, triples_list, isolated_items_list, item_list_list):
        """Encode a batch of scene graphs in the current ``self.mode``.

        Samples without triples skip the graph network and use the plain sentence
        embedding; previously they crashed on an empty ``torch.cat``.

        Args:
            triples_list: per-sample sequences of triple dicts with keys
                'item1', 'relation', 'item2'.
            isolated_items_list: per-sample sequences of items not in any triple.
            item_list_list: per-sample item lists; iterated in lockstep but not used here.

        Returns:
            ``(final_embeddings_all, pooled_embeddings_all)``: per-sample 77-token embeddings
            concatenated along dim 0, and the concatenated pooled embeddings (None when no
            sample produced one, i.e. in 't5' mode).
        """
        if self.mode == 'clip':
            projection_in, projection_out = self.projection_mlp_1, self.projection_mlp_2
            build_graph = self.get_triple_embeddings_clip
        else:  # t5
            projection_in, projection_out = self.projection_mlp_1_t5, self.projection_mlp_2_t5
            build_graph = self.get_triple_embeddings_t5

        final_embeddings_list = []
        pooled_embeddings_list = []

        for triples, isolated_items, _item_list in zip(triples_list, isolated_items_list, item_list_list):
            (be_sos_embeddings, be_prompt_embeddings, be_eos_embeddings,
             be_pooled_embeddings) = self.get_text_embeddings(triples, isolated_items)

            obj_vecs = pred_vecs = None
            triple_o_p_o_index_list = []

            if len(triples) > 0:
                (_isolated_embeddings, obj_embeddings, pred_embeddings,
                 edges, triple_o_p_o_index_list) = build_graph(triples, isolated_items)

                obj_vecs = projection_in(torch.cat(obj_embeddings, dim=1).squeeze(0))
                pred_vecs = projection_in(torch.cat(pred_embeddings, dim=1).squeeze(0))

                obj_vecs, pred_vecs = self.graph_conv(obj_vecs, pred_vecs, edges)
                if self.graph_net is not None:
                    obj_vecs, pred_vecs = self.graph_net(obj_vecs, pred_vecs, edges)

                obj_vecs = projection_out(F.normalize(obj_vecs, p=2, dim=1))
                pred_vecs = projection_out(F.normalize(pred_vecs, p=2, dim=1))

            final_embeddings, pooled_embeddings = self.generate_total_embedding_for_unet(
                be_prompt_embeddings,
                be_sos_embeddings,
                be_eos_embeddings,
                be_pooled_embeddings,
                obj_vecs,
                pred_vecs,
                triple_o_p_o_index_list)

            final_embeddings_list.append(final_embeddings)
            if pooled_embeddings is not None:
                pooled_embeddings_list.append(pooled_embeddings)

        final_embeddings_all = torch.cat(final_embeddings_list, dim=0)
        pooled_embeddings_all = torch.cat(pooled_embeddings_list, dim=0) if pooled_embeddings_list else None

        return final_embeddings_all, pooled_embeddings_all

    def forward(self, triples_list, isolated_items_list, item_list_list):
        """Encode a batch with CLIP and T5 and stack the results along the sequence axis.

        The CLIP embeddings are zero-padded on the last dim to the T5 width before
        concatenation along dim -2. Side effect: ``self.mode`` is left as 't5'.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)``; the pooled embeddings come from the
            CLIP pass.
        """
        self.mode = 'clip'
        clip_prompt_embeds, pooled_prompt_embeds = self.encode_graph_local_global(
            triples_list, isolated_items_list, item_list_list)

        self.mode = 't5'
        t5_prompt_embed, _ = self.encode_graph_local_global(
            triples_list, isolated_items_list, item_list_list)

        clip_prompt_embeds = torch.nn.functional.pad(
            clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
        )
        prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

        return prompt_embeds, pooled_prompt_embeds

    def get_alpha(self):
        """Return the current clamped graph-residual weight as a Python float."""
        return self.alpha_obj().item()

def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Visits every submodule of ``model`` (via ``nn.Module.apply``) and casts to half:

    * weight and bias of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear`` (and subclasses);
    * ``in/q/k/v_proj_weight``, ``in_proj_bias``, ``bias_k`` and ``bias_v`` of
      ``nn.MultiheadAttention`` and ``Attention`` modules, for whichever of these exist
      and are tensors;
    * any tensor-valued ``text_projection`` or ``proj`` attribute.

    Args:
        model: module to convert; modified in place.
    """

    def _convert_weights_to_fp16(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                # Default to None: attention modules need not define every one of these.
                tensor = getattr(module, attr, None)
                if isinstance(tensor, torch.Tensor):
                    tensor.data = tensor.data.half()

        for name in ("text_projection", "proj"):
            # Only cast tensors: a submodule stored under one of these names (e.g. an
            # nn.Linear `proj`) has no `.data`; apply() converts it when visiting it.
            attr = getattr(module, name, None)
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)

