import torch
import time
from sklearn.metrics import r2_score

from flow.transport.transport import create_transport, Sampler
from experiment.train_utils import ODEConfig, SDEConfig

#################################################################################
#                                  Sampling Loop                                #
#################################################################################

def sample_process(mode, args, de_config):
    """Draw samples with the transport sampler and print a per-batch R^2 score.

    For every batch yielded by ``args.sample_generator`` a Gaussian latent is
    pushed through the selected ODE/SDE sampler, the final state is mapped
    back through the pseudo-inverse of ``model.linear_encoder`` and the result
    is scored against the batch targets with ``r2_score``.

    Args:
        mode: ``"ODE"`` or ``"SDE"``; selects which sampler is built.
        args: Run configuration. Attributes read here: ``model``, ``device``,
            ``path_type``, ``prediction``, ``loss_weight``, ``train_eps``,
            ``sample_eps``, ``num_sampling_steps``, ``sample_generator`` and,
            in ODE mode, ``likelihood`` (plus ``cfg_scale`` when likelihood
            is enabled). An optional ``model_kwargs`` mapping is forwarded to
            the model call when present.
        de_config: Solver settings. ODE mode reads ``sampling_method``,
            ``atol``, ``rtol`` (and ``reverse`` without likelihood); SDE mode
            reads ``sampling_method``, ``diffusion_form``, ``diffusion_norm``,
            ``last_step`` and ``last_step_size``.

    Returns:
        None. Scores and the elapsed time are printed.

    Raises:
        AssertionError: If ``args.model`` or ``args.sample_generator`` is
            ``None``, or likelihood sampling is requested with
            ``args.cfg_scale != 1``.
        ValueError: If ``mode`` is neither ``"ODE"`` nor ``"SDE"``.
    """
    assert args.model is not None
    model = args.model
    device = args.device
    # load pre-trained model
    model.eval()
    model = model.to(device)

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps,
    )
    sampler = Sampler(transport)

    # mode: SDE/ODE
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=de_config.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=de_config.atol,
                rtol=de_config.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=de_config.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=de_config.atol,
                rtol=de_config.rtol,
                reverse=de_config.reverse,
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=de_config.sampling_method,
            diffusion_form=de_config.diffusion_form,
            diffusion_norm=de_config.diffusion_norm,
            last_step=de_config.last_step,
            last_step_size=de_config.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``sample_fn`` once the batch loop starts.
        raise ValueError(
            f"Unsupported sampling mode {mode!r}; expected 'ODE' or 'SDE'"
        )

    assert args.sample_generator is not None

    # NOTE(review): the previous code unpacked a ``model_kwargs`` name that
    # was never defined in this function. Extra keyword arguments for
    # ``model.forward_with_cfg`` are now taken from an optional
    # ``args.model_kwargs``; confirm which conditioning inputs the model
    # actually requires (the batch inputs ``x`` are currently not passed on).
    model_kwargs = getattr(args, "model_kwargs", None) or {}

    # Inverse weights, hoisted out of the loop because the encoder does not
    # change between batches. The original note gives the shape as
    # (d_model * d_pos) - presumably ``linear_encoder`` is an nn.Linear-style
    # layer; TODO confirm.
    encoder = model.linear_encoder
    pinv_decoder = torch.linalg.pinv(encoder.weight.t())

    start = time.perf_counter()
    # sampling batches
    for _x, y, batch_num in args.sample_generator:
        # ``as_tensor`` avoids the extra copy (and warning) that
        # ``torch.tensor`` incurs when the generator already yields tensors.
        sample_batch_y = torch.as_tensor(y).to(device)

        # noisy latent features, with a singleton axis inserted at dim 1
        z_0 = torch.randn(batch_num, model.hidden_size, device=device)
        z_0 = torch.unsqueeze(z_0, dim=1)

        # keep only the last element returned by the sampler
        samples = sample_fn(z_0, model.forward_with_cfg, **model_kwargs)[-1]

        # decoding behavior labels: undo the bias, then apply the
        # pseudo-inverse of the encoder weights
        dec_out_test = (samples - encoder.bias) @ pinv_decoder

        r2_score_tmp = r2_score(
            sample_batch_y.cpu().detach().numpy(),
            dec_out_test.cpu().detach().numpy(),
        )
        print("test r2 score: %.4f" % r2_score_tmp)

    # ``time`` is the module here, and the timer was stored in ``start``.
    print(f"Sampling took {time.perf_counter() - start:.2f} seconds.")