from itertools import cycle

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import get_dataset
from models import get_model


def _restore_rectified_checkpoint(path, device, model, optimizer):
    """Load model/optimizer state from ``path`` and return the stored loss curve."""
    # Deserialise the file once instead of once per key.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Copy into a list so per-epoch losses can be appended to it.
    return list(checkpoint['loss'])


def _draw_condition_batch(config, train_set):
    """Draw one shuffled batch from ``train_set`` and derive its condition.

    Returns ``(condition_data, param, c)`` where ``condition_data`` has been
    moved to ``config.device`` and ``c`` is the result of
    ``config.train_set.get_condition`` applied to it.
    """
    loader = DataLoader(train_set, batch_size=config.rectified_sampling_num,
                        shuffle=True, num_workers=16)
    # Only a single batch is needed; the iterator is dropped right away so
    # its worker processes are not kept referenced for the rest of the epoch.
    condition_data, param = next(iter(loader))
    condition_data = condition_data.to(config.device)
    c = config.train_set.get_condition(condition_data, type=config.condition_type)
    return condition_data, param, c


def _select_rectified_batch(config, samples, gaussian, c, start_idx, end_idx):
    """Return ``(batch_input, batch_output)`` for ``config.rectified_type``.

    Raises:
        NotImplementedError: if ``config.rectified_type`` is not one of
            ``'generation'``, ``'forward'`` or ``'inverse'``.
    """
    batch_output = samples[start_idx:end_idx]
    if config.rectified_type == 'generation':
        batch_input = gaussian[start_idx:end_idx]
    elif config.rectified_type == 'forward':
        batch_input = c[start_idx:end_idx]
    elif config.rectified_type == 'inverse':
        # The slice is a view of ``samples``: clone before masking, otherwise
        # the zeroing below also overwrites the regression target (and the
        # samples reused by later repeats).
        batch_input = batch_output.clone()
        mask = torch.ones_like(batch_input)
        mask = config.train_set.get_condition(mask, type=config.condition_type)
        # Entries where the condition of an all-ones tensor falls below 0.5
        # are zeroed in the input.
        batch_input[mask < 0.5] = 0
    else:
        raise NotImplementedError(
            f"Unsupported rectified_type: {config.rectified_type!r}")
    return batch_input, batch_output


