import matplotlib.pyplot as plt
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
import abc
import torch 
import torchdyn
from torchdyn.core import NeuralODE
from torch import nn
from torchdyn.datasets import generate_moons
from torchcfm.conditional_flow_matching import *
from torchcfm.utils import *
import wandb
import copy 
import argparse

def get_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options with attributes ``method``,
        ``data``, ``lr``, ``epochs``, ``reflow_epochs``, ``batch_size``,
        ``wandb_project`` and ``run_suffix``.
    """
    # One (flag, add_argument-kwargs) pair per option, registered in order.
    option_table = (
        ("--method", dict(
            type=str, default="ipmlp", choices=["ipmlp", "gmlp", "emlp"],
            help="Model architecture to use: 'ipmlp', 'gmlp', or 'emlp'.")),
        ("--data", dict(
            type=int, default=1, choices=[0, 1, 2],
            help="Data distribution: 0=small std Gaussian, 1=anisotropic Gaussian, 2=isotropic Gaussian.")),
        ("--lr", dict(
            type=float, default=1e-4,
            help="Learning rate for optimizer.")),
        ("--epochs", dict(
            type=int, default=400000,
            help="Number of epochs for initial training.")),
        ("--reflow_epochs", dict(
            type=int, default=2000,
            help="Number of epochs for reflow training.")),
        ("--batch_size", dict(
            type=int, default=100,
            help="Batch size for reflow training.")),
        ("--wandb_project", dict(
            type=str, default="test_104_ccfm",
            help="WandB project name.")),
        ("--run_suffix", dict(
            type=str, default="",
            help="Optional suffix to append to the WandB run name.")),
    )

    cli = argparse.ArgumentParser(description="Train and evaluate CFM models.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# Parsed once at import time: the module-level code below picks the model
# architecture and the target distribution from this namespace.
args = get_args()

class IPMLP(nn.Module):
    """Vector field defined as the x-gradient of an inner-product energy.

    ``NN`` maps the concatenation ``[x, t]`` to ``s(x, t)`` with the same
    width as ``x``. The per-sample energy is ``<x, s(x, t)>`` and ``forward``
    returns its gradient with respect to ``x``.

    Args:
        dim: number of features in ``x`` (and in the returned field).
        w: width of the hidden layers.
    """

    def __init__(self, dim, w=64):
        super().__init__()
        self.NN = nn.Sequential(
            nn.Linear(dim + 1, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return ``d/dx <x, NN([x, t])>`` evaluated row by row.

        Args:
            t: time tensor; reshaped to a column and broadcast to one value
                per row of ``x``, so it must hold either a single element or
                ``x.size(0)`` elements.
            x: tensor of shape ``(batch, dim)``. ``requires_grad`` is switched
                on in place so the energy can be differentiated w.r.t. it.
            *args, **kwargs: accepted and ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Only keep the second-order graph when the caller can backprop through it.
        build_graph = torch.is_grad_enabled()
        # autograd.grad needs a graph even if the caller is sampling under no_grad().
        with torch.enable_grad():
            x = x.requires_grad_(True)
            t_ = t.view(-1, 1).expand(x.size(0), 1).to(x.dtype)
            s = self.NN(torch.cat([x, t_], dim=1))
            # Row-wise inner product. `x.T @ s` would contract over the batch
            # axis and produce a (dim, dim) matrix instead of one energy per sample.
            energy = (x * s).sum(dim=1, keepdim=True)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x,
                grad_outputs=torch.ones_like(energy),
                create_graph=build_graph,  # retain_graph defaults to this value
            )[0]
        return grad_x

class GMLP(nn.Module):
    """Vector field defined as the x-gradient of the summed network outputs.

    ``NN`` maps the concatenation ``[x, t]`` to ``s(x, t)`` with the same
    width as ``x``. The per-sample energy is ``sum_j s_j(x, t)`` and
    ``forward`` returns its gradient with respect to ``x``.

    Args:
        dim: number of features in ``x`` (and in the returned field).
        w: width of the hidden layers.
    """

    def __init__(self, dim, w=64):
        super().__init__()
        self.NN = nn.Sequential(
            nn.Linear(dim + 1, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return ``d/dx sum_j NN([x, t])_j`` evaluated row by row.

        Args:
            t: time tensor; reshaped to a column and broadcast to one value
                per row of ``x``, so it must hold either a single element or
                ``x.size(0)`` elements.
            x: tensor of shape ``(batch, dim)``. ``requires_grad`` is switched
                on in place so the energy can be differentiated w.r.t. it.
            *args, **kwargs: accepted and ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Only keep the second-order graph when the caller can backprop through it.
        build_graph = torch.is_grad_enabled()
        # autograd.grad needs a graph even if the caller is sampling under no_grad().
        with torch.enable_grad():
            x = x.requires_grad_(True)
            t_ = t.view(-1, 1).expand(x.size(0), 1).to(x.dtype)
            s = self.NN(torch.cat([x, t_], dim=1))
            energy = s.sum(dim=1, keepdim=True)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x,
                grad_outputs=torch.ones_like(energy),
                create_graph=build_graph,  # retain_graph defaults to this value
            )[0]
        return grad_x

class EMLP(nn.Module):
    """Vector field defined as the x-gradient of a squared-residual energy.

    ``NN`` maps the concatenation ``[x, t]`` to ``s(x, t)`` with the same
    width as ``x``. The per-sample energy is ``||x - s(x, t)||^2`` and
    ``forward`` returns its gradient with respect to ``x``.

    Args:
        dim: number of features in ``x`` (and in the returned field).
        w: width of the hidden layers.
    """

    def __init__(self, dim, w=64):
        super().__init__()
        self.NN = nn.Sequential(
            nn.Linear(dim + 1, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return ``d/dx ||x - NN([x, t])||^2`` evaluated row by row.

        Args:
            t: time tensor; reshaped to a column and broadcast to one value
                per row of ``x``, so it must hold either a single element or
                ``x.size(0)`` elements.
            x: tensor of shape ``(batch, dim)``. ``requires_grad`` is switched
                on in place so the energy can be differentiated w.r.t. it.
            *args, **kwargs: accepted and ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Only keep the second-order graph when the caller can backprop through it.
        build_graph = torch.is_grad_enabled()
        # autograd.grad needs a graph even if the caller is sampling under no_grad().
        with torch.enable_grad():
            x = x.requires_grad_(True)
            t_ = t.view(-1, 1).expand(x.size(0), 1).to(x.dtype)
            s = self.NN(torch.cat([x, t_], dim=1))
            energy = ((x - s) ** 2).sum(dim=1, keepdim=True)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x,
                grad_outputs=torch.ones_like(energy),
                create_graph=build_graph,  # retain_graph defaults to this value
            )[0]
        return grad_x

# Initialize device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The initial model and the reflow model share the architecture chosen by
# --method; IPMLP is the fallback, matching the CLI default.
model_cls = {"gmlp": GMLP, "emlp": EMLP}.get(args.method, IPMLP)
model = model_cls(dim=2).to(device)
reflow_model = model_cls(dim=2).to(device)

# Target distribution parameters: training samples are drawn as
# `randn @ std + mean`, so `std` acts as a (diagonal) scale matrix.
if args.data == 1:
    mean = torch.tensor([5,5.])
    std = torch.tensor([[2.,0.],[0, 0.5]])
elif args.data == 2:
    mean = torch.tensor([0,0.])
    std = torch.tensor([[2.,0.],[0, 2.0]])
else:
    mean = torch.tensor([0,0.])
    std = torch.tensor([[0.5,0.],[0, 0.5]])

# Honour --lr instead of a hard-coded learning rate (the CLI default is the same 1e-4).
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
FM = ConditionalFlowMatcher(sigma=0.0)

# NOTE(review): mean_str / std_str are not consumed anywhere visible in this file.
mean_str = "_".join([f"{m:.2f}" for m in mean])
std_str = "_".join([f"{s:.2f}" for s in std.diag()]) if std.shape[0] == std.shape[1] else "custom_std"

run_name = f"{args.method}_{args.data}"
if args.run_suffix:
    # --run_suffix is documented as an optional suffix for the run name.
    run_name = f"{run_name}_{args.run_suffix}"

# Honour --wandb_project. The project used to be hard-coded to "test_105_ccfm",
# which silently ignored the flag; runs now go to the flag's value.
wandb.init(project=args.wandb_project, name=run_name)

def plot_trajectories(traj, title):
    """Plot 2-D trajectories and log the figure to WandB.

    Args:
        traj: array indexed as ``traj[step, sample, coord]``. Coordinates 0
            and 1 are drawn; ``traj[0]`` is marked as the initial samples and
            ``traj[-1]`` as the final samples.
        title: axes title, also used as the key of the WandB log entry.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # One gray line per sample (columns of the 2-D slices are plotted as separate lines).
        ax.plot(
            traj[:, :, 0],
            traj[:, :, 1],
            alpha=0.6,
            c='gray',
            linewidth=1,
        )
        ax.scatter(traj[0, :, 0], traj[0, :, 1], s=20, c='black', label='Initial Samples')
        ax.scatter(traj[-1, :, 0], traj[-1, :, 1], s=20, edgecolor='red', label='Final Samples')
        ax.set_title(title)
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.legend()
        ax.grid(alpha=0.3)
        wandb.log({title: wandb.Image(fig)})
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

