from __future__ import print_function
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
import torch
from torch import logit
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision.transforms as transforms

import os
import argparse

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from PIL import Image

from utils import prepare_dset, maha
from networks import *

from utils import get_pretrained_model
import matplotlib.pyplot as plt

import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np


from sklearn.preprocessing import KBinsDiscretizer
import numpy as np

import os
import torch

import sys
# Add the MUSK directory to the module search path.
sys.path.append('/home/username/piusername/MUSK/')
# Hugging Face environment variables, set together: HF_HOME is the scratch
# directory and HF_TOKEN is an empty string.
os.environ.update({
    'HF_HOME': '/home/username/scratch/',
    'HF_TOKEN': '',
})

def setup_seed(seed):
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # torch.backends.cudnn.deterministic = True  (left disabled)


# Seed once at import time so the whole script runs from a fixed state.
setup_seed(20)

parser = argparse.ArgumentParser(description='Ensemble Training')

# Every command-line option as (flags, add_argument keywords), listed in
# registration order so the generated help text keeps its layout.
_CLI_OPTIONS = (
    # pretrained models setting
    (('--maha_file',), dict(default='./ssl/maha_dict.npy', type=str)),
    (('-a', '--arch'), dict(metavar='ARCH', default='resnet50')),
    (('--pretrained',), dict(default='', type=str,
                             help='path to moco pretrained checkpoint')),
    (('--pretrained_model',), dict(default='vit', type=str,
                                   help='SSL feature map type')),
    # data settings
    (('--batch_size',), dict(default=1024, type=int)),
    (('--dataset',), dict(default='cifar10', type=str, help='cifar10/cifar100')),
    (('--num_classes',), dict(default=10, type=int)),
    (('--random_state',), dict(type=int, default=0)),
    # label / input noise and trigger settings
    (('--ynoise_type',), dict(default='symmetric', type=str,
                              help='symmetric/pairflip')),
    (('--ynoise_rate',), dict(default=0.0, type=float, help='label noise rate')),
    (('--xnoise_type',), dict(default='blur', type=str, help='gaussian/blur')),
    (('--xnoise_arg',), dict(default=1, type=float)),
    (('--xnoise_rate',), dict(default=0.0, type=float)),
    (('--trigger_size',), dict(type=int, default=3)),
    (('--trigger_ratio',), dict(type=float, default=0.)),
    # name of the module-level data loader that predict_maha iterates
    (('--select_dataloader',), dict(type=str, default="trainloader")),
)
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)

args = parser.parse_args()
# `scale` is not a command-line option; it is always attached as False.
args.scale = False

num_classes = args.num_classes

# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# Hyper Parameter settings
use_cuda = torch.cuda.is_available()
best_acc = 0
batch_size = args.batch_size
# Custom_Dataset class
class Custom_Dataset(Dataset):
    """Map-style dataset over in-memory image arrays and their labels.

    Args:
        x: Indexable collection of image arrays.
        y: Indexable collection of labels aligned with ``x``.
        data_set: Array layout selector. ``'cifar'`` hands each array directly
            to ``Image.fromarray``; ``'svhn'`` first reorders the axes with
            ``np.transpose(array, (1, 2, 0))``.
        transform: Optional callable applied to the PIL image. When ``None``
            the PIL image itself is returned.
    """

    def __init__(self, x, y, data_set, transform=None):
        self.x_data = x
        self.y_data = y
        self.data = data_set
        self.transform = transform

    def __len__(self):
        """Return the number of samples in ``x``."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return ``(image, label, idx)`` for the sample at position ``idx``.

        Raises:
            ValueError: If ``data_set`` is neither ``'cifar'`` nor ``'svhn'``.
        """
        if self.data == 'cifar':
            array = self.x_data[idx]
        elif self.data == 'svhn':
            array = np.transpose(self.x_data[idx], (1, 2, 0))
        else:
            # Fail with a clear message instead of an unbound-variable error
            # further down.
            raise ValueError(
                f"Unsupported data_set {self.data!r}; expected 'cifar' or 'svhn'"
            )

        img = Image.fromarray(array)
        # `transform` defaults to None, so only call it when one was supplied.
        if self.transform is not None:
            img = self.transform(img)

        return img, self.y_data[idx], idx


