
from model_utils import *

# Compute device for all tensors and the model: prefer the GPU when PyTorch can see one.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def _cycled_permutation(n_items, length):
    """Return ``length`` indices into ``range(n_items)`` drawn from one random permutation.

    The permutation is repeated from its start as often as needed, so when
    ``length <= n_items`` the result is simply a prefix of
    ``torch.randperm(n_items)``.
    """
    perm = torch.randperm(n_items)
    n_repeats = -(-length // n_items)  # ceil(length / n_items) without floats
    return perm.repeat(n_repeats)[:length]


def _shuffle_and_pair(data_x, data_y):
    """Shuffle both sets independently and pair their rows side by side.

    The result has ``max(len(data_x), len(data_y))`` rows. Its leading columns
    come from ``data_x`` and the trailing columns from ``data_y``. Every row of
    the larger set is used exactly once. The smaller set is cycled, re-using its
    permutation from the start, to fill the remaining rows. This works for any
    size ratio between the two sets.

    Raises:
        ValueError: if either set has no rows.
    """
    n_x, n_y = data_x.shape[0], data_y.shape[0]
    if n_x == 0 or n_y == 0:
        raise ValueError(
            'data_x and data_y must both be non-empty (got {} and {} rows)'.format(n_x, n_y))
    n_total = max(n_x, n_y)
    # x's permutation is drawn before y's, as in the original implementation.
    idx_x = _cycled_permutation(n_x, n_total)
    idx_y = _cycled_permutation(n_y, n_total)
    return torch.cat([data_x[idx_x], data_y[idx_y]], dim=1)


def train_model(model, opt, data_y, data_x, cv_data_y, cv_data_x):
    """Train ``model`` for ``opt['epochs']`` epochs and record per-epoch metrics.

    Each epoch proceeds as follows:

    - ``data_x`` and ``data_y`` are shuffled independently and paired row-wise;
      the smaller set is cycled.
    - The pairs are fed to ``model.train_instance`` in mini-batches of
      ``opt['batch_size']``.
    - The generator is evaluated on the full training and validation sets.
    - Every 50 epochs, a scatter plot of real vs. reconstructed validation data
      is shown.

    Note the parameter order: ``data_y`` comes BEFORE ``data_x``, and likewise
    ``cv_data_y`` before ``cv_data_x``. Pass them by keyword to avoid swapping.

    Args:
        model: object exposing ``train_instance(real_x, real_y, reduction=...)``,
            which returns a dict of losses, and a ``G_forward`` module with a
            ``modeldec`` decoder.
        opt: options dict. Reads ``nROI``, ``batch_size``, ``epochs``,
            ``DEVICE``, ``nLatent_z`` and ``nLatent_s``. Reads ``nLatent_zd``
            if present, defaulting to ``nLatent_s``.
        data_y, data_x: training tensors, each expected to have ``nROI``
            columns (the paired tensor is split at column ``nROI``).
        cv_data_y, cv_data_x: validation tensors with the same column layout.

    Returns:
        ``(model, loss_list)``, where ``loss_list`` maps metric names to the
        per-epoch histories collected below.

    Raises:
        ValueError: if ``data_x`` or ``data_y`` has no rows.
    """
    # Per-epoch training averages.
    kldz_loss = []
    z_mmdloss = []
    klds_loss = []
    reconx_loss = []
    recony_loss = []
    G_combined_loss = []

    # Per-epoch validation reconstruction errors.
    reconx_loss_eva = []
    recony_loss_eva = []

    w_list_y = []           # WD between real y and reconstructed (fake) y
    w_list_eva_y = []       # WD between validation y and synthetic y
    w_list_x = []           # WD between real x and reconstructed (fake) x
    w_list_eva_x = []       # WD between validation x and synthetic x
    w_list_pair = []        # WD between real x and real y
    w_list_synth_pair = []  # WD between synthetic x and synthetic y

    nROI = opt['nROI']
    device = opt['DEVICE']
    batch_size = opt['batch_size']
    max_epochs = opt['epochs']
    N = max(data_x.shape[0], data_y.shape[0])
    n_batches = int(np.ceil(N / batch_size))
    reduce = 'mean'

    # Width of the all-zero salient block used for synth_y. It must equal
    # nLatent_s, because synth_x feeds nLatent_s salient dims to the same
    # decoder. The previous code required opt['nLatent_zd'] unconditionally.
    n_salient_empty = opt.get('nLatent_zd', opt['nLatent_s'])
    n_synth = 400  # synthetic samples drawn per evaluation

    for epoch in range(max_epochs):
        kldz_loss_batch = 0
        recon_x_batch = 0
        recon_y_batch = 0
        klds_batch = 0
        G_combined_batch_loss = 0
        z_mmd_batch = 0

        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch + 1, max_epochs))
        print('Training...')

        data = _shuffle_and_pair(data_x, data_y)
        t0 = time.time()

        # range(0, N, batch_size) yields exactly n_batches batches; the last
        # one may be shorter.
        for start in range(0, N, batch_size):
            batch = data[start:start + batch_size]
            real_x = batch[:, 0:nROI].to(device).type(torch.float32)
            real_y = batch[:, nROI:2 * nROI].to(device).type(torch.float32)

            losses = model.train_instance(real_x, real_y, reduction='mean')

            kldz_loss_batch += losses['kldz_loss']  # KL divergence on common z
            G_combined_batch_loss += losses['Combined_loss']
            recon_x_batch += losses['Generator_loss_list'][0]  # reconstruction loss x
            recon_y_batch += losses['Generator_loss_list'][1]  # reconstruction loss y
            klds_batch += losses['Generator_loss_list'][2]  # KL divergence on salient
            z_mmd_batch += losses['Generator_loss_list'][3]  # MMD on z

        kldz_loss.append(kldz_loss_batch / n_batches)
        klds_loss.append(klds_batch / n_batches)
        z_mmdloss.append(z_mmd_batch / n_batches)
        reconx_loss.append(recon_x_batch / n_batches)
        recony_loss.append(recon_y_batch / n_batches)
        G_combined_loss.append(G_combined_batch_loss / n_batches)

        print("")
        print("  Average loss: {0:.2f}".format(G_combined_loss[-1]))
        print("  Average recon x loss: {0:.2f}".format(reconx_loss[-1]))
        print("  Average recon y loss: {0:.2f}".format(recony_loss[-1]))
        print("  Training epoch took: {:}".format(format_time(time.time() - t0)))

        # Evaluation
        model.G_forward.eval()
        with torch.no_grad():
            fake_x, fake_y, *_ = model.G_forward(data_x, data_y)
            fake_x_eva, fake_y_eva, *_ = model.G_forward(cv_data_x, cv_data_y)

            # Element-wise mean MSE times the feature count. For 2-D inputs
            # this is the per-sample sum of squared errors, averaged over
            # samples.
            reconx_loss_eva.append(float(F.mse_loss(fake_x_eva, cv_data_x, reduction=reduce) * fake_x_eva.shape[-1]))
            recony_loss_eva.append(float(F.mse_loss(fake_y_eva, cv_data_y, reduction=reduce) * fake_y_eva.shape[-1]))

            # Decode random common latents twice: with an all-zero salient part
            # (synth_y) and with a random salient part (synth_x).
            zcontent = torch.randn(n_synth, opt['nLatent_z'], device=device)
            s_cont_empty = torch.zeros(n_synth, n_salient_empty, device=device)
            s_cont = torch.randn(n_synth, opt['nLatent_s'], device=device)
            synth_y = model.G_forward.modeldec(torch.concat([zcontent, s_cont_empty], dim=1))
            synth_x = model.G_forward.modeldec(torch.concat([zcontent, s_cont], dim=1))

            # NOTE(review): independent='False' is a non-empty string and is
            # therefore truthy. Confirm how eval_w_distances_forward interprets it.
            w_list_y.append(eval_w_distances_forward(data_y, fake_y, independent='False'))
            w_list_eva_y.append(eval_w_distances_forward(cv_data_y, synth_y, independent='False'))
            w_list_x.append(eval_w_distances_forward(data_x, fake_x, independent='False'))
            w_list_eva_x.append(eval_w_distances_forward(cv_data_x, synth_x, independent='False'))
            w_list_pair.append(eval_w_distances_forward(data_y, data_x, independent='False'))
            w_list_synth_pair.append(eval_w_distances_forward(synth_y, synth_x, independent='False'))

        # Restore training mode; otherwise the eval() call above would persist
        # into the next epoch's training (train_instance is not visible here,
        # so it may or may not reset the mode itself).
        model.G_forward.train()

        if epoch % 50 == 0:
            # Move each tensor to host memory once, not once per panel.
            plot_sets = [t.to('cpu').numpy() for t in (cv_data_x, cv_data_y, fake_y_eva, fake_x_eva)]
            for k in range(8):
                ax = plt.subplot(2, 4, k + 1)
                for arr in plot_sets:
                    ax.scatter(arr[:, 0], arr[:, k], alpha=0.2)

            plt.legend(['real PT', 'real HC', 'fake HC', 'fake PT'])
            plt.show()

    loss_list = {
        'recon_x': [reconx_loss, reconx_loss_eva],
        'recon_y': [recony_loss, recony_loss_eva],
        'kl_loss': [kldz_loss, klds_loss],
        'mmd_loss': z_mmdloss,
        'w_list_y': [w_list_y, w_list_eva_y],
        'w_list_x': [w_list_x, w_list_eva_x],
        'w_list_pair': [w_list_pair, w_list_synth_pair],
    }
    return model, loss_list



