from math import sqrt
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer
import transformers
from models.StandardNorm import Normalize
transformers.logging.set_verbosity_error()
from rxnfp.transformer_fingerprints import (
    RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints)

from models.Projection_reprogramming import ReprogrammingLayer
from models.Projection_perceiver import PerceiverLayer

class FlattenHead(nn.Module):
    """Output head: a linear projection followed by dropout.

    Despite the name, no flattening is performed; the projection acts on the
    last dimension of whatever tensor is passed in.

    Args:
        nf: size of the input feature (last) dimension.
        target_window: size of the output feature dimension.
        head_dropout: dropout probability applied after the projection.
    """

    def __init__(self, nf, target_window, head_dropout=0):
        super().__init__()
        # These attribute names become state_dict keys; keep them stable.
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Project ``x`` along its last dimension, then apply dropout."""
        return self.dropout(self.linear(x))


class Model(nn.Module):
    """Reaction-condition predictor built on a frozen LLaMA backbone.

    SMILES embeddings (``inputs["input_emb"]``) are projected into the LLaMA
    token-embedding space with a reprogramming, perceiver, or MLP projection.
    The result is optionally combined with a tokenized text prompt and with
    graph or rxnfp-fingerprint embeddings. It is then run through the frozen
    LLaMA model. The first ``d_ff`` channels of the last ``output_len`` hidden
    states are mapped to ``tgt_vocab_size`` logits.
    """

    # Number of trailing LLaMA positions decoded into predictions, per dataset.
    _OUTPUT_LENS = {'USPTO-500MT': 6, 'USPTO-Condition': 5}

    def __init__(self, configs):
        super(Model, self).__init__()
        # Validate before loading the (large) backbone so a bad config fails fast.
        if configs.data not in self._OUTPUT_LENS:
            raise ValueError(
                f"Unsupported dataset {configs.data!r}; "
                f"expected one of {sorted(self._OUTPUT_LENS)}"
            )
        self.output_len = self._OUTPUT_LENS[configs.data]

        self.task_name = configs.task_name
        self.tgt_vocab_size = configs.tgt_vocab_size
        self.d_ff = configs.d_ff
        self.llama_pretrained_path = configs.llama_pretrained_path
        self.configs = configs

        self.llama_config = LlamaConfig.from_pretrained(self.llama_pretrained_path)
        self.llama_config.num_hidden_layers = configs.llm_layers
        # NOTE(review): these make the backbone return every attention map and
        # hidden state, but forward() only reads last_hidden_state — consider
        # disabling to save memory.
        self.llama_config.output_attentions = True
        self.llama_config.output_hidden_states = True
        # Width of the LLaMA token embeddings; every projection below targets it,
        # so read it from the config instead of assuming a fixed size.
        self.d_llm = self.llama_config.hidden_size

        self.llama = self._from_pretrained_local_first(
            LlamaModel,
            self.llama_pretrained_path,
            "Local model files not found. Attempting to download...",
            trust_remote_code=True,
            config=self.llama_config,
        )
        self.tokenizer = self._from_pretrained_local_first(
            LlamaTokenizer,
            self.llama_pretrained_path,
            "Local tokenizer files not found. Atempting to download them..",
            trust_remote_code=True,
        )

        # Batched tokenization with padding=True needs a pad token.
        if self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            pad_token = '[PAD]'
            self.tokenizer.add_special_tokens({'pad_token': pad_token})
            self.tokenizer.pad_token = pad_token

        # The backbone is frozen; only the projection layers below are trained.
        for param in self.llama.parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(configs.dropout)
        # Assigning the Parameter also registers it on this module (shared with self.llama).
        self.word_embeddings = self.llama.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1000
        # Compresses the vocabulary axis of the word embeddings to num_tokens prototypes.
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

        # Resamples the 512-long SMILES embedding sequence to 128 positions.
        self.source_len = 512
        self.target_len = 128
        self.mapping_layer_chemistry = nn.Linear(self.source_len, self.target_len)

        if self.configs.use_graph:
            self.mapping_layer_graph = nn.Linear(self.vocab_size, self.num_tokens)
            self.perceiver_linear_graph = nn.Linear(384, self.d_llm)
            self.perceiver_projection_graph = PerceiverLayer(self.d_llm, self.d_llm, 3)

        if self.configs.use_fp:
            self.mapping_layer_fp = nn.Linear(self.vocab_size, self.num_tokens)
            self.perceiver_linear_fp = nn.Linear(256, self.d_llm)
            self.perceiver_projection_fp = PerceiverLayer(self.d_llm, self.d_llm, 3)
            # Lazily created in create_fp(); not an nn.Module, so not in state_dict.
            self._rxnfp_generator = None

        # The MLP projection consumes num_tokens (1000) prototype rows + 128 SMILES rows.
        self.source_len = 1000 + 128
        self.target_len = 128
        self.mlp_linear = nn.Linear(configs.d_model, self.d_llm)
        self.mlp_projection = nn.Linear(self.source_len, self.target_len)

        self.merge_len = 64
        self.perceiver_linear = nn.Linear(configs.d_model, self.d_llm)
        self.perceiver_projection = PerceiverLayer(self.d_llm, self.d_llm, self.merge_len)

        self.reprogramming_layer = ReprogrammingLayer(configs.d_model, configs.n_heads, self.d_ff, self.d_llm)

        self.head_nf = self.d_ff
        self.output_projection = FlattenHead(self.head_nf, self.tgt_vocab_size,
                                             head_dropout=configs.dropout)

    @staticmethod
    def _from_pretrained_local_first(loader, path, fallback_msg, **kwargs):
        """Load ``loader.from_pretrained(path)`` from local files, downloading if needed.

        Tries ``local_files_only=True`` first. On ``EnvironmentError`` (OSError)
        it prints ``fallback_msg`` and retries with ``local_files_only=False``.
        """
        try:
            return loader.from_pretrained(path, local_files_only=True, **kwargs)
        except EnvironmentError:
            print(fallback_msg)
            return loader.from_pretrained(path, local_files_only=False, **kwargs)

    @staticmethod
    def _expand_to_batch(source, batch_size):
        """Return ``source`` as a (batch, seq, dim) tensor.

        A 2-D (seq, dim) tensor shared by all samples is broadcast over the
        batch. A 3-D tensor is returned unchanged.
        """
        if source.dim() == 3:
            return source
        if source.dim() == 2:
            # expand() is a view; the subsequent torch.cat materializes the copy.
            return source.unsqueeze(0).expand(batch_size, -1, -1)
        raise ValueError(
            f"expected a 2-D or 3-D source embedding, got shape {tuple(source.shape)}"
        )

    def _build_prompt(self, inputs, index):
        """Build the instruction prompt for sample ``index`` of the batch.

        The wording (including its spelling) is kept exactly as the prompts the
        model was trained with, because it is tokenized and fed to the LLM.
        """
        rxn_text_flag = self.configs.rxn_text_flag
        if rxn_text_flag:
            text = inputs['rxn_text'][index]
            if self.configs.data == 'USPTO-500MT':
                task = "given a reaction SMILES and the embedding of the reaction, please predict the optimal reagents of the reaction; "
            else:
                task = "given the text description of a reaction and the embedding of the SMILES representation of this reaction, please predict the optimal catalysts, solvents and reagents of the reaction; "
        elif self.configs.use_graph and rxn_text_flag == False:  # noqa: E712 (mirrors original semantics for None)
            text = inputs['paragraph_text'][index]
            task = "given the text description of a reaction, the embedding of the SMILES representation of this reaction and the graph embedding of the reaction, please predict the optimal catalysts, solvents and reagents of the reaction; "
        else:
            text = inputs['paragraph_text'][index]
            task = "given the text description of a reaction and the embedding of the SMILES representation of this reaction, please predict the optimal catalysts, solvents and reagents of the reaction; "
        return f"<|start_prompt|>{task}The reaction descpription text is as follows: {text}<|<end_prompt>|>"

    def create_fp(self, inputs):
        """Compute one rxnfp fingerprint per reaction in the batch.

        The pretrained rxnfp model and tokenizer are loaded on first use and
        cached on the instance, instead of being reloaded on every call.

        Args:
            inputs: batch dict; reads ``inputs['rxn_text'][i]`` for
                ``i < inputs['labels'].shape[0]``.

        Returns:
            A list with the output of ``RXNBERTFingerprintGenerator.convert``
            for each sample.
        """
        generator = getattr(self, "_rxnfp_generator", None)
        if generator is None:
            model, tokenizer = get_default_model_and_tokenizer()
            generator = RXNBERTFingerprintGenerator(model, tokenizer)
            self._rxnfp_generator = generator
        rxn_texts = inputs['rxn_text']
        return [generator.convert(rxn_texts[index])
                for index in range(inputs['labels'].shape[0])]

    def forward(self, inputs):
        """Predict reaction-condition logits for a batch.

        Args:
            inputs: batch dict. Keys read here:
                ``"input_emb"`` - SMILES embedding tensor. By the layer shapes it
                    is (B, 512, d_model); assumption, confirm against the loader.
                ``"labels"`` - only ``shape[0]`` (the batch size) is used.
                ``"rxn_text"`` - per-sample reaction strings. Used for the prompt
                    when ``rxn_text_flag`` is set, as source tokens when
                    ``rxn_source_flag`` is set, and for fingerprints when
                    ``use_fp`` is set.
                ``"paragraph_text"`` - per-sample description used in the prompt
                    when ``rxn_text_flag`` is falsy.
                ``"graph"`` - (B, 384) graph features, used when ``use_graph``.

        Returns:
            Logits of shape (B, output_len, tgt_vocab_size).

        Raises:
            NotImplementedError: ``configs.projection_type`` is not one of
                ``"reprogramming"``, ``"perceiver"``, ``"mlp"``.
        """
        smiles_emb = inputs["input_emb"]
        device = smiles_emb.device
        batch_size = inputs['labels'].shape[0]
        embed_tokens = self.llama.get_input_embeddings()

        prompts = [self._build_prompt(inputs, index) for index in range(batch_size)]
        prompt_ids = self.tokenizer(prompts, return_tensors="pt", padding=True,
                                    truncation=True, max_length=256).input_ids
        prompt_embeddings = embed_tokens(prompt_ids.to(device))

        if self.configs.rxn_source_flag:
            # Only tokenize the reaction text when it is actually used as the source.
            rxn_text_ids = self.tokenizer(inputs['rxn_text'], return_tensors="pt", padding=True,
                                          truncation=True, max_length=256).input_ids
            source_embeddings = embed_tokens(rxn_text_ids.to(device))  # (B, L, d_llm)
        else:
            # (num_tokens, d_llm) prototypes shared by every sample.
            source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        target_embeddings = self.mapping_layer_chemistry(smiles_emb.permute(0, 2, 1)).permute(0, 2, 1)
        projection_type = self.configs.projection_type
        if projection_type == "reprogramming":
            enc_out = self.reprogramming_layer(target_embeddings, source_embeddings, source_embeddings)
        elif projection_type == "perceiver":
            source = self._expand_to_batch(source_embeddings, target_embeddings.size(0))
            enc_out = self.perceiver_projection(
                torch.cat([source, self.perceiver_linear(target_embeddings)], dim=1))
        elif projection_type == "mlp":
            # NOTE(review): mlp_projection expects a sequence of num_tokens + 128
            # rows, which presumably only holds when rxn_source_flag is off; confirm.
            source = self._expand_to_batch(source_embeddings, target_embeddings.size(0))
            merged = torch.cat([source, self.mlp_linear(target_embeddings)], dim=1)
            enc_out = self.mlp_projection(merged.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            raise NotImplementedError(
                'projection method {} has not been implemented'.format(self.configs.projection_type))

        if self.configs.only_parrot:
            llama_enc_out = enc_out
        elif self.configs.only_corpus:
            llama_enc_out = prompt_embeddings
        elif self.configs.use_graph:
            graph_feat = inputs["graph"].unsqueeze(1)
            graph_source = self.mapping_layer_graph(self.word_embeddings.permute(1, 0)).permute(1, 0)
            graph_enc_out = self.perceiver_projection_graph(torch.cat(
                [self._expand_to_batch(graph_source, graph_feat.size(0)),
                 self.perceiver_linear_graph(graph_feat)], dim=1))
            llama_enc_out = torch.cat([prompt_embeddings, enc_out, graph_enc_out], dim=1)
        elif self.configs.use_fp:
            fps = self.create_fp(inputs)
            fp_source = self.mapping_layer_fp(self.word_embeddings.permute(1, 0)).permute(1, 0)
            # Match the projection's dtype instead of assuming bfloat16 weights.
            fp_feat = torch.as_tensor(fps, dtype=torch.float32).unsqueeze(1).to(
                device=device, dtype=self.perceiver_linear_fp.weight.dtype)
            fp_enc_out = self.perceiver_projection_fp(torch.cat(
                [self._expand_to_batch(fp_source, fp_feat.size(0)),
                 self.perceiver_linear_fp(fp_feat)], dim=1))
            # NOTE(review): unlike the graph path, enc_out is not included here; confirm intent.
            llama_enc_out = torch.cat([prompt_embeddings, fp_enc_out], dim=1)
        else:
            llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)

        dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state
        # Keep the last output_len positions and the first d_ff channels, then project.
        dec_out = dec_out[:, -self.output_len:, :self.d_ff]
        return self.output_projection(dec_out)

