import numpy as np
import tensorflow as tf
import os
import inspect
from scipy import stats
from aevisnet import ConvNet
from data import Dataset, load_data
import matplotlib.pyplot as plt
import better_exceptions
import scipy.io as sio
from tqdm import tqdm 

from skimage import measure
def p(test_, responses_test_avg):
    '''Return the mean column-wise Pearson correlation of two 2-D arrays.

    Column k of ``test_`` is correlated with column k of
    ``responses_test_avg``. Columns where either side has a standard
    deviation that is not above 1e-5 contribute 0 instead of an undefined
    correlation.

    Parameters:
        test_:              Predictions (columns are compared pairwise)
        responses_test_avg: Observations with the same column layout

    Output:
        Mean of the per-column correlations.
    '''
    def _column_corr(pred_col, obs_col):
        # Guard against (near-)constant columns, for which r is undefined.
        if not (np.std(pred_col) > 1e-5 and np.std(obs_col) > 1e-5):
            return 0
        return stats.pearsonr(pred_col, obs_col)[0]

    per_column = [_column_corr(pred_col, obs_col)
                  for pred_col, obs_col in zip(test_.T, responses_test_avg.T)]
    return np.array(per_column).mean()

def fit(net,
        learning_rate=0.001, max_iter=10000, batch_size=256, val_steps=50,
        early_stopping_steps=10, lr_decays=5):
    '''Fit a model using a step-decayed learning-rate schedule.

    Runs ``lr_decays`` consecutive training rounds through ``net.train``.
    After each round the learning rate is halved, so round k (0-based) uses
    ``learning_rate / 2**k``.

    Parameters:
        net:                   Model exposing ``train(...)`` (e.g. a ConvNet)
                               that returns an iterable of training steps.
        learning_rate:         Initial learning rate (default: 0.001)
        max_iter:              Max. number of iterations, forwarded to every
                               ``net.train`` call (default: 10000)
        batch_size:            Minibatch size (default: 256)
        val_steps:             Validation interval, in iterations (default: 50)
        early_stopping_steps:  Early-stopping tolerance forwarded to
                               ``net.train`` (default: 10); presumably the
                               number of validation steps without improvement
                               before a round stops -- confirm in ConvNet.
        lr_decays:             Number of training rounds (default: 5, the
                               previously hard-coded value).

    Output:
        net: The same object that was passed in, after training.
    '''
    for _ in range(lr_decays):
        training = net.train(max_iter=max_iter,
                             val_steps=val_steps,
                             early_stopping_steps=early_stopping_steps,
                             batch_size=batch_size,
                             learning_rate=learning_rate)
        # The returned object is iterated rather than used as a value, so
        # consuming it is what drives the round. Per-step statistics are
        # not needed here.
        for _step in training:
            pass
        learning_rate /= 2.0
    print('Done fitting')
    return net

def cal_performance(src_imgs, dst_imgs):
    '''Average image-quality metrics over a batch of image pairs.

    Both batches are cast to float32 and compared image by image along their
    first axis, with ``src_imgs`` as the reference (ground-truth) images.

    Output:
        (mean MSE, mean PSNR, mean SSIM) across the batch.
    '''
    # NOTE(review): skimage.measure.compare_* are deprecated in newer
    # scikit-image releases (moved to skimage.metrics) -- confirm the pinned
    # scikit-image version before upgrading.
    reference = src_imgs.astype('float32')
    candidate = dst_imgs.astype('float32')

    count = reference.shape[0]
    scores = np.zeros((3, count))  # rows: mse, psnr, ssim
    for idx in range(count):
        ref_img = reference[idx]
        cand_img = candidate[idx]
        scores[0, idx] = measure.compare_mse(ref_img, cand_img)
        scores[1, idx] = measure.compare_psnr(ref_img, cand_img)
        scores[2, idx] = measure.compare_ssim(ref_img, cand_img, multichannel = True)

    mean_mse = np.mean(scores[0])
    mean_psnr = np.mean(scores[1])
    mean_ssim = np.mean(scores[2])
    return mean_mse, mean_psnr, mean_ssim


