import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import wandb
import math
import copy
from torchcfm.conditional_flow_matching import *
from data import Synthetic
from nn import WMLP
import torch.nn as nn
from validation_utils import evaluate_model

# One fixed seed for both RNG libraries used by this script, so that repeated
# runs draw the same random numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on CUDA whenever torch reports it as available; otherwise stay on CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

class HybridFlowModel(torch.nn.Module):
    """Two time-conditioned ``WMLP`` networks: one for velocity, one for score.

    The ``switch_to_cfm`` flag selects what ``forward`` returns:

    * ``False`` (initial value): the tuple ``(velocity, score)``.
    * ``True``: only the velocity output; the score network is not evaluated.
    """

    def __init__(self, dim=2, w=64):
        super().__init__()
        # Both networks are built from the same constructor arguments.
        net_kwargs = dict(dim=dim, w=w, time_varying=True)
        self.velocity_net = WMLP(**net_kwargs)
        self.score_net = WMLP(**net_kwargs)
        self.switch_to_cfm = False

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the networks at ``(t, x)``; extra arguments are ignored."""
        velocity = self.velocity_net(t, x)
        if self.switch_to_cfm:
            return velocity
        return velocity, self.score_net(t, x)

class BFromVS(torch.nn.Module):
    """Class for turning a velocity model v and a score model s into a drift model b.

    The drift is ``b(t, x) = v(t, x) - gg_dot(t) * s(t, x)``.

    Args:
        v: callable ``(t, x) -> velocity``.
        s: callable ``(t, x) -> score``.
        sigma: noise scale used by ``gg_dot``.
    """

    def __init__(self, v, s, sigma=1.0) -> None:
        super().__init__()
        self.v = v
        self.s = s
        self.sigma = sigma

    def gg_dot(self, t):
        """Return ``(sigma**2 / 2) * (1 - 2 * t)``.

        This is a method rather than a lambda stored on the instance: a stored
        lambda closes over the *original* object, so a ``copy.deepcopy`` of the
        module (functions are copied by reference) would keep reading the
        original's ``sigma``. A method always reads ``self.sigma`` of the
        instance it is called on, and does not block pickling.
        """
        return (self.sigma**2 / 2) * (1 - 2 * t)

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the drift at ``(t, x)``; extra arguments are ignored.

        NOTE(review): ``gg_dot(t)`` is multiplied directly with the score, so
        ``t`` must be a scalar or already broadcastable against the score
        output -- confirm against the callers.
        """
        v = self.v(t, x)
        s = self.s(t, x)
        return v - self.gg_dot(t) * s

