import numpy as np
from einops import rearrange

import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

from torch_scatter import scatter_mean

from model.vq import l2norm
import contextlib
from torch.cuda.amp import autocast as autocast
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch_scatter import scatter
# from src.model.gnn import load_gnn_model
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
BOS = '<s>[INST]'
EOS_USER = '[/INST]'
EOS = '</s>'

IGNORE_INDEX = -100

# helper functions


def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d


def distance_metric(a, b, use_cosine_sim=True):
    """Score every row of ``a`` against every row of ``b``.

    The score is a negated distance, so larger values mean "closer".

    Args:
        a: Tensor of shape [n, d].
        b: Tensor of shape [m, d].
        use_cosine_sim: If True, both inputs are passed through ``l2norm``
            and the score is ``-(2 - 2 * a @ b.T)``. If False, the score is
            the negated squared Euclidean distance.

    Returns:
        Tensor of shape [n, m].
    """
    if not use_cosine_sim:
        # ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 <a, b>.
        sq_norm_a = (a ** 2).sum(dim=1).unsqueeze(1)  # [n, 1]
        sq_norm_b = (b ** 2).sum(dim=1).unsqueeze(0)  # [1, m]
        dots = torch.mm(a, b.t())  # [n, m]
        return -(sq_norm_a + sq_norm_b - 2 * dots)

    # Assuming l2norm yields unit-length rows, 2 - 2 * <a, b> is the squared
    # distance between the normalized vectors.
    unit_a = l2norm(a)
    unit_b = l2norm(b)
    similarity = torch.mm(unit_a, unit_b.t())
    return -(2 - 2 * similarity)


def efficient_compute_class_prototypes(embeddings, classes, num_classes_in_total, return_head_first):
    """Average the (l2norm-ed) embeddings of each class into one prototype per class.

    Args:
        embeddings: Tensor of shape [n, d], [n, h, d] or [r, n, h, d]
            (r = runs, n = nodes, h = heads, d = dim).
        classes: Class index per node, shape [n] (or [r, n] for 4-D embeddings).
        num_classes_in_total: Size of the class axis in the output.
        return_head_first: For 3-D / 4-D inputs, whether the head axis comes
            before the class axis in the result. Ignored for 2-D inputs.

    Returns:
        [c, d] for 2-D input; [h, c, d] or [c, h, d] for 3-D input;
        [r, h, c, d] or [r, c, h, d] for 4-D input.
    """
    embeddings = l2norm(embeddings)

    rank = embeddings.ndim
    assert rank in [2, 3, 4]

    # Bring everything to the canonical layout [run, head, node, dim] for the
    # embeddings and [run, node] for the class indices.
    if rank == 4:
        embeddings = rearrange(embeddings, "r n h d -> r h n d")
    else:
        to_canonical = "n d -> 1 1 n d" if rank == 2 else "n h d -> 1 h n d"
        embeddings = rearrange(embeddings, to_canonical)
        classes = rearrange(classes, "n -> 1 n")

    # After the rearrange the leading axis is the run axis (length 1 unless
    # the input was 4-D).
    num_runs = embeddings.shape[0]

    # One scatter-mean over the node axis per run.
    per_run = [
        scatter_mean(embeddings[run], classes[run], dim=1, dim_size=num_classes_in_total)
        for run in range(num_runs)
    ]
    prototypes = torch.stack(per_run, dim=0)  # [r, h, c, d]

    if rank == 2:
        return rearrange(prototypes, "1 1 c d -> c d")

    if rank == 3:
        prototypes = rearrange(prototypes, "1 h c d -> h c d")
        if return_head_first:
            return prototypes
        return rearrange(prototypes, "h c d -> c h d")

    # rank == 4
    if return_head_first:
        return prototypes
    return rearrange(prototypes, "r h c d -> r c h d")


