import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast as autocast
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch_scatter import scatter
from torch.nn import AdaptiveAvgPool1d
from torch_geometric.utils import get_laplacian, to_dense_adj, dense_to_sparse
from src.model.modeling_llama import LlamaForCausalLM
from src.model.gnn import load_gnn_model
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from torch_geometric.nn import TransformerConv, SAGPooling

# Prompt delimiters for the Llama-2 chat template ("<s>[INST] ... [/INST] answer </s>").
# GraphLLM lays every sequence out as:  BOS <graph token> prompt EOS_USER answer EOS
#
# Alternative templates (currently disabled), kept for reference:
#   Llama-3: BOS = '<|begin_of_text|>'; EOS_USER = '<|eot_id|>'; EOS = '<|end_of_text|>'
#   Plain:   BOS = '<s>';               EOS_USER = '';           EOS = '</s>'
#   Llama-2 with a system prompt:
#     BOS = '<s> <<SYS>>\nYou are an AI assistant that provides direct and concise answers to user questions. Please respond to each query with a clear, straightforward answer without any explanations or descriptions. If there are multiple answers, arrange them in order of confidence from highest to lowest and separate them with "|".\n<</SYS>>'
BOS, EOS_USER, EOS = '<s>[INST]', '[/INST]', '</s>'

# Label id excluded from the LM loss (the default `ignore_index` of PyTorch's cross-entropy).
IGNORE_INDEX = -100


