import numpy as np
import torch
import time
import tqdm
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from functools import partial

from margflow.marginal_flow import MarginalFlow
from margflow.other_models.flowmatching import FlowMatching
from margflow.other_models.free_form_flows import FreeFormFlow
from margflow.other_models.normalizing_flow import NormalizingFlow

from fff.loss import volume_change_surrogate

import argparse

# Command-line interface. The parsed ``args`` namespace is read by ``main()``.
parser = argparse.ArgumentParser(
    description="Benchmark the runtime of a generative model over a sweep of data dimensions.",
)

parser.add_argument(
    "--model_name",
    type=str,
    default="marginal_flow",
    choices=["freeform_flow", "freeform_flow_exact", "normalizing_flow", "flow_matching", "marginal_flow"],
    help="Model to benchmark (default: marginal_flow).",
)

parser.add_argument(
    "--evaluation",
    type=str,
    default="sampling",
    choices=["sampling", "log_prob"],
    help="Operation to time: drawing samples or evaluating log-probabilities (default: sampling).",
)

args = parser.parse_args()

def _time_cuda_call(fn):
    """Return the wall-clock seconds taken by ``fn()``, including queued GPU work.

    The device is synchronized before the clock starts and again before it is
    read, so asynchronously launched CUDA work is part of the measurement.
    """
    torch.cuda.synchronize()
    start = time.monotonic()
    fn()
    torch.cuda.synchronize()
    return time.monotonic() - start


def main():
    """Benchmark sampling or log-prob runtime of one model across dimensions.

    The model is selected by ``args.model_name`` and the timed operation by
    ``args.evaluation``. For each dimension in a log-spaced sweep the model is
    built once and then warmed up and timed ``n_repetitions`` times. The sweep
    stops early when a model cannot be constructed or when the overall time
    budget is used up. Whatever was measured is written to
    ``./runtime_results/<model_name>_runtime_<evaluation>.csv``.
    """
    script_path = Path(".")
    signature = Path("test")
    n_layers = 3
    hid_dim = 128
    device = "cuda"
    z_dim = 1_000

    # Fixed: ``signature`` and ``script_path`` used to be passed swapped here,
    # unlike every other initializer below.
    init_margflow = partial(
        MarginalFlow,
        z_dim=z_dim,
        n_layers=n_layers,
        hid_dim=hid_dim,
        device=device,
        signature=signature,
        script_path=script_path,
        base_distribution="uniform",
    )
    init_flowmatch = partial(
        FlowMatching,
        hid_dim=hid_dim,
        n_layers=n_layers,
        script_path=script_path,
        signature=signature,
        device=device,
    )
    init_normflow_forward = partial(
        NormalizingFlow, direction="forward", n_layers=n_layers, signature=signature, script_path=script_path
    )
    # Currently not registered in ``init_dict`` (see the disabled entry below).
    init_normflow_reverse = partial(
        NormalizingFlow, direction="reverse", n_layers=n_layers, signature=signature, script_path=script_path
    )
    init_freeform_flow = partial(
        FreeFormFlow, z_dim=z_dim, n_layers=n_layers, hid_dim=hid_dim, signature=signature, script_path=script_path
    )
    init_dict = dict(
        marginal_flow=init_margflow,
        # normalizing_flow_rev=init_normflow_reverse,
        freeform_flow=init_freeform_flow,
        freeform_flow_exact=init_freeform_flow,
        flow_matching=init_flowmatch,
        normalizing_flow=init_normflow_forward,
    )

    dimensions = [int(d) for d in np.logspace(2, 6, 10).astype("int")]
    n_samples = 100
    n_repetitions = 10
    n_warmups = 1
    runtime_results = []

    very_beginning = time.monotonic()
    max_time = 60 * 60 * 0.4  # overall budget in seconds (24 minutes)
    print(f"Running runtime comparison for {args.model_name}")

    # Looked up outside the ``try`` below so a bad model name is not mistaken
    # for a construction failure.
    build_model = init_dict[args.model_name]
    budget_exhausted = False

    with torch.no_grad():
        for dim in tqdm.tqdm(dimensions):
            try:
                model = build_model(x_dim=dim)
            except Exception as err:
                # Best-effort sweep: stop at the first dimension that cannot be
                # built, but say why instead of failing silently.
                print(f"Stopping at dim={dim}: could not build {args.model_name} ({err!r})")
                break

            for _rep in range(n_repetitions):
                if args.evaluation == "sampling":
                    # warm-up phase
                    for _warmup in range(n_warmups):
                        model.sample(n_samples)

                    # actual runtime timing
                    sampling_time = _time_cuda_call(lambda: model.sample(n_samples))

                    runtime = dict(model_name=args.model_name,
                                   sampling_time=sampling_time,
                                   dim=dim)
                elif args.evaluation == "log_prob":
                    x = torch.randn((n_samples, dim), device=device)

                    # warm-up phase
                    for _warmup in range(n_warmups):
                        try:
                            model.log_prob(x, exact=True)
                        except Exception:
                            model.log_prob(x)

                    # actual runtime timing
                    # NOTE(review): the warm-up prefers ``exact=True`` but the
                    # timed call uses the default, so "freeform_flow" and
                    # "freeform_flow_exact" run the same timed call. Kept as in
                    # the original -- confirm which variant should be measured.
                    eval_time = _time_cuda_call(lambda: model.log_prob(x))

                    runtime = dict(model_name=args.model_name,
                                   eval_time=eval_time,
                                   dim=dim)

                runtime_results.append(runtime)
                if time.monotonic() - very_beginning > max_time:
                    budget_exhausted = True
                    break
            if budget_exhausted:
                break

    runtime_df = pd.DataFrame(runtime_results)
    # Create the output directory first so a long run is not lost to a missing folder.
    output_dir = Path("runtime_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    runtime_df.to_csv(output_dir / f"{args.model_name}_runtime_{args.evaluation}.csv")


# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()