# Initial training phase
best_model_loss = float('inf')
best_model = None

# Honour --epochs instead of a hard-coded iteration count (the CLI default is the same 400000).
for epoch in range(args.epochs):
    optimizer.zero_grad()
    x0 = torch.randn((64, 2)).to(device)
    x1 = (torch.randn((64, 2)) @ std + mean).to(device)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
    vt = model(t, xt)
    loss = torch.mean((vt - ut) ** 2)
    # Plain float: avoids logging / storing a tensor that keeps the autograd graph alive.
    loss_value = loss.item()

    wandb.log({"loss": loss_value})

    # Snapshot BEFORE the optimizer step so the copy holds the weights that
    # actually produced this loss, not the weights after the update.
    if loss_value < best_model_loss:
        best_model_loss = loss_value
        best_model = copy.deepcopy(model)

    loss.backward()
    optimizer.step()

# No snapshot is taken when there are zero epochs or every loss was NaN;
# fall back to the current weights so the evaluation below still has a model.
if best_model is None:
    best_model = copy.deepcopy(model)

# Initial evaluation: integrate the best snapshot from t=0 to t=1 on the CPU
best_model.to("cpu")
node = NeuralODE(best_model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)

eval_start = torch.randn((64, 2))
eval_times = torch.linspace(0, 1, 100)
traj = node.trajectory(eval_start, t_span=eval_times)

