import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import sys
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import model
import torch
import os
import dataset_drug
import time
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import numpy as np
#import scipy.stats
#from apex import amp
#from apex.parallel import convert_syncbn_model

import random

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to PyTorch (CPU generator and all CUDA devices), to
    NumPy's global generator and to the standard-library ``random`` module,
    then switches cuDNN to its deterministic mode.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed all RNGs once, before any dataset or model object is created.
setup_seed(30)

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort; torch and numpy are already imported above, so confirm the variable
# is still honoured when it is set this late.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def mean_squared_error(target, prediction):
    """Return the mean of the squared element-wise differences.

    Args:
        target: Array-like of reference values.
        prediction: Array-like of predicted values; it is subtracted from
            ``target`` element-wise after both are converted with
            ``np.asarray``.

    Returns:
        The mean over all elements of ``(target - prediction) ** 2``.
    """
    residual = np.asarray(target) - np.asarray(prediction)
    return np.square(residual).mean()

def mean_absolute_error(target, prediction):
    """Return the mean of the absolute element-wise differences.

    Args:
        target: Array-like of reference values.
        prediction: Array-like of predicted values; it is subtracted from
            ``target`` element-wise after both are converted with
            ``np.asarray``.

    Returns:
        The mean over all elements of ``|target - prediction|``.
    """
    residual = np.asarray(target) - np.asarray(prediction)
    return np.abs(residual).mean()


def accuracy(target, prediction):
    """Return the fraction of labels matched by the arg-max of the scores.

    Two input layouts are handled, chosen by the type of ``target[0]``:

    * ``target[0]`` is a Python ``int``: ``target`` is a flat sequence of
      class ids and ``prediction`` holds one score vector per item.
    * otherwise: ``target`` is a sequence of label sequences and
      ``prediction`` a sequence of per-position score arrays. Positions whose
      label equals ``-1`` are left out of both the hit count and the total.

    Args:
        target: Reference labels in one of the two layouts above.
        prediction: Class scores; the predicted class is the arg-max over
            the last axis.

    Returns:
        Number of correct predictions divided by the number of labels
        that were counted.
    """
    if not isinstance(target[0], int):
        hits = 0
        counted = 0
        for labels, scores in zip(target, prediction):
            labels = np.asarray(labels)
            guesses = np.asarray(scores).argmax(-1)
            keep = labels != -1  # -1 marks positions to ignore
            matches = labels[keep] == guesses[keep]
            hits += matches.sum()
            counted += matches.size
        return hits / counted

    # Flat case: one integer label per example.
    predicted = np.asarray(prediction).argmax(-1)
    return np.mean(np.asarray(target) == predicted)

if __name__ == '__main__':
    # Script entry point: run the downstream model restored from a checkpoint
    # over the test split and collect its outputs in `a` and the matching
    # targets in `b`. The module-level code after this block reads both lists.
    # NOTE(review): `data_path` and `epochs` are not referenced again in the
    # code visible here.
    data_path = './data/downstream'
    epochs = 60
    batch_size = 16
    
    a3m_dir='/home/public/bigdata/my/datasets/drug/a3m/train'
    filenames = [
        os.path.join(a3m_dir,name) for name in os.listdir(a3m_dir)
        if os.path.splitext(name)[-1] == '.a3m'
    ]  # keep only the directory entries whose extension is .a3m
    # NOTE(review): the training dataset/loader is built here but never
    # iterated in the code visible in this file.
    drug_train_data = dataset_drug.Dataset(filenames)
    drug_train_loader = DataLoader(
        drug_train_data, batch_size=batch_size, shuffle=True, collate_fn=drug_train_data.collate_fn
    )
    
    a3m_dir_test='/home/public/bigdata/my/datasets/drug/a3m/test'
    filenames = [
        os.path.join(a3m_dir_test,name) for name in os.listdir(a3m_dir_test)
        if os.path.splitext(name)[-1] == '.a3m'
    ]  # keep only the directory entries whose extension is .a3m
    drug_test_data = dataset_drug.Dataset(filenames)
    # shuffle=True together with the early `break` below means that which
    # test batches get collected depends on the shuffle order.
    drug_test_loader = DataLoader(
        drug_test_data, batch_size=batch_size, shuffle=True, collate_fn=drug_test_data.collate_fn
    )
    
    downstream_model = model.model_down_tape.ProteinBertForSequenceClassification().cuda()
    # Weights come from the 'model_state_dict' entry of the checkpoint.
    # strict=False: missing or unexpected keys do not raise an error.
    downstream_model.load_state_dict(torch.load('save/downstream/best_drug_ori.pt')['model_state_dict'],strict=False)
    #downstream_model=torch.nn.DataParallel(downstream_model)

    #downstream_model = convert_syncbn_model(downstream_model)
    #downstream_model, optimizer = amp.initialize(downstream_model, optimizer, opt_level='O0')
    # a: collected model outputs, b: collected targets (both extended per batch).
    a,b=[],[]
    downstream_model.eval()
    for idx, batch in enumerate(drug_test_loader):
        # The collated batch is indexed as (inputs, targets).
        drug_inputs = batch[0]
        drug_targets = batch[1]
        drug_inputs, drug_targets = drug_inputs.cuda(), drug_targets.cuda()
        with torch.no_grad():
            # Only the first element of the model's return value is kept.
            outputs = downstream_model(drug_inputs)[0]
            #value_prediction = outputs['representations'][33][:,0,:].squeeze().cpu().numpy()
        # extend() appends one entry per element along the first axis.
        a.extend(outputs.cpu().numpy())
        b.extend(drug_targets.cpu().numpy())
        # Stop once batch index 256 has been processed (at most 257 batches).
        if(idx==256):break

from anndata import AnnData
import scanpy as sc

def plot_umap(adata, namespace='flu'):
    """Draw the UMAP plot of ``adata`` coloured by ``label1`` and save it.

    Args:
        adata: Annotated data object passed straight to ``sc.pl.umap``.
        namespace: Accepted for interface compatibility; the save name has
            no placeholder, so this value does not influence the output.
    """
    # NOTE(review): the save name is constant, so `namespace` is effectively
    # unused; confirm whether it was meant to be part of the file name.
    figure_suffix = 'label1.png'
    sc.pl.umap(adata, color='label1', save=figure_suffix)

if __name__ == '__main__':
    # UMAP visualisation of the collected model outputs.
    #
    # Guarded by the same condition as the block that fills `a` and `b`:
    # those lists exist only when the file runs as a script, so executing
    # this pipeline on import would fail with NameError.

    # One row per collected output; the collected targets become the
    # `label1` annotation that plot_umap() colours by.
    adata = AnnData(np.array(a))
    adata.obs["label1"] = list(np.array(b))

    # Neighbour graph computed directly on the data matrix (use_rep='X').
    sc.pp.neighbors(adata, n_neighbors=100, use_rep='X')
    # Louvain clustering is computed, but the plot below colours by `label1`.
    sc.tl.louvain(adata, resolution=1.)

    sc.set_figure_params(dpi_save=500)

    sc.tl.umap(adata, min_dist=1.)
    plot_umap(adata)