# Contrastive VAE - z is also a distribution
# Include ROIs of choice
df_roi = pd.read_csv('data/ROI_Dictionary.csv')
col_roi_ind = df_roi[df_roi.ROI_LEVEL == 'SINGLE'].ROI_INDEX.tolist()
# Exclude ventricle and CSF ROIs.
# NOTE(review): the set difference also fixes the feature column order (set
# iteration order, not CSV order). It is kept as-is so the column layout is
# unchanged.
col_roi_ind = list(set(col_roi_ind) - set([4, 11, 49, 50, 51, 52, 46, 63, 64]))
print(len(col_roi_ind))
col_features = ['_'.join(['H_ROI_Volume', str(j)]) for j in col_roi_ind]

cn = pd.read_csv('data/CVAE/train_cn_10to30.csv').loc[:, col_features].to_numpy()
cn_pt = pd.read_csv('data/CVAE/train_pt_10to30.csv').loc[:, col_features].to_numpy()

cn_eva = pd.read_csv('data/CVAE/cval_cn_10to30.csv').loc[:, col_features].to_numpy()
cn_pt_eva = pd.read_csv('data/CVAE/cval_pt_10to30.csv').loc[:, col_features].to_numpy()

opt = {'nROI': len(col_features),  # number of ROIs
       'nLatent_z': 5,
       'nLatent_s': 5,
       'mmd_lambda': 10,  # hyperparameter controlling the MMD loss
       'kld_lambda': 1,  # hyperparameter controlling the KLD z loss
       'lr': 0.001,  # learning rate
       'beta1': 0.5,  # beta1 hyperparameter of the Adam optimizer
       'DEVICE': DEVICE,  # 'cpu' or 'cuda'
       'epochs': 500,  # total number of epochs
       'decay_epochs': 1,  # epoch at which learning-rate decay begins
       'scheduler': False,  # True to enable the learning-rate scheduler
       'batch_size': 128
       }

