import torch
from torch import Tensor
from evaluation import compute_similarities
from utils.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
)
import wandb
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar, RichModelSummary
from lightning.pytorch.loggers import WandbLogger
from data import GraphsDataModule
from lightning.pytorch.utilities import disable_possible_user_warnings
from models import DynamicGraphNetwork
from torch_geometric.data import Data, Batch
import time
from configs import NBODY, MD17, NBA
from tqdm import tqdm
import argparse


# FLAGS
# True: no wandb logger is created and test metrics are printed instead of logged.
DEBUGGING = True
# Forwarded to config["do_preprocess"].
PROCESS_DATA = True
# Forwarded to config["process_graphs"].
PROCESS_GRAPHS = True
# True: train a new model; False: load the pickled model from config["weights_path"].
TRAIN = False
# When not training: only load the model, skipping the data module and the test run.
ONLY_LOAD = False

# CHOOSE DATASET (one of the keys of _DATASET_CONFIGS below)
dataset = "NBody"
# dataset = "MD17"
# dataset = "NBA"

_DATASET_CONFIGS = {"NBody": NBODY, "MD17": MD17, "NBA": NBA}
try:
    # NOTE: this is the dict object imported from `configs`; it is mutated in place below.
    config = _DATASET_CONFIGS[dataset]
except KeyError:
    # Fail loudly here instead of with a NameError on `config` further down.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of {sorted(_DATASET_CONFIGS)}"
    ) from None
config["do_preprocess"] = PROCESS_DATA
config["process_graphs"] = PROCESS_GRAPHS


