import argparse
import torch
import numpy as np
import wandb
from data import Synthetic
from nn import WMLP, IPMLP, GMLP, EMLP
from torchcfm.conditional_flow_matching import *
from trainers import get_trainer
from main_2 import HybridFlowModel

# Seed the global numpy and torch RNGs once at import time so that repeated
# runs of this script draw identical random numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def main(args):
    """Train the flow model selected on the command line.

    Builds the ``Synthetic`` data handler, the network for ``args.method``
    and a matching flow matcher, hands them to the trainer returned by
    ``get_trainer`` for ``args.trainer_type``, runs ``trainer.train()`` and
    finally calls ``wandb.finish()``.

    Args:
        args: parsed command-line namespace. This function reads ``method``,
            ``dim``, ``w``, ``trainer_type`` and (for non-"cfm" methods)
            ``sigma``; the whole namespace is also forwarded to
            ``Synthetic`` and ``get_trainer``.
    """
    # Keep the construction order (data -> model -> flow matcher): any of
    # these constructors may draw from the globally seeded RNGs, so
    # reordering them could change results.
    data = Synthetic(args)

    method = args.method

    # Network architecture for the requested method.
    if method == "hybrid":
        model = HybridFlowModel(dim=args.dim, w=args.w)
    elif method == "gmlp":
        model = GMLP(dim=args.dim, w=args.w)  # gradient-based approach
    elif method == "emlp":
        model = EMLP(dim=args.dim, w=args.w)  # energy-based approach
    else:
        # "cfm", "sb" and any unrecognised method share a time-varying WMLP.
        model = WMLP(dim=args.dim, w=args.w, time_varying=True)

    # Flow matcher: "cfm" always uses the noiseless CFM, "sb"/"hybrid" always
    # use the Schrodinger bridge, every other method picks the bridge only
    # when a positive sigma was requested.
    if method == "cfm":
        use_bridge = False
    elif method in ("sb", "hybrid"):
        use_bridge = True
    else:
        use_bridge = args.sigma > 0

    if use_bridge:
        flow_matcher = SchrodingerBridgeConditionalFlowMatcher(sigma=args.sigma)
    else:
        flow_matcher = ConditionalFlowMatcher(sigma=0.0)

    trainer = get_trainer(args.trainer_type, model, data, args, flow_matcher)
    trainer.train()

    print("Training complete!")
    # NOTE(review): wandb.init is not called in this file; presumably the
    # trainer starts the run. Confirm before relying on this.
    wandb.finish()

def build_parser():
    """Return the argument parser for this training script.

    Kept separate from the ``__main__`` guard so the options and their
    defaults can be inspected or reused without launching training.
    """
    parser = argparse.ArgumentParser(description="Train flow matching models")
    parser.add_argument("--dim", type=int, default=2,
                        help="Dimension of the data")
    # Fixed help text: this is a mixture component count, not "medians".
    parser.add_argument("--components", type=int, default=4,
                        help="Number of Gaussian mixture components")
    # main() also uses sigma > 0 to choose the bridge matcher for gmlp/emlp.
    parser.add_argument("--sigma", type=float, default=0.01,
                        help="Sigma value for SchrodingerBridge "
                             "(for gmlp/emlp, sigma > 0 selects SchrodingerBridge)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=500,
                        help="Number of epochs for initial training")
    parser.add_argument("--re_epoch", type=int, default=100,
                        help="Number of epochs for reflow iterations")
    parser.add_argument("--source_type", type=str, default="normal",
                        help="Source distribution type: 'standard' or 'normal'")
    parser.add_argument("--target_type", type=str, default="normal",
                        help="Target distribution type: 'gmm' or 'normal'")
    parser.add_argument("--lr", type=float, default=0.0001,
                        help="Learning rate")
    parser.add_argument("--w", type=int, default=32,
                        help="Width of neural networks")
    parser.add_argument("--method", type=str, default="hybrid",
                        choices=["cfm", "sb", "hybrid", "gmlp", "emlp"],
                        help="Flow matching method")
    parser.add_argument("--kreflow", type=int, default=4,
                        help="Number of reflow iterations")
    parser.add_argument("--grad_clip", type=float, default=0.1,
                        help="Gradient clipping value")
    parser.add_argument("--trainer_type", type=str, default="simple_reflow",
                        choices=["simple_reflow", "sbm", "stochastic_interpolant"],
                        help="Training methodology to use")
    parser.add_argument("--score_weight", type=float, default=1.0,
                        help="Weight of score loss for stochastic interpolant")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)
