import argparse
from datetime import datetime
from typing import Callable

import torch
import wandb
from matplotlib import pyplot as plt

from neural_mpm.data.data_manager import (
    DataManager,
)
import neural_mpm.util.interpolation as interp
from neural_mpm.util.metaparams import DIM, LOW, HIGH
import neural_mpm.util.metaparams as metaparams
import neural_mpm.util.simulate as simulate
from neural_mpm.nn import FNO, FFNO, UNet
from neural_mpm.util import viz
from neural_mpm.util.model_logger import ModelLogger

from tqdm import tqdm


# from torch.profiler import profile, record_function, ProfilerActivity


def loss_fn(
    module,
    grids,
    target_grids,
    state,
    targets,
    gmean,
    gstd,
    types,
    step_fn: Callable,
    steps_per_call: int,
    grid_size: int = 64,
    autoregressive_steps: int = 1,
    debug: int = -1,
):
    """Roll the model out autoregressively and return the particle-position loss.

    The model is called ``autoregressive_steps`` times. Each call's
    de-normalised prediction is turned by ``step_fn`` into particle positions
    (``steps_per_call`` frames) and into the grids fed to the next call. The
    squared position error against ``targets`` is masked to particles whose
    type is > 0 and normalised by ``count_nonzero(types) * steps_per_call * DIM``.

    Args:
        module: Network applied to the normalised grids.
        grids: Input grids; channels live on the last axis. A 6-channel input
            is treated as carrying two extra (gravity) channels in its last
            two slots, which are re-appended after every ``step_fn`` call.
        target_grids: Unused; kept for call-site compatibility.
        state: Particle state; its first ``DIM`` last-axis entries are the
            starting positions.
        targets: Target positions, indexed with time on axis 1.
        gmean: Grid normalisation mean.
        gstd: Grid normalisation standard deviation.
        types: Per-particle types, indexed as a rank-2 tensor (presumably
            ``(batch, particle)`` -- confirm against the data manager).
        step_fn: Callable ``(preds, prev_pos, types) -> (pred_pos, grids)``.
        steps_per_call: Number of frames predicted per model call.
        grid_size: Unused; kept for call-site compatibility.
        autoregressive_steps: Number of chained model calls.
        debug: Unused; kept for call-site compatibility.

    Returns:
        ``(loss, preds)``: the accumulated scalar loss and the de-normalised
        predictions of the final model call.

    Raises:
        ValueError: if ``autoregressive_steps * steps_per_call`` is not
            positive (there would be no model call to produce ``preds``).
    """
    total_steps = autoregressive_steps * steps_per_call
    if total_steps <= 0:
        raise ValueError(
            "autoregressive_steps * steps_per_call must be positive, got "
            f"{autoregressive_steps} * {steps_per_call}"
        )

    loss = 0
    prev_pos = state[..., :DIM]

    # Number of scored scalars: non-zero-type particles x frames x dims.
    norm = torch.count_nonzero(types) * steps_per_call * DIM

    # TODO: this -- 6-channel grids carry two extra gravity channels that
    # step_fn does not regenerate, so they are saved here and re-attached.
    has_extra_channels = grids.shape[-1] == 6
    if has_extra_channels:
        grav = grids[..., -2:]

    for i in range(0, total_steps, steps_per_call):
        points = (grids - gmean) / gstd
        preds = module(points)
        preds = preds * gstd[None, ..., :DIM] + gmean[None, ..., :DIM]
        pred_pos, grids = step_fn(preds, prev_pos, types)

        if has_extra_channels:
            grids = torch.cat((grids, grav), axis=-1)

        real_loss = (
            targets[:, i : i + steps_per_call, ..., :DIM] - pred_pos[..., :DIM]
        ) ** 2
        # Only particles with type > 0 contribute to the loss.
        real_loss = torch.where(types[:, None, :, None] > 0.0, real_loss, 0.0)

        # NOTE(review): the 64**2 scale is fixed and does not follow
        # `grid_size`; kept as-is because changing it rescales the loss (and
        # interacts with gradient clipping) for every run.
        real_loss = real_loss.sum() * 64**2 / norm
        loss += real_loss

        # The last predicted frame seeds the next autoregressive call.
        prev_pos = pred_pos[:, -1, ..., :DIM]

    return loss, preds


