from itertools import cycle

import numpy as np
import torch
from matplotlib import pyplot as plt, ticker
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import get_dataset
from models import get_model
from util.visualize import visualize_batch


def _load_rectified_model(config):
    """Build ``config.rectified_model`` and load the ``'model'`` entry of its ``_best`` checkpoint."""
    model_cls = get_model(config.rectified_model)
    model = model_cls(**config.rectified_model_config).to(config.device)
    checkpoint = torch.load(config.rectified_model_path + '_best',
                            map_location=config.device)
    model.load_state_dict(checkpoint['model'])
    return model


def _fit_generation_latent(z, rectified_model, train_set, true_output, mask, param, config):
    """Optimise the latent ``z`` in place with L-BFGS.

    The objective is a weighted sum of the squared data misfit on the entries
    selected by ``mask`` and the squared PDE residual of the model output with
    those entries replaced by ``true_output``. Progress is printed each call.
    """
    cnt = 0

    def loss_fn(x0):
        # Per-sample normalisation to unit std, as done for inputs at inference.
        x = x0 / torch.std(x0, dim=(1, 2, 3), keepdim=True)
        x = rectified_model(x)
        residual = (x - true_output) * mask
        data_loss = residual.square().sum() * config.rectifed_data_loss_weight
        data_loss_mean = residual.square().mean()

        # Out-of-place substitution: writing into the model output in place
        # could invalidate tensors autograd saved for the backward pass.
        x_sub = torch.where(mask, true_output.to(x.dtype), x)

        pde_error = train_set.compute_pde_error(x_sub, **param)
        pde_error_mean = torch.mean(pde_error ** 2)
        pde_loss = torch.sum(pde_error ** 2) * config.rectifed_pde_loss_weight
        loss = data_loss + pde_loss
        print(f'Iter {cnt}: data: {data_loss_mean.item()}, pde: {pde_error_mean.item()}')
        return x_sub, loss

    optimizer = torch.optim.LBFGS([z], max_iter=50, lr=5e-1, line_search_fn='strong_wolfe')

    def closure():
        nonlocal cnt
        cnt += 1
        optimizer.zero_grad()
        _, loss = loss_fn(z)
        loss.backward()
        return loss

    optimizer.step(closure)


def _run_in_chunks(model, inputs, chunk_size=128):
    """Apply ``model`` to ``inputs`` in chunks of ``chunk_size`` rows without autograd."""
    with torch.no_grad():
        outputs = [model(chunk) for chunk in torch.split(inputs, chunk_size)]
        return torch.cat(outputs, dim=0)


def _print_metrics(samples, true_output, condition_data, train_set, param):
    """Print data MSE, mean squared PDE residual, and MMSE/SMSE of ``samples``."""
    data_error = torch.mean((samples - true_output) ** 2)
    print(f'Data Loss: {data_error.detach().cpu().numpy()}')

    pde_error = train_set.compute_pde_error(samples, **param)
    pde_error = torch.mean(pde_error ** 2)
    print(f'PDE Loss: {pde_error.detach().cpu().numpy()}')

    # MSE between per-position batch means / stds of samples and ground truth.
    mmse = ((samples.mean(0) - condition_data.mean(0)) ** 2).mean()
    smse = ((samples.std(0) - condition_data.std(0)) ** 2).mean()
    print(f'MMSE Loss: {mmse.detach().cpu().numpy()}')
    print(f'SMSE Loss: {smse.detach().cpu().numpy()}')


def evaluate_rectified_flow(config):
    """Evaluate a trained rectified-flow model on the first training batch.

    Loads the checkpoint at ``config.rectified_model_path + '_best'`` and takes
    the first batch (up to 1000 samples, unshuffled) of the training set.
    Depending on ``config.rectified_type`` the model inputs are:

    * ``'forward'``    -- ``train_set.get_condition(batch)``; the mask selects
      entries where ``get_condition`` of a ones-tensor is > 0.5;
    * ``'inverse'``    -- a copy of the batch with the entries where
      ``get_condition`` of a ones-tensor is < 0.5 set to zero (those entries
      form the mask);
    * ``'generation'`` -- a Gaussian latent optimised with L-BFGS against a
      weighted data-misfit + PDE-residual loss (mask as for ``'forward'``).

    Inputs are normalised per sample to unit std and passed through the model
    in chunks of 128. Mask-selected entries of the output are overwritten with
    the ground truth, and data MSE, mean squared PDE residual, MMSE and SMSE
    are printed.

    Side effects: sets ``config.sampling``, ``config.train_set`` and
    ``config.test_set``; prints to stdout.

    Args:
        config: Experiment configuration object.

    Raises:
        ValueError: if ``config.rectified_type`` is not ``'forward'``,
            ``'inverse'`` or ``'generation'``.
    """
    config.sampling = 'vanilla'
    train_set, test_set = get_dataset(config)
    config.train_set, config.test_set = train_set, test_set

    rectified_model = _load_rectified_model(config)

    # Only the first batch is evaluated. The iterator is created and dropped
    # inline: holding it (e.g. via itertools.cycle) would keep the 16 workers
    # and their prefetched batches alive for the whole evaluation.
    loader = DataLoader(train_set, batch_size=1000, shuffle=False, num_workers=16)
    condition_data, param = next(iter(loader))
    condition_data = condition_data.to(config.device)
    c = train_set.get_condition(condition_data, type=config.condition_type)

    mode = config.rectified_type
    if mode == 'forward':
        inputs = c
        true_output = condition_data
        mask = train_set.get_condition(torch.ones_like(inputs), type=config.condition_type)
        mask = mask > 0.5
    elif mode == 'inverse':
        # Clone so that zeroing entries does not also overwrite condition_data,
        # which is the ground truth and the MMSE/SMSE reference below.
        inputs = condition_data.clone()
        mask = train_set.get_condition(torch.ones_like(inputs), type=config.condition_type)
        mask = mask < 0.5
        inputs[mask] = 0
        true_output = condition_data
    elif mode == 'generation':
        resolution = train_set.resolution
        # One latent per loaded sample (the dataset may hold fewer than 1000).
        num_samples = condition_data.shape[0]
        z = torch.randn((num_samples, resolution[0], resolution[1], resolution[2]))
        z = z.to(config.device).requires_grad_()

        mask = train_set.get_condition(torch.ones_like(z), type=config.condition_type)
        mask = mask > 0.5
        true_output = condition_data

        _fit_generation_latent(z, rectified_model, train_set, true_output, mask, param, config)
        # Optimisation is done; no gradients are needed for inference.
        inputs = z.detach()
    else:
        raise ValueError(
            f"Unknown rectified_type {mode!r}; expected 'forward', 'inverse' or 'generation'")

    # Per-sample normalisation to unit std (assumes 4-D inputs).
    inputs = inputs / torch.std(inputs, dim=(1, 2, 3), keepdim=True)
    samples = _run_in_chunks(rectified_model, inputs, 128)
    samples[mask] = true_output[mask]

    _print_metrics(samples, true_output, condition_data, train_set, param)


if __name__ == '__main__':
    # Script entry point: evaluate using the default Poisson training config.
    from config.poisson_config import TrainConfig
    default_config = TrainConfig()
    evaluate_rectified_flow(default_config)
