import argparse
import random

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

import sampler_NN
from vae import *

def split_data(joint_img, num_proposals):
    """Build side-information crops and a masked copy of the input image.

    Each of the ``num_proposals`` crops is a 14x14 window taken from the first
    14 columns of the first image in the batch. Its top row is drawn with
    ``random.randint(0, 14)``, whose endpoints are both inclusive.

    Args:
        joint_img: Image tensor indexed as ``[batch, channel, row, col]``.
            Assumes at least 28 rows and 14 columns (as in MNIST) -- TODO
            confirm before using with other datasets.
        num_proposals: Number of side-information crops to draw.

    Returns:
        img: Detached copy of ``joint_img`` with its first 14 columns set to
            zero. ``joint_img`` itself is not modified.
        cor_img: ``(num_proposals, channels, 14, 14)`` tensor of crops.
        splits: ``(num_proposals,)`` float tensor holding each crop's top row.
    """
    source = joint_img.detach()
    cor_img = torch.zeros((num_proposals, joint_img.shape[1], 14, 14), device=joint_img.device)
    splits = torch.zeros(num_proposals, device=joint_img.device)
    for i in range(num_proposals):
        start_y = random.randint(0, 14)
        # Slice assignment copies into cor_img, so the source needs no clone here.
        cor_img[i] = source[0, :, start_y:start_y + 14, :14]
        splits[i] = start_y

    # .detach() alone shares storage with joint_img; clone before the in-place
    # masking so the caller's tensor is not overwritten.
    img = source.clone()
    img[..., :14] = 0
    return img, cor_img, splits