def predict_maha(model, epoch, args):
    """Gather loader inputs and labels, run ``maha`` on them and save the result.

    The data loader named by ``args.select_dataloader`` is iterated ``epoch``
    times. Each batch's inputs are used directly as features and the first
    label column (``targets.T[0]``) as labels. The stacked arrays are passed
    to ``maha`` together with the module-level ``num_classes``, and whatever
    ``maha`` returns is written to ``args.maha_file`` with ``np.save``.

    Args:
        model: Not used by this function; kept so existing callers keep working.
        epoch: Number of passes over the selected loader.
        args: Namespace providing ``select_dataloader`` and ``maha_file``.

    Raises:
        ValueError: If no batch was produced (``epoch`` is 0 or the loader is
            empty), since there would be nothing to compute statistics from.
    """
    # Batches are collected in lists and concatenated once at the end; growing
    # an array with np.concatenate per batch re-copies everything each time.
    feature_chunks = []
    label_chunks = []
    rows_so_far = 0
    for i in range(epoch):
        print(i)
        # NOTE(review): eval() executes whatever string --select_dataloader
        # carries. It is expected to be the name of a module-level loader such
        # as "trainloader"; consider an explicit name-to-loader mapping.
        loader = eval(args.select_dataloader)
        with torch.no_grad():
            for inputs, targets in loader:
                logits = inputs.cpu().data.numpy()
                labels_np = targets.T[0].cpu().data.numpy()

                feature_chunks.append(logits)
                label_chunks.append(labels_np)
                rows_so_far += logits.shape[0]
                # Progress output: shape of everything accumulated so far.
                print((rows_so_far,) + logits.shape[1:])

    if not feature_chunks:
        raise ValueError(
            "predict_maha received no batches; check `epoch` and the selected data loader"
        )

    feature_all = np.concatenate(feature_chunks, axis=0)
    labels_all = np.concatenate(label_chunks, axis=0)

    maha_intermediate_dict = maha(feature_all, labels_all, indist_classes=num_classes)
    print(feature_all.shape)
    np.save(args.maha_file, maha_intermediate_dict)





def compute_mean_cov(pretrain_model, args):
    """Prepare the pretrained model and run ``predict_maha`` for three passes.

    Args:
        pretrain_model: Model to prepare; it is handed on to ``predict_maha``.
        args: Parsed options; ``args.arch`` decides whether the model is
            wrapped, and the namespace is forwarded to ``predict_maha``.
    """
    # Only move the model to the GPU when one exists, so the script can also
    # run on CPU-only machines.
    if torch.cuda.is_available():
        pretrain_model.cuda()
    if not args.arch.startswith('clip'):
        # Non-CLIP architectures are wrapped in DataParallel and switched to
        # evaluation mode; CLIP architectures are passed on as they are.
        pretrain_model = torch.nn.DataParallel(pretrain_model)
        pretrain_model.eval()
    # train for N epochs
    n_passes = 3
    predict_maha(pretrain_model, n_passes, args)
# Data upload
print('\n[Phase 1] : Data Preparation')


import os
# NOTE(review): presumably set so cuBLAS can run deterministically (see the
# PyTorch reproducibility notes) — confirm it is still required.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import scanpy as sc
import pandas as pd
import pickle

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

import time
import skimage.io
import numpy as np
import pandas as pd
import cv2
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm_notebook as tqdm

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import uniform, choice, normal
from torch import nn, optim, Tensor, manual_seed, argmax
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics.classification import Accuracy, MulticlassConfusionMatrix
from pytorch_lightning.utilities.model_summary import ModelSummary
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
import pandas as pd
import torch

import os
import torch

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np

import pickle

import timm
from PIL import Image
from torchvision import transforms
import torch

# Architecture settings forwarded to timm.create_model (using UNI2-h as example).
timm_kwargs = dict(
    img_size=224,
    patch_size=14,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667 * 2,
    num_classes=0,
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)
# pretrained=True is needed to load the UNI weights (and to download them the
# first time).
model = timm.create_model(
    "hf-hub:MahmoodLab/UNI2-h",
    pretrained=True,
    **timm_kwargs,
)
# Preprocessing transform built from the model's pretrained configuration.
transform = create_transform(
    **resolve_data_config(model.pretrained_cfg, model=model)
)
model.eval()


import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import login

from PIL import Image
from matplotlib import cm
import numpy as np

import pickle