def train_epoch(model, optimizer, FM, args, device, epoch, x0, x1, reflow_iteration=0, use_hybrid=False):
    """Run one optimisation step on a single ``(x0, x1)`` batch and log it to wandb.

    Despite the name, each call processes exactly one batch.

    Args:
        model: network called as ``model(t, xt)``. With ``use_hybrid=True`` it
            must return ``(velocity, score)``; otherwise a velocity tensor (a
            tuple is tolerated, and only its first element is used).
        optimizer: optimizer over ``model``'s parameters.
        FM: flow matcher providing ``sample_location_and_conditional_flow`` and,
            in hybrid mode, a non-zero ``sigma`` attribute.
        args: namespace providing ``re_epoch`` and ``grad_clip`` (``num_epochs``
            is used for the logged epoch offset when present).
        device: device the batch and the model are moved to.
        epoch: step index within the current reflow iteration.
        x0, x1: paired source / target batches.
        reflow_iteration: index of the current reflow iteration.
        use_hybrid: train velocity and score jointly when True.

    Returns:
        The scalar loss of this step as a Python float.
    """
    optimizer.zero_grad()
    x0 = x0.to(device)
    x1 = x1.to(device)
    # The caller may have moved the model to the CPU for evaluation, so it is
    # moved back on every step.
    model = model.to(device)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    # Global step used as the logged "epoch". The first iteration may run a
    # different number of steps (args.num_epochs) than the later ones
    # (args.re_epoch); offsetting by `reflow_iteration * re_epoch` alone makes
    # later iterations overlap the first one whenever the two lengths differ.
    if reflow_iteration == 0:
        global_epoch = epoch
    else:
        first_phase_len = getattr(args, "num_epochs", args.re_epoch)
        global_epoch = first_phase_len + (reflow_iteration - 1) * args.re_epoch + epoch

    # All metrics of this step are collected and logged in ONE wandb.log call,
    # so that they share the same wandb step.
    metrics = {}

    if use_hybrid:
        # For HybridFlowModel
        vt, st = model(t, xt)
        velocity_loss = torch.mean((vt - ut) ** 2)

        # NOTE(review): the score regression target is -(x1 - x0) / sigma and
        # does not depend on t or on the sampled noise -- confirm this is the
        # intended conditional score for the bridge. FM.sigma must be non-zero.
        score_target = -1 * (x1 - x0) / FM.sigma
        score_loss = torch.mean((st - score_target) ** 2)

        # Combined loss
        loss = velocity_loss + score_loss
        metrics["v_loss"] = velocity_loss.item()
        metrics["s_loss"] = score_loss.item()
    else:
        # Original single network model
        vt = model(t, xt)
        if isinstance(vt, tuple):
            # Handle tuple output (velocity, score): use just the velocity.
            vt = vt[0]
        loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    if args.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

    optimizer.step()

    loss_value = loss.item()
    metrics["loss"] = loss_value
    metrics["epoch"] = global_epoch
    wandb.log(metrics)

    return loss_value

def _drift_from_hybrid(hybrid_model, sigma):
    """Move ``hybrid_model`` to the CPU and wrap it as a ``BFromVS`` drift model.

    ``Module.cpu()`` moves the module in place and returns the same object, so
    the returned drift model reads from ``hybrid_model`` itself (no copy).
    """
    cpu_model = hybrid_model.cpu()
    return BFromVS(
        v=lambda t, x: cpu_model(t, x)[0],
        s=lambda t, x: cpu_model(t, x)[1],
        sigma=sigma,
    )