def compute_multitask_loss(pred, y):
    """Binary cross-entropy averaged over all valid (sample, task) entries.

    Labels equal to 0 are treated as the negative class (mapped to -1), and
    every label is converted to a BCE target via ``(label + 1) / 2``. Entries
    whose squared label is not ``> 0`` after that mapping (e.g. NaN) are
    considered missing and are excluded from the loss.

    Args:
        pred: Logits, indexed as ``pred[sample, task]``.
        y: Labels, indexed as ``y[sample, task]``. Not modified.

    Returns:
        Scalar tensor: summed per-entry loss divided by the number of valid
        entries. If no entry is valid the result is 0 rather than NaN.
    """
    criterion = nn.BCEWithLogitsLoss(reduction="none")

    # Work on a copy: relabelling in place would silently rewrite the
    # caller's tensor (and anything sharing its storage).
    labels = y.clone()
    labels[labels == 0] = -1
    is_valid = labels ** 2 > 0
    loss = 0.0

    for idx in range(labels.shape[1]):
        mask = is_valid[:, idx]
        targets = (labels[mask, idx] + 1) / 2  # -1 -> 0, +1 -> 1
        loss += torch.sum(criterion(pred[mask, idx], targets))

    # Guard against 0 / 0 when every label is missing.
    return loss / torch.sum(is_valid).clamp(min=1)