class REGDataset(Dataset):
    """Dataset yielding a row of ``image`` together with expression values.

    NOTE(review): a second class named ``REGDataset`` is defined further down
    in this file and replaces this one once the module has finished loading.

    Args:
        adata: Annotated data object; it is sliced to ``gene`` on construction.
        image: Array indexed by sample position.
        transform: Stored on the instance; not applied by this class.
        gene: Gene name, or list of names, used to slice ``adata``.
    """

    def __init__(self, adata, image, transform, gene='NOSTRIN'):
        self.adata = adata[:, gene]
        self.image = image
        self.transform = transform

    def __len__(self):
        n_samples = self.adata.shape[0]
        return n_samples

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # The observation is looked up by the string form of the index, so
        # observation names are expected to be "0", "1", ...
        expression = torch.FloatTensor(self.adata[str(index)].X)[0, :]
        return features, expression

# ##test idc
# dataset_name = 'IDC'
# adata_train = sc.read("/home/username/scratch/IDC_train.h5ad")
# adata_valid = sc.read("/home/username/scratch/IDC_valid.h5ad")
# adata_test = sc.read("/home/username/scratch/IDC_test.h5ad")

# adata_train.obs_names =[str(i) for i in range(len(adata_train))]
# adata_valid.obs_names =[str(i) for i in range(len(adata_valid))]
# adata_test.obs_names =[str(i) for i in range(len(adata_test))]


# image_train = np.load("/home/username/scratch/IDC_train_imagefeature.npy")
# image_valid = np.load("/home/username/scratch/IDC_valid_imagefeature.npy")
# image_test = np.load("/home/username/scratch/IDC_test_imagefeature.npy")
# gene_list = pd.read_json(f"/home/username/scratch/hest-bench/{dataset_name}/var_50genes.json")
    
# #test brain
# adata_train = sc.read("/home/username/scratch/adata_brain_norm_train.h5ad")
# adata_valid = sc.read("/home/username/scratch/adata_brain_norm_valid.h5ad")
# adata_test = sc.read("/home/username/scratch/adata_brain_norm_test.h5ad")

# adata_train.obs_names =[str(i) for i in range(len(adata_train))]
# adata_valid.obs_names =[str(i) for i in range(len(adata_valid))]
# adata_test.obs_names =[str(i) for i in range(len(adata_test))]


# image_train = np.load("/home/username/scratch/train_data/brain_train_imagefeature_univ2.npy")
# image_valid = np.load("/home/username/scratch/train_data/brain_valid_imagefeature_univ2.npy")
# image_test = np.load("/home/username/scratch/train_data/brain_test_imagefeature_univ2.npy")

# gene_list = {"genes": [
#         "ACTA2",
#         "BST2",
#         "CCND1",
#         "COL18A1",
#         "COL4A1",
#         "COL4A2",
#         "COL6A1",
#         "CPD",
#         "CREG1",
#         "CTSH",
#         "DDIT4",
#         "DNAJB1",
#         "ENG",
#         "ERO1A",
#         "FN1",
#         "FSTL1",
#         "FURIN",
#         "GLRX",
#         "HSPA1A",
#         "IER3",
#         "IFI27",
#         "IGFBP4",
#         "IGFBP7",
#         "ISG15",
#         "LGALS3",
#         "MFGE8",
#         "MFSD12",
#         "MGP",
#         "MT1X",
#         "MYL9",
#         "NDRG1",
#         "NFKBIA",
#         "NGRN",
#         "NOTCH3",
#         "NUPR1",
#         "PFKP",
#         "PI4KA",
#         "PLXND1",
#         "RPL11",
#         "RPS3",
#         "SNHG25",
#         "SOD2",
#         "SPTSSA",
#         "TAGLN",
#         "TAP1",
#         "THY1",
#         "TPM1",
#         "TSPYL1",
#         "VEGFA",
#         "XBP1"
#     ]}




# # dataset_name = "/home/username/scratch/train_data/"
# dataset_name = 'CCRCC'
# sample_list = ['READ', 'PRAD', 'LYMPH_IDC', 'COAD', 'CCRCC']