def plt_imgs(bgra_img_test, name, resolution=31, n_rows=5, n_cols=10,
             figures="fig_s1"):
    '''Save a grid of grayscale images to ``<figures>/<name>.pdf``.

    Parameters:
        bgra_img_test: Indexable collection of images. Each one is reshaped to
                       (resolution, resolution), so it must contain
                       resolution**2 values.
        name:          Output file stem.
        resolution:    Side length of each image (default: 31).
        n_rows:        Number of grid rows (default: 5).
        n_cols:        Number of grid columns (default: 10).
        figures:       Output directory, created if missing (default: "fig_s1").

    Images fill the grid row by row. If fewer than n_rows * n_cols images are
    given, the remaining cells stay empty instead of raising IndexError.
    '''
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(figures, exist_ok=True)
    plt.figure(figsize=(n_cols, n_rows))
    try:
        n_shown = min(n_rows * n_cols, len(bgra_img_test))
        for k in range(n_shown):
            ax = plt.subplot(n_rows, n_cols, k + 1)
            plt.imshow(bgra_img_test[k].reshape(resolution, resolution), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

        plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0, hspace=0)
        plt.savefig(os.path.join(figures, name + '.pdf'), dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

    
def evaluate(net):
    '''Evaluate a fitted model on its training, validation and test data.

    Parameters:
        net: A fitted ConvNet object. Uses ``net.data`` (``next_epoch``,
             ``minibatch``, ``val``, ``test``), ``net.eval_test`` and ``net.eval``.

    Outputs:
        results:
            A dictionary containing the following entries:
                train_loss:      Mean of eval_test()[0] over 5 training minibatches
                train_imgs_mse:  Mean of eval_test()[1] over the same minibatches
                train_imgs_r:    eval_test()[3] for one further training minibatch
                train_imgs_o:    Input images of that further minibatch
                val_loss:        eval_test()[0] on the validation set
                val_imgs_mse:    eval_test()[1] on the validation set
                test_loss:       eval_test()[0] on the test set
                test_imgs_mse:   eval_test()[1] on the test set
                test_imgs_r:     eval_test()[3] on the test set
                test_imgs_o:     The test images
                test_mse, test_psnr, test_ssim:
                                 Mean image metrics (cal_performance) between the
                                 test images and test_imgs_r

            The following statistics are all evaluated on the test set, against
            the responses averaged over repeats:
                mse:        Mean-squared error for each cell
                avg_mse:    Average mean-squared error across cells
                corr:       Correlation between prediction and observation for each
                            cell (0 where either side is (near-)constant)
                avg_corr:   Average correlation across cells
                var:        Variance of prediction for each cell
                avg_var:    Average variance across cells
                ve:         Variance explained for each cell
                avg_ve:     Average variance explained across cells
                eve:        Explainable variance explained for each cell
                    (Excludes an estimate of the observation noise. Be careful: this
                    quantity is not very reliable and needs to be taken with a grain
                    of salt)
                avg_eve:    Average explainable variance explained across cells
                nnp:        Normalized noise power (see Antolik et al. 2016)
    '''
    n_train_batches = 5
    batch_size = 288
    # The keys of `results` are used as column names downstream (the original
    # note mentions a CSV).
    results = dict()

    # --- Training set: loss and image MSE averaged over a few minibatches ---
    net.data.next_epoch()
    train_loss = 0
    train_imgs_mse = 0
    for _ in range(n_train_batches):
        imgs_train, responses_train = net.data.minibatch(batch_size)
        train_out = net.eval_test(images=imgs_train, responses=responses_train)
        train_loss += train_out[0]
        # Index 1 is what the val/test sections below store as *_imgs_mse.
        # This accumulator used to stay at 0, so train_imgs_mse was always 0.
        train_imgs_mse += train_out[1]
    results['train_loss'] = train_loss / n_train_batches
    results['train_imgs_mse'] = train_imgs_mse / n_train_batches

    imgs_train, responses_train = net.data.minibatch(batch_size)
    train_results = net.eval_test(images=imgs_train, responses=responses_train)
    results['train_imgs_r'] = train_results[3]
    results['train_imgs_o'] = imgs_train

    # --- Validation set ---
    imgs_val, responses_val = net.data.val()
    val_results = net.eval_test(images=imgs_val, responses=responses_val)
    results['val_loss'] = val_results[0]
    results['val_imgs_mse'] = val_results[1]

    # --- Test set (responses are averaged over the first axis, i.e. repeats) ---
    imgs_test, responses_test = net.data.test(averages=False)
    responses_test_avg = responses_test.mean(axis=0)
    test_results = net.eval_test(images=imgs_test, responses=responses_test_avg)
    results['test_loss'] = test_results[0]
    results['test_imgs_mse'] = test_results[1]
    results['test_imgs_r'] = test_results[3]
    results['test_imgs_o'] = imgs_test

    test_mse, test_psnr, test_ssim = cal_performance(imgs_test, test_results[3])
    results['test_mse'] = test_mse
    results['test_psnr'] = test_psnr
    results['test_ssim'] = test_ssim

    # --- Per-cell response statistics on the test set ---
    # NOTE(review): mse/var use the last output of net.eval(), while corr uses
    # the last output of net.eval_test(); confirm both are the same prediction.
    result = net.eval(images=imgs_test, responses=responses_test_avg)
    prediction = result[-1]

    results['mse'] = np.mean((prediction - responses_test_avg) ** 2, axis=0)
    results['avg_mse'] = results['mse'].mean()
    results['corr'] = np.array([
        stats.pearsonr(yhat, y)[0] if np.std(yhat) > 1e-5 and np.std(y) > 1e-5 else 0
        for yhat, y in zip(test_results[-1].T, responses_test_avg.T)
    ])
    results['avg_corr'] = results['corr'].mean()

    results['var'] = prediction.var(axis=0)
    results['avg_var'] = results['var'].mean()
    results['ve'] = 1 - results['mse'] / responses_test_avg.var(axis=0)
    results['avg_ve'] = results['ve'].mean()

    # responses_test is unpacked as (reps, _, num_neurons).
    reps, _, num_neurons = responses_test.shape
    obs_var_avg = (responses_test.var(axis=0, ddof=1) / reps).mean(axis=0)
    total_var_avg = responses_test.mean(axis=0).var(axis=0, ddof=1)
    results['eve'] = (total_var_avg - results['mse']) / (total_var_avg - obs_var_avg)
    results['avg_eve'] = results['eve'].mean()

    obs_var = (responses_test.var(axis=0, ddof=1)).mean(axis=0)
    total_var = responses_test.reshape([-1, num_neurons]).var(axis=0, ddof=1)
    results['nnp'] = obs_var / total_var
    return results

    
def _run_ae_experiment(ae_name, read_ae_conv, region_num,
                       fully_connected_readout, fixed_rfs):
    '''Build, train, evaluate and save one autoencoder + readout experiment.

    Parameters:
        ae_name:                 Autoencoder variant ('dae', 'vae' or 'vqvae').
                                 Selects ``ConvNet.build_<ae_name>`` and is
                                 embedded in the output filename.
        read_ae_conv:            Sequence of at least 3 values. It is forwarded
                                 to the build method, and its first three
                                 entries are embedded in the output filename.
        region_num:              Region index passed to ``load_data``.
        fully_connected_readout: Readout flag forwarded to the build method.
        fixed_rfs:               Readout flag forwarded to the build method.

    Output:
        The ``evaluate`` results dict. It is also written with
        ``scipy.io.savemat`` to ``results_mession8/region_<region_num>/``.
    '''
    log_dir = 'region_' + str(region_num)
    log_results = 'results_mession8/' + 'region_' + str(region_num) + '/'

    os.makedirs(os.path.join('train_logs', log_dir), exist_ok=True)
    os.makedirs(log_results, exist_ok=True)

    data = load_data(region_num=region_num)
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    cnn = ConvNet(data, log_dir=log_dir, log_hash='manual')
    # e.g. ae_name 'vae' -> cnn.build_vae(...)
    build = getattr(cnn, 'build_' + ae_name)
    build(read_ae_conv, fully_connected_readout, fixed_rfs)

    if fully_connected_readout:
        vismodelling_name = 'fc'
        print('FC--------------------- FC')
    elif fixed_rfs:
        print('FF--------------------- FF')
        vismodelling_name = 'ff'
    else:
        print('FR--------------------- FR')
        vismodelling_name = 'fr'

    fit(cnn)
    results = evaluate(cnn)

    mat_name = ('exp1_' + ae_name + vismodelling_name
                + '_ae_' + str(read_ae_conv[0])
                + 'read_' + str(read_ae_conv[1])
                + 'conv_' + str(read_ae_conv[2]) + '.mat')
    sio.savemat(os.path.join(log_results, mat_name), results)
    # Arguments follow the label order (PSNR and SSIM were previously swapped).
    print('Training loss: {:.4f} | Validation loss: {:.4f} | test loss: {:.4f} | test_imgs_mse: {:.4f} | test_imgs_psnr: {:.4f} | test_imgs_ssim: {:.4f} | PCC: {:.4f} '.format(
        results['train_loss'], results['val_loss'], results['test_loss'],
        results['test_mse'], results['test_psnr'], results['test_ssim'],
        results['avg_corr']))
    return results


def main_dae(read_ae_conv,x,fully_connected_readout,fixed_rfs):
    '''Run the 'dae' experiment (ConvNet.build_dae) for region ``x``; return its results.'''
    return _run_ae_experiment('dae', read_ae_conv, x,
                              fully_connected_readout, fixed_rfs)


def main_vae(read_ae_conv,x,fully_connected_readout,fixed_rfs):
    '''Run the 'vae' experiment (ConvNet.build_vae) for region ``x``; return its results.'''
    return _run_ae_experiment('vae', read_ae_conv, x,
                              fully_connected_readout, fixed_rfs)


def main_vqvae(read_ae_conv,x,fully_connected_readout,fixed_rfs):
    '''Run the 'vqvae' experiment (ConvNet.build_vqvae) for region ``x``; return its results.'''
    return _run_ae_experiment('vqvae', read_ae_conv, x,
                              fully_connected_readout, fixed_rfs)


if __name__ == "__main__":
    # Each entry is a `read_ae_conv` triple forwarded to the main_* runners.
    # Its values are also embedded in the output .mat filename.
    # NOTE(review): several triples repeat, and repeats write to the same
    # output file, so later runs overwrite earlier ones -- confirm intended.
    ex1 = [[0.0,1.0,0],[0.0,1.0,1],[0.0,1.0,2],[0.0,1.0,3],
           [0.0,1.0,1],[0.0,1.0,1],[0.0,1.0,1],[0.0,1.0,1],
           [0.0,1.0,2],[0.0,1.0,2],[0.0,1.0,2],[0.0,1.0,2],
           [0.0,1.0,3],[0.0,1.0,3],[0.0,1.0,3],[0.0,1.0,3]]

    # Readout configurations:
    #   fc: fully_connected_readout=True
    #   ff: fully_connected_readout=False, fixed_rfs=True
    #   fr: fully_connected_readout=False, fixed_rfs=False
    # Only the 'ff' readout is run here.
    regions = (1, 2, 3)
    # Region 1 previously ran main_vqvae twice and never main_dae; every
    # region now runs dae, vae and vqvae in that order.
    experiments = (main_dae, main_vae, main_vqvae)

    for read_ae_conv in tqdm(ex1):
        for region in regions:
            for run_experiment in experiments:
                run_experiment(read_ae_conv, region,
                               fully_connected_readout=False, fixed_rfs=True)