def run_test(model, classifier, N, L_max, num_proposals, latent_dim, test_loader, device, baseline,
             *, return_match_rate=False):
    """Evaluate one sampler configuration on every batch of ``test_loader``.

    For each batch:

    1. ``split_data`` yields a masked image (first 14 columns zeroed) and
       ``num_proposals`` side-information crops.
    2. The model's ``mu`` and ``logvar`` define a diagonal Gaussian that is
       handed to ``exp_sampler.encode``, which returns the encoder latent
       ``y_A``.
    3. ``exp_sampler.decode`` returns one latent per crop (``y_B``).
    4. The ``y_B`` row closest to ``y_A`` in MSE selects which crop's
       reconstructions enter the image-space statistics.

    Args:
        model: VAE called as ``model(img, side)``, returning a 5-tuple with
            mu and logvar at positions 1 and 2. Also provides
            ``model.decode(latent, side)``, returning a 2-tuple whose first
            element is the reconstruction.
        classifier: Passed through to ``exp_sampler.decode``.
        N: Number of shared proposal samples.
        L_max: Range passed to ``generate_M`` for the hash values.
        num_proposals: Number of side-information crops per image.
        latent_dim: Dimensionality of the proposal samples.
        test_loader: Iterable of ``(image, label)`` batches. Only index 0 of
            each batch is evaluated, so batch size 1 is presumably expected --
            TODO confirm.
        device: Device the images and samples are placed on.
        baseline: If True, one ``logexp_rv`` draw is repeated for every
            decoder. Otherwise each decoder gets its own draw and the encoder
            uses their element-wise minimum.
        return_match_rate: If True, also return the fraction of batches whose
            best decoder latent is within MSE 0.01 of ``y_A``.

    Returns:
        ``(mseX_enc, mseX)``: mean squared error, on channel 0, between the
        masked image and the reconstruction decoded from ``y_A`` (respectively
        from the selected ``y_B``) using the selected crop. With
        ``return_match_rate=True``, a third element holds the match rate.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    exp_sampler = sampler_NN.Exp_Sampler()

    total_match = 0
    list_x, list_xhat, list_xhat_enc = [], [], []
    with torch.no_grad():
        for joint_img, _ in test_loader:
            img, cor_img, _ = split_data(joint_img.to(device), num_proposals)
            # Only mu/logvar are used; the encoder is given the first crop only.
            _, mu, logvar, _, _ = model(img, cor_img[:1])

            # Proposal samples shared by the encoder and the decoder.
            y = sampler_NN.gauss_gen(var=torch.tensor(1.0), B=1, N=N, dim=latent_dim, device=device)
            if baseline:
                # A single draw, reused by every decoder.
                logS_min = sampler_NN.logexp_rv(B=1, N=N, device=device)
                logS = logS_min.repeat(num_proposals, 1, 1)
            else:
                # Independent draws per decoder; the encoder uses their minimum.
                logS = sampler_NN.logexp_rv(B=num_proposals, N=N, device=device)
                logS_min, _ = torch.min(logS, dim=0, keepdim=True)
            hash_val = generate_M(B=1, N=N, L=L_max, device=device).unsqueeze(-1)

            # Gaussian N(mu, diag(exp(logvar))) built from the model outputs.
            mmnorm = MultivariateNormal(mu.flatten(), torch.diag(logvar.exp().flatten()))

            _, y_A, selected_M, _ = exp_sampler.encode(logS_min, y, mmnorm,
                                                      mean_p=0.0, var_p=torch.tensor(1.0),
                                                      hash_val=hash_val, ers_selection=False,
                                                      num_batch=1)
            _, y_B = exp_sampler.decode(classifier, logS, hash_val, y, cor_img, selected_M, num_batch=1)

            # Select the decoder whose latent is closest to the encoder's.
            mses = torch.mean((y_A.expand(y_B.shape[0], -1) - y_B) ** 2, dim=-1)
            min_mse, min_mse_idx = torch.min(mses, dim=0)
            total_match += min_mse.item() < 0.01

            xhat_enc, _ = model.decode(y_A.expand(cor_img.shape[0], -1).float(), cor_img)
            xhat, _ = model.decode(y_B.float(), cor_img)

            list_x.append(img[0, 0])
            list_xhat.append(xhat[min_mse_idx, 0])
            list_xhat_enc.append(xhat_enc[min_mse_idx, 0])

    if not list_x:
        raise ValueError('test_loader yielded no batches')

    x = torch.stack(list_x, dim=0)
    mseX_enc = ((x - torch.stack(list_xhat_enc, dim=0)) ** 2).mean().item()
    mseX = ((x - torch.stack(list_xhat, dim=0)) ** 2).mean().item()
    if return_match_rate:
        return mseX_enc, mseX, total_match / len(list_x)
    return mseX_enc, mseX

def test_config(L_max, num_proposals, latent_dim, test_loader, device, baseline,
                *, betas=(0.15, 0.35, 0.55, 0.75, 0.95), num_samples=(7, 8, 9, 10, 11, 12)):
    """Grid-search beta and log2(N) for one (L_max, num_proposals) setting.

    For each beta, loads ``model/model_XX.pth`` and ``model/classifier_XX.pth``,
    where XX is ``int(beta * 100)`` formatted with ``:2d``. It then calls
    ``run_test`` for every ``n`` in ``num_samples`` with ``N = 2**n``, skipping
    any ``n < int(log2(L_max))``. Every result is printed, and the setting with
    the lowest decoder MSE is returned.

    Args:
        L_max: Hash range forwarded to ``run_test``.
        num_proposals: Number of side-information crops per image.
        latent_dim: Latent size used to build the VAE and classifier.
        test_loader: Evaluation data loader.
        device: Target device. Checkpoints are mapped onto it when loaded.
        baseline: Forwarded to ``run_test``.
        betas: Beta values whose checkpoints are evaluated.
        num_samples: Candidate values of log2(N).

    Returns:
        ``(best_mseX, best_mseX_enc, best_beta, best_num_samples)``. If no
        ``n`` qualifies, the MSEs are ``inf`` and the other two values are 0.
    """
    log_l_max = int(np.log2(L_max))
    # Only N = 2**n with n >= log2(L_max) is evaluated.
    candidate_ns = [n for n in num_samples if n >= log_l_max]

    best_mseX = np.inf
    best_mseX_enc = np.inf
    best_beta = 0
    best_num_samples = 0
    if not candidate_ns:
        # Nothing to evaluate, so skip loading checkpoints.
        return best_mseX, best_mseX_enc, best_beta, best_num_samples

    for b in betas:
        # NOTE(review): the tag truncates (int, not round) and space-pads
        # single-digit values. Presumably this matches how the checkpoints were
        # saved, so it is left unchanged.
        tag = f'{int(b*100):2d}'
        model = VAE(latent_dim=latent_dim).to(device)
        # map_location lets GPU-saved checkpoints load on any target device.
        model.load_state_dict(torch.load(f'model/model_{tag}.pth', map_location=device,
                                         weights_only=True))
        classifier = NCEClassifier(u_dim=latent_dim).to(device)
        classifier.load_state_dict(torch.load(f'model/classifier_{tag}.pth', map_location=device,
                                              weights_only=True))
        model.eval()
        classifier.eval()

        for n in candidate_ns:
            mseX_enc, mseX = run_test(model, classifier, 2**n, L_max, num_proposals, latent_dim,
                                      test_loader, device, baseline)
            print(f'K={num_proposals}  log(Lmax)={log_l_max}  log(N)={n:2d}  beta={b:.2f}  '
                  f'DecoderMSE={mseX:.4e}  EncoderMSE={mseX_enc:.4e}', flush=True)
            if mseX < best_mseX:
                best_mseX = mseX
                best_mseX_enc = mseX_enc
                best_beta = b
                best_num_samples = n

    return best_mseX, best_mseX_enc, best_beta, best_num_samples

def test_loop(num_proposals, device, baseline):
    """Run the full (L_max, beta, N) sweep on the MNIST test split.

    Loads MNIST from ``./data``, downloading it if missing, and scales pixels to
    [-1, 1]. For each ``L_max = 2**l`` with ``l`` from 2 through 8, it prints
    the best configuration reported by ``test_config``.

    Args:
        num_proposals: Number of side-information crops per image.
        device: Device passed through to ``test_config``.
        baseline: Selects the baseline sampler variant in ``test_config``.
    """
    # ToTensor produces values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    mnist_test = torchvision.datasets.MNIST(root='./data', train=False, transform=preprocess,
                                            download=True)
    loader = torch.utils.data.DataLoader(mnist_test, batch_size=1, shuffle=True)

    latent_dim = 4
    for log_l in range(2, 9):
        mse_dec, mse_enc, beta, log_n = test_config(2**log_l, num_proposals, latent_dim,
                                                    loader, device, baseline)
        print(f'K={num_proposals}  log(Lmax)={log_l}  log(N)={log_n:2d}  beta={beta:.2f}  '
              f'DecoderMSE={mse_dec:.4e}  EncoderMSE={mse_enc:.4e}  (BEST)\n', flush=True)

if __name__ == '__main__':
    # Sweep K = 1..8 side-information proposals for both schemes. Fall back to
    # the CPU so the sweep still runs on machines without a CUDA device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for k in range(1, 9):
        print('Testing normal scheme\n=====================', flush=True)
        test_loop(k, device, baseline=False)
        print('\nTesting baseline scheme\n=======================', flush=True)
        test_loop(k, device, baseline=True)
        print('')
    