def train(
    model,
    datamanager,
    optimizer,
    schedulers,
    use_wandb=True,
    particle_noise: float = metaparams.position_noise,
    grid_noise: float = metaparams.grid_noise,
    passes_over_buffer: int = 3,
    epochs: int = 10,
    model_logger: ModelLogger = None,
    progress_bars=False,
):
    """Train ``model`` on data streamed by ``datamanager``.

    Every epoch fetches a fresh dataloader and makes ``passes_over_buffer``
    passes over it. Each batch gets noise augmentation, a ``loss_fn``
    rollout, a gradient step with norm clipping at 1.0, and optional LR
    scheduling. On every 5th epoch (starting with epoch 0), the model is
    rolled out on the validation simulations for at most 1000 steps. When the
    mean error improves, the model is checkpointed through ``model_logger``
    and, with wandb enabled, an error curve and a comparison video are
    logged.

    Args:
        model: Network being trained.
        datamanager: Data source. It provides dataloaders, normalisation
            statistics, validation simulations and the grid/interpolation
            settings used to build ``step_fn``.
        optimizer: Torch optimizer over ``model``'s parameters.
        schedulers: ``None``, or a dict with keys ``"linear"``,
            ``"cosine"``, ``"warmup_end"``, ``"cosine_start"`` and optionally
            ``"total_iters"`` (defaults to 1e5).
        use_wandb: Log metrics/media to wandb and call ``wandb.finish()``.
        particle_noise: Std of Gaussian noise added to particle states of
            type > 0.
        grid_noise: Std of Gaussian noise added to the input grids.
        passes_over_buffer: Passes over each epoch's dataloader.
        epochs: Number of epochs.
        model_logger: Optional checkpoint logger.
        progress_bars: Show tqdm progress bars.
    """
    if model_logger is not None:
        model_logger.start_timer()

    use_cuda = torch.cuda.is_available()
    total_step = 0
    best_val = torch.inf

    # Noise scales as tensors on the training device.
    if use_cuda:
        grid_noise = torch.tensor(grid_noise, device="cuda")
        particle_noise = torch.tensor(particle_noise, device="cuda")
    else:
        grid_noise = torch.tensor(grid_noise)
        particle_noise = torch.tensor(particle_noise)

    step_fn = simulate.get_step_fn(
        datamanager.grid_coords,
        datamanager.interp_fn,
        datamanager.size,
        datamanager.interaction_radius,
    )

    epochs_iter = range(epochs)
    if progress_bars:
        epochs_iter = tqdm(epochs_iter, desc="Epochs")

    for e in epochs_iter:
        logs = {"epoch": e}
        if use_wandb:
            wandb.log(logs, step=total_step)

        dataloader = datamanager.get_dataloader()
        gmean, gstd = datamanager.gmean, datamanager.gstd

        # Time the whole epoch (all buffer passes plus validation). Creating
        # the events inside the pass loop would only time the last pass.
        if use_cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        passes_iter = range(passes_over_buffer)
        if progress_bars:
            passes_iter = tqdm(passes_iter, desc="Buffer Passes", leave=False)

        for pss in passes_iter:
            batches_iter = dataloader
            if progress_bars:
                batches_iter = tqdm(batches_iter, desc="Batches", leave=False)

            for batch in batches_iter:
                grids, states, targets, types, target_grids = batch
                if use_cuda:
                    grids, states, targets, types, target_grids = (
                        t.to("cuda")
                        for t in (grids, states, targets, types, target_grids)
                    )

                # Noise augmentation: grids are perturbed everywhere; particle
                # states only where types > 0, then clamped into [LOW, HIGH].
                grids = grids + torch.randn_like(grids) * grid_noise
                noisy_states = states + torch.randn_like(states) * particle_noise
                states = torch.where(types[..., None] > 0.0, noisy_states, states)
                states = torch.clamp(states, min=LOW, max=HIGH)

                loss, _ = loss_fn(
                    model,
                    grids,
                    target_grids,
                    states,
                    targets,
                    gmean,
                    gstd,
                    types,
                    step_fn,
                    steps_per_call=datamanager.steps_per_call,
                    autoregressive_steps=datamanager.autoregressive_steps,
                    debug=total_step,
                    grid_size=datamanager.grid_size,
                )

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                # Linear warm-up before warmup_end, LR held unchanged until
                # cosine_start, cosine decay afterwards; no scheduling past
                # total_iters.
                if schedulers is not None and total_step <= schedulers.get(
                    "total_iters", int(1e5)
                ):
                    if total_step >= schedulers["cosine_start"]:
                        schedulers["cosine"].step()
                    elif total_step < schedulers["warmup_end"]:
                        schedulers["linear"].step()
                # The optimizer holds whatever LR the schedulers just set, and
                # it is defined even on steps where no scheduler is stepped.
                current_lr = optimizer.param_groups[0]["lr"]

                if total_step % 50 == 0:
                    logs = {
                        "loss": loss.item(),
                        "log_loss": torch.log10(loss).item(),
                        "total_step": total_step,
                        "sim_idx": datamanager.current_sim_idx,
                        "lr": current_lr,
                    }

                    if use_wandb:
                        wandb.log(logs, step=total_step)

                    if model_logger is not None:
                        model_logger.try_saving(model)

                total_step += 1

        if e % 5 == 0:
            with torch.no_grad():
                init_state, valid_types, valid_sims = datamanager.get_valid_sims()
                # Roll out at most 1000 steps to bound validation cost.
                sim_length = min(datamanager.true_sim_length - 1, 1000)
                num_calls = sim_length // datamanager.steps_per_call + 1

                _, trajectories, _ = simulate.unroll(
                    model,
                    init_state,
                    datamanager.grid_coords,
                    num_calls,
                    gmean,
                    gstd,
                    valid_types,
                    datamanager.size,
                    interaction_radius=datamanager.interaction_radius,
                    interp_fn=datamanager.interp_fn,
                )

                # Per-timestep MSE averaged over the validation sims.
                # Prediction t is compared with ground-truth frame t + 1; only
                # the first n particles (n = number of non-zero types) are
                # scored -- presumably valid particles come first, confirm.
                time_err = []
                for i, (sim, sim_types) in enumerate(zip(valid_sims, valid_types)):
                    n = torch.count_nonzero(sim_types)
                    te = (
                        sim[1:sim_length, :n, :DIM]
                        - trajectories[i, : sim_length - 1, :n, :DIM]
                    ) ** 2
                    time_err.append(te.mean(axis=(1, 2)))

                time_err = sum(time_err) / len(time_err)
                err = time_err.mean()

                # Drawn unconditionally so the RNG stream does not depend on
                # whether media are produced.
                plot_idx = torch.randint(0, len(valid_sims), (1,)).item()

                logs = {
                    "total_step": total_step,
                    "total_error": err,
                }
                if err <= best_val:
                    best_val = err

                    if model_logger:
                        model_logger.save_model(
                            model,
                            checkpoint_name="best",
                            json_dict={
                                "total_step": total_step,
                                "total_error": err.item(),
                                "epoch": e,
                                "pass": pss,
                                "loss": loss.item(),
                                "log_loss": torch.log10(loss).item(),
                            },
                        )

                    # Plots and videos are only built when they will be logged.
                    if use_wandb:
                        table = wandb.Table(
                            data=[[t, v] for t, v in enumerate(time_err.tolist())],
                            columns=["t", "err"],
                        )
                        logs["error_over_time"] = wandb.plot.line(
                            table, "t", "err", title="Error over time"
                        )

                        # Same frame alignment as the error above: prediction
                        # t against ground-truth frame t + 1, equal lengths.
                        length = sim_length - 1
                        vid = viz.animate_comparison(
                            trajectories[plot_idx][:length].detach().cpu().numpy(),
                            truth=valid_sims[plot_idx][1 : length + 1, ..., :2]
                            .detach()
                            .cpu()
                            .numpy(),
                            type_=valid_types[plot_idx].detach().cpu().numpy(),
                            interval=1,
                            save_path=None,
                            return_ani=False,
                            as_array=True,
                            bounds=((LOW, HIGH), (LOW, HIGH)),
                        )
                        logs["sim"] = wandb.Video(vid, fps=16)

                if use_wandb:
                    wandb.log(logs, step=total_step)

                del init_state, valid_types, valid_sims

        if use_cuda:
            end.record()
            torch.cuda.synchronize()
            logs = {"time": start.elapsed_time(end) / 1000.0}  # ms to s
        else:
            logs = {"time": -1.0}

        if use_wandb:
            wandb.log(logs, step=total_step)

    if use_wandb:
        wandb.finish()