def main(args):
    """Train a hybrid velocity/score model over several reflow iterations.

    For iteration ``k`` the flow matcher's sigma is ``args.sigma / 10**k``.
    On the last iteration (``k == args.kreflow``) training switches to a plain
    ``ConditionalFlowMatcher`` with sigma 0 and only the velocity network is
    trained. After every iteration the current model is pushed through
    ``data.forward``, a snapshot is stored, and the model is evaluated.

    Args:
        args: parsed command-line namespace (see the argument parser at the
            bottom of the file).
    """
    data = Synthetic(args)

    # Initialize data pairs for first reflow iteration
    data.update_pairs(0)

    # NOTE(review): the loader is built once from `data.all_pairs`. If
    # `data.forward` / `update_pairs` rebind that attribute instead of mutating
    # it in place, later iterations keep training on the initial pairs --
    # confirm against the Synthetic implementation.
    dataloader = torch.utils.data.DataLoader(
        data.all_pairs, batch_size=args.batch_size, shuffle=True, num_workers=2
    )

    # Initialize model and optimizer
    model = HybridFlowModel(dim=args.dim, w=args.w)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Initialize flow matcher with the starting sigma
    FM = SchrodingerBridgeConditionalFlowMatcher(sigma=args.sigma)
    use_hybrid = True

    # Setup WandB
    wandb.init(project=f"test_flow_reflow_{args.dim}",
               name=f"{args.method}_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_kreflow_{args.kreflow}_grad_clip_{args.grad_clip}",
               config=args)

    # NOTE(review): second call with the same argument, kept from the original;
    # whether it is redundant depends on Synthetic.update_pairs.
    data.update_pairs(0)
    models = []
    # For each reflow iteration
    for k in range(args.kreflow + 1):
        # Update sigma for this reflow iteration (divide by 10)
        current_sigma = args.sigma / (10 ** k)
        FM.sigma = current_sigma

        # Switch to pure CFM at last reflow
        if k == args.kreflow:
            print(f"Switching to pure CFM at reflow {k}")
            FM = ConditionalFlowMatcher(sigma=0.0)
            current_sigma = 0.0

            use_hybrid = False
            model.switch_to_cfm = True

        print(f"Reflow iteration {k}, sigma = {current_sigma}")
        wandb.log({"sigma": current_sigma, "reflow_iteration": k})

        # Number of optimisation steps (one batch each) for this iteration.
        epochs = args.num_epochs if k == 0 else args.re_epoch

        epoch = 0
        while epoch < epochs:
            steps_before_pass = epoch
            for x0_batch, x1_batch in dataloader:
                train_epoch(model, optimizer, FM, args, device, epoch, x0_batch, x1_batch,
                            reflow_iteration=k, use_hybrid=use_hybrid)
                epoch += 1

                # Periodically evaluate
                if epoch % 10 == 0:
                    if use_hybrid:
                        # Create drift model from hybrid model for evaluation
                        drift_model = _drift_from_hybrid(model, current_sigma)
                        evaluate_model(drift_model, args, data, reflow_iteration=k, epoch=epoch)
                    else:
                        evaluate_model(model, args, data, reflow_iteration=k, epoch=epoch)

                # Stop as soon as the step budget is reached. Checking with
                # `>` right after the increment allowed one extra training
                # step whenever the budget ran out in the middle of a pass.
                if epoch >= epochs:
                    break

            if epoch == steps_before_pass:
                # A pass that yields no batch would otherwise loop forever.
                raise RuntimeError("dataloader yielded no batches; cannot train")

        if use_hybrid:
            drift_model = _drift_from_hybrid(model, current_sigma)
            data.forward(drift_model, args)
            # Snapshot a deep copy of the MODEL. Deep-copying the drift wrapper
            # does not copy the lambdas it holds, so the stored entry would
            # keep reading the live model, which continues training and later
            # has `switch_to_cfm` flipped (changing what `model(t, x)[0]` means).
            models.append(_drift_from_hybrid(copy.deepcopy(model), current_sigma))
            # Final evaluation
            evaluate_model(drift_model, args, data, reflow_iteration=k)
        else:
            data.forward(model, args)
            models.append(copy.deepcopy(model))
            # Final evaluation
            evaluate_model(model, args, data, reflow_iteration=k)

    # NOTE(review): `plot_combined_trajectories` is neither defined nor
    # explicitly imported in the visible part of this file -- confirm it is
    # actually in scope (e.g. provided by the star import).
    plot_combined_trajectories(models, args, data)

    print("Training complete!")
    wandb.finish()

def _build_arg_parser():
    """Create the command-line parser for the sigma-scheduled training script."""
    parser = argparse.ArgumentParser(description="Train with sigma scheduling")
    add = parser.add_argument

    # Problem setup
    add("--dim", type=int, default=2, help="Dimension of the data")
    add("--components", type=int, default=4, help="Number of Gaussian medians")
    add("--sigma", type=float, default=0.1, help="Starting sigma value")

    # Optimisation schedule
    add("--batch_size", type=int, default=32, help="Batch size")
    add("--num_epochs", type=int, default=500, help="Number of epochs for initial training")
    add("--re_epoch", type=int, default=100, help="Number of epochs for reflow iterations")

    # Data distributions
    add("--source_type", type=str, default="standard", help="Source distribution type")
    add("--target_type", type=str, default="gmm", help="Target distribution type")

    # Model and optimizer
    add("--lr", type=float, default=0.0001, help="Learning rate")
    add("--w", type=int, default=32, help="Width of neural networks")
    add("--kreflow", type=int, default=4, help="Number of reflow iterations")
    add("--grad_clip", type=float, default=0.1, help="Gradient clipping value")
    add("--method", type=str, default="ultra_hybrid")
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())