import argparse
import copy
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from torchdyn.core import NeuralODE
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *

from data.gaussians import sample_gaussian
from utils.var_helpers import plot_variance_heatmaps, sample_points

# Set up argument parser. The experiment is parameterized by the scale `s` of
# the target Gaussian. Without a value, `s` would be None and would flow into
# sample_gaussian and the plot title, so the flag is required up front.
parser = argparse.ArgumentParser(description='Run Gaussian Flow Matching experiment')
parser.add_argument('--s', type=float, required=True, help='s value for the experiment')
args = parser.parse_args()
# Initialize wandb; the parsed CLI arguments are recorded as the run config.
wandb.init(project="fm_vec", config=args)

# Set up directories and parameters
savedir = "models/2gaussian"
os.makedirs(savedir, exist_ok=True)

sigma = 0.1  # passed to ConditionalFlowMatcher below
dim = 2
# NOTE(review): batch_size is not referenced elsewhere in this script; the
# training loops hard-code batches of 100 samples.
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

FM = ConditionalFlowMatcher(sigma=sigma)

s = args.s

# Fixed reference samples, drawn with the same arguments as the training
# batches. Presumably sample_gaussian(n, mean, scale) — confirm against data.gaussians.
# NOTE(review): x0_fixed is only used for its row count and x1_fixed is unused.
x0_fixed = sample_gaussian(100, 0, 1)
x1_fixed = sample_gaussian(100, 5, s)
# Ten evaluation times in [0, 1]; each collects one model output per snapshot.
t_eval = torch.linspace(0, 1, 10)
vector_field_outputs = {t.item(): [] for t in t_eval}

# Primary vector field training: regress the model's velocity onto the
# conditional flow-matching target between fresh source/target batches, and
# snapshot the model during the final training steps.
num_steps = 10000
num_snapshot_steps = 50  # snapshots are taken on the last 50 steps
saved_models = []

for k in range(num_steps):
    optimizer.zero_grad()
    x0 = sample_gaussian(100, 0, 1)
    x1 = sample_gaussian(100, 5, s)
    t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)
    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()

    if k >= num_steps - num_snapshot_steps:
        with torch.no_grad():
            # state_dict() returns references to the live parameter tensors,
            # which optimizer.step() keeps updating in place. Without a copy,
            # every stored snapshot would end up holding the final weights.
            saved_models.append(copy.deepcopy(model.state_dict()))
            for t_fixed in t_eval:
                # NOTE(review): evaluation points are re-drawn for every
                # snapshot. If sample_points is random, the per-point variance
                # computed later mixes spatial variation with
                # checkpoint-to-checkpoint variation — confirm intent.
                x = sample_points(100, 1)
                # Time column must match the number of evaluated points.
                inputs = torch.cat([x, t_fixed.expand(x.shape[0], 1)], dim=-1)
                vt_fixed = model(inputs)
                vector_field_outputs[t_fixed.item()].append(vt_fixed.cpu().numpy())

    wandb.log({"primary_vf_loss": loss.item(), "primary_vf_step": k})

# Summarize the primary field's snapshots: heatmaps first, then one scalar per
# evaluation time. That scalar is the variance across snapshots, taken per
# point and dimension and then averaged over both.
plot_variance_heatmaps(vector_field_outputs, wandb, t_eval, saved_models)

primary_variances = []
for time_value, snapshot_list in vector_field_outputs.items():
    # np.array(snapshot_list) stacks snapshots on axis 0; var over that axis,
    # then the mean over all remaining entries.
    avg_var = np.mean(np.var(np.array(snapshot_list), axis=0))
    wandb.log({"variance_1": avg_var, "time_step": time_value, "vector_field": "primary"})
    primary_variances.append(avg_var)

# Integrate the learned primary field from t=0 to t=1 with the dopri5 solver
# (atol/rtol 1e-4). The start is x0, the last source batch drawn by the
# primary training loop. Only the final state of each trajectory is kept.
node = NeuralODE(
    torch_wrapper(model),
    solver="dopri5",
    sensitivity="adjoint",
    atol=1e-4,
    rtol=1e-4,
)
integration_grid = torch.linspace(0, 1, 100)
with torch.no_grad():
    states_over_time = node.trajectory(x0, t_span=integration_grid)
simulated_x1 = states_over_time[-1]

# Secondary vector field training: fit a fresh model on the fixed pairs
# (x0, simulated_x1) produced by the primary model. The same batch of pairs is
# reused at every step. Snapshot the model during the final training steps.
model2 = MLP(dim=dim, time_varying=True)
optimizer2 = torch.optim.Adam(model2.parameters())
vector_field_outputs2 = {t.item(): [] for t in t_eval}

num_steps = 10000
num_snapshot_steps = 50  # snapshots are taken on the last 50 steps
saved_models2 = []  # Initialize a separate list for secondary models
for k in range(num_steps):
    optimizer2.zero_grad()
    t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, simulated_x1, return_noise=True)
    vt2 = model2(torch.cat([xt, t[:, None]], dim=-1))
    loss2 = torch.mean((vt2 - ut) ** 2)
    loss2.backward()
    optimizer2.step()

    if k >= num_steps - num_snapshot_steps:
        with torch.no_grad():
            # state_dict() returns references to the live parameter tensors,
            # which optimizer2.step() keeps updating in place. Deep-copy so
            # each snapshot captures the weights at this step.
            saved_models2.append(copy.deepcopy(model2.state_dict()))
            for t_fixed in t_eval:
                # NOTE(review): evaluation points are re-drawn per snapshot
                # (see the same note on the primary loop) — confirm intent.
                x = sample_points(100, 1)
                # Time column must match the number of evaluated points.
                inputs = torch.cat([x, t_fixed.expand(x.shape[0], 1)], dim=-1)
                vt_fixed = model2(inputs)  # Use model2 for secondary field
                vector_field_outputs2[t_fixed.item()].append(vt_fixed.cpu().numpy())

    wandb.log({"secondary_vf_loss": loss2.item(), "secondary_vf_step": k})

# Summarize the secondary field's snapshots the same way: heatmaps first, then
# one scalar per evaluation time. That scalar is the variance across
# snapshots, taken per point and dimension and then averaged over both.
plot_variance_heatmaps(vector_field_outputs2, wandb, t_eval, saved_models2)

secondary_variances = []
for time_value, snapshot_list in vector_field_outputs2.items():
    # np.array(snapshot_list) stacks snapshots on axis 0; var over that axis,
    # then the mean over all remaining entries.
    avg_var = np.mean(np.var(np.array(snapshot_list), axis=0))
    secondary_variances.append(avg_var)
    wandb.log({"variance_2": avg_var, "time_step": time_value, "vector_field": "secondary"})

# Overlay the per-time variance curves of both fields, log the figure to
# wandb, release it, and close the run.
fig, ax = plt.subplots()
times = t_eval.numpy()
curves = (
    (primary_variances, "Primary Vector Field"),
    (secondary_variances, "Secondary Vector Field"),
)
for variances, curve_label in curves:
    ax.plot(times, variances, label=curve_label)
ax.set(xlabel="Time", ylabel="Variance", title=f"Variance Over Time (s={s})")
ax.legend()
wandb.log({"variance_plot": wandb.Image(fig)})
plt.close(fig)

wandb.finish()
