import _pathfix
import hydra
import torch, torch.nn as nn
from omegaconf import DictConfig
import os
import matplotlib.pyplot as plt
from pathlib import Path
from hydra.utils import to_absolute_path
from torch.utils.data import Dataset, DataLoader
import numpy as np

from output_path import _output_path
from tgm.data.stock.data_stock import make_loaders
from tgm.train.trainer import Trainer
from tgm.train.callbacks import WandbCallback, PrintCallback, PlottingCallback, SavingCallback

from tgm.models.drift_diff_model import DriftDiffusionModel
from tgm.models.jump_model import JumpModel
from tgm.models.jump_model_uncoupled import JumpModelUncoupled
from tgm.models.jump_model_full_cov import JumpModelFullCov
from tgm.models.tfm_model import TfmModel

# Lookup table from the config's model name (cfg.model.model_name) to the
# model class that main() instantiates with cfg.model.
MODEL_REGISTRY = dict(
    DriftDiffusionModel=DriftDiffusionModel,
    JumpModel=JumpModel,
    JumpModelUncoupled=JumpModelUncoupled,
    JumpModelFullCov=JumpModelFullCov,
    TfmModel=TfmModel,
)

@hydra.main(config_path="../conf", config_name="config_stock_2d", version_base=None)
def main(cfg: DictConfig):
    """Train the configured model on the stock data and print the best MMD.

    Seeds all RNGs and derives ``cfg.train.no_bridges`` from ``cfg.data.T_sub``.
    Then it builds the data loaders, the model selected by
    ``cfg.model.model_name`` and an Adam optimizer, and hands them to
    ``Trainer``. The value returned by ``Trainer.train()`` is printed as the
    minimal MMD.
    """
    # os.environ['CUDA_LAUNCH_BLOCKING']='1' # for debugging
    _seed_all(cfg.train.manual_seed)

    # Bridge count follows T_sub, except that T_sub == 101 maps to 100 bridges.
    t_sub = cfg.data.T_sub
    cfg.train.no_bridges = 100 if t_sub == 101 else t_sub

    train_loader, val_sub, val_full, test_full = make_loaders(cfg.data, cfg.train)

    # Optional: plot and save the test set.
    # fig = _trajectory_plot_all_dims(test_full["x"], test_full["t"])
    # fig.savefig(os.path.join(_output_path(cfg), "trajectories_test_set.png"), dpi=300)
    # plt.show(fig)

    model = MODEL_REGISTRY[cfg.model.model_name](cfg.model)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.lr)

    # Insert PlottingCallback() after PrintCallback() to enable plotting during training.
    callbacks = [
        PrintCallback(),
        SavingCallback(cfg),
        WandbCallback(
            project=os.getenv("WANDB_PROJECT", "TGM-debug"),
            run_name=os.getenv("WANDB_RUN_NAME", "test_run"),
        ),
    ]

    trainer = Trainer(cfg.train, model, optimizer, train_loader, val_sub, val_full, callbacks)
    best_mmd = trainer.train()

    print(f"\nMinimal mmd: {best_mmd}")
    
def _trajectory_plot_all_dims(trajectories, times, no_traj_to_plot=100):
    """Plot sample trajectories with one subplot per state dimension.

    Args:
        trajectories: tensor indexed as ``[trajectory, time_step, dim]``; the
            last axis gives the number of subplots.
        times: time points, either one row per trajectory
            (``times[i]`` aligned with ``trajectories[i]``) or a single 1-D
            grid shared by all trajectories.
        no_traj_to_plot: maximum number of trajectories to draw. It is clamped
            to the number of trajectories available, so small sets do not
            raise ``IndexError``.

    Returns:
        The matplotlib ``Figure`` containing the plots.
    """
    times_cpu = times.detach().cpu().numpy()
    trajectories_cpu = trajectories.detach().cpu().numpy()

    n_traj = min(no_traj_to_plot, trajectories_cpu.shape[0])
    dims = trajectories_cpu.shape[-1]
    # squeeze=False always returns a 2-D array of Axes, so dims == 1 needs no special case.
    fig, axes = plt.subplots(1, dims, figsize=(8 * dims, 5), squeeze=False)
    for i in range(n_traj):
        # A 1-D time tensor is a grid shared by every trajectory.
        t_i = times_cpu if times_cpu.ndim == 1 else times_cpu[i]
        for j in range(dims):
            axes[0, j].plot(t_i, trajectories_cpu[i, :, j])

    return fig

def _seed_all(seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
# Script entry point: main() is called without arguments; the @hydra.main
# decorator composes and supplies cfg.
if __name__ == "__main__":
    main()