class TaskModel(nn.Module):
    """Task head combining an encoder, a vector-quantization module and a linear decoder.

    Provides prototype-based logits/losses (distance between query embeddings
    and per-class prototypes) and a linear classifier applied to the
    quantized embeddings.
    """

    def __init__(self, encoder, vq, num_classes, params):
        """
        Args:
            encoder: Callable module invoked as ``encoder(x, edge_index, edge_attr)``.
            vq: Quantizer; must expose a 4-D ``codebook`` whose shape unpacks as
                (expert_num, num_heads, codebook_size, code_dim).
            num_classes: Output width of the linear decoder.
            params: Mapping read for ``use_z_in_predict`` and ``use_cosine_sim``.
        """
        super().__init__()
        self.encoder = encoder
        self.vq = vq

        self.num_classes = num_classes

        # The codebook size itself is not needed by this class.
        self.expert_num, self.num_heads, _, self.code_dim = vq.codebook.shape

        self.use_z_in_predict = params["use_z_in_predict"]
        self.use_cosine_sim = params["use_cosine_sim"]

        self.decoder = nn.Linear(self.code_dim, num_classes)

    def encode(self, x, edge_index, edge_attr):
        """Return the raw encoder output for the given graph inputs."""
        return self.encoder(x, edge_index, edge_attr)

    def encode_graph(self, x, edge_index, edge_attr=None, batch=None, pool="mean"):
        """Encode nodes and pool them per graph.

        ``pool`` selects "mean", "sum" or "max" pooling over ``batch``; any
        other value returns the encoder output unpooled.
        """
        # NOTE(review): other classes in this file unpack the encoder output
        # as a 2-tuple; here it is pooled directly, so this presumably expects
        # an encoder that returns a single tensor - confirm.
        z = self.encoder(x, edge_index, edge_attr)
        if pool == "mean":
            z = global_mean_pool(z, batch)
        elif pool == "sum":
            z = global_add_pool(z, batch)
        elif pool == "max":
            z = global_max_pool(z, batch)
        return z

    def get_class_prototypes(self, z, y, num_classes_in_total):
        """Compute class prototypes from embeddings ``z`` and labels ``y``.

        Args:
            z: Embeddings forwarded to ``efficient_compute_class_prototypes``.
            y: Either a label tensor (node / link classification) or a dict
                mapping an integer task id to an array of binary labels
                (graph classification with multiple binary tasks).
            num_classes_in_total: Number of classes (number of tasks in the
                dict case, where every task contributes two classes).

        Returns:
            Prototypes with the class axis first; in the dict case they are
            reshaped to [n_task, 2, num_heads, code_dim].
        """
        if isinstance(y, dict):
            # This works for graph classification with multiple binary tasks
            n_task = len(y)

            # Class id of (task, label) is task * 2 + label. The leading empty
            # float array keeps the empty-dict case valid, as before.
            parts = [np.array([])]
            parts.extend(task * 2 + labels for task, labels in y.items())
            flat_y = torch.tensor(np.concatenate(parts, axis=0), dtype=torch.long, device=z.device)

            proto_emb = efficient_compute_class_prototypes(z, flat_y, num_classes_in_total * 2, return_head_first=False)
            # reshape replaces the deprecated non-inplace Tensor.resize.
            return proto_emb.reshape(n_task, 2, self.num_heads, self.code_dim)

        # This works for node and link classification
        return efficient_compute_class_prototypes(z, y, num_classes_in_total, return_head_first=False)

    def _split_heads(self, query_emb, proto_emb):
        """Bring queries and prototypes to head-first layout.

        Args:
            query_emb: [n, d] or [n, h, d].
            proto_emb: [c, d], [c, h, d] or [t, c, h, d].

        Returns:
            (query [h, n, d], proto [h, c, d], n_task) where ``n_task`` is the
            leading size of a 4-D ``proto_emb`` (its task axis is merged into
            the class axis) and ``None`` otherwise.
        """
        ndim_query = query_emb.ndim
        ndim_proto = proto_emb.ndim

        assert ndim_query in [2, 3]
        assert ndim_proto in [2, 3, 4]

        n_task = None
        if ndim_query == 2:
            query_emb = rearrange(query_emb, "n d -> n 1 d")
        if ndim_proto == 2:
            proto_emb = rearrange(proto_emb, "c d -> c 1 d")
        elif ndim_proto == 4:
            n_task = proto_emb.shape[0]
            proto_emb = rearrange(proto_emb, "t c h d -> (t c) h d")

        query_emb = rearrange(query_emb, "n h d -> h n d")
        proto_emb = rearrange(proto_emb, "c h d -> h c d")
        return query_emb, proto_emb, n_task

    def _head_logits(self, query_emb, proto_emb, h):
        """Logits [n, c] for head ``h``; a single-head side is reused for every head."""
        query_h = query_emb[0] if query_emb.shape[0] == 1 else query_emb[h]
        proto_h = proto_emb[0] if proto_emb.shape[0] == 1 else proto_emb[h]
        return distance_metric(query_h, proto_h, self.use_cosine_sim)

    @staticmethod
    def _binary_task_logits(logits, n_task):
        """Turn [n, t * 2] class logits into [n, t] per-task binary logits."""
        axes = {"c": 2}
        if n_task is not None:
            # Known task count: let einops verify it. Otherwise it is inferred
            # from the class axis.
            axes["t"] = n_task
        logits = rearrange(logits, "n (t c) -> n t c", **axes)
        return logits[:, :, 0] - logits[:, :, 1]  # The 0-th is positive, the 1-th is negative

    def compute_proto_loss(self, query_emb, proto_emb, y, task="single"):
        """Prototype classification loss averaged over heads.

        Args:
            query_emb: [n, d] or [n, h, d].
            proto_emb: [c, d] or [c, h, d]; [t, c, h, d] is accepted for
                ``task="multi"``.
            y: Targets for ``F.cross_entropy`` ("single") or for
                ``compute_multitask_loss`` ("multi").
            task: "single" or "multi".

        Raises:
            ValueError: If ``task`` is not "single" or "multi", or if 4-D
                prototypes are given with ``task="single"``.
        """
        if task not in ("single", "multi"):
            raise ValueError('task must be either "single" or "multi"')
        if proto_emb.ndim == 4 and task != "multi":
            raise ValueError('4-D prototypes are only supported when task is "multi"')

        query_emb, proto_emb, n_task = self._split_heads(query_emb, proto_emb)
        num_heads = max(query_emb.shape[0], proto_emb.shape[0])

        proto_loss = 0
        for h in range(num_heads):
            logits = self._head_logits(query_emb, proto_emb, h)
            if task == "single":
                proto_loss += F.cross_entropy(logits, y)
            else:
                proto_loss += compute_multitask_loss(self._binary_task_logits(logits, n_task), y)
        proto_loss /= num_heads

        return proto_loss

    def compute_proto_reg(self, proto_emb):
        """Regularizer comparing each head's prototypes with the mean over heads.

        For every head, ``F.kl_div`` is evaluated with the head's prototypes
        (log-softmax over the last axis) as input and the head-averaged
        prototypes (softmax over the last axis) as target; the result is the
        mean over heads. Returns 0 for 2-D prototypes (no head axis).
        """
        # proto_emb in [c, d] or [c, h, d]
        ndim = proto_emb.ndim
        if ndim == 2:
            return 0
        if ndim == 4:
            proto_emb = rearrange(proto_emb, "t c h d -> (t c) h d")

        proto_emb = rearrange(proto_emb, "c h d -> h c d")
        target = proto_emb.mean(0).softmax(dim=-1)

        num_heads = proto_emb.shape[0]

        proto_reg = 0
        for h in range(num_heads):
            proto_reg += F.kl_div(
                proto_emb[h].log_softmax(dim=-1),
                target,
                reduction="batchmean",
            )
        proto_reg /= num_heads

        return proto_reg

    def compute_activation_loss(self, z, y, task="single"):
        """Loss of the linear decoder on the quantized embeddings.

        Raises:
            ValueError: If ``task`` is not "single" or "multi".
        """
        if task not in ("single", "multi"):
            raise ValueError('task must be either "single" or "multi"')

        pred = self.get_lin_logits(z).mean(1)
        if task == "single":
            return F.cross_entropy(pred, y)
        return compute_multitask_loss(pred, y)

    def get_lin_logits(self, z):
        """Quantize ``z``, decode it linearly and return logits shaped [-1, 1, num_classes]."""
        quantize, _ = self.vq(z)
        pred = self.decoder(quantize)
        return pred.reshape(-1, 1, self.num_classes)

    def get_proto_logits(self, query_emb, proto_emb, task='single'):
        """Prototype logits averaged over heads.

        Args:
            query_emb: [n, d] or [n, h, d].
            proto_emb: [c, d], [c, h, d] or [t, c, h, d].
            task: With 'multi', class logits are folded into one binary logit
                per task; any other value returns the class logits as is.
        """
        query_emb, proto_emb, n_task = self._split_heads(query_emb, proto_emb)
        num_heads = max(query_emb.shape[0], proto_emb.shape[0])

        total_logits = 0
        for h in range(num_heads):
            logits = self._head_logits(query_emb, proto_emb, h)
            if task == 'multi':
                logits = self._binary_task_logits(logits, n_task)
            total_logits += logits

        total_logits = total_logits / num_heads

        return total_logits
    