class ConditionalFlowMatching(L.LightningModule):
    """Conditional Flow Matching model.

    Lightning wrapper around a flow network. The wrapped ``model`` is called as
    ``model(x=..., t=...)`` to predict the vector field, and also provides
    ``sample_flow``, ``compute_loss``, ``generate_batch`` and ``evaluate``.
    Test-time priors, predictions and ground truths are accumulated across
    ``test_step`` calls and evaluated once in ``on_test_epoch_end``.

    Args:
        model: the flow network (e.g. a ``DynamicGraphNetwork``).
        config: run configuration dict (reads ``time_window``,
            ``use_optimal_transport``, ``fm_sigma``, ``batch_size``, ``nfe``,
            ``learning_rate`` and others).
    """

    # Roll out the generative model on one validation batch every N epochs
    # and log val/ade and val/fde.
    VAL_INFERENCE_EVERY_N_EPOCHS = 5

    def __init__(self, model, config):
        super().__init__()
        self.config = config
        self.model = model
        self.window = config["time_window"]
        self.step_times = []
        self.test_priors = []
        self.test_predictions = []
        self.test_ground_truths = []
        # Optimal-transport matcher when requested, plain CFM otherwise;
        # both use sigma=config["fm_sigma"].
        if config["use_optimal_transport"]:
            self.cfm = ExactOptimalTransportConditionalFlowMatcher(sigma=config["fm_sigma"])
        else:
            self.cfm = ConditionalFlowMatcher(sigma=config["fm_sigma"])
        self.save_hyperparameters()

    def forward(self, t: Tensor, x: Tensor | Data) -> torch.Tensor:
        """Predict the vector field at time ``t`` for input ``x``."""
        return self.model(x=x, t=t)

    def training_step(self, batch, batch_idx):
        """Compute the flow-matching loss on one batch and log it with the current LR."""
        # Sample location and conditional flow using CFM library
        t, xt_graph, ut, xt = self.model.sample_flow(self.cfm, batch)

        vt = self(t, xt_graph)  # Calculate vector field

        loss = self.model.compute_loss(t, vt, ut, xt_graph, xt)

        self.log("train/loss", loss, on_epoch=True, batch_size=self.config["batch_size"])

        # Plot learning rate
        self.log(
            "train/lr",
            self.trainer.optimizers[0].param_groups[0]["lr"],
            batch_size=self.config["batch_size"],
            on_epoch=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the flow-matching loss on one validation batch."""
        # Sample location and conditional flow using CFM library
        t, xt_graph, ut, xt = self.model.sample_flow(self.cfm, batch)
        # Calculate vector field using model
        vt = self(t, xt_graph)

        val_loss = self.model.compute_loss(t, vt, ut, xt_graph, xt)

        self.log("val/loss", val_loss, on_epoch=True, batch_size=self.config["batch_size"])

        return val_loss

    def on_validation_epoch_end(self):
        """Every ``VAL_INFERENCE_EVERY_N_EPOCHS`` epochs, roll out one validation
        batch with ``self.config["nfe"]`` steps and log ADE/FDE."""
        if self.current_epoch % self.VAL_INFERENCE_EVERY_N_EPOCHS != 0:
            return None

        # Use this module's device rather than assuming CUDA, so validation
        # also works when the trainer falls back to CPU.
        batch = next(iter(self.trainer.datamodule.val_dataloader())).to(self.device)  # type: ignore
        self.model.eval()
        with torch.no_grad():
            rollout, _, _, _, x1 = self.model.generate_batch(
                self, batch, steps=self.config["nfe"]
            )
            if "nba" in self.config["data_paths"].lower() and self.config["type"] in [0, 1]:
                # Remove ball (index 0) from predictions and targets.
                rollout = [r[1:] for r in rollout]
                x1 = [x[1:] for x in x1]
            x, x1 = torch.vstack(rollout), torch.vstack(x1)
            similarities = compute_similarities(x, x1, self.config)
            self.log("val/ade", similarities["ade"], batch_size=self.config["batch_size"])
            self.log("val/fde", similarities["fde"], batch_size=self.config["batch_size"])

        return None

    def on_test_epoch_start(self):
        """Reset the test accumulators so repeated ``trainer.test`` runs don't mix results."""
        self.test_priors = []
        self.test_predictions = []
        self.test_ground_truths = []

    def test_step(self, batch, batch_idx):
        """Generate 20 / 5 / 1 sample rollouts per batch (best-of-20, mean-of-5,
        or single run per config) and accumulate them for evaluation."""
        if self.config["use_best_of_20"]:
            num_runs = 20
        elif self.config["use_mean_of_5"]:
            num_runs = 5
        else:
            num_runs = 1

        for _ in tqdm(
            range(num_runs),
            desc=f"Generating {len(batch)} test samples, {num_runs} runs each",
        ):
            pred_trajectories, _, _, priors, x1 = self.model.generate_batch(
                self, batch, steps=self.config["nfe"]
            )
            self.test_priors.extend(priors)
            self.test_predictions.extend(pred_trajectories)
            self.test_ground_truths.extend(x1)

    def on_test_epoch_end(self):
        """Evaluate the accumulated test rollouts; log plots/metrics to wandb,
        or print the metrics when ``DEBUGGING``."""
        plots, metrics = self.model.evaluate(
            self.test_priors,
            self.test_predictions,
            self.test_ground_truths,
            True,
            DEBUGGING,
        )

        if not DEBUGGING:
            for label, plot in plots.items():
                label = "test/" + label
                if isinstance(plot, str):
                    # String plots are treated as GIF files -> log as a video only.
                    self.logger.experiment.log(  # type: ignore
                        {label: wandb.Video(plot, fps=3, format="gif")}
                    )
                else:
                    self.logger.experiment.log({label: wandb.Image(plot)})  # type: ignore
            for label, metric in metrics.items():
                self.log("test/" + label, metric)  # type: ignore
        else:
            print("\nEvaluation metrics:\n")
            for key, value in metrics.items():
                print(f"{key}: {value}\n")

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau on ``val/loss`` (halve the LR after 30
        scheduler steps without improvement)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config["learning_rate"])

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, "min", factor=0.5, patience=30
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }


