import torch
import numpy as np
import json
import torch.nn as nn

from utils import batch_by_size
from utils import calculate_recall_ndcg_at_ks
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import Adam
from models import EmerGNN
from sklearn.metrics import roc_auc_score, average_precision_score
import pickle

class BaseModel(object):
    def __init__(self, eval_ent, eval_rel, args, entity_vocab=None, relation_vocab=None):
        """Build the EmerGNN model, its optimiser and the id -> name vocabularies.

        Args:
            eval_ent: Forwarded to ``EmerGNN`` and stored as ``self.eval_ent``.
            eval_rel: Forwarded to ``EmerGNN`` and stored as ``self.eval_rel``;
                ``visualize`` uses it to decode relation ids.
            args: Run configuration. Attributes read here: ``load_model``,
                ``dataset``, ``all_rel``, ``lr`` and ``lamb``.
            entity_vocab: Accepted for backward compatibility but ignored; the
                vocabulary is always read from ``data/TS_all_node.json``.
            relation_vocab: Accepted for backward compatibility but ignored; the
                vocabulary is always read from ``data/TS_all_rel.json``.
        """
        import os

        self.model = EmerGNN(eval_ent, eval_rel, args)
        if args.load_model:
            # Deserialize onto the CPU first so that a checkpoint written from
            # another device still loads; the model is moved to the GPU below.
            state_dict = torch.load(args.dataset + '_saved_model.pt', map_location='cpu')
            self.model.load_state_dict(state_dict)
        self.model.cuda()

        self.eval_ent = eval_ent
        self.eval_rel = eval_rel
        self.all_rel = args.all_rel
        self.args = args

        self.optimizer = Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.lamb)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max')

        self.bce_loss = nn.BCELoss()

        # The vocabulary files live in the ``data`` directory next to this module.
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')

        def read_names(file_name):
            # JSON is UTF-8 by specification; being explicit avoids decoding
            # errors on platforms whose default encoding is something else.
            with open(os.path.join(data_dir, file_name), 'r', encoding='utf-8') as f:
                raw = json.load(f)
            # JSON keys are strings; convert them to int so integer ids can be
            # used for lookup directly. Entries without a 'name' become 'Unknown'.
            return {int(k): v.get('name', 'Unknown') for k, v in raw.items()}

        self.entity_vocab = read_names('TS_all_node.json')
        self.relation_vocab = read_names('TS_all_rel.json')


    def train(self, train_pos, train_neg, KG):
        """Run one optimisation pass over the positive training pairs.

        Args:
            train_pos: 2-D array; column 0 holds head ids, column 1 tail ids and
                the remaining columns the target vector fed to the BCE loss.
            train_neg: Same layout as ``train_pos``. The negatives are batched
                alongside the positives but do not contribute to the loss.
            KG: Knowledge graph handed unchanged to ``self.model.enc_ht``.
        """
        def to_gpu(samples):
            # ids become LongTensors, the label columns a FloatTensor.
            heads = torch.LongTensor(samples[:, 0]).cuda()
            tails = torch.LongTensor(samples[:, 1]).cuda()
            targets = torch.FloatTensor(samples[:, 2:]).cuda()
            return heads, tails, targets

        pos_head, pos_tail, pos_label = to_gpu(train_pos)
        neg_head, neg_tail, neg_label = to_gpu(train_neg)

        n_train = len(pos_head)
        batch_size = self.args.n_batch
        # Number of batches, counting a trailing partial batch (progress bar only).
        full_batches, leftover = divmod(n_train, batch_size)
        n_steps = full_batches + int(leftover > 0)

        self.model.train()
        batches = batch_by_size(
            batch_size, pos_head, pos_tail, pos_label, neg_head, neg_tail, neg_label,
            n_sample=n_train,
        )
        for heads, tails, targets, _neg_h, _neg_t, _neg_r in tqdm(
                batches, ncols=100, leave=False, total=n_steps):
            self.model.zero_grad()
            logits = self.model.enc_r(self.model.enc_ht(heads, tails, KG))
            probabilities = torch.sigmoid(logits)

            # Positive samples only; negatives are intentionally left out.
            loss = self.bce_loss(probabilities, targets.float())
            loss.backward()
            self.optimizer.step()

    def save_present(self, train_pos, train_neg, KG):
        """Encode every positive training pair and save the result as an ``.npz`` file.

        The archive ``<tag>-training_data.npz`` (``tag`` is ``S1``, ``S2`` or
        ``S0`` depending on ``self.args.dataset``) is written to the current
        working directory with the arrays ``h``, ``t``, ``r`` and ``ht_embed``.
        ``knn`` reads this file back as its retrieval database.

        The embeddings are computed in evaluation mode and without gradient
        tracking, so the export is not affected by training-only layers and
        does not build autograd graphs. The model's previous train/eval mode
        is restored afterwards.

        Args:
            train_pos: 2-D array; column 0 holds head ids, column 1 tail ids and
                the remaining columns the label vector.
            train_neg: Same layout as ``train_pos``; batched alongside the
                positives but not written to the archive.
            KG: Knowledge graph handed unchanged to ``self.model.enc_ht``.
        """
        pos_head, pos_tail, pos_label = torch.LongTensor(train_pos[:,0]).cuda(), torch.LongTensor(train_pos[:,1]).cuda(), torch.FloatTensor(train_pos[:,2:]).cuda()
        neg_head, neg_tail, neg_label = torch.LongTensor(train_neg[:,0]).cuda(), torch.LongTensor(train_neg[:,1]).cuda(), torch.FloatTensor(train_neg[:,2:]).cuda()
        n_train = len(pos_head)
        n_batch = self.args.n_batch
        # Number of batches, counting a trailing partial batch (progress bar only).
        n_steps = n_train // n_batch + int(n_train % n_batch > 0)

        # Per-batch pieces, concatenated once the loop is done.
        h_list, t_list, r_list, ht_embed_list = [], [], [], []

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                batches = batch_by_size(n_batch, pos_head, pos_tail, pos_label, neg_head, neg_tail, neg_label, n_sample=n_train)
                for p_h, p_t, p_r, _n_h, _n_t, _n_r in tqdm(batches, ncols=100, leave=False, total=n_steps):
                    ht_embed = self.model.enc_ht(p_h, p_t, KG)

                    h_list.append(p_h.cpu().numpy())
                    t_list.append(p_t.cpu().numpy())
                    r_list.append(p_r.cpu().numpy())
                    ht_embed_list.append(ht_embed.cpu().numpy())
        finally:
            # Leave the model in the mode the caller had it in.
            self.model.train(was_training)

        h_array = np.concatenate(h_list)
        t_array = np.concatenate(t_list)
        r_array = np.concatenate(r_list)
        ht_embed_array = np.concatenate(ht_embed_list, axis=0)

        # 'S1' is checked before 'S2'; anything else is filed under 'S0'.
        dataset = next((tag for tag in ('S1', 'S2') if tag in self.args.dataset), 'S0')
        np.savez(f'{dataset}-training_data.npz', h=h_array, t=t_array, r=r_array, ht_embed=ht_embed_array)

    def knn(self, test_pos, KG, record=0, filename=None):
        """Score nearest-neighbour label retrieval over previously saved pair embeddings.

        Reads ``<tag>-training_data.npz`` and ``<tag>-test_data.npz`` from the
        current working directory (``tag`` is ``S1``, ``S2`` or ``S0``
        depending on ``self.args.dataset``). For every test embedding the
        training embeddings are ranked by L2 distance, and the labels of the k
        nearest ones are pooled as the prediction.

        Args:
            test_pos: Unused; kept for interface compatibility.
            KG: Unused; kept for interface compatibility.
            record: When greater than 0, the 20 nearest neighbours of every
                test case are written to ``filename`` as JSON.
            filename: Output path for the neighbour records; required when
                ``record > 0``.

        Returns:
            dict mapping each k in ``[1, 5, 10, 20, 30, 50]`` to recall@k
            averaged over all test cases. A test case without any true label
            contributes 0 instead of raising ``ZeroDivisionError``.

        Raises:
            ValueError: If ``record > 0`` but no ``filename`` was given.
        """
        if record > 0 and filename is None:
            # Fail before the expensive search rather than at open(None) after it.
            raise ValueError('knn: `filename` is required when record > 0')

        # 'S1' is checked before 'S2'; anything else is filed under 'S0'.
        dataset = next((tag for tag in ('S1', 'S2') if tag in self.args.dataset), 'S0')
        self.model.eval()

        # Only stored embeddings are compared here, so fall back to the CPU
        # when no GPU is available instead of failing.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        with np.load(f'{dataset}-training_data.npz') as training_data:
            train_ht_embed = torch.FloatTensor(training_data['ht_embed']).to(device)
            train_labels = training_data['r']
        with np.load(f'{dataset}-test_data.npz') as test_data:
            ht_embed = torch.FloatTensor(test_data['ht_embed']).to(device)
            pos_label = test_data['r']

        k_values = [1, 5, 10, 20, 30, 50]
        n_record = 20  # number of neighbours kept per test case in the JSON record
        recall_at_ks = {k: 0.0 for k in k_values}
        neighbors = []
        neighbors_labels = []

        # Row i is True wherever training sample i carries a label.
        train_mask = train_labels != 0
        n_query = ht_embed.size(0)

        for j in range(n_query):
            query = ht_embed[j].unsqueeze(0)
            distance = torch.nn.functional.pairwise_distance(query, train_ht_embed, p=2)
            # Ascending distance: position 0 is the closest training sample.
            order = np.argsort(distance.detach().cpu().numpy())

            nearest = order[:n_record]
            neighbors.append(nearest)
            neighbors_labels.append([np.nonzero(row)[0].tolist() for row in train_labels[nearest]])

            true_idx = np.flatnonzero(pos_label[j])
            if true_idx.size == 0:
                # Recall is undefined without true labels; count the case as 0.
                continue
            for k in k_values:
                # Union of the labels carried by the k nearest training samples.
                pred_idx = np.flatnonzero(train_mask[order[:k]].any(axis=0))
                hits = np.intersect1d(true_idx, pred_idx).size
                recall_at_ks[k] += hits / true_idx.size / n_query

        if record > 0:
            records = [
                {
                    'id': idx,
                    'K-neighbor': nearest.tolist(),  # row indices into the training archive
                    'K-neighbor_labels': labels,  # label ids of those training rows
                }
                for idx, (nearest, labels) in enumerate(zip(neighbors, neighbors_labels))
            ]
            with open(filename, 'w') as f:
                json.dump(records, f, indent=4)
            print(filename, 'record done')

        return recall_at_ks

    def finger(self, test_pos, train_file, filename):
        """Retrieve nearest training pairs by Morgan-feature distance and score recall@k.

        A drug pair is represented by the concatenation of the two ``'Morgan'``
        feature vectors read from ``data/id2drug_feat.pkl``. Each test pair is
        compared with all training pairs in both (head, tail) and (tail, head)
        order and the smaller L2 distance is kept, so the ranking does not
        depend on the order in which the test pair is written.

        Args:
            test_pos: 2-D array or tensor; column 0 is the head id, column 1 the
                tail id and the remaining columns a label vector whose
                non-zero positions are the true labels.
            train_file: JSON-lines file; every record needs a ``'test_triplet'``
                entry of the form ``[head, tail, labels]``. ``labels`` is used
                as a list of label ids (a bare scalar is treated as a single
                label) — presumably in the same id space as the positions of
                the test label vector; verify against the data producer.
            filename: JSON file that receives, for every test pair, the indices
                and labels of its 20 nearest training pairs.

        Returns:
            dict mapping each k in ``[1, 5, 10, 20, 30, 50]`` to recall@k
            averaged over all test pairs. A test pair without any true label
            contributes 0 instead of raising ``ZeroDivisionError``.
        """
        with open(train_file, 'r') as f:
            traindata = [json.loads(line) for line in f if line.strip()]

        # NOTE(review): unpickling runs arbitrary code; only use a trusted feature file.
        with open('data/id2drug_feat.pkl', 'rb') as f:
            drug_feat = pickle.load(f, encoding='utf-8')
        # Rows follow the dict's iteration order, so ``mfeat[i]`` is the i-th
        # entry of the pickle — assumes that order matches the drug ids; TODO confirm.
        mfeat = np.array([drug_feat[key]['Morgan'] for key in drug_feat])

        features = []
        train_labels = []
        for entry in traindata:
            h, t, labels = entry['test_triplet']
            features.append(np.concatenate([mfeat[h], mfeat[t]]))
            # Plain Python lists work for ragged label sets and serialise to JSON as-is.
            train_labels.append(list(labels) if isinstance(labels, (list, tuple)) else [labels])

        # Nothing here runs the model, so fall back to the CPU when no GPU is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        features_tensor = torch.FloatTensor(np.array(features)).to(device)
        features_tensor = features_tensor.reshape(len(traindata), -1)

        # Accept both tensors (on any device) and numpy arrays.
        if torch.is_tensor(test_pos):
            test_rows = test_pos.detach().cpu().numpy()
        else:
            test_rows = np.asarray(test_pos)
        n_test = len(test_rows)

        k_values = [1, 5, 10, 20, 30, 50]
        n_record = 20  # number of neighbours kept per test pair in the JSON record
        recall_at_ks = {k: 0.0 for k in k_values}
        neighbors = []
        neighbors_labels = []
        pairwise_distance = torch.nn.functional.pairwise_distance

        for i in tqdm(range(n_test)):
            h, t = int(test_rows[i][0]), int(test_rows[i][1])
            true_label = np.flatnonzero(test_rows[i][2:]).tolist()

            forward = torch.FloatTensor(np.concatenate([mfeat[h], mfeat[t]])).to(device).reshape(1, -1)
            # Same pair with head and tail swapped.
            swapped = torch.FloatTensor(np.concatenate([mfeat[t], mfeat[h]])).to(device).reshape(1, -1)
            # L2 distance, smaller is closer. (With cosine similarity the sort
            # order below would have to be reversed.)
            distance = torch.min(
                pairwise_distance(forward, features_tensor, p=2),
                pairwise_distance(swapped, features_tensor, p=2),
            )
            order = np.argsort(distance.detach().cpu().numpy())

            nearest = order[:n_record]
            neighbors.append(nearest)
            neighbors_labels.append([train_labels[idx] for idx in nearest])

            if not true_label:
                # Recall is undefined without true labels; count the pair as 0.
                continue
            for k in k_values:
                # Union of the labels carried by the k nearest training pairs.
                predicted = set()
                for idx in order[:k]:
                    predicted.update(train_labels[idx])
                hits = sum(1 for label in true_label if label in predicted)
                recall_at_ks[k] += hits / len(true_label) / n_test

        records = [
            {
                'id': idx,
                'K-neighbor': nearest.tolist(),  # indices into the training records
                'K-neighbor_labels': labels,  # labels of those training records
            }
            for idx, (nearest, labels) in enumerate(zip(neighbors, neighbors_labels))
        ]
        with open(filename, 'w') as f:
            json.dump(records, f, indent=4)
        print(filename, 'record done')

        return recall_at_ks

    def descript(self, test_pos, train_file, filename):
        """Retrieve nearest training pairs by drug-embedding distance and score recall@k.

        A drug pair is represented by the concatenation of the two rows of the
        ``emb`` array stored in ``TS_drug_emb.npz``. Each test pair is compared
        with all training pairs in both (head, tail) and (tail, head) order and
        the smaller L2 distance is kept, so the ranking does not depend on the
        order in which the test pair is written.

        Args:
            test_pos: 2-D array or tensor; column 0 is the head id, column 1 the
                tail id and the remaining columns a label vector whose
                non-zero positions are the true labels.
            train_file: JSON-lines file; every record needs a ``'test_triplet'``
                entry of the form ``[head, tail, labels]``. ``labels`` is used
                as a list of label ids (a bare scalar is treated as a single
                label) — presumably in the same id space as the positions of
                the test label vector; verify against the data producer.
            filename: JSON file that receives, for every test pair, the indices
                and labels of its 20 nearest training pairs.

        Returns:
            dict mapping each k in ``[1, 5, 10, 20, 30, 50]`` to recall@k
            averaged over all test pairs. A test pair without any true label
            contributes 0 instead of raising ``ZeroDivisionError``.
        """
        with open(train_file, 'r') as f:
            traindata = [json.loads(line) for line in f if line.strip()]

        # Row i of 'emb' is used as the feature vector of drug id i.
        with np.load('TS_drug_emb.npz') as des_emb:
            mfeat = des_emb['emb']

        features = []
        train_labels = []
        for entry in traindata:
            h, t, labels = entry['test_triplet']
            features.append(np.concatenate([mfeat[h], mfeat[t]]))
            # Plain Python lists work for ragged label sets and serialise to JSON as-is.
            train_labels.append(list(labels) if isinstance(labels, (list, tuple)) else [labels])

        # Nothing here runs the model, so fall back to the CPU when no GPU is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        features_tensor = torch.FloatTensor(np.array(features)).to(device)
        features_tensor = features_tensor.reshape(len(traindata), -1)

        # Accept both tensors (on any device) and numpy arrays.
        if torch.is_tensor(test_pos):
            test_rows = test_pos.detach().cpu().numpy()
        else:
            test_rows = np.asarray(test_pos)
        n_test = len(test_rows)

        k_values = [1, 5, 10, 20, 30, 50]
        n_record = 20  # number of neighbours kept per test pair in the JSON record
        recall_at_ks = {k: 0.0 for k in k_values}
        neighbors = []
        neighbors_labels = []
        pairwise_distance = torch.nn.functional.pairwise_distance

        for i in tqdm(range(n_test)):
            h, t = int(test_rows[i][0]), int(test_rows[i][1])
            true_label = np.flatnonzero(test_rows[i][2:]).tolist()

            forward = torch.FloatTensor(np.concatenate([mfeat[h], mfeat[t]])).to(device).reshape(1, -1)
            # Same pair with head and tail swapped.
            swapped = torch.FloatTensor(np.concatenate([mfeat[t], mfeat[h]])).to(device).reshape(1, -1)
            # L2 distance, smaller is closer. (With cosine similarity the sort
            # order below would have to be reversed.)
            distance = torch.min(
                pairwise_distance(forward, features_tensor, p=2),
                pairwise_distance(swapped, features_tensor, p=2),
            )
            order = np.argsort(distance.detach().cpu().numpy())

            nearest = order[:n_record]
            neighbors.append(nearest)
            neighbors_labels.append([train_labels[idx] for idx in nearest])

            if not true_label:
                # Recall is undefined without true labels; count the pair as 0.
                continue
            for k in k_values:
                # Union of the labels carried by the k nearest training pairs.
                predicted = set()
                for idx in order[:k]:
                    predicted.update(train_labels[idx])
                hits = sum(1 for label in true_label if label in predicted)
                recall_at_ks[k] += hits / len(true_label) / n_test

        records = [
            {
                'id': idx,
                'K-neighbor': nearest.tolist(),  # indices into the training records
                'K-neighbor_labels': labels,  # labels of those training records
            }
            for idx, (nearest, labels) in enumerate(zip(neighbors, neighbors_labels))
        ]
        with open(filename, 'w') as f:
            json.dump(records, f, indent=4)
        print(filename, 'record done')

        return recall_at_ks


    def evaluate(self, test_pos, test_neg, KG, record=0, filename=None, save_pre=0):
        """Score the positive test pairs and compute recall@k / NDCG@k.

        Args:
            test_pos: 2-D tensor; column 0 holds head ids, column 1 tail ids and
                the remaining columns the label vector.
            test_neg: Unused; kept for interface compatibility.
            KG: Knowledge graph handed unchanged to ``self.model.enc_ht``.
            record: When greater than 0, the 20 highest-scoring relations of
                every test case are written to ``filename`` as JSON.
            filename: Output path for the prediction records; required when
                ``record > 0``.
            save_pre: When truthy, heads, tails, labels and pair embeddings are
                saved to ``<tag>-test_data.npz`` in the current working
                directory (``tag`` is ``S1``, ``S2`` or ``S0`` depending on
                ``self.args.dataset``), the file ``knn`` reads its queries from.

        Returns:
            ``(recall_at_ks, ndcg_at_ks)`` exactly as returned by
            ``calculate_recall_ndcg_at_ks`` for k in ``[5, 10, 20]``.

        Raises:
            ValueError: If ``record > 0`` but no ``filename`` was given.
        """
        if record > 0 and filename is None:
            # Fail before running inference rather than at open(None) after it.
            raise ValueError('evaluate: `filename` is required when record > 0')

        pos_head, pos_tail, pos_label = test_pos[:,0], test_pos[:,1], test_pos[:,2:]
        batch_size = self.args.test_batch_size
        n_test = len(pos_head)

        # Per-batch pieces collected only when save_pre is set.
        h_list, t_list, r_list, ht_embed_list = [], [], [], []
        pos_scores = []

        self.model.eval()
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            for start in range(0, n_test, batch_size):
                end = min(start + batch_size, n_test)
                p_h = pos_head[start:end]
                p_t = pos_tail[start:end]
                ht_embed = self.model.enc_ht(p_h, p_t, KG)
                p_scores = torch.sigmoid(self.model.enc_r(ht_embed))

                if save_pre:
                    h_list.append(p_h.cpu().numpy())
                    t_list.append(p_t.cpu().numpy())
                    r_list.append(pos_label[start:end].cpu().numpy())
                    ht_embed_list.append(ht_embed.cpu().numpy())

                pos_scores.append(p_scores.cpu().numpy())

        labels = pos_label.cpu().data.numpy()
        pos_scores = np.concatenate(pos_scores)

        k_values = [5,10,20]
        recall_at_ks, ndcg_at_ks = calculate_recall_ndcg_at_ks(labels, pos_scores, k_values)

        if record > 0:
            # For every test case store its id and the 20 most probable
            # relations (relation id and probability) in one JSON file.
            n_top = 20
            records = []
            for idx, prob_dist in enumerate(pos_scores):
                # Relation ids ordered by descending probability.
                top_indices = np.argsort(prob_dist)[::-1][:n_top]
                predictions = [
                    {
                        'relation_id': int(rel_id),
                        # Truncated (not rounded) to three decimals.
                        'probability': int(float(prob)*1000)/1000
                    }
                    for rel_id, prob in zip(top_indices, prob_dist[top_indices])
                ]
                records.append({'id': idx, 'predictions': predictions})
            with open(filename, 'w') as f:
                json.dump(records, f, indent=4)

        if save_pre:
            h_array = np.concatenate(h_list)
            t_array = np.concatenate(t_list)
            r_array = np.concatenate(r_list)
            ht_embed_array = np.concatenate(ht_embed_list, axis=0)
            print(h_array.shape, ht_embed_array.shape)

            # 'S1' is checked before 'S2'; anything else is filed under 'S0'.
            dataset = next((tag for tag in ('S1', 'S2') if tag in self.args.dataset), 'S0')
            np.savez(f'{dataset}-test_data.npz', h=h_array, t=t_array, r=r_array, ht_embed=ht_embed_array)

        return recall_at_ks, ndcg_at_ks

    def test_single(self, triplet, KG):
        heads = triplet[0].unsqueeze(0)
        tails = triplet[1].unsqueeze(0)
        ht_embed = self.model.enc_ht(heads, tails, KG)
        scores = self.model.enc_r(ht_embed)
        rela_scores = torch.sigmoid(scores).data.cpu().numpy()

        pred = (rela_scores > 0.5).astype('float')
        return pred[0]

    def visualize(self, triplet, KG, head_batch=True):
        """Render the paths returned by ``visualize_forward`` as readable strings.

        Args:
            triplet: Tensor; element 0 is the head id, element 1 the tail id
                and the rest the relation/label part.
            KG: Knowledge graph handed unchanged to ``visualize_forward``.
            head_batch: Handed unchanged to ``visualize_forward``.

        Returns:
            Tuple ``(outputs, rel_weights, rel_freq)``. ``outputs`` has one
            line per path: the path weight, a tab, then one
            ``(source, relation, target),`` group per edge. ``rel_weights``
            and ``rel_freq`` are all-zero arrays of length
            ``all_rel - eval_rel`` and ``all_rel``; nothing fills them here.
        """
        head, tail, relation = triplet[0], triplet[1], triplet[2:]
        paths, weights = self.model.visualize_forward(
            head.unsqueeze(0), tail.unsqueeze(0), relation.unsqueeze(0), KG, 15, head_batch)

        rel_weights = [0] * (self.all_rel - self.eval_rel)
        rel_freq = [0] * self.all_rel
        # Edges carrying this relation id are left out of the rendering.
        skipped_rel = 2*self.all_rel - self.eval_rel

        outputs = []
        for path, weight in zip(paths, weights):
            pieces = ['%4f\t' % weight]
            for edge in path:
                src, dst, rel = edge
                src_name = self.entity_vocab[src]
                dst_name = self.entity_vocab[dst]
                if rel == skipped_rel:
                    continue

                if self.all_rel <= rel < skipped_rel:
                    # Ids in this range are printed with their endpoints
                    # swapped and looked up under a shifted relation id.
                    rel_name = self.relation_vocab[rel - self.all_rel + self.eval_rel]
                    pieces.append(f'({dst_name}, {rel_name}, {src_name}),')
                else:
                    rel_name = self.relation_vocab[rel]
                    pieces.append(f'({src_name}, {rel_name}, {dst_name}),')
            pieces.append('\n')
            outputs.append(''.join(pieces))
        return outputs, np.array(rel_weights), np.array(rel_freq)

    def save_model(self, out_str=''):
        """Write the model weights to ``<args.dataset>_saved_model-314.pt``.

        Args:
            out_str: Text printed in front of the confirmation message.
        """
        # NOTE(review): the '-314' suffix differs from the '_saved_model.pt'
        # name that __init__ loads when args.load_model is set — confirm
        # which file name is intended.
        weights = self.model.state_dict()
        checkpoint_path = self.args.dataset+'_saved_model-314.pt'
        torch.save(weights, checkpoint_path)
        print(out_str, 'model saved')

    def load_model(self, model_name='_saved_model.pt'):
        """Load model weights from the checkpoint file ``model_name``.

        Args:
            model_name: Path of a state-dict checkpoint. It is used as given;
                no dataset prefix is added.
        """
        checkpoint = torch.load(model_name)
        self.model.load_state_dict(checkpoint)
        print(model_name, 'model loaded')