import wandb
import torch
from tqdm import tqdm
from pathlib import Path
from ema_pytorch import EMA
from omegaconf import DictConfig
from hydra.utils import instantiate
from lightning.fabric.utilities import AttributeDict

import utils

import logging

log = logging.getLogger(__name__)


def get_fabric(config):
    """Build the Fabric object from ``config.fabric``, seed it and launch it.

    The seed is read from ``config.exp.seed``. The launched Fabric object is
    returned to the caller.
    """
    launched = instantiate(config.fabric)
    seed = config.exp.seed
    launched.seed_everything(seed)
    launched.launch()
    return launched


def get_components(config, fabric):
    """
    Instantiate the diffusion object and the preprocessing callable from the config.

    Only the diffusion object (built from ``config.diffusion``) is passed through
    ``fabric.setup``; the object built from ``config.preprocess_fn`` is returned as
    instantiated.

    Returns:
        A ``(diffusion, preprocess_fn)`` tuple.
    """
    # diffusion first, wrapped by Fabric
    raw_diffusion = instantiate(config.diffusion)
    wrapped_diffusion = fabric.setup(raw_diffusion)

    # preprocessing callable is used as-is
    preprocess = instantiate(config.preprocess_fn)

    return wrapped_diffusion, preprocess


def get_train_dataloader(config, fabric):
    """Instantiate the training dataloader from ``config.train_dataloader`` and set it up with Fabric."""
    loader = instantiate(config.train_dataloader)
    return fabric.setup_dataloaders(loader)


def _measurement_preview(y):
    """Return the first measurement of the batch in a form loggable as an image.

    When ``y`` has exactly two entries along dim 1, they are combined into
    ``log1p(sqrt(y[:, 0]**2 + y[:, 1]**2))`` and a singleton axis is re-inserted
    at dim 1 before taking the first batch element. Otherwise the first batch
    element of ``y`` is returned unchanged.
    """
    if y.shape[1] != 2:
        return y[0]
    # NOTE(review): the two channels are presumably real/imaginary parts — confirm.
    magnitude = torch.sqrt(y[:, 0] ** 2 + y[:, 1] ** 2).log1p()
    return magnitude.unsqueeze(1)[0]


def run(config: DictConfig):
    """Log forward-diffusion trajectories of training samples to wandb.

    Sets up the config, wandb and Fabric, instantiates the diffusion object and
    the training dataloader, and then, for every sample of every batch:

    * samples ``x_t`` for all timesteps returned by
      ``timestep_sampler.get_path(config.exp.n_steps)``,
    * splits each ``x_t`` into ``pinvA(mean_system_response(x_t))`` (logged as
      "range") and the remainder (logged as "null"),
    * logs the clean sample, the measurement, ``pinvA_y`` and the per-timestep
      images through ``wandb.log``.

    Args:
        config: Hydra/OmegaConf configuration. Reads ``fabric``, ``diffusion``,
            ``preprocess_fn``, ``train_dataloader``, ``exp.seed``,
            ``exp.log_dir`` and ``exp.n_steps``.

    Raises:
        RuntimeError: if a multiprocessing start method other than "spawn" has
            already been fixed for this process.
    """
    # ``set_start_method`` raises RuntimeError once a start method is fixed, so
    # a second call to ``run`` in the same process used to crash. Skip the call
    # when "spawn" is already active; a conflicting method still raises.
    if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
        torch.multiprocessing.set_start_method("spawn")
    utils.hydra.preprocess_config(config)
    utils.wandb.setup_wandb(config)

    # output logging directory path
    log.info(f"config.exp.log_dir={str(config.exp.log_dir)}")

    log.info("Launching Fabric")
    fabric = get_fabric(config)

    # context manager to automatically move newly created tensors to correct device
    with fabric.init_tensor():
        log.info("Initializing components")
        diffusion, preprocess_fn = get_components(config, fabric)

        # extract sde and measurement system
        sde = diffusion.sde
        system = sde.measurement_system
        timestep_sampler = diffusion.timestep_sampler

        # get timesteps (the second element of the returned pair is not needed here)
        ts, _ = timestep_sampler.get_path(config.exp.n_steps)
        n_steps = ts.shape[0]

        log.info("Initializing dataloaders")
        train_dataloader = get_train_dataloader(config, fabric)

        log.info("Looping over training dataloader")

        for batch in train_dataloader:
            for sample in batch:
                # replicate the sample once per timestep so the whole trajectory
                # is computed in one batched call; the four-argument repeat
                # requires ``sample`` to have at most three dimensions
                x_0 = sample.unsqueeze(0).repeat(n_steps, 1, 1, 1)

                # reset random state of the system
                system.mean_system_response.reset_random_state(x_0, fabric, eval=False)
                if hasattr(system.mean_system_response, "idx"):
                    # point every replicated entry at index 0
                    system.mean_system_response.idx = fabric.to_device(
                        torch.tensor([0] * n_steps)
                    )

                # get trajectory of x_t, y and A^+ \cdot y
                x_ts, pinvA_y, y = sde.sample_x_t_given_x_0(x_0, ts)

                # decompose into range and null
                x_ts_range = system.pinvA(system.mean_system_response(x_ts))
                x_ts_null = x_ts - x_ts_range

                # preprocess if needed; all replicas of x_0 are identical, so a
                # single-element batch is enough for logging
                x_0 = preprocess_fn(x_0[0].unsqueeze(0))
                x_ts_null = preprocess_fn(x_ts_null)
                x_ts_range = preprocess_fn(x_ts_range)

                # log the per-sample images once
                wandb.log(
                    {
                        "images/x_0": wandb.Image(x_0[0]),
                        "images/y": wandb.Image(_measurement_preview(y)),
                        "images/pinvA_y": wandb.Image(pinvA_y[0]),
                    }
                )

                # log one wandb entry per timestep of the trajectory
                for x_t, x_t_range, x_t_null in zip(x_ts, x_ts_range, x_ts_null):
                    wandb.log(
                        {
                            "images/x_t": wandb.Image(x_t),
                            "images/x_t_range": wandb.Image(x_t_range),
                            "images/x_t_null": wandb.Image(x_t_null),
                        }
                    )