def start_training():
    """Placeholder for a future training entry point (wandb.init & co.).

    Currently does nothing and returns ``None``.

    TODO: Maybe a ``Trainer`` class or similar, which could avoid threading
    so many arguments through the functions in this module.
    """


def main():
    """Command-line entry point: build model, optimizer and data, then train.

    Raises:
        NotImplementedError: if ``--slurm`` is given; only local runs are
            supported.
    """
    parser = argparse.ArgumentParser("Neural MPM")
    parser.add_argument("--nowandb", action="store_true")
    parser.add_argument("--slurm", action="store_true")
    parser.add_argument("--epochs", help="Number of epochs", type=int, default=1)

    parser.add_argument(
        "--path",
        type=str,
        help="Path to dataset",
        required=True,
    )

    parser.add_argument(
        "--steps-per-call",
        help="Number of predictions per model call",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--autoregressive-steps",
        help="Number of autoregressive steps during training",
        type=int,
        default=1,
    )

    parser.add_argument("--grid-size", help="Grid size", type=int, default=64)

    parser.add_argument("--batch-size", help="Batch size", type=int, default=1)

    parser.add_argument(
        "--passes-over-buffer",
        help="How many times to repeat a buffer",
        type=int,
        default=2,
    )

    parser.add_argument(
        "--sims-in-memory",
        help="How many simulations to keep in memory at once",
        type=int,
        default=2,
    )

    # Required: the model cannot be built without it (omitting it used to
    # crash later with a TypeError on `None + list`).
    parser.add_argument(
        "--architecture",
        nargs="+",
        type=int,
        required=True,
        help="An integer list of architecture parameters",
    )

    parser.add_argument("--lr", help="Initial learning rate", type=float, default=1e-3)
    parser.add_argument(
        "--min-lr", help="Minimum learning rate", type=float, default=1e-6
    )
    parser.add_argument("--use-schedulers", action="store_true")

    args = parser.parse_args()
    use_wandb = not args.nowandb

    if args.slurm:
        # Raising a plain string is itself a TypeError in Python 3.
        raise NotImplementedError("Only run-locally supported for now!")

    # 4 input channels by default; datasets whose path contains "WBC" use 6
    # (presumably the two extra gravity channels -- confirm with the data).
    in_channels = 6 if "WBC" in args.path else 4

    # The appended final entry is steps_per_call; one factor of 2 per
    # user-supplied architecture entry (presumably per-level resampling).
    architecture = args.architecture + [args.steps_per_call]
    factors = [2] * len(args.architecture)
    model = UNet(architecture, factors, in_channels=in_channels)
    if torch.cuda.is_available():
        model.to("cuda")
        model = torch.compile(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    schedulers = None
    if args.use_schedulers:
        warmup_end = 100
        cosine_start = 1000
        total_iters = int(1e5)
        warmup_end = min(warmup_end, cosine_start)
        linear_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1e-2,
            end_factor=1.0,
            total_iters=warmup_end,
        )
        cos_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_iters - cosine_start,
            eta_min=args.min_lr,
            last_epoch=-1,
        )
        schedulers = {
            "linear": linear_scheduler,
            "cosine": cos_scheduler,
            "warmup_end": warmup_end,
            "cosine_start": cosine_start,
            "total_iters": total_iters,
        }

    datamanager = DataManager(
        args.path,
        batch_size=args.batch_size,
        dim=2,
        grid_size=args.grid_size,
        steps_per_call=args.steps_per_call,
        autoregressive_steps=args.autoregressive_steps,
        sims_in_memory=args.sims_in_memory,
    )

    if use_wandb:
        formatted_date_time = datetime.now().strftime("%m%d%H%M%S")
        wandb.init(
            project="debug_prod",
            entity="neuralmpm",
            name=f"{formatted_date_time}",
            config={
                "steps_per_call": datamanager.steps_per_call,
                "autoregressive_steps": datamanager.autoregressive_steps,
                "grid_size": args.grid_size,
                "interaction_radius": datamanager.interaction_radius,
                # NOTE(review): hard-coded; may not reflect datamanager.interp_fn.
                "interp_fn": "constant",
                "batch_size": args.batch_size,
                "sims_in_memory": args.sims_in_memory,
                "use_schedulers": args.use_schedulers,
                "architecture": args.architecture,
                "passes_over_buffer": args.passes_over_buffer,
                "data": args.path,
                "lr": args.lr,
                "min_lr": args.min_lr,
                "epochs": args.epochs,
            },
        )
        wandb.watch(model, log="all", log_freq=4)

    train(
        model=model,
        datamanager=datamanager,
        optimizer=optimizer,
        schedulers=schedulers,
        use_wandb=use_wandb,
        epochs=args.epochs,
        passes_over_buffer=args.passes_over_buffer,
        progress_bars=True,
    )


if __name__ == "__main__":
    # Run the CLI entry point only when executed as a script, not on import.
    main()