def _str_to_bool(text: str) -> bool:
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` would turn any non-empty string,
    including ``"False"``, into ``True``.
    """
    lowered = text.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _parse_literal(text: str):
    """Parse a Python literal (int, float, bool, None, ...), else return the raw string.

    Used for config keys whose default is ``None``, where no type can be inferred.
    """
    import ast

    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return text


def _arg_type_for(sample: object):
    """Return the argparse ``type=`` converter for a config value like *sample*."""
    if isinstance(sample, bool):
        return _str_to_bool
    if sample is None:
        return _parse_literal
    return type(sample)


def main():
    """Run the experiment selected by the module-level flags.

    1. Expose every non-dict entry of ``config`` as a ``--<key>`` command-line
       override and apply those that were passed.
    2. If ``TRAIN`` is False, load a pickled ``ConditionalFlowMatching`` from
       ``config["weights_path"]`` and merge its stored config with ``config``.
       Explicit command-line overrides are re-applied on top.
    3. If ``TRAIN`` is True, build a ``DynamicGraphNetwork``, fit it, save the
       whole module under ``final_weights/`` and run the test loop.
    4. Otherwise, unless ``ONLY_LOAD``, run the test loop on the loaded model.

    Returns:
        The trained or loaded ``ConditionalFlowMatching`` module.
    """
    import os

    disable_possible_user_warnings()
    torch.set_float32_matmul_precision("high")

    parser = argparse.ArgumentParser()
    for flag, value in config.items():
        if isinstance(value, dict):
            continue  # nested dicts (e.g. node_features) are not CLI-overridable
        if isinstance(value, list):
            elem_type = _arg_type_for(value[0]) if value else str
            parser.add_argument("--" + flag, type=elem_type, nargs="+")
        else:
            parser.add_argument("--" + flag, type=_arg_type_for(value))

    args = parser.parse_args()
    cli_overrides = {}
    for arg, arg_value in vars(args).items():
        if arg in config and arg_value is not None:
            config[arg] = arg_value
            cli_overrides[arg] = arg_value
            print(f"Argument {arg} set to {config[arg]}")

    if not TRAIN:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load weight files from a trusted source.
        model = torch.load(
            config["weights_path"],
            weights_only=False,
            # Allow CUDA-saved weights to load on CPU-only machines.
            map_location=None if torch.cuda.is_available() else "cpu",
        )
        # Add keys missing from the stored config...
        for k, v in config.items():
            if k not in model.config:
                model.config[k] = v
                model.model.config[k] = v
        # ...then let the stored config overwrite the current one.
        for new_k, new_v in model.config.items():
            if new_k == "weights_path":
                continue
            if new_k == "radius_graph" and new_v is None:
                # Fix wrong setting in springs weights
                config[new_k] = 100
                continue
            config[new_k] = new_v
        # Values passed explicitly on the command line take precedence.
        for k, v in cli_overrides.items():
            config[k] = v
            model.config[k] = v

        model.model.config = config
        model.test_priors = []
        model.test_predictions = []
        model.test_ground_truths = []

    if TRAIN:
        # Only needed for training; a loaded model already contains its network.
        mask_size = int(config["use_inpainting_mask"])
        flow_model = DynamicGraphNetwork(
            in_channels=len(config["node_features"].keys()) + mask_size,
            in_edge_channels=len(config["edge_features"].keys()),
            hidden_channels=config["hidden_dim"],
            hidden_edge_channels=config["hidden_dim"],
            out_channels=config["dims"],
            config=config,
        )

    # Training always needs data; ONLY_LOAD only skips it on the load-and-test path.
    if TRAIN or not ONLY_LOAD:
        data_module = GraphsDataModule(config, TRAIN)

    if TRAIN:
        model = ConditionalFlowMatching(flow_model, config).to(
            device="cuda" if torch.cuda.is_available() else "cpu"
        )

    if not DEBUGGING:  # Initialize wandb
        wandb_logger = WandbLogger(
            project="STFlow",
            config=config,
            log_model=True,
            save_dir="logs/wandb",
        )

    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    if TRAIN:
        run_name = (
            wandb_logger.experiment.name
            if not DEBUGGING and wandb_logger.experiment.name is not None
            else "debug"
        )
        trainer = L.Trainer(
            default_root_dir="logs",
            max_epochs=config["epochs"],
            devices=1,
            accelerator=accelerator,
            callbacks=[
                ModelCheckpoint(
                    monitor="val/ade",
                    mode="min",
                    dirpath=config["model_path"],
                    save_last=True,
                    filename=f"{time.strftime('%Y-%m-%d-%H-%M-%S')}-{dataset}-{run_name}",
                ),
                RichProgressBar(),
                RichModelSummary(max_depth=-1),
            ],
            logger=wandb_logger if not DEBUGGING else None,
            enable_model_summary=False,
            num_sanity_val_steps=0,
            gradient_clip_val=0.5,
        )

        trainer.fit(model, datamodule=data_module)  # type: ignore

        name = f"{time.strftime('%Y-%m-%d-%H-%M-%S')}-{dataset}-{config['type']}"
        # Make sure the output directory exists so a finished run isn't lost.
        os.makedirs("final_weights", exist_ok=True)
        torch.save(model, f"final_weights/{name}_weights.pth")

        if not DEBUGGING:
            wandb_logger.log_metrics({"train_data_size": len(data_module.train_data)})

        trainer.test(model, datamodule=data_module)  # type: ignore

        if not DEBUGGING:
            wandb_logger.log_metrics({"test_data_size": len(data_module.test_data)})

    if not TRAIN and not ONLY_LOAD:
        trainer = L.Trainer(
            default_root_dir="logs",
            # Same device selection as training: single device, CPU fallback.
            accelerator=accelerator,
            devices=1,
            logger=wandb_logger if not DEBUGGING else None,
            num_sanity_val_steps=0,
        )

        trainer.test(model, datamodule=data_module)  # type: ignore

        if not DEBUGGING:
            wandb_logger.log_metrics({"test_data_size": len(data_module.test_data)})

    if not DEBUGGING:
        wandb.finish()

    return model


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