class GraphLLM(torch.nn.Module):
    """Causal LM conditioned on a graph through a single soft "graph token".

    A GNN (``load_gnn_model[args.gnn_model_name]``) encodes each sample's
    graph; node embeddings are projected to the LLM hidden size and
    mean-pooled per graph. The resulting vector is spliced into the LLM input
    right after the BOS embeddings, and sequences are left-padded:

        [pad ...] BOS <graph token> prompt EOS_USER [target EOS]

    The GNN is additionally conditioned on a per-node copy of the mean-pooled
    question token embeddings (passed through ``self.linear``).
    """

    def __init__(
        self,
        args,
        **kwargs
    ):
        """Load the tokenizer, the LLM (frozen or LoRA-wrapped), the GNN and heads.

        Args:
            args: config namespace; reads ``max_txt_len``, ``max_new_tokens``,
                ``llm_model_path``, ``llm_frozen``, ``gnn_model_name``,
                ``gnn_in_dim``, ``gnn_hidden_dim``, ``gnn_num_layers``,
                ``gnn_dropout`` and ``gnn_num_heads``.
            **kwargs: accepted for call-site compatibility and ignored. They
                are intentionally NOT forwarded to ``from_pretrained``.
        """
        super().__init__()
        self.args = args
        self.max_txt_len = args.max_txt_len
        self.max_new_tokens = args.max_new_tokens

        print('Loading LLAMA for graph_llm')
        # Own name so the caller's **kwargs are not silently shadowed.
        model_kwargs = {"device_map": "auto"}

        self.tokenizer = AutoTokenizer.from_pretrained(args.llm_model_path)
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = 'left'

        model = AutoModelForCausalLM.from_pretrained(
            args.llm_model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            **model_kwargs,
        )
        # Explicit comparison kept on purpose: only a value equal to True
        # freezes the LLM; other truthy values (e.g. the string "False") do not.
        if args.llm_frozen == True:  # noqa: E712
            print("Freezing LLAMA!")
            for param in model.parameters():
                param.requires_grad = False
        else:
            print("Training LLAMA with LORA!")
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, lora_config)

        # Registered before the other submodules, so the `device` property
        # (first registered parameter) reports one of the LLM's parameters.
        self.model = model
        print('Finish loading LLAMA!')

        self.graph_encoder = load_gnn_model[args.gnn_model_name](
            in_channels=args.gnn_in_dim,
            out_channels=args.gnn_hidden_dim,
            hidden_channels=args.gnn_hidden_dim,
            num_layers=args.gnn_num_layers,
            dropout=args.gnn_dropout,
            num_heads=args.gnn_num_heads,
        ).to(self.model.device)

        llm_hidden_size = model.config.hidden_size

        # GNN node embeddings -> LLM embedding width.
        self.projector = nn.Sequential(
            nn.Linear(args.gnn_hidden_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, llm_hidden_size),
        ).to(self.model.device)

        # Pooled question embeddings (LLM width; formerly hard-coded as 4096)
        # -> conditioning width consumed by the graph encoder.
        # NOTE(review): 1024 presumably must match what graph_encoder expects
        # for its question-conditioning input — confirm against src/model/gnn.
        self.linear = nn.Linear(llm_hidden_size, 1024).to(self.model.device)

    @property
    def device(self):
        """Device of the first registered parameter (one of the LLM's)."""
        first_param = next(self.parameters(), None)
        if first_param is None:
            raise IndexError("GraphLLM has no parameters; cannot infer its device")
        return first_param.device

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context when on CPU."""
        if self.device == torch.device("cpu"):
            return contextlib.nullcontext()
        # Same behavior as the deprecated torch.cuda.amp.autocast(dtype=...).
        return torch.autocast(device_type="cuda", dtype=dtype)

    def _embed_tokens(self, token_ids):
        """Look up LLM input embeddings for ``token_ids`` on the LLM's device.

        ``token_ids`` may be a list of ints, a 1-D tensor or a scalar id (a
        scalar yields a 1-D embedding). Forcing ``dtype=torch.long`` keeps an
        empty list a valid zero-length index instead of a float tensor.
        """
        embedding = self.model.model.get_input_embeddings()
        ids = torch.as_tensor(token_ids, dtype=torch.long, device=self.model.device)
        return embedding(ids)

    def _special_token_embeds(self):
        """Return ``(bos_embeds, pad_embeds)``; each is 2-D (n_tokens, hidden)."""
        bos_ids = self.tokenizer(BOS, add_special_tokens=False).input_ids
        bos_embeds = self._embed_tokens(bos_ids)
        pad_embeds = self._embed_tokens(self.tokenizer.pad_token_id).unsqueeze(0)
        return bos_embeds, pad_embeds

    def _question_conditioning(self, samples, question_ids):
        """Build the per-node question conditioning fed to the graph encoder.

        Each sample's question token embeddings are mean-pooled; the pooled
        vector is gathered once per node via ``samples['graph'].batch`` and
        mapped through ``self.linear`` in float32. An empty question pools to
        a zero vector (a plain mean over zero tokens would be NaN).
        """
        pooled = []
        for ids in question_ids:
            token_embeds = self._embed_tokens(ids)
            if token_embeds.shape[0] == 0:
                pooled.append(token_embeds.new_zeros(token_embeds.shape[-1]))
            else:
                pooled.append(token_embeds.mean(dim=0))
        pooled = torch.stack(pooled)

        graphs = samples['graph'].to(self.model.device)
        return self.linear(pooled[graphs.batch].float())

    def encode_graphs(self, samples, question_pooling_embed=None):
        """Encode ``samples['graph']`` into one LLM-width vector per graph.

        Node embeddings from the graph encoder are passed through
        ``self.projector`` and mean-reduced over ``graphs.batch``. When the
        graph has no ``edge_attr``, an all-zero one of width
        ``args.gnn_hidden_dim`` is attached to the graph object in place.
        """
        graphs = samples['graph'].to(self.model.device)
        if 'edge_attr' not in graphs:
            num_edges = graphs.edge_index.size(1)
            graphs.edge_attr = torch.zeros(
                (num_edges, self.args.gnn_hidden_dim), device=graphs.x.device
            )
        n_embeds, _ = self.graph_encoder(
            graphs.x, graphs.edge_index.long(), graphs.edge_attr, question_pooling_embed
        )
        n_embeds = self.projector(n_embeds)
        return scatter(n_embeds, graphs.batch, dim=0, reduce='mean')

    def _assemble_batch(self, prompt_ids, graph_embeds, bos_embeds, pad_embeds, target_ids=None):
        """Build left-padded LLM inputs ``[pad*] BOS <graph> prompt [target]``.

        Args:
            prompt_ids: one list of token ids per sample.
            graph_embeds: one graph vector per sample (indexed by position).
            bos_embeds / pad_embeds: 2-D embeddings from ``_special_token_embeds``.
            target_ids: optional per-sample target token ids appended after
                the prompt; only these positions get real labels.

        Returns:
            ``(inputs_embeds, attention_mask, labels)``; ``labels`` is None
            when ``target_ids`` is None. Padded and prompt positions are
            labelled ``IGNORE_INDEX``.
        """
        rows, label_rows = [], []
        for i, prompt in enumerate(prompt_ids):
            target = [] if target_ids is None else target_ids[i]
            token_embeds = self._embed_tokens(prompt + target)
            row = torch.cat([bos_embeds, graph_embeds[i].unsqueeze(0), token_embeds], dim=0)
            rows.append(row)
            if target_ids is not None:
                label_rows.append([IGNORE_INDEX] * (row.shape[0] - len(target)) + target)

        max_length = max(row.shape[0] for row in rows)
        attention_mask = []
        for i, row in enumerate(rows):
            pad_length = max_length - row.shape[0]
            # Left padding: generation continues from the right edge.
            rows[i] = torch.cat([pad_embeds.repeat(pad_length, 1), row])
            attention_mask.append([0] * pad_length + [1] * row.shape[0])
            if target_ids is not None:
                label_rows[i] = [IGNORE_INDEX] * pad_length + label_rows[i]

        device = self.model.device
        inputs_embeds = torch.stack(rows, dim=0).to(device)
        attention_mask = torch.tensor(attention_mask, device=device)
        labels = None if target_ids is None else torch.tensor(label_rows, device=device)
        return inputs_embeds, attention_mask, labels

    def _lm_loss(self, prompt_ids, target_ids, graph_embeds, bos_embeds, pad_embeds):
        """Return the LLM's teacher-forced loss on ``target_ids`` given ``prompt_ids``."""
        inputs_embeds, attention_mask, labels = self._assemble_batch(
            prompt_ids, graph_embeds, bos_embeds, pad_embeds, target_ids
        )
        with self.maybe_autocast():
            outputs = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=labels,
            )
        return outputs.loss

    def forward(self, samples):
        """Return the training loss for a batch.

        ``samples`` is a dict with ``'question'``, ``'desc'`` and ``'label'``
        (lists of strings) and ``'graph'`` (a batched graph exposing ``x``,
        ``edge_index``, ``batch`` and optionally ``edge_attr``).

        Two teacher-forced passes share the same graph embeddings:
        (1) prompt = description (truncated to ``max_txt_len`` tokens) + question;
        (2) prompt = question only.
        Targets in both are the label truncated to ``max_new_tokens`` tokens
        plus EOS. The two losses are summed unweighted.
        """
        questions = self.tokenizer(samples["question"], add_special_tokens=False)
        descriptions = self.tokenizer(samples["desc"], add_special_tokens=False)
        labels = self.tokenizer(samples["label"], add_special_tokens=False)
        eos_ids = self.tokenizer(EOS, add_special_tokens=False).input_ids
        eos_user_ids = self.tokenizer(EOS_USER, add_special_tokens=False).input_ids
        bos_embeds, pad_embeds = self._special_token_embeds()

        # Feed the question embedding into the graph encoder.
        question_cond = self._question_conditioning(samples, questions.input_ids)
        graph_embeds = self.encode_graphs(samples, question_cond)

        target_ids = [ids[:self.max_new_tokens] + eos_ids for ids in labels.input_ids]
        desc_prompts = [
            desc[:self.max_txt_len] + question + eos_user_ids
            for desc, question in zip(descriptions.input_ids, questions.input_ids)
        ]
        question_prompts = [question + eos_user_ids for question in questions.input_ids]

        loss_with_desc = self._lm_loss(desc_prompts, target_ids, graph_embeds, bos_embeds, pad_embeds)
        loss_question_only = self._lm_loss(question_prompts, target_ids, graph_embeds, bos_embeds, pad_embeds)
        return loss_with_desc + loss_question_only

    def inference(self, samples):
        """Generate answers for a batch and return them with the inputs.

        Prompts are description (truncated to ``max_txt_len`` tokens) +
        question. Generation length is capped at ``self.max_new_tokens``, the
        same budget used to truncate training labels.
        """
        questions = self.tokenizer(samples["question"], add_special_tokens=False)
        descriptions = self.tokenizer(samples["desc"], add_special_tokens=False)
        eos_user_ids = self.tokenizer(EOS_USER, add_special_tokens=False).input_ids
        bos_embeds, pad_embeds = self._special_token_embeds()

        question_cond = self._question_conditioning(samples, questions.input_ids)
        graph_embeds = self.encode_graphs(samples, question_cond)

        prompts = [
            desc[:self.max_txt_len] + question + eos_user_ids
            for desc, question in zip(descriptions.input_ids, questions.input_ids)
        ]
        inputs_embeds, attention_mask, _ = self._assemble_batch(
            prompts, graph_embeds, bos_embeds, pad_embeds
        )

        with self.maybe_autocast():
            outputs = self.model.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=self.max_new_tokens,
                attention_mask=attention_mask,
                use_cache=True,  # IMPORTANT!
                pad_token_id=self.tokenizer.eos_token_id,
            )
        pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return {'id': [int(t) for t in samples['id']],
                'pred': pred,
                'label': samples['label'],
                'question': samples['question'],
                'desc': samples['desc']
                }

    def print_trainable_params(self):
        """Return ``(trainable_param_count, total_param_count)``."""
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        all_param = sum(p.numel() for p in self.parameters())
        return trainable_params, all_param

    def encode_dialog_prompt(self, dialog):
        # NOTE(review): this method is not functional as written. It calls
        # self.encode_message and self.encode_header, which GraphLLM does not
        # define (nn.Module attribute lookup raises AttributeError), and it
        # indexes self.tokenizer.special_tokens, which appears to come from a
        # non-HF tokenizer API. Confirm whether it is still needed.
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in dialog:
            tokens.extend(self.encode_message(message))
        # Add the start of an assistant message for the model to complete.
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
        return tokens