import os.path
from itertools import cycle

import numpy as np
import torch
from matplotlib import pyplot as plt, ticker
from torch.utils.data import DataLoader

from datasets import get_dataset
from models import get_model
from util.visualize import visualize_batch, calculate_metrics

# Seed PyTorch's global RNG once at import time so that random model
# initialisation and every later random draw start from the same state on
# each run of this script.
_SEED = 1
torch.manual_seed(_SEED)


def _load_flow(config):
    """Instantiate the model and flow described by ``config`` and load weights.

    The checkpoint is read from ``config.model_path + '_iter_{iterations}'``
    and its ``'model'`` entry is loaded into ``flow.model``.

    Args:
        config: Experiment configuration; reads ``model``, ``model_config``,
            ``device``, ``flow``, ``trajectory``, ``repara``, ``model_path``
            and ``iterations``.

    Returns:
        The flow object built by ``config.flow`` with the checkpoint loaded.
    """
    model_cls = get_model(config.model)
    model = model_cls(**config.model_config).to(config.device)
    flow = config.flow(model=model, trajectory=config.trajectory,
                       config=config, reparam=config.repara)

    checkpoint_path = config.model_path + f'_iter_{config.iterations}'
    state_dict = torch.load(checkpoint_path, map_location=config.device)['model']
    # Saved keys carry a doubled "model.model." prefix (presumably from an
    # extra wrapper at save time -- confirm against the training script);
    # collapse it to "model." before loading into flow.model.
    state_dict = {k.replace("model.model.", "model."): v for k, v in state_dict.items()}
    flow.model.load_state_dict(state_dict)
    return flow


def _first_batch(dataset, batch_size):
    """Return the first unshuffled batch of ``dataset`` as produced by DataLoader.

    Callers unpack it as ``(data, param)``.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    # One in-process pass is enough for a single batch. Spawning worker
    # processes (and wrapping the loader in itertools.cycle, which kept the
    # iterator and its workers alive for the rest of the evaluation) only
    # added start-up cost and prefetched batches that were never used.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    try:
        return next(iter(loader))
    except StopIteration:
        raise ValueError('dataset is empty; cannot build a reference batch') from None


def evaluate_flow(config, num_samples=5):
    """Load a trained flow checkpoint, draw samples and print evaluation metrics.

    Steps:
      1. Build train/test datasets via ``get_dataset`` and store them on
         ``config.train_set`` / ``config.test_set``.
      2. Build the model and flow and load the checkpoint weights.
      3. If ``config.condition_type`` is not None, take the first
         ``num_samples`` training examples (unshuffled) as a reference batch
         and derive the conditioning signal ``c`` from it.
      4. Sample with ``flow.sample_dflow`` when ``config.sampling == 'dflow'``,
         otherwise with ``flow.sample_ode``.
      5. Print the data loss (conditional runs only), the PDE loss, and the
         MMSE / SMSE moment-matching losses against the reference batch
         (skipped when there is no reference batch).

    Args:
        config: Experiment configuration object. Besides the attributes used
            to build the flow, reads ``condition_type``, ``sampling`` and
            ``sampling_steps``; writes ``train_set`` and ``test_set``.
        num_samples: Number of samples to draw and size of the reference
            batch. Defaults to 5.

    Returns:
        None; metrics are printed to stdout.
    """
    train_set, test_set = get_dataset(config)
    config.train_set, config.test_set = train_set, test_set

    flow = _load_flow(config)

    # Defaults for the unconditional case. Previously these names were bound
    # only inside the conditional branch, so an unconditional run failed with
    # NameError when they were used below.
    c = None
    condition_data = None
    param = None
    if config.condition_type is not None:
        condition_data, param = _first_batch(train_set, num_samples)
        condition_data = condition_data.to(config.device)
        c = config.train_set.get_condition(condition_data, type=config.condition_type)

    sampler = flow.sample_dflow if config.sampling == 'dflow' else flow.sample_ode
    samples = sampler(num_samples=num_samples, num_steps=config.sampling_steps,
                      batch_size=100, c=c, param=param)

    if condition_data is not None:
        data_error = torch.mean((samples - condition_data) ** 2)
        print(f'Data Loss: {data_error.detach().cpu().numpy()}')
        calculate_metrics(condition_data, samples)

    # param is unpacked as keyword arguments, so it is expected to be a
    # mapping; with no reference batch there are no PDE parameters to pass.
    pde_kwargs = param if param is not None else {}
    pde_error = train_set.compute_pde_error(samples, **pde_kwargs)
    pde_error = torch.mean(pde_error ** 2)
    print(f'PDE Loss: {pde_error.abs().detach().cpu().numpy()}')

    if condition_data is None:
        print('MMSE/SMSE skipped: no reference batch (condition_type is None)')
        return

    # Compare per-position sample mean and standard deviation (over the
    # sample dimension 0) with those of the reference batch.
    mmse = ((samples.mean(0) - condition_data.mean(0)) ** 2).mean()
    smse = ((samples.std(0) - condition_data.std(0)) ** 2).mean()
    print(f'MMSE Loss: {mmse.detach().cpu().numpy()}')
    print(f'SMSE Loss: {smse.detach().cpu().numpy()}')

if __name__ == '__main__':
    # Script entry point: evaluate using the Poisson experiment's TrainConfig.
    from config.poisson_config import TrainConfig

    poisson_config = TrainConfig()
    evaluate_flow(poisson_config)