# Summary statistics of the final state, pooled over all samples and coordinates
final_state = traj[-1]
val_mean = final_state.mean().detach().numpy()
val_std = final_state.std().detach().numpy()
wandb.log({"val_mean": val_mean, "val_std": val_std})

plot_trajectories(traj.detach().numpy(), "Initial Trajectories")

# Reflow phase
# Honour --lr instead of a hard-coded learning rate (the CLI default is the same 1e-4).
reflow_optimizer = torch.optim.Adam(reflow_model.parameters(), lr=args.lr)

# Generate reflow dataset: pair each noise sample with the point the trained
# flow transports it to at t=1.

x0 = torch.randn((10000, 2))
node = NeuralODE(best_model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
traj = node.trajectory(x0, t_span=torch.linspace(0, 1, 100))
x1 = traj[-1].detach()

# The model's forward switches requires_grad on its input in place; detach x0
# in case the solver handed this very tensor to it, so batches carry no graph.
dataset = torch.utils.data.TensorDataset(x0.detach(), x1)
# Honour --batch_size (the CLI default is the same 100).
loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

# Reflow training
# Honour --reflow_epochs (the CLI default is the same 2000).
for epoch in range(args.reflow_epochs):
    for batch_x0, batch_x1 in loader:
        reflow_optimizer.zero_grad()
        batch_x0, batch_x1 = batch_x0.to(device), batch_x1.to(device)

        t, xt, ut = FM.sample_location_and_conditional_flow(batch_x0, batch_x1)
        vt = reflow_model(t, xt)
        loss = torch.mean((vt - ut) ** 2)

        loss.backward()
        reflow_optimizer.step()
        # Log a plain float rather than a graph-attached tensor.
        wandb.log({"reloss": loss.item()})

# Reflow evaluation: integrate the reflow model from t=0 to t=1 on the CPU
reflow_model.to("cpu")
node = NeuralODE(reflow_model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)

reflow_start = torch.randn((64, 2))
reflow_times = torch.linspace(0, 1, 100)
traj_refined = node.trajectory(reflow_start, t_span=reflow_times)

refined_np = traj_refined.detach().numpy()
plot_trajectories(refined_np, "Reflowed Trajectories")