# # gene_list = {"genes": ["ABCC11", "ADH1B", "ADIPOQ", "ANKRD30A", "AQP1", "AQP3", "CCR7", "CD3E", "CEACAM6", "CEACAM8", "CLIC6", "CYTIP", "DST", "ERBB2", "ESR1", "FASN", "GATA3", "IL2RG", "IL7R", "KIT", "KLF5", "KRT14", "KRT5", "KRT6B", "MMP1", "MMP12", "MS4A1", "MUC6", "MYBPC1", "MYH11", "MYLK", "OPRPN", "OXTR", "PIGR", "PTGDS", "PTN", "PTPRC", "SCD", "SCGB2A1", "SERHL2", "SERPINA3", "SFRP1", "SLAMF7", "TACSTD2", "TCL1A", "TENT5C", "TOP2A", "TPSAB1", "TRAC", "VWF"]}

# gene_list = pd.read_json(f"/home/username/scratch/hest-bench/{dataset_name}/var_50genes.json")

# adata_train = sc.read(f"/home/username/scratch/train_data/{dataset_name}_norm_train.h5ad")
# adata_valid = sc.read(f"/home/username/scratch/train_data/{dataset_name}_norm_valid.h5ad")
# adata_test = sc.read(f"/home/username/scratch/train_data/{dataset_name}_norm_test.h5ad")

# adata_train.obs_names =[str(i) for i in range(len(adata_train))]
# adata_valid.obs_names =[str(i) for i in range(len(adata_valid))]
# adata_test.obs_names =[str(i) for i in range(len(adata_test))]


# image_train = np.load(f"/home/username/scratch/train_data/{dataset_name}_train_imagefeature_univ2.npy")
# image_valid = np.load(f"/home/username/scratch/train_data/{dataset_name}_valid_imagefeature_univ2.npy")
# image_test = np.load(f"/home/username/scratch/train_data/{dataset_name}_test_imagefeature_univ2.npy")


dataset_name = "skin"

# The 50 genes used for the skin dataset, kept in their original order.
gene_list = {
    "genes": [
        "KRT15", "CRCT1", "RHCG", "IGHG1", "GGCT",
        "ASPRV1", "JCHAIN", "PI3", "SLURP1", "CA2",
        "SSFA2", "SPRR2E", "LCE3D", "MT1X", "IGHG4",
        "S100A7A", "HERC6", "WARS", "RPL22L1", "SDR16C5",
        "IL1RN", "MX1", "GGH", "CLEC2B", "CAST",
        "PTGS1", "IGFBP6", "PRSS8", "LAP3", "TXNL4A",
        "SERPINE2", "IGLC2", "KRT75", "C9orf3", "IGHG3",
        "SPRR2D", "IFIT1", "PPL", "LSM5", "TMEM256",
        "KRT2", "SPRR2G", "SNRPD1", "SPRR2A", "OAS1",
        "GBA", "GADD45GIP1", "LYZ", "SPINT1", "NAGK",
    ],
}

# Read the skin train / valid / test .h5ad files, in that order.
adata_train, adata_valid, adata_test = (
    sc.read(h5ad_path)
    for h5ad_path in (
        "/home/username/scratch/adata_skin_norm_train.h5ad",
        "/home/username/scratch/adata_skin_norm_valid.h5ad",
        "/home/username/scratch/adata_skin_norm_test.h5ad",
    )
)

# Rename the observations of every split to "0", "1", ... so a sample can be
# addressed by the string form of its position.
for _split_adata in (adata_train, adata_valid, adata_test):
    _split_adata.obs_names = list(map(str, range(len(_split_adata))))

# Load the matching per-split image arrays.
image_train, image_valid, image_test = (
    np.load(npy_path)
    for npy_path in (
        "/home/username/scratch/train_data/skin_train_imagefeature_univ2.npy",
        "/home/username/scratch/train_data/skin_valid_imagefeature_univ2.npy",
        "/home/username/scratch/train_data/skin_test_imagefeature_univ2.npy",
    )
)

# Hard-coded batch labels: the first 3838 training observations get '0' and
# the following 3650 get '1'.
adata_train.obs['batch'] = ['0'] * 3838 + ['1'] * 3650




from sklearn.decomposition import PCA

# Fit the 20-component PCA on the training split only, then project the
# validation and test splits with that same fitted model. Re-fitting on each
# split would give every split its own basis, and the KMeans model trained on
# the training embedding could not be applied to the other two meaningfully.
pca = PCA(n_components=20, random_state=0)
adata_train.obsm['X_pca'] = pca.fit_transform(adata_train[:, gene_list['genes']].X)
adata_valid.obsm['X_pca'] = pca.transform(adata_valid[:, gene_list['genes']].X)
adata_test.obsm['X_pca'] = pca.transform(adata_test[:, gene_list['genes']].X)

