import torch
import pickle
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
from torch.nn import functional as F
import torch.nn.utils.prune as prune
import time
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

from mmids.RNA_ATAC_networks import RNA_ATAC_mixVAE


class validat_cplmixVAE:
    def __init__(self, saving_folder='', device=None, eps=1e-8, save_flag=True):
        """Store the evaluation settings and select the compute device.

        Args:
            saving_folder: Base folder, stored as ``self.folder``.
            device: CUDA device handed to ``torch.cuda.set_device``. ``None``
                means no GPU is used (``self.gpu`` is ``False``).
            eps: Small constant, stored as ``self.eps``.
            save_flag: Stored as ``self.save``.
        """
        self.eps = eps
        self.save = save_flag
        self.folder = saving_folder
        self.device = device
        self.gpu = device is not None

        if self.gpu:
            # set_device is the only call needed to select the GPU. The former
            # unused ``torch.device('cuda:' + str(device))`` object was dropped:
            # it had no effect and raised for ``torch.device`` / 'cuda:N' specs.
            torch.cuda.set_device(device)
            print('using GPU ' + torch.cuda.get_device_name(torch.cuda.current_device()))
        else:
            print('using CPU ...')



    def eval_model(self, data_T, data_E, data_M, modalities, batch_size=1000):

            data_set_troch_T = torch.FloatTensor(data_T)
            data_set_troch_E = torch.FloatTensor(data_E)
            data_set_troch_M = torch.FloatTensor(data_M)

            indx_set_troch = torch.FloatTensor(np.arange(data_T.shape[0]))
            all_data = TensorDataset(data_set_troch_T, data_set_troch_E, data_set_troch_M, indx_set_troch)
            self.batch_size = batch_size

            data_loader = DataLoader(all_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
            bias = self.model.qc_T.bias.detach().cpu().numpy()
            self.pruning_mask = np.where(bias != 0.)[0]
            prune_indx = np.where(bias == 0.)[0]
            max_len = len(data_loader.dataset)

            state_sample = dict()
            state_mu = dict()
            state_var = dict()
            z_prob = dict()
            z_sample = dict()
            state_cat = dict()
            prob_cat = dict()
            predicted_label = dict()
            lowD_x = dict()

            for mod in ['T', 'ME']:
                state_sample[mod] = []
                state_mu[mod] = []
                state_var[mod] = []
                z_prob[mod] = []
                z_sample[mod] = []
                state_cat[mod] = []
                prob_cat[mod] = []
                predicted_label[mod] = []
                lowD_x[mod] = []

            total_loss_val = []
            total_dist_z = []
            total_dist_qz = []

            total_loss_rec = dict()
            r2_sc = dict()
            recon_x = dict()
            for mod in modalities:
                total_loss_rec[mod] = []
                r2_sc[mod] = []
                recon_x[mod] = []

            data_indx = []

            self.model.eval()
            with torch.no_grad():
                for i, (data_T, data_E, data_M, data_idx) in enumerate(data_loader):

                    data_idx = data_idx.type(torch.IntTensor)
                    
                    if self.gpu:
                        data_T = data_T.cuda(self.device)
                        data_E = data_E.cuda(self.device)
                        data_M = data_M.cuda(self.device)

                    trans_data = dict()
                    trans_data['T'] = data_T
                    trans_data['E'] = data_E
                    trans_data['M'] = data_M

                    recon, x_low, z_category, state, z_smp, mu, log_sigma, _ = self.model(x_T=trans_data['T'], x_E=trans_data['E'], x_M=trans_data['M'], temp=self.temp, eval=True, pruning_mask=self.pruning_mask)
                    loss, loss_arms, loss_joint, _, dist_z, d_qz, _, _ = self.model.loss(recon, trans_data, mu, log_sigma, z_category, z_smp)

                    data_indx.append(data_idx.numpy().astype(int))
                    total_loss_val.append(loss.data.item())
                    total_dist_z.append(dist_z.data.item())
                    total_dist_qz.append(d_qz.data.item())

                    for mod in modalities:
                        total_loss_rec[mod].append(loss_arms[mod].data.item())
                        recon_x[mod].append(recon[mod].detach().cpu().numpy())
                        if mod == 'M':
                            r2_sc[mod].append([])
                        else:
                            r2_sc[mod].append(r2_score(trans_data[mod].detach().cpu().numpy(), recon[mod].detach().cpu().numpy()))

                    for mod in ['T', 'ME']:
                        state_sample[mod].append(state[mod].cpu().detach().numpy())
                        state_mu[mod].append(mu[mod].cpu().detach().numpy())
                        state_var[mod].append(log_sigma[mod].cpu().detach().numpy())
                        z_encoder = z_category[mod].cpu().data.view(z_category[mod].size()[0], self.n_categories).detach().numpy()
                        z_prob[mod].append(z_encoder)
                        z_samp = z_smp[mod].cpu().data.view(z_smp[mod].size()[0], self.n_categories).detach().numpy()
                        z_sample[mod].append(z_samp)
                        lowD_x[mod].append(x_low[mod].detach().cpu().numpy())

                        for n in range(z_encoder.shape[0]):
                            state_cat[mod].append(np.argmax(z_encoder[n, :]) + 1)
                            prob_cat[mod].append(np.max(z_encoder[n, :]))

                        label_predict = []

                        for d in range(len(data_idx)):
                            z_cat = np.squeeze(z_encoder[d, :])
                            label_predict.append(np.argmax(z_cat) + 1)

                        predicted_label[mod].append(label_predict)

            for mod in ['T', 'ME']:
                state_sample[mod] = self.map_data(state_sample[mod], mod)
                state_mu[mod] = self.map_data(state_mu[mod], mod)
                try:
                    state_var[mod] = self.map_data(state_var[mod], mod) 
                except:
                    state_var[mod] = 0. * state_mu[mod]

                z_prob[mod] = self.map_data(z_prob[mod], mod)
                z_sample[mod] = self.map_data(z_sample[mod], mod)
                state_cat[mod] = self.map_data(state_cat[mod], mod)
                prob_cat[mod] = self.map_data(prob_cat[mod], mod)
                predicted_label[mod] = self.map_data(predicted_label[mod], mod)
                lowD_x[mod] = self.map_data(lowD_x[mod], mod)

            mean_total_loss_rec = dict()
            mean_recon_x = dict()
            for mod in modalities:
                mean_total_loss_rec[mod] = np.mean(np.array(total_loss_rec[mod]))
                mean_recon_x[mod] = self.map_data(recon_x[mod], mod)

            # save data
            data_file_id = self.folder + '/model/model_eval' #_pruning_' + str(len(prune_indx))

            self.save_file(data_file_id,
                        state_sample=state_sample,
                        state_mu=state_mu,
                        state_var=state_var,
                        state_cat=state_cat,
                        prob_cat=prob_cat,
                        r2_sc=r2_sc,
                        total_loss_rec=mean_total_loss_rec,
                        total_dist_z=np.mean(np.array(total_dist_z)),
                        total_dist_qz=np.mean(np.array(total_dist_qz)),
                        predicted_label=predicted_label,
                        data_indx=np.concatenate(data_indx),
                        recon_x=mean_recon_x,
                        z_prob=z_prob,
                        z_sample=z_sample,
                        prune_indx=prune_indx)

            outcome_dict = dict()
            outcome_dict['state_sample'] = state_sample
            outcome_dict['state_mu'] = state_mu
            outcome_dict['state_var'] = state_var
            outcome_dict['state_cat'] = state_cat
            outcome_dict['prob_cat'] = prob_cat
            outcome_dict['z_prob'] = z_prob
            outcome_dict['z_sample'] = z_sample
            outcome_dict['r2_sc'] = r2_sc
            outcome_dict['total_loss_rec'] = mean_total_loss_rec
            outcome_dict['total_dist_z'] = np.mean(np.array(total_dist_z))
            outcome_dict['total_dist_qz'] = np.mean(np.array(total_dist_qz))
            outcome_dict['predicted_label'] = predicted_label
            outcome_dict['data_indx'] = np.concatenate(data_indx)
            outcome_dict['prune_indx'] = prune_indx
            outcome_dict['recon_x'] = recon_x
            outcome_dict['lowD_x'] = lowD_x

            return outcome_dict
