from sympy import var
from noise import *
from pathlib import Path 
import json
from data import *
from trackers import *
from main import *
import random
import torchvision

from networks.conditioning import *

import lpips


# Seed every random number generator the script uses so runs are reproducible.
seed = 42
for _seed_fn in (
    torch.manual_seed,
    np.random.seed,
    random.seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
del _seed_fn

# Give up cuDNN autotuning in exchange for deterministic kernel selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def load_args(name, step="last", log=True, dataloaders=False, num_workers=8):
    """Load the saved training arguments of experiment ``name``.

    Reads ``models/<name>/args.json`` and returns it as a dict, with
    ``"num_workers"`` overwritten by ``num_workers``.

    Args:
        name: experiment directory, relative to ``models/``.
        step, log, dataloaders: accepted only for signature parity with
            ``load_exp``; they are ignored here.
        num_workers: value stored under ``"num_workers"`` (default 8).

    Returns:
        The decoded ``args.json`` dict.
    """
    args_path = Path("models") / name / "args.json"
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(args_path, encoding="utf-8") as f:
        args_dict = json.load(f)

    args_dict["num_workers"] = num_workers
    return args_dict


def load_exp(name, step="last", log=True, dataloaders=False):
    """Restore experiment ``name`` as a frozen, eval-mode ``TrainingContext``.

    The arguments come from ``models/<name>/args.json``. Two of them are forced
    before construction: ``size_network="small"`` and ``adaptive_scale=False``.

    Args:
        name: experiment directory, relative to ``models/``.
        step: checkpoint to restore; an integer, "best", or "last" (default).
        log: if true, print which step was retrieved.
        dataloaders: forwarded to ``TrainingContext``.

    Returns:
        The context. Its model is in eval mode and none of its parameters
        require gradients.
    """
    args_path = Path("models") / name / "args.json"
    with args_path.open() as handle:
        experiment_args = json.load(handle)

    experiment_args.update(size_network="small", adaptive_scale=False)

    ctx = TrainingContext(
        **experiment_args,
        step=step,
        key_remap=None,
        seed=None,
        dataloaders=dataloaders,
        writer=False,
    )
    if log:
        print(f"{name}: retrieved model at step {ctx.step}")

    # NOTE: DataParallel is left enabled (unwrapping it was needed for Hessian
    # computation), and energies are not renormalized here.
    model = ctx.model
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    return ctx

def get_sigmas(sigma_begin, sigma_end, num_classes, device, sigma_dis):
    """Build a schedule of ``num_classes`` noise levels from ``sigma_begin`` to ``sigma_end``.

    Args:
        sigma_begin: first level. Must be positive for 'geometric', because its
            log is taken.
        sigma_end: last level. Must be positive for 'geometric'.
        num_classes: number of levels, endpoints included.
        device: torch device of the returned tensor.
        sigma_dis: spacing of the levels. 'geometric' means evenly spaced in
            log space; 'uniform' means evenly spaced linearly.

    Returns:
        A float32 tensor of shape ``(num_classes,)`` on ``device``.

    Raises:
        NotImplementedError: for any other ``sigma_dis``.
    """
    # Use the argument itself. Reading a similarly named global would make the
    # result depend on whatever happens to be defined at module level.
    if sigma_dis == 'geometric':
        values = np.exp(np.linspace(np.log(sigma_begin), np.log(sigma_end), num_classes))
    elif sigma_dis == 'uniform':
        values = np.linspace(sigma_begin, sigma_end, num_classes)
    else:
        raise NotImplementedError('sigma distribution not supported')

    return torch.tensor(values).float().to(device)


def compute_sigma(y, model, inp_mask_type = 'box', type_cov = 'spatial'):
    """Estimate the degradation box size of ``y`` by minimizing the model output.

    For every candidate box size ``box`` in ``range(0, H - 1)``, where ``H`` is
    ``y.shape[-1]``, the function does the following:

    * build a covariance with ``spatial_corr_covariance_testing``;
    * evaluate ``model.network`` on ``y`` with that covariance;
    * keep the minimum output.

    The candidate with the smallest minimum is returned.

    Relies on the module-level ``default_ctx`` and ``device``, which are set in
    ``__main__``.

    Args:
        y: degraded image batch. Only ``y.shape[-1]`` sizes the search.
        model: object whose ``network(y, cov, type_cov_list=...)`` result is
            stored into one element of a 1-D tensor. NOTE(review): presumably
            a scalar energy; confirm.
        inp_mask_type: forwarded to ``spatial_corr_covariance_testing``.
        type_cov: 'spatial' or 'freq'.

    Returns:
        The ``np.argmin`` index of the best candidate. It equals the box size
        because candidates start at 0.

    Raises:
        ValueError: if ``type_cov`` is neither 'spatial' nor 'freq'.
    """
    # Validate first: an unknown value used to leave `type_cov_tensor` unbound.
    type_cov_codes = {'spatial': 0, 'freq': 1}
    if type_cov not in type_cov_codes:
        raise ValueError(f"type_cov must be 'spatial' or 'freq', got {type_cov!r}")
    type_cov_tensor = torch.tensor(type_cov_codes[type_cov])

    num_variances = 1  # 30
    psnr_min = 90
    psnr_max = -30
    psnrs = torch.linspace(psnr_min, psnr_max, num_variances, device=default_ctx.device)  # (L,)
    noise_levels = DenoisingError(dataset_info=default_ctx.dataset_info, psnr=psnrs).to_noise_level().variance  # (L,)
    noise_level_2 = 1e-9
    H = y.shape[-1]
    num_boxes = H - 1

    output_reg_d2 = torch.zeros(num_variances)
    min_energy = np.zeros(num_boxes)
    # NOTE(review): filled in below but never returned or otherwise used.
    estimated_var = np.zeros(num_boxes)

    with torch.no_grad():
        for box in range(num_boxes):
            for idx in range(num_variances):
                # NOTE(review): the covariance does not depend on `idx` (var_box is
                # fixed at 10**2), so every variance slot gets the same value.
                # Confirm whether noise_levels[idx] was intended.
                covariance = spatial_corr_covariance_testing(spatial_size=H, box_size=box, var_box=10**2, device=device, var_clean=noise_level_2, inp_mask_type=inp_mask_type, half_box_size=box)
                output_reg_d2[idx] = model.network(y, covariance.get_matrix()[None, None, :, :], type_cov_list=type_cov_tensor)

            energies = output_reg_d2.cpu().numpy()
            min_energy[box] = np.min(energies)
            estimated_var[box] = noise_levels[np.argmin(energies).item()].cpu().numpy()

    return np.argmin(min_energy)

def PC_sampler(x_mod, scorenet, sigmas, 
               n_steps_each=200, step_lr=0.000008,
                batch_size = 1, final_only=False, verbose=False, 
                denoise=True, data_score = False,
                inp_mask_type="half", box_size = 12, temp = 1, snr = 0.1):
    """Run a predictor-corrector sampler starting at ``x_mod`` along ``sigmas``.

    For ``c`` in ``range(len(sigmas) - 2)``, with ``sigma_curr = sigmas[c]``
    and ``sigma_next = sigmas[c + 1]``:

    * Predictor: query ``scorenet`` at variance ``sigma_curr**2`` and move
      ``x_mod`` by the negated ``data_score`` plus fresh Gaussian noise. Both
      terms are shaped by box covariances built with
      ``spatial_corr_covariance_testing``.
    * Corrector: run ``n_steps_each`` updates at variance ``sigma_next**2``.
      Each step size is ``2 * (snr * noise_norm / grad_norm) ** 2``, and the
      noise term is scaled by ``sqrt(2 * temp)``.

    Every 200 corrector iterations a clamped image grid is saved to
    ``samples/sampling_intermediate_iter_<k>.pdf``. When ``n_steps_each == 0``,
    the first image is instead shown every 500 predictor steps.

    Relies on the module-level ``device``, set in ``__main__``. ``plt``,
    ``NoiseLevel``, ``ModelInput`` and ``spatial_corr_covariance_testing`` are
    presumably in scope via the star imports.

    Args:
        x_mod: initial (degraded) image batch, shape ``(N, C, H, W)``. NOTE(review): assumed
            from ``x_mod.shape[-1]`` and the grid plotting; confirm.
        scorenet: model whose ``forward(ModelInput, create_graph=False)``
            output has a ``data_score`` attribute.
        sigmas: decreasing noise std schedule (1-D tensor).
        n_steps_each: corrector steps per noise level.
        batch_size: length of the covariance list passed to ``ModelInput``.
        final_only: return only the final iterate.
        inp_mask_type, box_size: layout of the box covariance.
        temp: temperature multiplying the corrector noise variance.
        snr: target signal-to-noise ratio for the corrector step size.
        step_lr, verbose, denoise, data_score: accepted but unused.

    Returns:
        ``[x_final]`` on CPU if ``final_only``. Otherwise a list with every
        corrector iterate, on CPU. NOTE(review): that list is empty when
        ``n_steps_each == 0``.
    """
    images = []
    var_clean = 1e-9
    H = x_mod.shape[-1]

    def _box_cov(var_box, var_clean=var_clean):
        # All covariances share the spatial layout; only the two variances change.
        return spatial_corr_covariance_testing(spatial_size=H, box_size=box_size, var_box=var_box, device=device, var_clean=var_clean, inp_mask_type=inp_mask_type, half_box_size=box_size)

    # Unit-variance box covariance, used only to measure gradient/noise norms.
    covariance_id = _box_cov(1, var_clean=0)
    iter_print_img = 0

    with torch.no_grad():
        # NOTE(review): iterating sigmas[:-2] means the last predictor step lands
        # on sigmas[-2] and sigmas[-1] is never reached; confirm this is intended.
        for c, sigma_curr in enumerate(sigmas[:-2]):
            sigma_next = sigmas[c + 1]

            ## Predictor step
            covariance = [_box_cov(sigma_curr**2)] * batch_size
            noise_level = NoiseLevel(variance=sigma_curr**2)
            model_input = ModelInput(noisy=x_mod, noise_level=noise_level, covariance=covariance)
            grad = -scorenet.forward(model_input, create_graph=False).data_score

            diff = sigma_curr**2 - sigma_next**2
            covariance_step = _box_cov(diff)
            covariance_step_noise = _box_cov((diff * sigma_next ** 2) / sigma_curr ** 2)
            noise = torch.randn_like(x_mod)
            x_mod = x_mod + covariance_step.apply_power(grad, p=1) + covariance_step_noise.apply_power(noise, p=0.5)

            ## Corrector steps
            for _ in range(n_steps_each):
                covariance = [_box_cov(sigma_next**2)] * batch_size
                noise_level = NoiseLevel(variance=(sigma_next**2))
                model_input = ModelInput(noisy=x_mod, noise_level=noise_level, covariance=covariance)
                grad = -scorenet.forward(model_input, create_graph=False).data_score

                noise = torch.randn_like(x_mod)
                grad_norm = torch.norm(covariance_id.apply_power(grad, p=1).view(grad.shape[0], -1), dim=-1).mean()
                noise_norm = torch.norm(covariance_id.apply_power(noise, p=1).view(noise.shape[0], -1), dim=-1).mean()
                step_size = (snr * noise_norm / grad_norm) ** 2 * 2
                covariance_step = _box_cov(step_size)
                x_mod = x_mod + covariance_step.apply_power(grad, p=1) + np.sqrt(2 * temp) * covariance_step.apply_power(noise, p=0.5)

                if not final_only:
                    images.append(x_mod.to('cpu'))

                if iter_print_img % 200 == 0:
                    samples = torch.clamp(x_mod, 0.0, 1.0)
                    grid = torchvision.utils.make_grid(samples, nrow=4, padding=2, normalize=False)
                    fig = plt.figure(figsize=(7, 7))
                    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
                    plt.axis('off')
                    plt.title("Generated images - Energy")
                    plt.savefig(f'samples/sampling_intermediate_iter_{iter_print_img}.pdf')
                    plt.close(fig)  # otherwise one figure leaks per snapshot
                    print(f"Saved intermediate sampling image at iteration {iter_print_img}")
                iter_print_img += 1

            if n_steps_each == 0:
                if iter_print_img % 500 == 0:
                    samples = torch.clamp(x_mod[0], 0.0, 1.0)
                    fig = plt.figure(figsize=(2, 2))
                    plt.imshow(samples.cpu().numpy().transpose(1, 2, 0))
                    plt.axis('off')
                    plt.show()
                    plt.close(fig)  # non-interactive backends keep the figure open after show()
                iter_print_img += 1

        if final_only:
            return [x_mod.to('cpu')]
        return images

def _save_image_grid(batch, path, nrow=4):
    """Save a batch of images as one grid figure at ``path``.

    The batch is moved to the CPU and tiled with ``torchvision.utils.make_grid``
    (``padding=2``). The figure is 10x10 inches with axes hidden, and it is
    closed after saving so figures do not accumulate across loop iterations.
    """
    grid = torchvision.utils.make_grid(batch.cpu(), nrow=nrow, padding=2)
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(path)
    plt.close(fig)


if __name__ == "__main__":
    import argparse

    ## Load models
    parser = argparse.ArgumentParser()
    # Without a default, omitting --dataset left `ctxs` undefined further down.
    parser.add_argument('--dataset', type=str, required=False, default='CIFAR10', choices=['CIFAR10', 'Celeba'])
    args_cmd = parser.parse_args()
    print(args_cmd)

    # argparse `choices` guarantees the key exists.
    exp_names = {
        "CIFAR10": "multigpu/all_together/energy_songSmall_dual_truncatedFreq",
        "Celeba": "multigpu/all_together/energy_songSmall_dual_truncatedFreq_celeba",
    }
    exp_name = exp_names[args_cmd.dataset]
    args = load_args(exp_name, step="best")
    step = "last"
    ctxs = {
        "Energy-Dual": load_exp(exp_name, step=step),
    }

    loss_fn_alex = lpips.LPIPS(net='alex').cuda()  # best forward scores

    time_tracker: TimeTracker = TimeTracker()
    time_tracker.switch("initialization")

    # `default_ctx` and `device` are also read as module globals by
    # compute_sigma and PC_sampler.
    default_ctx = ctxs["Energy-Dual"]
    device = default_ctx.device
    dataset_info = default_ctx.dataset_info

    ## Load data
    test_batch_size = 128
    img_size = dataset_info.spatial_size

    # NOTE(review): eval() executes whatever string args.json holds; only load
    # trusted experiment directories.
    train_dataloader, test_dataloader, dataset_info = load_data(
        dataset=args["dataset"], spatial_size=args["spatial_size"], grayscale=args["grayscale"], data_subset=eval(args["data_subset"]),
        train_batch_size=test_batch_size, test_batch_size=test_batch_size, num_workers=args["num_workers"], seed=2
    )

    # NOTE(review): the fetched batch is unused, but the call is kept. Creating
    # a DataLoader iterator presumably advances the global torch RNG (shuffling /
    # worker seeds), so removing it could change the noise drawn below.
    next(iter(test_dataloader))

    ## Sampling and degradation parameters (identical for every noise level)
    num_variances = 20
    psnr_min = 0
    psnr_max = 50
    sigma_end = 1e-2
    num_classes = 1000
    sigma_dist = 'geometric'
    n_steps_each = 0
    batch_size = 128
    snr = 0.15
    step_lr = None
    temp = 1
    blind = False
    sigma_clean_init = 1e-5
    inp_mask_type = "box"
    box_size = 10  # 19
    idx_img = 0
    nrow = 4

    psnrs = torch.linspace(psnr_min, psnr_max, num_variances, device=default_ctx.device)  # (L,)
    noise_levels = DenoisingError(dataset_info=default_ctx.dataset_info, psnr=psnrs).to_noise_level().variance  # (L,)
    energyDual_denoiser_error_all_scales = torch.zeros(num_variances)
    energy_denoiser_error_all_scales = torch.zeros(num_variances)

    for batch_idx, images in enumerate(test_dataloader):
        if batch_idx > 0:  # evaluate the first test batch only
            break

        clean_images = images[0][idx_img:idx_img + batch_size].cuda()
        # Average over the images actually present; a batch may be smaller than batch_size.
        n_images = clean_images.shape[0]
        _save_image_grid(clean_images, f'samples/clean_images_batch_{batch_idx}.pdf', nrow=nrow)

        for idx in range(num_variances):
            sigma_begin = noise_levels[idx]  # a variance; the schedule starts at its sqrt
            sigmas = get_sigmas(np.sqrt(sigma_begin.cpu().numpy()), sigma_end, num_classes, device, sigma_dist)

            # Box-structured degradation covariance for this noise level.
            cov = spatial_corr_covariance_testing(spatial_size=img_size, box_size=box_size, var_box=sigma_begin, device=device, var_clean=sigma_clean_init, half_box_size=box_size, inp_mask_type=inp_mask_type)

            ## Run sampler
            print(f"Running sampling with noise level {noise_levels[idx]}..")
            print(clean_images.shape)
            for model_, ctx in ctxs.items():
                if model_ == "Energy-Single":
                    continue
                print(f"Running {model_} ...")

                x_init = clean_images + cov.apply_power(torch.randn_like(clean_images), p=0.5)

                # Either estimate the box size from the data or use the true one.
                box_size_hat = compute_sigma(x_init, ctx) if blind else box_size
                print(f"Box size estimated: {box_size_hat}, True box size:{box_size}")

                x = PC_sampler(x_init, ctx.model, sigmas, n_steps_each=n_steps_each, step_lr=step_lr, batch_size=n_images, verbose=False, final_only=True, denoise=False,
                               data_score=True, inp_mask_type=inp_mask_type, temp=temp, box_size=box_size_hat, snr=snr)

                # PC_sampler returns CPU tensors.
                samples = torch.clamp(x[0], 0.0, 1.0)
                mse = torch.mean((clean_images.cpu() - samples) ** 2, dim=(-1, -2, -3)).sum()
                with torch.no_grad():
                    # LPIPS expects inputs in [-1, 1]; normalize=True rescales from [0, 1].
                    # NOTE(review): assumes dataset images are in [0, 1], as the clamp above implies.
                    lpips_ = loss_fn_alex(clean_images, samples.cuda(), normalize=True).sum()

                if model_ == "Energy-Single":
                    energy_denoiser_error_all_scales[idx] = mse / n_images
                elif model_ == "Energy-Dual":
                    energyDual_denoiser_error_all_scales[idx] = mse / n_images

                print("MSE", mse / n_images)
                print("LPIPS", lpips_ / n_images)

                # The noise-level index keeps each level's figure from overwriting the previous one.
                _save_image_grid(x_init, f'samples/noisy_images_batch_{batch_idx}_noise_{idx}.pdf', nrow=nrow)
                _save_image_grid(samples, f'samples/generated_sample_batch_energy_{model_}_{batch_idx}_noise_{idx}.pdf', nrow=nrow)

    results_file_name = f"denoising_error_posteriorsampling_onlypred_energyvsdenoiser_song_{step}.pt"
    torch.save({
        'psnrs': psnrs,
        'energy_denoiser_denoising': energy_denoiser_error_all_scales,
        'energyDual_denoiser_denoising': energyDual_denoiser_error_all_scales,
        'mask_type': inp_mask_type,
    }, f'plots/{results_file_name}')
    print(f"Saved in plots/{results_file_name}")