import numpy as np
import pandas as pd
import torch
import json
import os
import tqdm
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import add_remaining_self_loops
from torch_sparse import SparseTensor
from utils.data_utils import collator_graph_data, preprocess_v1, preprocess_llama_2, preprocess_llama_3
import pyximport

pyximport.install(setup_args={"include_dirs": np.get_include()})
from . import algos


# Maps every supported dataset name to the key of its graph domain in
# DOMAIN_DESC. Datasets are listed grouped by domain; the resulting dict keeps
# the listing order.
DOMAIN_DICT = {
    dataset_name: domain
    for domain, dataset_names in (
        ('citation', ('arxiv', 'pubmed', 'cora', 'cora_simple')),
        ('product', ('children', 'history', 'computer', 'sports', 'photo', 'products')),
        ('wiki', ('wikics',)),
        ('reddit', ('reddit',)),
        ('instagram', ('instagram',)),
        ('kg', ('wn18rr', 'fb15k237')),
        ('molecule', ('chemblpre', 'hiv', 'bace', 'pcba', 'lipo', 'esol', 'freesolv')),
    )
    for dataset_name in dataset_names
}

# One-sentence natural-language description of each graph domain, keyed by the
# domain names used as values of DOMAIN_DICT. InstructionDataset.__getitem__
# places the matching sentence at the start of the encoder text.
# The strings are emitted verbatim at runtime, so their wording is part of the
# program's behavior. The 'normal' entry is not referenced by the code visible
# in this file.
DOMAIN_DESC = {
    'citation': 'This is a citation network. Nodes represent academic papers and edges represent citation relationships.',
    'product': 'This is a co-purchase network. Nodes represent the products and edges represent two products are co-purchased together.',
    'wiki': 'This is a wikipedia graph focusing. Nodes represent wikipedia terms and edges represent two terms have hyperlink.',
    'reddit': 'This is a social network where each node represents a user, and edges denote whether two users have replied to each other.',
    'instagram': 'This is a social network where nodes represent users and edges represent following relationships.',
    'kg': 'This is a knowledge graph, node represents a entity and the edge between two entities represents a relation with a specific category.',
    'molecule': 'This is a molecular graph. Nodes represent atoms, and edges represent chemical bonds between atoms.',
    'normal': 'This is a graph.'
}

# One-sentence description of each task, keyed by the task code read from an
# instruction's 'task' field (InstructionDataset.__getitem__ falls back to 'nc'
# when the field is absent). The sentence follows the domain description in the
# encoder text. The 'normal' entry is not referenced by the code visible in
# this file.
TASK_DESC = {
    'nc': "Determine this node's most likely category within the network's classification schema.",
    'lp': "Determine whether there is a specific relationship between these two nodes.",
    'lc': "Determine the edge's most likely category between these two nodes.",
    'gc': "Determine whether the molecule possesses specific physicochemical or bioactivity properties.",
    'gr': "Predict the continuous numerical value of a physicochemical or bioactivity property of the molecule.",
    'normal': "This is a graph task."
}


def get_instructions(data_path, frac=1, mode='train', max_eval_samples=1000, random_state=None):
    """Load an instruction JSON file and subsample it according to ``mode``.

    The file is read with ``pd.read_json`` and must contain an ``output``
    column, which is duplicated into a ``gpt`` column.

    Args:
        data_path: Path of the JSON file to load.
        frac: Sampling fraction applied in ``'train'`` mode only. ``1`` keeps
            the data untouched, values below 1 sample without replacement and
            values above 1 oversample with replacement.
        mode: ``'train'`` applies ``frac``; any other value caps the frame at
            ``max_eval_samples`` rows instead and ignores ``frac``.
        max_eval_samples: Maximum number of rows kept outside ``'train'``
            mode. ``None`` disables the cap.
        random_state: Forwarded to ``DataFrame.sample``. The default ``None``
            keeps the sampling unseeded; pass a seed for reproducible subsets.

    Returns:
        A ``pandas.DataFrame`` with a fresh ``0..n-1`` index.
    """
    df = pd.read_json(data_path)
    df['gpt'] = df['output']
    df = df.reset_index(drop=True)

    if mode != 'train':
        # Evaluation splits are only capped, never re-weighted.
        if max_eval_samples is not None and len(df) > max_eval_samples:
            df = df.sample(n=max_eval_samples, ignore_index=True, random_state=random_state)
        return df

    if frac != 1:
        # Sampling more rows than exist is only possible with replacement.
        df = df.sample(frac=frac, ignore_index=True, replace=frac > 1, random_state=random_state)
    return df


