from model.model import MLP
import torch
from pathlib import Path


from rewards import (
    weighted_reward,
    weighted_reward_1d,
    weighted_reward_nd,
)
from steering import SourceTemperingSampler, FKSampler
from plotting import (
    plot_fk_vs_st_exploration,
    plot_fk_vs_st_histograms,
    plot_fk_vs_st_betas,
    plot_wasserstein,
)
from utils import marginal_prob_std, get_avg_wasserstein
import numpy as np
import argparse
import os

def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value used to seed every random number generator.

    Note:
        ``PYTHONHASHSEED`` is only read when an interpreter starts, so setting
        it here affects subprocesses launched afterwards, not ``hash()`` in
        the current process.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    # Seeds the CPU generator (and CUDA generators, if any).
    torch.manual_seed(seed)
    # Explicitly seed every CUDA device; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)

# Hyperparameters shared by every experiment below.
MEAN_FLOW_STEPS = 50
NUM_CHAINS = 30
BATCH_SIZE = 1024
SIGMA = 4.0
RESAMPLE_EVERY = 10
CKPT_1D = "model/fk_failed_1d_2clouds.pt"
CKPT_2D = "model/fk_failed_2d_4clouds.pt"


def _load_model(data_dim, checkpoint, device):
    """Build ``MLP(data_dim, 512)`` on ``device``, load ``checkpoint`` and set eval mode."""
    model = MLP(data_dim, 512).to(device)
    state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _with_width(reward_fn, base_width):
    """Return ``x -> reward_fn(x, base_width)`` with ``base_width`` bound now.

    Binding through a function argument avoids the late-binding closure
    pitfall when samplers are constructed inside a loop over widths.
    """
    return lambda x: reward_fn(x, base_width)


def _st_sigma():
    """Source std at t=1 for ``SIGMA``.

    FKSampler computes this internally from ``sigma``, but
    SourceTemperingSampler expects it precomputed.
    """
    return marginal_prob_std(1, SIGMA).item()


def _plot_wasserstein(device):
    """Plot FK vs. ST Wasserstein distances, computing and caching them if absent.

    The sweep takes around 12 hours, so its results are cached in
    ``results/wasserstein.npy`` and reused on later runs.
    """
    path = Path("results/wasserstein.npy")
    if not path.exists():
        print("Warning, this plot takes a long time to generate")
        print("Run it offline, it takes around 12 hours")
        # Ensure the save location exists *before* the long run, not after it.
        path.parent.mkdir(parents=True, exist_ok=True)

        data_dim = 1
        beta = 20  # Use a moderate temperature across each spike width
        model = _load_model(data_dim, CKPT_1D, device)
        # The ST sampler takes the precomputed source std, as in the other
        # experiments. An earlier version passed the raw sigma here; a cache
        # file produced by it should be deleted and regenerated.
        st_sigma = _st_sigma()
        reward_fn = weighted_reward_1d

        fks, spts = [], []
        # Sweep base widths; truncated to the widest 15 of 50 for time constraints.
        for base_width in np.logspace(-6, -1, 50)[-15:]:
            reward = _with_width(reward_fn, base_width)
            fk, spt = get_avg_wasserstein(
                model=model,
                st_sampler=SourceTemperingSampler(
                    model=model,
                    reward_fn=reward,
                    beta=beta,
                    sigma=st_sigma,
                    n_chains=NUM_CHAINS,
                    n_time_steps=MEAN_FLOW_STEPS,
                    device=device,
                    data_dim=data_dim,
                ),
                fk_sampler=FKSampler(
                    model=model,
                    reward_fn=reward,
                    sigma=SIGMA,
                    lmbda=beta,  # override lambda
                    n_particles=BATCH_SIZE,
                    n_time_steps=MEAN_FLOW_STEPS,
                    resample_every=RESAMPLE_EVERY,
                    data_dim=data_dim,
                    device=device,
                ),
                beta=beta,
                batch_size=BATCH_SIZE,
                base_width=base_width,
            )
            fks.append(fk)
            spts.append(spt)
        # Cache so the sweep does not have to be repeated.
        np.save(path, [fks, spts])
    plot_wasserstein()