class TaskModelWithoutVQ(nn.Module):
    """Task head that applies a linear classifier directly to encoder embeddings (no quantizer)."""

    def __init__(self, encoder, vq, num_classes, params):
        """
        Args:
            encoder: Callable module invoked as ``encoder(x, edge_index, edge_attr)``.
            vq: Must be ``None``; kept so the signature matches ``TaskModel``.
            num_classes: Output width of the linear classifier.
            params: Mapping read for ``hidden_dim`` (classifier input width).
        """
        super().__init__()

        assert vq is None

        self.encoder = encoder
        self.num_classes = num_classes

        self.trade_off = 1

        self.lin = nn.Linear(params['hidden_dim'], num_classes)

    @staticmethod
    def _embeddings_of(encoder_output):
        """Return the embedding tensor from an encoder output.

        ``encode`` historically unpacked the encoder output as a 2-tuple while
        ``encode_graph`` used it as a tensor; accepting both keeps the two
        methods working with the same encoder.
        """
        if isinstance(encoder_output, (tuple, list)):
            return encoder_output[0]
        return encoder_output

    def encode(self, x, edge_index, edge_attr=None):
        """Return node-level embeddings (first element of a tuple-valued encoder output)."""
        return self._embeddings_of(self.encoder(x, edge_index, edge_attr))

    def encode_graph(self, x, edge_index, edge_attr=None, batch=None, pool="mean"):
        """Encode nodes and pool them per graph.

        ``pool`` selects "mean", "sum" or "max" pooling over ``batch``; any
        other value returns the node-level embeddings unpooled.
        """
        # Previously the raw encoder output was pooled, which cannot work for
        # the tuple-returning encoder that ``encode`` expects.
        z = self._embeddings_of(self.encoder(x, edge_index, edge_attr))
        if pool == "mean":
            z = global_mean_pool(z, batch)
        elif pool == "sum":
            z = global_add_pool(z, batch)
        elif pool == "max":
            z = global_max_pool(z, batch)
        return z

    def compute_activation_loss(self, z, y, task="single"):
        """Loss of the linear classifier on ``z``.

        Uses cross-entropy for ``task="single"`` and ``compute_multitask_loss``
        for ``task="multi"``.

        Raises:
            ValueError: If ``task`` is not "single" or "multi".
        """
        pred = self.lin(z)
        if task == "single":
            return F.cross_entropy(pred, y)
        elif task == "multi":
            return compute_multitask_loss(pred, y)
        else:
            raise ValueError('task must be either "single" or "multi"')

    def get_lin_logits(self, z):
        """Return classifier logits reshaped to [-1, 1, num_classes]."""
        return self.lin(z).reshape(-1, 1, self.num_classes)
    

