import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import math
import copy
from torchdyn.core import NeuralODE
from torchcfm.conditional_flow_matching import *
from data import Synthetic, SourceDistribution, TargetDistribution
from nn import WMLP
import torch.nn as nn
from validation_utils import evaluate_model, plot_combined_trajectories

# Seed NumPy and PyTorch as soon as the module is imported so that any
# randomness triggered before main() runs is repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class BidirectionalCFMTrainer:
    """
    Trainer that implements bidirectional CFM with reflow.

    Two networks are trained with the same flow matching loss:
    ``forward_model`` on (source, target) pairs and ``backward_model`` on the
    swapped (target, source) pairs. After every reflow iteration both models
    are handed to ``self.data.forward`` / ``self.data.backward``, which are
    expected to regenerate ``all_pairs_forward`` / ``all_pairs_backward`` for
    the next iteration (their implementation lives in ``data.Synthetic``;
    verify there).
    """

    def __init__(self, args):
        """
        Initialize the trainer.

        Args:
            args: Command-line arguments. The attributes read by this class
                are ``dim``, ``w``, ``lr``, ``sigma``, ``kreflow``,
                ``batch_size``, ``num_epochs``, ``re_epoch``, ``no_wandb``,
                ``wandb_project`` and (optionally) ``grad_clip``.
        """
        self.args = args
        self.device = device
        self.data = Synthetic(args)

        # Initialize forward and backward models
        self.forward_model = WMLP(dim=args.dim, w=args.w, time_varying=True)
        self.backward_model = WMLP(dim=args.dim, w=args.w, time_varying=True)

        # Initialize optimizers
        self.forward_optimizer = torch.optim.Adam(self.forward_model.parameters(), lr=args.lr)
        self.backward_optimizer = torch.optim.Adam(self.backward_model.parameters(), lr=args.lr)

        # Both directions use the same flow matcher type; the backward model
        # simply receives its endpoints swapped (target -> source).
        self.forward_fm = ConditionalFlowMatcher(sigma=args.sigma)
        self.backward_fm = ConditionalFlowMatcher(sigma=args.sigma)

        # Snapshots of the models taken after each reflow iteration
        self.forward_models = []
        self.backward_models = []

        # Initialize WandB
        if not args.no_wandb:
            wandb.init(
                project=args.wandb_project,
                name=f"bidirectional_cfm_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_kreflow_{args.kreflow}",
                config=vars(args)
            )

    def train_epoch(self, model, optimizer, fm, x0, x1, epoch, direction="forward", reflow_iteration=0):
        """
        Run one optimisation step on a single batch.

        Despite its name this method processes exactly one batch; the
        per-epoch loop lives in ``train_bidirectional``.

        Args:
            model: Network called as ``model(t, xt)``.
            optimizer: Optimizer holding ``model``'s parameters.
            fm: Flow matcher providing ``sample_location_and_conditional_flow``.
            x0: Batch of start points of the flow.
            x1: Batch of end points of the flow.
            epoch: Epoch index, only used for logging.
            direction: Label used as prefix of the logged loss key.
            reflow_iteration: Reflow iteration index, only used for logging.

        Returns:
            The batch loss as a Python float.
        """
        optimizer.zero_grad()
        x0 = x0.to(self.device)
        x1 = x1.to(self.device)
        model = model.to(self.device)

        t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
        vt = model(t, xt)
        loss = torch.mean((vt - ut) ** 2)

        loss.backward()

        # Honour --grad_clip (interpreted as a maximum gradient norm). The
        # option defaults to None, in which case no clipping is applied.
        grad_clip = getattr(self.args, "grad_clip", None)
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)

        optimizer.step()

        loss_value = loss.item()
        if not self.args.no_wandb:
            wandb.log({
                f"{direction}_loss": loss_value,
                "epoch": epoch,
                "reflow_iteration": reflow_iteration,
                "direction": direction
            })

        return loss_value

    def _make_loader(self, pairs):
        """Wrap a dataset of pairs in a shuffling DataLoader of ``args.batch_size``."""
        return torch.utils.data.DataLoader(
            pairs,
            batch_size=self.args.batch_size,
            shuffle=True
        )

    def _reflow_loader(self, pairs_attr, model_label):
        """Return a loader over ``self.data.<pairs_attr>``, or None (with a warning) if it is missing."""
        pairs = getattr(self.data, pairs_attr, None)
        if pairs is None:
            print(f"Warning: {pairs_attr} not available. Using standard pairs for {model_label} model.")
            return None
        return self._make_loader(pairs)

    def _run_pass(self, loader, model, optimizer, fm, epoch, direction, reflow_iteration, swap):
        """Train ``model`` on every batch of ``loader`` once and return the mean batch loss."""
        total_loss = 0.0
        batch_count = 0
        for x0_batch, x1_batch in loader:
            if swap:
                # The backward model flows target -> source, so the second
                # element of each pair becomes its starting point.
                x0_batch, x1_batch = x1_batch, x0_batch
            total_loss += self.train_epoch(
                model,
                optimizer,
                fm,
                x0_batch,
                x1_batch,
                epoch,
                direction=direction,
                reflow_iteration=reflow_iteration
            )
            batch_count += 1
        # max(1, ...) avoids a division by zero when the loader is empty
        return total_loss / max(1, batch_count)

    def train_bidirectional(self, reflow_iteration=0):
        """
        Train the forward and the backward model for one reflow iteration.

        In iteration 0 both models are trained on ``self.data.all_pairs``.
        In later iterations the forward model uses
        ``self.data.all_pairs_backward`` and the backward model uses
        ``self.data.all_pairs_forward`` when those are available; otherwise
        each falls back to ``all_pairs``. Within an epoch the forward model
        sees all of its batches before the backward model is trained.

        Args:
            reflow_iteration: Index of the current reflow iteration. Selects
                the datasets and the number of epochs (``args.num_epochs``
                for 0, ``args.re_epoch`` otherwise).
        """
        print(f"=== Starting bidirectional training for reflow iteration {reflow_iteration} ===")

        # Initialize data pairs for this reflow iteration
        self.data.update_pairs()

        # Always bound, so the fallback below does not depend on
        # short-circuit evaluation.
        forward_loader = None
        backward_loader = None
        if reflow_iteration > 0:
            # Each model trains on the pairs produced by the opposite direction
            forward_loader = self._reflow_loader("all_pairs_backward", "forward")
            backward_loader = self._reflow_loader("all_pairs_forward", "backward")

        # Standard dataloader for first iteration or if specialized pairs aren't available
        standard_loader = self._make_loader(self.data.all_pairs)
        if forward_loader is None:
            forward_loader = standard_loader
        if backward_loader is None:
            backward_loader = standard_loader

        # Training epochs for this reflow iteration
        epochs = self.args.num_epochs if reflow_iteration == 0 else self.args.re_epoch

        for epoch in range(epochs):
            avg_forward_loss = self._run_pass(
                forward_loader,
                self.forward_model,
                self.forward_optimizer,
                self.forward_fm,
                epoch,
                direction="forward",
                reflow_iteration=reflow_iteration,
                swap=False
            )
            avg_backward_loss = self._run_pass(
                backward_loader,
                self.backward_model,
                self.backward_optimizer,
                self.backward_fm,
                epoch,
                direction="backward",
                reflow_iteration=reflow_iteration,
                swap=True
            )

            # Print progress and evaluate every 50 epochs and on the last epoch
            if epoch % 50 == 0 or epoch == epochs - 1:
                print(f"Epoch {epoch}/{epochs}, Reflow {reflow_iteration}: Forward Loss: {avg_forward_loss:.6f}, Backward Loss: {avg_backward_loss:.6f}")
                self.evaluate_models(reflow_iteration, epoch)

    def train(self):
        """
        Train models with reflow methodology.

        Runs ``args.kreflow + 1`` rounds of bidirectional training, stores a
        deep copy of both models after each round, regenerates the pairs
        between rounds, and finally plots the results and closes the WandB
        run (if enabled).
        """
        for k in range(self.args.kreflow + 1):
            # Train both forward and backward models
            self.train_bidirectional(reflow_iteration=k)

            # Snapshot the models; training continues on the live instances
            self.forward_models.append(copy.deepcopy(self.forward_model))
            self.backward_models.append(copy.deepcopy(self.backward_model))

            # Nothing to prepare after the last iteration
            if k >= self.args.kreflow:
                continue

            # Forward integration (source→target) to create pairs for the backward model
            print(f"=== Applying forward integration (source→target) for reflow iteration {k+1} ===")
            self.data.forward(self.forward_model, self.args)

            # Target→source integration to create pairs for the forward model
            print(f"=== Applying target→source integration for reflow iteration {k+1} ===")
            self.data.backward(self.backward_model, self.args)

            # Fresh optimizers: drop the Adam state accumulated on the old pairs
            print("Resetting optimizers for next reflow iteration")
            self.forward_optimizer = torch.optim.Adam(self.forward_model.parameters(), lr=self.args.lr)
            self.backward_optimizer = torch.optim.Adam(self.backward_model.parameters(), lr=self.args.lr)

            # Update data pairs for next iteration (train_bidirectional calls
            # update_pairs again when it starts)
            self.data.update_pairs()

        # Plot results at the end
        self.plot_results()

        # Finish WandB logging
        if not self.args.no_wandb:
            wandb.finish()

    def evaluate_models(self, reflow_iteration=0, epoch=0):
        """
        Evaluate forward and backward models.

        Args:
            reflow_iteration: Current reflow iteration
            epoch: Current epoch
        """
        for direction, model in (("forward", self.forward_model), ("backward", self.backward_model)):
            evaluate_model(
                model,
                self.args,
                self.data,
                direction=direction,
                reflow_iteration=reflow_iteration,
                epoch=epoch
            )

    def plot_results(self):
        """
        Plot results after training.

        Uses the model snapshots collected by ``train``; the last snapshot of
        each direction is used for the integration plots.
        """
        # Plot combined trajectories for forward models
        plot_combined_trajectories(self.forward_models, self.args, self.data)

        # Generate samples with the final forward model
        self.plot_integrated_samples(self.forward_models[-1], "final_forward_integration", direction="forward")

        # Generate samples with the final backward model
        self.plot_integrated_samples(self.backward_models[-1], "final_backward_integration", direction="backward")

    def plot_integrated_samples(self, model, plot_name, direction="forward"):
        """
        Plot samples and their integrated values.

        Only 2D data is plotted. For any other ``args.dim`` the method returns
        immediately without integrating, since the trajectory would not be
        used. In the 2D case ``model`` is moved to the CPU in place.

        Args:
            model: Model to use for integration
            plot_name: WandB key and file stem of the saved ``.png``
            direction: "forward" integrates the source test samples; any
                other value integrates the target test samples
        """
        # The trajectory is only ever consumed by the 2D plot below, so skip
        # the ODE solve entirely when there is nothing to draw.
        if self.args.dim != 2:
            return

        model_cpu = model.cpu()

        # Get appropriate samples based on direction
        if direction == "forward":
            # Forward: start with source, integrate to target
            initial_samples = self.data.source_distr.test_values.cpu()
            target_distribution = self.data.target_distr
            title = "Source→Target Integration"
            initial_label = "Initial samples (source)"
            final_label = "Integrated samples (target)"
        else:
            # Backward model: a flow over t in [0, 1] starting from the target samples
            initial_samples = self.data.target_distr.test_values.cpu()
            target_distribution = self.data.source_distr
            title = "Target→Source Integration"
            initial_label = "Initial samples (target)"
            final_label = "Integrated samples (source)"

        # Both directions integrate over the same time grid
        t_span = torch.linspace(0, 1, 100)

        # Create NeuralODE for integration
        node = NeuralODE(
            model_cpu,
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-3,
            rtol=1e-3
        )

        # Integrate samples
        with torch.no_grad():
            traj = node.trajectory(
                initial_samples,
                t_span=t_span
            )

        # Indexed below as (time step, sample, coordinate)
        traj_np = traj.cpu().numpy()

        fig = plt.figure(figsize=(10, 8))
        try:
            # Plot initial samples
            plt.scatter(
                traj_np[0, :, 0],
                traj_np[0, :, 1],
                s=20,
                c='blue',
                alpha=0.7,
                label=initial_label
            )

            # Plot integrated samples
            plt.scatter(
                traj_np[-1, :, 0],
                traj_np[-1, :, 1],
                s=20,
                c='red',
                alpha=0.7,
                label=final_label
            )

            # Plot target distribution background
            target_samples = target_distribution.test_values.cpu().numpy()
            sns.kdeplot(
                x=target_samples[:, 0],
                y=target_samples[:, 1],
                cmap='Greys',
                fill=True,
                alpha=0.3,
                zorder=0
            )

            # Plot trajectory paths for a few samples
            num_paths = min(5, initial_samples.shape[0])
            for i in range(num_paths):
                plt.plot(
                    traj_np[:, i, 0],
                    traj_np[:, i, 1],
                    'g-',
                    alpha=0.3
                )

            plt.title(title)
            plt.xlabel("Dimension 1")
            plt.ylabel("Dimension 2")
            plt.legend()
            plt.grid(alpha=0.3)

            # Log or save the plot
            if not self.args.no_wandb:
                wandb.log({plot_name: wandb.Image(plt)})
            plt.savefig(f"{plot_name}.png")
        finally:
            # Release the figure even if plotting or logging failed
            plt.close(fig)