class InstructionDataset(Dataset):
    """Instruction-tuning dataset that pairs text prompts with small graphs.

    Each item combines an encoder text (domain + task description + memory
    tokens), a decoder prompt, the reference answer, and a
    ``torch_geometric`` ``Data`` object carrying node features, edge features,
    shortest-path distances and per-path edge types.
    """

    def __init__(self, tokenizer, accelerator, model_args, data_args, training_args, mode='train', dataset=None) -> None:
        """Load the instruction tables and the per-dataset feature tensors.

        Args:
            tokenizer: A pair ``(encoder_tokenizer, decoder_tokenizer)``. When
                the decoder tokenizer is falsy, ``collate_fn`` uses the encoder
                tokenizer for the decoder side as well.
            accelerator: Object exposing a ``main_process_first()`` context
                manager, used around feature loading.
            model_args: Reads ``memory_token_nums``, ``model_arch`` and
                ``spatial_pos_max``.
            data_args: Reads the comma-separated strings ``datasets`` and
                ``data_weights`` (only when ``dataset`` is not given).
            training_args: Reads ``bf16`` to choose the feature dtype.
            mode: ``'train'`` or any other value for evaluation.
            dataset: Optional single dataset name that overrides
                ``data_args``.
        """
        super().__init__()
        self.tokenizer, self.tokenizer_dec = tokenizer
        self.accelerator = accelerator
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args
        self.mode = mode

        self.instructions, self.datasets = self.get_instructions(dataset)

        self.name2data = self.get_node_feature()
        self.name2edge = self.get_edge_feature()

    def __len__(self):
        """Return the number of instruction rows."""
        return len(self.instructions)

    def _load_features(self, file_name, kind):
        """Load ``/<dataset>/<file_name>`` for every dataset in ``self.datasets``.

        A trailing ``'_LP'`` is stripped from dataset names, so a dataset and
        its ``_LP`` variant share one entry. ``kind`` is only used in the
        error message.

        Raises:
            ValueError: If a feature file does not exist.
        """
        features = {}
        with self.accelerator.main_process_first():
            for name in self.datasets:
                if name.endswith('_LP'):
                    name = name[:-3]
                if name in features:
                    continue
                path = f"/{name}/{file_name}"
                if not os.path.exists(path):
                    raise ValueError(f"no {kind} feature {name}")
                # NOTE(security): torch.load unpickles the file; only load
                # feature files from trusted sources.
                features[name] = torch.load(path)
        return features

    def get_node_feature(self):
        """Return ``{dataset name: node features}`` loaded from ``node_features.pt``.

        Raises:
            ValueError: If a dataset has no node feature file.
        """
        return self._load_features("node_features.pt", "node")

    def get_edge_feature(self):
        """Return ``{dataset name: edge features}`` loaded from ``edge_features.pt``.

        Raises:
            ValueError: If a dataset has no edge feature file.
        """
        return self._load_features("edge_features.pt", "edge")

    def get_instructions(self, dataset):
        """Load and concatenate the instruction tables of all datasets.

        Args:
            dataset: Single dataset name, loaded with a fixed weight of 0.2.
                When falsy, names and weights come from ``self.data_args``.

        Returns:
            ``(instructions, datasets)``: the concatenated ``DataFrame`` and
            the list of dataset names.

        Raises:
            ValueError: If the number of weights differs from the number of
                datasets.
        """
        if dataset:
            datasets = [dataset]
            data_weights = [0.2]
        else:
            datasets = self.data_args.datasets.split(',')
            data_weights = [float(w) for w in self.data_args.data_weights.split(',')]
        if len(datasets) != len(data_weights):
            # zip() would silently drop the unmatched datasets or weights.
            raise ValueError(
                f"got {len(datasets)} datasets but {len(data_weights)} data weights"
            )
        print(data_weights, datasets)

        all_instructions = [
            get_instructions(f"/{name}/{name}_dataset_{self.mode}.json", weight, mode=self.mode)
            for name, weight in zip(datasets, data_weights)
        ]
        print([len(frame) for frame in all_instructions])

        return pd.concat(all_instructions, ignore_index=True), datasets

    def truncate_text(self, node_text, max_length):
        """Return ``node_text`` cut to at most ``max_length`` encoder tokens.

        Text that already fits is returned unchanged; otherwise the first
        ``max_length`` tokens are converted back to a string.
        """
        tokens = self.tokenizer.tokenize(node_text)
        if len(tokens) <= max_length:
            return node_text
        return self.tokenizer.convert_tokens_to_string(tokens[:max_length])

    def __getitem__(self, idx):
        """Build the texts and the graph of the ``idx``-th instruction.

        Returns:
            A dict with the keys ``'encoder_text'``, ``'prompt_text'``,
            ``'answer_text'`` and ``'graph'``.
        """
        instruction = self.instructions.iloc[idx]
        data_name = instruction['data']
        task_type = instruction.get('task', 'nc')
        if not isinstance(task_type, str):
            # After pd.concat, rows from a dataset without a 'task' column
            # hold NaN here; treat them like a missing field.
            task_type = 'nc'
        domain_type = DOMAIN_DICT[data_name]
        mem_tokens = " ".join(f"<MEM {i}>" for i in range(self.model_args.memory_token_nums))
        special_token = '<-FineTune->'

        # Context = everything before the first "\nQuestion: ", limited to 1400 tokens.
        prompt = instruction['prompt']
        raw_node_text = prompt.split('\nQuestion: ')[0].replace(': <Node 1>,', '')
        raw_node_text = self.truncate_text(raw_node_text, 1400)

        # Re-attach the question, i.e. the text after the last "Question: ".
        node_text = raw_node_text + "\nQuestion: " + prompt.split('Question: ')[-1]

        encoder_text = f"{DOMAIN_DESC[domain_type]} {TASK_DESC[task_type]}" + ' ' + mem_tokens

        if self.model_args.model_arch == 'llama3':
            prompt_text = ' ' + mem_tokens + special_token + node_text
        else:
            prompt_text = mem_tokens + special_token + ' ' + node_text

        answer_text = instruction["gpt"]

        # graph data
        feature_dtype = torch.bfloat16 if self.training_args.bf16 else torch.float16
        graph = Data()
        graph.edge_index = torch.LongTensor(instruction['edge_index'])
        node_list = torch.LongTensor(instruction['node_set'])
        graph.x = self.name2data[data_name][node_list].to(dtype=feature_dtype)

        num_nodes = graph.x.shape[0]
        graph.graph_attention_mask = torch.ones([num_nodes], dtype=torch.long)
        adj = SparseTensor(row=graph.edge_index[0], col=graph.edge_index[1], sparse_sizes=(num_nodes, num_nodes))
        shortest_path_result, path = algos.floyd_warshall(adj.to_dense().long().numpy())

        # Entries equal to 510 are ignored when taking the maximum distance.
        # NOTE(review): 510 is presumably the value algos.floyd_warshall uses
        # for unreachable pairs - confirm against algos.
        shortest_path_result_truncate = np.where(shortest_path_result == 510, 0, shortest_path_result)
        max_dist = np.amax(shortest_path_result_truncate)

        edge_type_set = torch.LongTensor([-1] + instruction['edge_type_set'])  # add special token for unreachable path
        graph.edge_attr = self.name2edge[data_name][edge_type_set].to(dtype=feature_dtype)
        edge_type_inv = torch.LongTensor(instruction['edge_type_inv']).unsqueeze(-1)
        # [n, n, 1] table holding edge_type_inv at every (source, target) edge, 0 elsewhere.
        attn_edge_type = torch.zeros([num_nodes, num_nodes, 1], dtype=torch.long)
        attn_edge_type[graph.edge_index[0, :], graph.edge_index[1, :]] = edge_type_inv

        edge_type = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
        graph.edge_type = torch.from_numpy(edge_type).squeeze(-1).long()

        spatial_pos = torch.from_numpy(shortest_path_result).long()
        # truncate spatial_pos
        spatial_pos = torch.where(spatial_pos >= self.model_args.spatial_pos_max, self.model_args.spatial_pos_max - 1, spatial_pos)
        # NOTE(review): the attribute is spelled `rel_postion` while collate_fn
        # reads `spatial_pos` from the collated batch; whether
        # collator_graph_data maps one to the other is not visible here.
        graph.rel_postion = spatial_pos  # [n, n]

        return {
            'encoder_text': encoder_text,
            'prompt_text': prompt_text,
            'answer_text': answer_text,
            'graph': graph,
        }

    def collate_fn(self, batch):
        """Collate items produced by ``__getitem__`` into one model input dict.

        Graphs are merged by ``collator_graph_data``; encoder texts are
        tokenized with left truncation, and the prompt/answer pairs are
        tokenized by the preprocess function matching
        ``model_args.model_arch``, with the decoder tokenizer set to right
        truncation.

        Returns:
            A dict with the keys ``graph_embeds``, ``graph_attention_mask``,
            ``rel_position``, ``edge_attr``, ``edge_type``, ``input_ids``,
            ``attention_mask``, ``prompt_answer_ids``,
            ``prompt_attention_mask`` (``None`` in train mode),
            ``text_length`` and ``labels``.
        """
        is_train = self.mode == 'train'

        # graph data
        batch_graph = collator_graph_data(batch, max_node=256)

        # encoder inputs
        encoder_texts = [entry['encoder_text'] for entry in batch]
        self.tokenizer.truncation_side = 'left'
        encoder_inputs = self.tokenizer(
            encoder_texts,
            return_tensors="pt",
            padding="longest",
            max_length=1200,
            truncation=True,
        )

        # The decoder tokenizer may be the encoder tokenizer itself, so its
        # truncation side is switched only after the encoder texts are encoded.
        t_dec = self.tokenizer_dec if self.tokenizer_dec else self.tokenizer
        t_dec.truncation_side = 'right'

        # decoder inputs (instructions); the answer is withheld outside training
        sources = [
            [
                {'from': 'human', 'value': entry['prompt_text']},
                {'from': 'gpt', 'value': entry['answer_text'] if is_train else None},
            ]
            for entry in batch
        ]

        if self.model_args.model_arch == 'llama2':
            tokenize_fn = preprocess_llama_2
        elif self.model_args.model_arch == 'llama3':
            tokenize_fn = preprocess_llama_3
        else:
            tokenize_fn = preprocess_v1
        answers = None if is_train else [entry['answer_text'] for entry in batch]
        decoder_inputs = tokenize_fn(sources, t_dec, self.mode, answers=answers, mem_size=self.model_args.memory_token_nums)

        # collect
        return {
            'graph_embeds': batch_graph.x,  # tensor
            'graph_attention_mask': batch_graph.graph_attention_mask,  # tensor
            'rel_position': batch_graph.spatial_pos,  # tensor
            'edge_attr': batch_graph.edge_attr,
            'edge_type': batch_graph.edge_type,
            'input_ids': encoder_inputs.input_ids,
            'attention_mask': encoder_inputs.attention_mask,
            'prompt_answer_ids': decoder_inputs['input_ids'],
            'prompt_attention_mask': None if is_train else decoder_inputs['attention_mask'],
            'text_length': encoder_inputs['attention_mask'].sum(dim=-1),
            'labels': decoder_inputs['labels'],
        }