class GraphLLM(torch.nn.Module):
    """Causal language model conditioned on a projected graph embedding.

    Every prompt is assembled directly in embedding space as
    ``BOS + graph token + description + question + EOS_USER`` (followed by the
    answer and ``EOS`` during training) and the batch is left-padded to a
    common length.
    """

    def __init__(
        self,
        args,  # args = params
        **kwargs
    ):
        """Load tokenizer and LLM, freeze it or wrap it with LoRA, and build the projector.

        Args:
            args: Mapping read for ``max_txt_len``, ``max_new_tokens``,
                ``llm_model_path``, ``llm_frozen`` and ``hidden_dim``.
            **kwargs: Accepted but not used: the name is rebound below to a
                fixed set of ``from_pretrained`` options.
        """
        super().__init__()
        self.max_txt_len = args['max_txt_len']
        self.max_new_tokens = args['max_new_tokens']

        print('Loading LLAMA')
        kwargs = {
            "max_memory": {0: '80GiB'},
            # "max_memory": {0: '80GiB', 1: '80GiB'},
            "device_map": "auto",
            "revision": "main",
        }

        self.tokenizer = AutoTokenizer.from_pretrained(args['llm_model_path'], use_fast=False, revision=kwargs["revision"])
        self.tokenizer.pad_token_id = 0
        # Left padding keeps the end of every prompt adjacent to the generated tokens.
        self.tokenizer.padding_side = 'left'

        model = AutoModelForCausalLM.from_pretrained(  # llama-7b
            args['llm_model_path'],
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            **kwargs
        )

        # NOTE: the flag is compared with the *string* 'True'; any other value
        # (including the boolean True) selects the LoRA branch.
        if args['llm_frozen'] == 'True':
            print("Freezing LLAMA!")
            for param in model.parameters():
                param.requires_grad = False
        else:
            print("Training LLAMA with LORA!")
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=8,
                lora_alpha=16,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)

        self.model = model
        print('Finish loading LLAMA!')

        # self.graph_encoder = load_gnn_model[args.gnn_model_name](
        #     in_channels=args.gnn_in_dim,
        #     out_channels=args.gnn_hidden_dim,
        #     hidden_channels=args.gnn_hidden_dim,
        #     num_layers=args.gnn_num_layers,
        #     dropout=args.gnn_dropout,
        #     num_heads=args.gnn_num_heads,
        # ).to(self.model.device)

        word_embedding = self.model.model.get_input_embeddings()
        # The graph token is concatenated with word embeddings, so the
        # projector must emit the LLM's embedding width (4096 for llama-7b,
        # which is also the fallback).
        llm_dim = getattr(word_embedding, "embedding_dim", 4096)

        self.projector = nn.Sequential(
            nn.Linear(args['hidden_dim'], 2048),
            nn.Sigmoid(),
            nn.Linear(2048, llm_dim),
        ).to(self.model.device)

        self.word_embedding = word_embedding

    @property
    def device(self):
        """Device of this module's first parameter."""
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.bfloat16):
        """Return a CUDA autocast context for ``dtype``, or a no-op context on CPU."""
        if self.device == torch.device("cpu"):
            return contextlib.nullcontext()
        return torch.autocast(device_type="cuda", dtype=dtype)

    def encode_graphs(self, samples):
        """Encode ``samples['graph']`` and mean-pool node embeddings per graph.

        NOTE(review): relies on ``self.graph_encoder``, which ``__init__`` does
        not create (its construction is commented out); it has to be attached
        externally before this method can be used.
        """
        graphs = samples['graph']
        graphs = graphs.to(self.model.device)
        n_embeds, _ = self.graph_encoder(graphs.x, graphs.edge_index.long(), graphs.edge_attr)

        # mean pooling
        g_embeds = scatter(n_embeds, graphs.batch, dim=0, reduce='mean')

        return g_embeds

    def _special_embeds(self):
        """Return (BOS embeddings [len(BOS), dim], pad embedding [1, dim])."""
        bos_ids = self.tokenizer(BOS, add_special_tokens=False, return_tensors='pt').input_ids[0]
        bos_embeds = self.word_embedding(bos_ids.to(self.device))
        pad_id = torch.tensor(self.tokenizer.pad_token_id).to(self.device)
        pad_embeds = self.word_embedding(pad_id).unsqueeze(0)
        return bos_embeds, pad_embeds

    def _embed_sequence(self, bos_embeds, graph_embed, input_ids):
        """Embed ``input_ids`` and prepend the BOS embeddings and one graph token."""
        token_embeds = self.word_embedding(torch.tensor(input_ids).to(self.model.device))
        return torch.cat([bos_embeds, graph_embed.unsqueeze(0), token_embeds], dim=0)

    def _left_pad_batch(self, batch_inputs_embeds, pad_embeds, batch_label_input_ids=None):
        """Left-pad variable-length sequences and stack them into batch tensors.

        Returns:
            (inputs_embeds, attention_mask, labels); ``labels`` is ``None``
            when ``batch_label_input_ids`` is not given. Padded positions get
            attention 0 and label ``IGNORE_INDEX``.
        """
        max_length = max(embeds.shape[0] for embeds in batch_inputs_embeds)

        padded_embeds = []
        attention_rows = []
        label_rows = []
        for idx, embeds in enumerate(batch_inputs_embeds):
            seq_len = embeds.shape[0]
            pad_length = max_length - seq_len
            padded_embeds.append(torch.cat([pad_embeds.repeat(pad_length, 1), embeds]))
            attention_rows.append([0] * pad_length + [1] * seq_len)
            if batch_label_input_ids is not None:
                label_rows.append([IGNORE_INDEX] * pad_length + batch_label_input_ids[idx])

        inputs_embeds = torch.stack(padded_embeds, dim=0).to(self.model.device)
        attention_mask = torch.tensor(attention_rows).to(self.model.device)
        labels = None
        if batch_label_input_ids is not None:
            labels = torch.tensor(label_rows).to(self.model.device)
        return inputs_embeds, attention_mask, labels

    def forward(self, z, samples):
        """Compute the language-modelling loss of the answers given graph and text.

        Args:
            z: Graph embeddings, one row per sample, fed to the projector.
            samples: Batch object exposing ``question``, ``desc`` and ``y``
                (texts to tokenize) and supporting ``len()``.

        Returns:
            The loss reported by the wrapped LLM; only answer tokens and the
            trailing ``EOS`` carry labels, everything else is ``IGNORE_INDEX``.
        """
        # encode description, questions and labels
        questions = self.tokenizer(samples.question, add_special_tokens=False)
        descriptions = self.tokenizer(samples.desc, add_special_tokens=False)
        labels = self.tokenizer(samples.y, add_special_tokens=False)

        # encode special tokens
        eos_ids = self.tokenizer(EOS, add_special_tokens=False).input_ids
        eos_user_ids = self.tokenizer(EOS_USER, add_special_tokens=False).input_ids
        bos_embeds, pad_embeds = self._special_embeds()

        # encode graphs
        graph_embeds = self.projector(z)

        batch_inputs_embeds = []
        batch_label_input_ids = []
        for i in range(len(samples)):
            # Answer (truncated) + EOS is the supervised part of the sequence.
            answer_ids = labels.input_ids[i][:self.max_new_tokens] + eos_ids
            prompt_ids = descriptions.input_ids[i][:self.max_txt_len] + questions.input_ids[i] + eos_user_ids
            inputs_embeds = self._embed_sequence(bos_embeds, graph_embeds[i], prompt_ids + answer_ids)

            batch_inputs_embeds.append(inputs_embeds)
            # Mask everything before the answer so it does not contribute to the loss.
            prefix_len = inputs_embeds.shape[0] - len(answer_ids)
            batch_label_input_ids.append([IGNORE_INDEX] * prefix_len + answer_ids)

        inputs_embeds, attention_mask, label_input_ids = self._left_pad_batch(
            batch_inputs_embeds, pad_embeds, batch_label_input_ids
        )

        with self.maybe_autocast():
            outputs = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )

        return outputs.loss

    def inference(self, z,samples):
        """Generate answers for a batch.

        Args:
            z: Graph embeddings, one row per sample, fed to the projector.
            samples: Batch object exposing ``question``, ``desc`` and ``y``
                and supporting ``len()``.

        Returns:
            Dict with the decoded generations under 'pred' and the sample's
            ``y``, ``question`` and ``desc`` under 'label', 'question', 'desc'.
        """
        # encode description and questions
        questions = self.tokenizer(samples.question, add_special_tokens=False)
        descriptions = self.tokenizer(samples.desc, add_special_tokens=False)

        # encode special tokens
        eos_user_ids = self.tokenizer(EOS_USER, add_special_tokens=False).input_ids
        bos_embeds, pad_embeds = self._special_embeds()

        # encode graphs
        graph_embeds = self.projector(z)

        batch_inputs_embeds = []
        for i in range(len(samples)):
            prompt_ids = descriptions.input_ids[i][:self.max_txt_len] + questions.input_ids[i] + eos_user_ids
            batch_inputs_embeds.append(self._embed_sequence(bos_embeds, graph_embeds[i], prompt_ids))

        inputs_embeds, attention_mask, _ = self._left_pad_batch(batch_inputs_embeds, pad_embeds)

        with self.maybe_autocast():
            outputs = self.model.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=self.max_new_tokens,
                attention_mask=attention_mask,
                # do_sample=True,
                use_cache=True  # IMPORTANT!
            )
        pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return {#'id': samples.x,
                'pred': pred,
                'label': samples.y,
                'question': samples.question,
                'desc': samples.desc, }

    def print_trainable_params(self):
        """Return ``(trainable_params, all_param)`` element counts. Despite the name, nothing is printed."""
        trainable_params = 0
        all_param = 0

        for param in self.parameters():
            num_params = param.numel()
            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        return trainable_params, all_param