def train_rectified_flow(config):
    """Train ``config.rectified_model`` to reproduce samples of a pretrained flow.

    A pretrained flow model is restored from
    ``config.model_path + '_iter_<config.iterations>'``. Every epoch a fresh set
    of samples is drawn from it (``sample_dflow`` when ``config.sampling ==
    'dflow'``, otherwise ``sample_ode``) and the rectified model is regressed
    onto those samples with a weighted sum of a data MSE and a squared PDE
    residual from ``config.train_set.compute_pde_error``.

    The unweighted data MSE of the last batch of each epoch drives early
    stopping; whenever it improves, the model, optimizer and loss curve are
    saved to ``config.rectified_model_path + '_best'``.

    Args:
        config: Training configuration object. Besides the attributes read
            here, ``config.train_set`` and ``config.test_set`` are assigned
            from ``get_dataset(config)``.

    Raises:
        NotImplementedError: if ``config.rectified_type`` is unsupported.
    """
    train_set, test_set = get_dataset(config)
    config.train_set, config.test_set = train_set, test_set

    # Rebuild the pretrained flow and load its weights.
    model_cls = get_model(config.model)
    model = model_cls(**config.model_config).to(config.device)
    flow = config.flow(model=model, trajectory=config.trajectory,
                       config=config, reparam=config.repara)
    state_dict = torch.load(config.model_path + f'_iter_{config.iterations}',
                            map_location=config.device)['model']
    # Checkpoint keys carry a "model.model." prefix; collapse it to "model."
    # before loading into ``flow.model``.
    new_state_dict = {k.replace("model.model.", "model."): v for k, v in state_dict.items()}
    flow.model.load_state_dict(new_state_dict)

    rectified_model_cls = get_model(config.rectified_model)
    rectified_model = rectified_model_cls(**config.rectified_model_config).to(config.device)
    optimizer = torch.optim.Adam(rectified_model.parameters(), lr=config.learning_rate)

    loss_curve = []
    current_epoch = config.rectified_current_epoch
    if current_epoch not in [None, 0]:
        # NOTE(review): resuming reads from config.model_path while the best
        # checkpoint below is written under config.rectified_model_path --
        # confirm this is the intended location.
        loss_curve = _restore_rectified_checkpoint(
            config.model_path + f'_epoch_{current_epoch}',
            config.device, rectified_model, optimizer)

    patience = 0
    best_loss = float('inf')
    for epoch in range(config.rectified_epoch_num + 1):
        # Defaults for the unconditional case so the sampler call below never
        # references an unbound name.
        condition_data, param, c = None, None, None
        if config.condition_type is not None:
            condition_data, param, c = _draw_condition_batch(config, train_set)

        sampler = flow.sample_dflow if config.sampling == 'dflow' else flow.sample_ode
        samples, gaussian = sampler(num_samples=config.rectified_sampling_num,
                                    num_steps=config.sampling_steps, batch_size=256,
                                    return_z0=True, c=c, param=param)
        # The samples are fixed regression targets: make sure backward() can
        # never reach into the sampler (no-op if sampling ran without grad).
        samples = samples.detach()

        skip_data_error = config.conditional_model == False and config.sampling == 'vanilla'  # noqa: E712
        if condition_data is not None and not skip_data_error:
            data_error = torch.mean((samples - condition_data) ** 2)
            print(f'Dataset error: {data_error.detach().cpu().numpy()}')

        # Trailing samples that do not fill a whole batch are not used, but at
        # least one (possibly short) batch is always trained on.
        num_batches = max(samples.shape[0] // config.rectified_batch_size, 1)

        # Context manager closes the bar even if a training step raises.
        with tqdm(total=num_batches * config.rectified_epoch_repeat, leave=False) as pbar:
            for _ in range(config.rectified_epoch_repeat):
                for i in range(num_batches):
                    start_idx = i * config.rectified_batch_size
                    end_idx = start_idx + config.rectified_batch_size
                    batch_input, batch_output = _select_rectified_batch(
                        config, samples, gaussian, c, start_idx, end_idx)

                    pred_samples = rectified_model(batch_input)

                    # Keep the unweighted terms around for logging and early
                    # stopping instead of dividing the weights back out
                    # (which breaks when a weight is 0).
                    data_mse = ((pred_samples - batch_output) ** 2).mean()
                    pde_residual = config.train_set.compute_pde_error(pred_samples, )
                    pde_mse = (pde_residual ** 2).mean()
                    loss = (data_mse * config.rectifed_data_loss_weight
                            + pde_mse * config.rectifed_pde_loss_weight)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    pbar.update(1)

        print(f"Epoch [{epoch}/{config.rectified_epoch_num}], Data Loss: {data_mse.item()}, ")
        print(f"PDE Loss: {pde_mse.item():.6f}")

        # Record the total loss of the epoch's last batch so the saved loss
        # curve is not permanently empty.
        loss_curve.append(loss.item())

        early_stop_metric = data_mse.detach().cpu().numpy()

        if early_stop_metric < best_loss:
            best_loss = early_stop_metric
            patience = 0
            print(f"True Loss: {early_stop_metric}, best record appears! patience: {patience}/{config.rectified_patience}")
            torch.save({
                'loss': loss_curve,
                'model': rectified_model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, config.rectified_model_path + '_best')
        else:
            patience += 1
            print(f"True Loss: {early_stop_metric}, patience: {patience}/{config.rectified_patience}")
            if patience >= config.rectified_patience:
                print("Early stopping...")
                break


if __name__ == '__main__':
    # Script entry point: build the Poisson training configuration and run
    # the rectified-flow training with it.
    from config.poisson_config import TrainConfig as _PoissonTrainConfig

    poisson_train_config = _PoissonTrainConfig()
    train_rectified_flow(poisson_train_config)