def _plot_beta_ablation(device):
    """Ablation of FK vs. ST over the temperature (beta) on the 1-D toy model."""
    data_dim = 1
    model = _load_model(data_dim, CKPT_1D, device)
    plot_fk_vs_st_betas(
        model=model,
        base_width=0.01,
        reward_fn=weighted_reward_1d,
        sampler_params=dict(
            n_chains=NUM_CHAINS,
            n_time_steps=MEAN_FLOW_STEPS,
            device=device,
            data_dim=data_dim,
        ),
        fk_params=dict(
            n_particles=BATCH_SIZE,
            n_time_steps=MEAN_FLOW_STEPS,
            data_dim=data_dim,
            resample_every=RESAMPLE_EVERY,
        ),
        marginal_prob_std_fn=marginal_prob_std,
        st_betas=(10.0, 100.0),
        batch_size=BATCH_SIZE,
        base_sigma=SIGMA,
        save_path="results/imgs/",
    )


def _plot_2d_toy(device):
    """Exploration comparison of FK vs. ST on the 2-D toy model."""
    data_dim = 2
    base_width = 0.1
    model = _load_model(data_dim, CKPT_2D, device)
    st_sigma = _st_sigma()
    reward = _with_width(weighted_reward, base_width)
    plot_fk_vs_st_exploration(
        model=model,
        reward_fn=weighted_reward,
        st_uncond=SourceTemperingSampler(
            model=model,
            reward_fn=reward,
            beta=0.0,
            sigma=st_sigma,
            n_chains=NUM_CHAINS,
            n_time_steps=MEAN_FLOW_STEPS,
            device=device,
            data_dim=data_dim,
        ),
        sampler_params=dict(
            n_chains=NUM_CHAINS,
            n_time_steps=MEAN_FLOW_STEPS,
            device=device,
            data_dim=data_dim,
        ),
        fk_sampler=FKSampler(
            model=model,
            reward_fn=reward,
            sigma=SIGMA,
            n_particles=BATCH_SIZE,
            n_time_steps=MEAN_FLOW_STEPS,
            # Only plot the one for high exploitation, they both behave similarly.
            lmbda=100.0,
            device=device,
            resample_every=RESAMPLE_EVERY,
            # The model is 2-D; pass it explicitly like every other FKSampler call.
            data_dim=data_dim,
        ),
        marginal_prob_std_fn=marginal_prob_std,
        st_betas=(10.0, 100.0),
        batch_size=BATCH_SIZE,
        base_sigma=SIGMA,
        device=device,
        save_path="results/imgs/",
    )


def _plot_main_toy(device):
    """Histogram comparison of FK vs. ST for the 1-D toy example in the introduction."""
    # These experiments were originally run with seed 46 rather than the
    # script-wide 39; reseed so that figure reproduces.
    set_seed(46)
    data_dim = 1
    beta = 20
    model = _load_model(data_dim, CKPT_1D, device)
    st_sigma = _st_sigma()
    for base_width in [1e-2]:
        reward = _with_width(weighted_reward_1d, base_width)
        plot_fk_vs_st_histograms(
            model=model,
            reward_fn=weighted_reward_1d,
            fk_sampler=FKSampler(
                model=model,
                reward_fn=reward,
                sigma=SIGMA,
                n_particles=BATCH_SIZE,
                n_time_steps=MEAN_FLOW_STEPS,
                resample_every=RESAMPLE_EVERY,
                data_dim=data_dim,
                lmbda=beta,  # override lambda
                device=device,
            ),
            st_sampler=SourceTemperingSampler(
                model=model,
                reward_fn=reward,
                beta=beta,
                sigma=st_sigma,
                n_chains=NUM_CHAINS,
                n_time_steps=MEAN_FLOW_STEPS,
                device=device,
                data_dim=data_dim,
            ),
            base_width=base_width,
            batch_size=BATCH_SIZE,
            beta=beta,
            device=device,
            save_path=f"results/imgs/{base_width}_",
        )


# Maps each --plot choice to the experiment that produces it.
_PLOTS = {
    "main-toy": _plot_main_toy,
    "2d-toy": _plot_2d_toy,
    "beta-ablation": _plot_beta_ablation,
    "wasserstein": _plot_wasserstein,
}


def main():
    """Parse ``--plot``, seed all RNGs and run the selected experiment."""
    # Use CUDA if possible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plot",
        help="Please indicate which plot you want",
        required=True,
        choices=list(_PLOTS),
    )
    args = parser.parse_args()

    set_seed(39)
    # Figures are written under results/imgs/ (and the Wasserstein cache under
    # results/); create them up front so a long run cannot fail at save time.
    Path("results/imgs").mkdir(parents=True, exist_ok=True)
    _PLOTS[args.plot](device)


if __name__ == "__main__":
    main()