def parse_args(argv=None):
    """
    Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When None (the default), argparse reads
            ``sys.argv[1:]``, which is the behaviour of calling
            ``parse_args()`` with no argument.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Train bidirectional CFM with reflow")

    # Model parameters
    parser.add_argument("--dim", type=int, default=2, help="Data dimension")
    parser.add_argument("--w", type=int, default=64, help="Width of network layers")

    # Training parameters
    parser.add_argument("--sigma", type=float, default=0.0, help="Sigma value for flow matcher")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=10000, help="Number of epochs for first iteration")
    parser.add_argument("--re_epoch", type=int, default=500, help="Number of epochs for reflow iterations")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--grad_clip", type=float, default=None, help="Gradient clipping value")

    # Reflow parameters
    parser.add_argument("--kreflow", type=int, default=1, help="Number of reflow iterations")

    # Distributions
    parser.add_argument("--source_type", type=str, default="standard", help="Source distribution type (standard, normal)")
    parser.add_argument("--target_type", type=str, default="normal", help="Target distribution type (normal, gmm)")

    # Method parameters
    parser.add_argument("--method", type=str, default="cfm", help="Flow matching method")

    # Logging and seed parameters
    parser.add_argument("--no_wandb", action="store_true", help="Disable WandB logging")
    parser.add_argument("--wandb_project", type=str, default="bidirectional_cfm_reflow", help="WandB project name")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    return parser.parse_args(argv)

def main():
    """Script entry point: read the CLI options, seed the RNGs, report the setup and train."""
    args = parse_args()

    # Apply the user-selected seed to both NumPy and PyTorch
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Summary of the run configuration, one line per entry
    banner = (
        "=== Bidirectional CFM with Reflow ===",
        f"Dimensions: {args.dim}, Width: {args.w}, Sigma: {args.sigma}",
        f"Reflow iterations: {args.kreflow}",
        f"Device: {device}",
    )
    for line in banner:
        print(line)

    # Build the trainer and run the full reflow schedule
    BidirectionalCFMTrainer(args).train()

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