# x comes from the 'pt' CSVs and y from the 'cn' CSVs. This matches the
# 'real PT' / 'real HC' legend inside train_model.
data_x = torch.from_numpy(cn_pt.copy()).to(DEVICE).type(torch.float32)
data_y = torch.from_numpy(cn.copy()).to(DEVICE).type(torch.float32)

data_x_eva = torch.from_numpy(cn_pt_eva.copy()).to(DEVICE).type(torch.float32)
data_y_eva = torch.from_numpy(cn_eva.copy()).to(DEVICE).type(torch.float32)

# define model
model = ContrastiveVAE()
model.create(opt=opt)
# train_model's signature lists data_y before data_x, so positional passing
# silently swapped x and y. Pass by keyword instead.
model, loss_list = train_model(
    model, opt,
    data_y=data_y, data_x=data_x,
    cv_data_y=data_y_eva, cv_data_x=data_x_eva,
)

checkpoint = {
    'G_forward': model.G_forward.state_dict(),
    'optimizer_G': model.optimizer_G.state_dict(),
    'loss_list': loss_list
}
checkpoint.update(opt)
os.makedirs('weights', exist_ok=True)  # torch.save does not create parent dirs
torch.save(checkpoint, os.path.join('weights', 'ContrastiveVAE_modelchkpt.pt'))

plt.plot(loss_list['recon_x'][0])
plt.plot(loss_list['recon_x'][1])
plt.plot(loss_list['recon_y'][0])
plt.plot(loss_list['recon_y'][1])
plt.legend(['recon x', 'recon eva x', 'recon y', 'recon eva y'])
plt.show()

plt.plot(loss_list['w_list_x'][0])
plt.plot(loss_list['w_list_x'][1])
plt.plot(loss_list['w_list_y'][0])
plt.plot(loss_list['w_list_y'][1])
plt.legend(['wd x', 'wd eva x', 'wd  y', 'wd eva y'])
plt.show()

# Index 1 of 'w_list_pair' holds the synthetic-pair distances, not validation ones.
plt.plot(loss_list['w_list_pair'][0])
plt.plot(loss_list['w_list_pair'][1])
plt.legend(['wd pair', 'wd synth pair'])
plt.show()

plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(loss_list['kl_loss'][0])
plt.legend(['kl divergence z'])
plt.subplot(1, 2, 2)
plt.plot(loss_list['kl_loss'][1])
plt.legend(['kl divergence s'])
plt.show()

plt.plot(loss_list['mmd_loss'])
plt.legend(['mmd loss'])
plt.show()