if len(set(adata_train.obs['batch'])) > 1:
    # More than one batch label: replace the training embedding with the
    # Harmony-integrated one.
    # NOTE(review): the validation and test embeddings are not passed through
    # Harmony, so they may still sit in a slightly different space than the
    # corrected training embedding — confirm this is acceptable.
    sc.external.pp.harmony_integrate(adata_train, key='batch', random_state=0)
    adata_train.obsm['X_pca'] = adata_train.obsm['X_pca_harmony'].astype('double')
else:
    adata_train.obsm['X_pca'] = adata_train.obsm['X_pca'].astype('double')


import sklearn.cluster
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.metrics import silhouette_samples, silhouette_score

# Choose the KMeans cluster count (3..14) with the highest mean silhouette
# score on the training embedding. The earlier
# `n_bins = int(np.max(...)) + 1` was overwritten before it was ever read, so
# it has been removed.
search_data = {}
for test_cluster in range(3, 15):
    kbd = sklearn.cluster.KMeans(
        n_clusters=test_cluster, random_state=0, n_init="auto"
    ).fit(adata_train.obsm['X_pca'])
    pred_label = kbd.predict(adata_train.obsm['X_pca'])
    avg_sil = silhouette_score(adata_train.obsm['X_pca'], pred_label)
    search_data[test_cluster] = avg_sil

# Position of the best score, mapped back to its cluster count (the first one
# wins on ties).
max_avg = np.argmax(list(search_data.values()))
n_bins = list(search_data.keys())[max_avg]
# n_bins = 3

# Refit with the selected cluster count; random_state is fixed, so this
# reproduces the model scored during the search.
kbd = sklearn.cluster.KMeans(
    n_clusters=n_bins, random_state=0, n_init="auto"
).fit(adata_train.obsm['X_pca'])
print("optimal clustering number", n_bins)

# Store one cluster label per observation for every split.
adata_train.obsm['bin'] = kbd.predict(adata_train.obsm['X_pca'])
adata_valid.obsm['bin'] = kbd.predict(adata_valid.obsm['X_pca'].astype('double'))
adata_test.obsm['bin'] = kbd.predict(adata_test.obsm['X_pca'].astype('double'))
# Number of distinct labels actually assigned on the training split; this
# replaces the value taken from the command line.
num_classes = len(set(adata_train.obsm['bin']))

class REGDataset(Dataset):
    """Dataset yielding a row of ``image`` together with its cluster label.

    Args:
        adata: Annotated data object; it is sliced to ``gene`` on construction
            and must carry an ``obsm['bin']`` entry.
        image: Array indexed by sample position.
        transform: Stored on the instance; not applied by this class.
        gene: Gene name, or list of names, used to slice ``adata``.
    """

    def __init__(self, adata, image, transform, gene='NOSTRIN'):
        self.adata = adata[:, gene]
        self.image = image
        self.transform = transform

    def __len__(self):
        n_samples = self.adata.shape[0]
        return n_samples

    def __getitem__(self, index):
        features = torch.from_numpy(self.image[index])
        # The observation is looked up by the string form of the index, so
        # observation names are expected to be "0", "1", ...
        bin_values = self.adata[str(index)].obsm['bin']
        output_label = torch.FloatTensor(bin_values).long()
        return features, output_label
    
    

# Wrap every split in a REGDataset restricted to the selected genes.
train_dataset_new = REGDataset(
    adata=adata_train,
    image=image_train,
    transform=transform,
    gene=gene_list['genes'],
)
valid_dataset = REGDataset(
    adata=adata_valid,
    image=image_valid,
    transform=transform,
    gene=gene_list['genes'],
)
test_dataset = REGDataset(
    adata=adata_test,
    image=image_test,
    transform=transform,
    gene=gene_list['genes'],
)

# Unshuffled loaders that keep the last partial batch. predict_maha picks one
# of them by variable name, so these names must stay as they are.
trainloader = torch.utils.data.DataLoader(
    train_dataset_new,
    batch_size=512,
    shuffle=False,
    drop_last=False,
)
validloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
)
testloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
)

# Script entry point: gather the features and save the statistics.
compute_mean_cov(model, args)