class GQATaskModel(nn.Module):
    """Bundles an encoder, a quantizer and a graph-conditioned LLM for graph question answering."""

    def __init__(self, encoder, vq, llm, params):
        """Store the three sub-models and the parameter mapping."""
        super().__init__()
        self.encoder = encoder
        self.vq = vq
        self.llm = llm  # Graph_LLM
        self.params = params

    def encode_graph(self, x, edge_index, edge_attr, batch=None, pool="mean", field=None):
        """Encode nodes, pool them per graph and quantize the result.

        ``pool`` selects "mean", "sum" or "max" pooling over ``batch``; any
        other value skips pooling. The quantizer is called as
        ``vq(embeddings, field, 'ft')`` and its first output is returned.
        """
        node_emb, _unused = self.encoder(x, edge_index, edge_attr)

        pooled = node_emb
        if pool == "mean":
            pooled = global_mean_pool(node_emb, batch)
        elif pool == "sum":
            pooled = global_add_pool(node_emb, batch)
        elif pool == "max":
            pooled = global_max_pool(node_emb, batch)

        quantized, _unused = self.vq(pooled, field, 'ft')
        return quantized

    def llm_train(self, z, batch):
        """Run the LLM's forward pass on ``(z, batch)`` and return its result."""
        result = self.llm(z, batch)
        return result

    def llm_eval(self, z, batch):
        """Run the LLM's ``inference`` on ``(z, batch)`` and return its result."""
        result = self.llm.inference(z, batch)
        return result