
import sys
import os
import time
import math

from absl import app
from absl import flags

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons


# Let the Agg backend split very long paths (e.g. plots with thousands of
# trajectory lines) into chunks of this many vertices when rendering.
plt.rc('agg.path', chunksize=10000)

# Put the repository root (two directories above this script) at the front of
# sys.path so the local ``torchcfm`` package is imported before any installed copy.
current_dir = os.path.split(os.path.abspath(__file__))[0]
root_dir = os.path.join(current_dir, '../../')
sys.path.insert(0, os.path.abspath(root_dir))

import torchcfm
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *
from torchcfm.utils import sample_unbalanced_kgaussians  #  import 
from torchcfm.optimal_transport import OTPlanSampler

# Import logging utilities
try:
    from torchcfm.logging_utils import setup_logging, get_logger
except ImportError:
    # Fallback used when ``torchcfm.logging_utils`` is unavailable. It only
    # configures console output through the standard ``logging`` module.

    def setup_logging(log_dir="logs", log_level="INFO", experiment_name=None, console_output=True):
        """Return the ``torchcfm`` logger configured for console output.

        Parameters
        ----------
        log_dir, experiment_name :
            Accepted for signature compatibility with
            ``torchcfm.logging_utils.setup_logging``; this fallback writes no
            log file, so they are ignored.
        log_level : str
            Level name such as ``"DEBUG"`` or ``"INFO"`` (case-insensitive).
            Unknown names fall back to ``INFO``.
        console_output : bool
            When true, attach a stderr handler (at most once per process).

        Returns
        -------
        logging.Logger
            The logger named ``"torchcfm"``.
        """
        import logging
        logger = logging.getLogger("torchcfm")
        level = getattr(logging, str(log_level).upper(), None)
        if not isinstance(level, int):
            level = logging.INFO
        # Without an explicit level the logger inherits the root's WARNING
        # level and every ``logger.info`` call would be silently dropped.
        logger.setLevel(level)
        already_attached = any(getattr(h, "_torchcfm_fallback", False) for h in logger.handlers)
        if console_output and not already_attached:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
            handler._torchcfm_fallback = True  # marks our handler so repeat calls don't duplicate it
            logger.addHandler(handler)
            # Avoid printing each record twice if the root logger also has handlers.
            logger.propagate = False
        return logger

    def get_logger(name="torchcfm"):
        """Return the standard-library logger called ``name``."""
        import logging
        return logging.getLogger(name)


def sample_conditional_pt(x0, x1, t, sigma):
    """
    Sample x_t from the Gaussian probability path of (Eq.14) [1].

    The path has mean t * x1 + (1 - t) * x0 and isotropic noise with standard
    deviation ``sigma``.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        Source minibatch.
    x1 : Tensor, shape (bs, *dim)
        Target minibatch, paired row-by-row with ``x0``.
    t : FloatTensor, shape (bs,)
        Per-sample time; reshaped so it broadcasts over the non-batch axes.
    sigma : float or Tensor
        Scale applied to the standard-normal noise.

    Returns
    -------
    xt : Tensor, shape (bs, *dim)

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Anonymous et al.
    """
    broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
    t_b = t.reshape(broadcast_shape)
    mean = t_b * x1 + (1 - t_b) * x0
    noise = torch.randn_like(x0)
    return mean + sigma * noise

def compute_conditional_vector_field(x0, x1):
    """
    Return the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

    For the straight-line path t * x1 + (1 - t) * x0 the velocity does not
    depend on t, so only the endpoints are needed.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        Source minibatch.
    x1 : Tensor, shape (bs, *dim)
        Target minibatch, paired row-by-row with ``x0``.

    Returns
    -------
    ut : Tensor, shape (bs, *dim)
        Elementwise difference ``x1 - x0``.

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Anonymous et al.
    """
    velocity = x1 - x0
    return velocity

def parse_list_of_tuples(arg_string):
    """Parse ``"x1,y1;x2,y2;..."`` into ``[(x1, y1), (x2, y2), ...]`` of floats.

    Groups are separated by ``;`` and components by ``,``. Empty or
    whitespace-only groups (e.g. from a trailing ``";"`` or ``"; "``) are
    skipped, matching how ``parse_square_list`` treats blank specs.

    Raises
    ------
    ValueError
        If a component is not a valid float.
    """
    return [
        tuple(map(float, group.split(',')))
        for group in arg_string.split(';')
        if group.strip()
    ]

def parse_list_of_floats(arg_string):
    """Parse ``"1.0,2.0"`` into ``[1.0, 2.0]``.

    Empty or whitespace-only entries (e.g. from a trailing ``", "``) are
    skipped, consistent with ``parse_list_of_tuples``.

    Raises
    ------
    ValueError
        If a non-blank entry is not a valid float.
    """
    return [float(token) for token in arg_string.split(',') if token.strip()]

def parse_square(arg_string):
    """Parse ``"cx,cy,w,h,n"`` into ``(cx, cy, w, h, n)``.

    ``cx, cy`` is the square's center, ``w, h`` its width and height, and
    ``n`` the number of points to draw from it. The first four values are
    returned as floats. ``n`` is parsed as a float and truncated to int, so
    ``"256.0"`` is accepted. Blank comma-separated fields are ignored.

    Raises
    ------
    ValueError
        If there are not exactly five fields, a field is not a number, any
        value is NaN or infinite, or ``w``, ``h`` or ``n`` is not positive.
    """
    parts = [p.strip() for p in arg_string.split(',') if p.strip() != '']
    if len(parts) != 5:
        raise ValueError(f"Invalid square spec: '{arg_string}'. Expected format 'cx,cy,w,h,n'")
    cx, cy, w, h, n_raw = map(float, parts)
    # NaN slips through the "<= 0" checks below and int(inf) raises
    # OverflowError, so reject non-finite values explicitly.
    if not all(math.isfinite(v) for v in (cx, cy, w, h, n_raw)):
        raise ValueError(f"Square spec must contain only finite numbers: '{arg_string}'")
    n = int(n_raw)
    if w <= 0 or h <= 0 or n <= 0:
        raise ValueError(f"Square spec must have positive w,h,n: '{arg_string}'")
    return (cx, cy, w, h, n)

def parse_square_list(arg_string):
    """Parse ``"cx,cy,w,h,n;cx,cy,w,h,n;..."`` into a list of square tuples.

    Each non-blank ``;``-separated chunk is handed to ``parse_square``
    unchanged. A blank or whitespace-only input yields ``[]``.
    """
    if not arg_string.strip():
        return []
    specs = []
    for chunk in arg_string.split(';'):
        if chunk.strip():
            specs.append(parse_square(chunk))
    return specs

def sample_uniform_square(cx, cy, w, h, n, device=None, dtype=None):
    """Draw ``n`` points uniformly from the axis-aligned rectangle centered at
    ``(cx, cy)`` with width ``w`` and height ``h``.

    Returns a tensor of shape ``(n, 2)`` on ``device`` with ``dtype``.
    """
    unit = torch.rand(n, 2, device=device, dtype=dtype)
    left = cx - w / 2.0
    bottom = cy - h / 2.0
    xs = left + unit[:, 0] * w
    ys = bottom + unit[:, 1] * h
    return torch.stack((xs, ys), dim=1)

def sample_uniform_multi_squares(specs, device=None, dtype=None):
    """Sample every ``(cx, cy, w, h, n)`` square in ``specs`` and stack the results.

    Points are concatenated in spec order. An empty ``specs`` yields an empty
    ``(0, 2)`` tensor on ``device`` with ``dtype``.
    """
    pieces = [
        sample_uniform_square(cx, cy, w, h, n, device=device, dtype=dtype)
        for (cx, cy, w, h, n) in specs
    ]
    if not pieces:
        return torch.empty(0, 2, device=device, dtype=dtype)
    return torch.cat(pieces, dim=0)

def _define_flags():
    """Register this script's command-line flags with absl.

    Each row is ``(definer, name, default, help)``; rows are registered in
    the listed order. This runs at import time, before ``app.run`` parses the
    command line.
    """
    rows = [
        # Training / OT configuration
        (flags.DEFINE_string, 'plan', 'icfm', 'OT plan type: icfm, exact, uot_fm, uot_wfm'),
        (flags.DEFINE_string, 'weight_type', 'none', 'Weight type: inv_tnu, none'),
        (flags.DEFINE_integer, 'epochs', 1000, 'Number of training epochs'),
        (flags.DEFINE_integer, 'visual_interval', 50, 'Interval to save visualizations (epochs)'),
        (flags.DEFINE_integer, 'batch_size', 256, 'Batch size for training'),
        (flags.DEFINE_float, 'weight_power_factor', 1.0, 'Power factor for inverse tnu weighting'),
        # Gaussian-mixture data
        (flags.DEFINE_string, 'source_centers', '0,0', 'Source centers, e.g., "0,0;1,1"'),
        (flags.DEFINE_string, 'source_variances', '1,1', 'Source variances, e.g., "1,1;0.5,0.5"'),
        (flags.DEFINE_string, 'source_weights', '1.0', 'Source weights, e.g., "0.5,0.5"'),
        (flags.DEFINE_string, 'target_centers', '10,-10;10,10', 'Target centers, e.g., "10,-10;10,10"'),
        (flags.DEFINE_string, 'target_variances', '0.01,0.01;1,1', 'Target variances, e.g., "0.01,0.01;1,1"'),
        (flags.DEFINE_string, 'target_weights', '0.01,0.99', 'Target weights, e.g., "0.01,0.99"'),
        # OT solver options
        (flags.DEFINE_float, 'reg', 0.1, 'Regularization parameter for unbalanced knopp'),
        (flags.DEFINE_float, 'tau_b', float('inf'), 'KL weight of target marginal'),
        (flags.DEFINE_bool, 'replace', True, 'Replace samples from OT plan (allowing duplicates)'),
        (flags.DEFINE_bool, 'fixed_source', True, 'Fixed source samples'),
        # Data mode, squares data and visualization
        (flags.DEFINE_integer, 'visual_count', 480, 'Number of samples to visualize'),
        (flags.DEFINE_string, 'data_mode', 'gaussian', 'Data mode: gaussian or squares'),
        (flags.DEFINE_string, 'source_square', '0,0,2,2,256', 'Source square spec: "cx,cy,w,h,n"'),
        (flags.DEFINE_string, 'target_squares', '10,-10,1,1,128;10,10,1,1,384', 'Target squares specs: ";"-separated list of "cx,cy,w,h,n"'),
        (flags.DEFINE_bool, 'draw_square_outlines', False, 'Draw square outlines on scatter plots (squares mode only)'),
        (flags.DEFINE_integer, 'traj_points', 10000, 'Number of source samples for trajectory plot (squares mode)'),
        (flags.DEFINE_integer, 'vis_batch_size', 0, 'Visualization batch size (0: auto; gaussian uses batch_size, squares uses source n)'),
        (flags.DEFINE_integer, 'traj_vis_count', 2000, 'Number of trajectories to visualize in compact plot'),
        # Logging
        (flags.DEFINE_string, 'experiment_log_dir', 'logs', 'Directory to store log files'),
        (flags.DEFINE_string, 'experiment_name', None, 'Name for the experiment (used in log filename)'),
        (flags.DEFINE_string, 'log_level', 'INFO', 'Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)'),
    ]
    for define, name, default, help_text in rows:
        define(name, default, help_text)


_define_flags()

FLAGS = flags.FLAGS

# Plans accepted by --plan and weightings accepted by --weight_type.
_PLANS = ('icfm', 'exact', 'uot_fm', 'uot_wfm')
_WEIGHT_TYPES = ('none', 'inv_tnu')


def _build_ot_sampler(plan):
    """Return the OT plan sampler for ``plan``, or ``None`` for 'icfm'.

    'exact' uses exact OT, 'uot_fm' uses Sinkhorn with ``--reg``, and 'uot_wfm'
    uses ``unbalanced_knopp`` with ``reg=--reg`` and ``reg_m=(inf, --tau_b)``.

    Raises
    ------
    ValueError
        If ``plan`` is not one of ``_PLANS``.
    """
    if plan == 'icfm':
        return None
    if plan == 'exact':
        return OTPlanSampler(method="exact")
    if plan == 'uot_fm':
        return OTPlanSampler(method="sinkhorn", reg=FLAGS.reg)
    if plan == 'uot_wfm':
        return OTPlanSampler(method="unbalanced_knopp", reg=FLAGS.reg, reg_m=(float("inf"), float(FLAGS.tau_b)))
    raise ValueError(f"plan must be one of {_PLANS}, got '{plan}'")


def _result_suffix(plan):
    """Return the output-filename suffix encoding ``plan`` and its OT hyper-parameters."""
    if plan in ('exact', 'icfm'):
        return f"{plan}"
    if plan == 'uot_wfm':
        return f"{plan}_reg{FLAGS.reg}_taub{FLAGS.tau_b}_ps{FLAGS.weight_power_factor}"
    return f"{plan}_reg{FLAGS.reg}_taub{FLAGS.tau_b}"


def _inv_tnu_weights(pi, j, n_target, weight_power_factor):
    """Per-pair loss weights ``(1 / t_nu[j]) ** weight_power_factor`` for UOT-WFM.

    ``t_nu`` is the plan's column sum ``pi.sum(dim=0)`` multiplied by the
    target batch size ``n_target``. Its inverse up-weights target points that
    receive little mass (minority targets). The result keeps a trailing
    singleton axis so it broadcasts against the per-pair squared error
    (assumes ``j`` is a 1-D index tensor of sampled target indices; confirm
    against OTPlanSampler).
    """
    tnu = pi.sum(dim=0).reshape(-1, 1)
    tnu = tnu / (1 / n_target)  # normalized by the target batch size
    weight = (1 / tnu.detach())[j]  # inverse weight (minority)
    return weight ** weight_power_factor


def _recouple(plan, ot_sampler, x0, x1, weight_type, weight_power_factor):
    """Re-pair ``(x0, x1)`` according to ``plan`` and compute the loss weights.

    Returns ``(x0_recoupled, x1_recoupled, fm_weight)``. ``fm_weight`` is the
    scalar ``1.0`` unless ``plan == 'uot_wfm'`` and
    ``weight_type == 'inv_tnu'``; in that case it is a per-pair weight tensor.
    """
    if plan == 'icfm':
        return x0, x1, 1.0
    if plan == 'exact':
        x0_r, x1_r = ot_sampler.sample_plan(x0, x1)
        return x0_r, x1_r, 1.0
    # uot_fm / uot_wfm: unbalanced-style sampling that also exposes the plan
    x0_r, x1_r, pi, _u, _v, _i, j = ot_sampler.sample_plan_with_weights_and_indices(
        x0, x1, fixed_source=FLAGS.fixed_source, replace=FLAGS.replace
    )
    if plan == 'uot_wfm' and weight_type == 'inv_tnu':
        return x0_r, x1_r, _inv_tnu_weights(pi, j, x1.size(0), weight_power_factor)
    return x0_r, x1_r, 1.0


def _sample_square_vis_batches(source_square_spec, target_squares_specs, vis_bs, plan):
    """Draw a fresh visualization batch ``(x0_vis, x1_vis)`` for squares mode.

    The source square contributes ``vis_bs`` points. ``vis_bs`` target points
    are split across the target squares in proportion to their training
    counts, using largest-remainder rounding. For 'icfm'/'exact', which need
    equal batch sizes, the target batch is grown (via the last square) or
    truncated to match the source.
    """
    tgt_counts = np.array([int(spec[4]) for spec in target_squares_specs], dtype=float)
    total_tgt = tgt_counts.sum()
    if total_tgt == 0:
        proportions = np.ones_like(tgt_counts) / len(tgt_counts)
    else:
        proportions = tgt_counts / total_tgt
    raw = proportions * vis_bs
    base = np.floor(raw).astype(int)
    remainder = vis_bs - int(base.sum())
    if remainder > 0:
        # Hand leftover points to the squares with the largest fractional parts.
        order = np.argsort(-(raw - base))
        base[order[:remainder]] += 1

    vis_target_specs = [
        (cx, cy, w, h, int(nvis))
        for (cx, cy, w, h, _), nvis in zip(target_squares_specs, base.tolist())
        if nvis > 0
    ]
    # ensure at least 1 sample if vis_bs > 0 but rounding produced zeros
    if not vis_target_specs and vis_bs > 0:
        cx, cy, w, h, _ = target_squares_specs[0]
        vis_target_specs = [(cx, cy, w, h, int(vis_bs))]

    scx, scy, sw, sh = source_square_spec[:4]
    x0_vis = sample_uniform_square(scx, scy, sw, sh, int(vis_bs))
    x1_vis = sample_uniform_multi_squares(vis_target_specs)
    if plan in ['icfm', 'exact'] and x0_vis.shape[0] != x1_vis.shape[0]:
        diff = x0_vis.shape[0] - x1_vis.shape[0]
        if diff > 0 and vis_target_specs:
            cx, cy, w, h, nprev = vis_target_specs[-1]
            vis_target_specs[-1] = (cx, cy, w, h, nprev + diff)
            x1_vis = sample_uniform_multi_squares(vis_target_specs)
        elif diff < 0:
            x1_vis = x1_vis[:x0_vis.shape[0]]
    return x0_vis, x1_vis


def main(argv):
    """Train a time-conditioned MLP with (optionally OT-coupled) flow matching.

    Reads all settings from absl ``FLAGS``. Each epoch does the following:
    draws one source batch and one target batch (Gaussian mixtures or uniform
    squares, per ``--data_mode``); re-pairs them according to ``--plan``; and
    takes one Adam step on the (optionally ``inv_tnu``-weighted)
    flow-matching loss. Every ``--visual_interval`` epochs it integrates the
    learned field with a ``NeuralODE`` and writes sample-pairing and
    trajectory plots under ``results/<data_mode>/``.

    Raises
    ------
    ValueError
        Raised in these cases:

        - an unknown ``--data_mode``, ``--plan`` or ``--weight_type``;
        - ``inv_tnu`` with a plan other than 'uot_wfm';
        - a non-positive ``--visual_interval``;
        - malformed square specs;
        - in squares mode with 'icfm'/'exact', unequal source and total
          target counts.
    RuntimeError
        Propagated from the OT solver (logged first for 'uot_wfm').
    """
    del argv  # Unused

    # Set up logging
    logger = setup_logging(
        log_dir=FLAGS.experiment_log_dir,
        log_level=FLAGS.log_level,
        experiment_name=FLAGS.experiment_name,
        console_output=True
    )

    logger.info("Starting majority_test experiment")
    logger.info(f"Parameters: plan={FLAGS.plan}, reg={FLAGS.reg}, tau_b={FLAGS.tau_b}, "
                f"replace={FLAGS.replace}, fixed_source={FLAGS.fixed_source}, data_mode={FLAGS.data_mode}")

    # Parse complex arguments
    source_centers = parse_list_of_tuples(FLAGS.source_centers)
    source_variances = parse_list_of_tuples(FLAGS.source_variances)
    source_weights = parse_list_of_floats(FLAGS.source_weights)
    target_centers = parse_list_of_tuples(FLAGS.target_centers)
    target_variances = parse_list_of_tuples(FLAGS.target_variances)
    target_weights = parse_list_of_floats(FLAGS.target_weights)

    # Use flags for parameters
    plan = FLAGS.plan
    weight_type = FLAGS.weight_type
    epochs = FLAGS.epochs
    visual_interval = FLAGS.visual_interval
    batch_size = FLAGS.batch_size
    weight_power_factor = FLAGS.weight_power_factor

    # Validate everything up front so a misconfiguration fails before any
    # directory is created or any training step is taken.
    data_mode = FLAGS.data_mode.lower()
    if data_mode not in ["gaussian", "squares"]:
        raise ValueError("data_mode must be either 'gaussian' or 'squares'")
    if plan not in _PLANS:
        raise ValueError(f"plan must be one of {_PLANS}, got '{plan}'")
    if weight_type not in _WEIGHT_TYPES:
        raise ValueError(f"weight_type must be one of {_WEIGHT_TYPES}, got '{weight_type}'")
    if weight_type == "inv_tnu" and plan != 'uot_wfm':
        raise ValueError("inv_tnu is only supported for uot_wfm")
    if visual_interval <= 0:
        raise ValueError(f"visual_interval must be a positive integer, got {visual_interval}")

    # Parse square specs if needed
    source_square_spec = None
    target_squares_specs = None
    if data_mode == 'squares':
        source_square_spec = parse_square(FLAGS.source_square)
        target_squares_specs = parse_square_list(FLAGS.target_squares)
        if len(target_squares_specs) == 0:
            raise ValueError("At least one target square must be provided in squares mode")
        n0 = int(source_square_spec[4])
        n1 = int(sum(spec[4] for spec in target_squares_specs))
        if plan in ['icfm', 'exact'] and n0 != n1:
            raise ValueError(f"In '{plan}' plan with squares mode, source n ({n0}) must equal total target n ({n1})")

    # Set up save directory (uses the normalized data_mode)
    savedir = f"results/{data_mode}/"
    os.makedirs(savedir, exist_ok=True)

    ot_sampler = _build_ot_sampler(plan)
    print(f"plan: {plan}, weight_type: {weight_type}, reg: {FLAGS.reg}, tau_b: {FLAGS.tau_b}, data_mode: {data_mode}")
    logger.info(f"OT sampler initialized: plan={plan}, reg={FLAGS.reg}, tau_b={FLAGS.tau_b}, data_mode={data_mode}")

    dim = 2
    model = MLP(dim=dim, time_varying=True)
    optimizer = torch.optim.Adam(model.parameters())

    # Loop-invariant plotting metadata.
    suffix = _result_suffix(plan)
    if data_mode == 'gaussian':
        centers_for_names = target_centers
    else:
        centers_for_names = [(cx, cy) for (cx, cy, w, h, n) in target_squares_specs]
    target_names = [f"({c[0]}, {c[1]})" for c in centers_for_names]

    start = time.time()
    for k in range(epochs):
        optimizer.zero_grad()

        if data_mode == 'gaussian':
            x0 = sample_unbalanced_kgaussians(batch_size, len(source_centers), source_centers, source_variances, source_weights)
            x1 = sample_unbalanced_kgaussians(batch_size, len(target_centers), target_centers, target_variances, target_weights)
        else:
            x0 = sample_uniform_square(*source_square_spec)
            x1 = sample_uniform_multi_squares(target_squares_specs)

        # Draw samples from the OT plan and compute the per-pair loss weights
        try:
            x0_recoupled, x1_recoupled, fm_weight = _recouple(
                plan, ot_sampler, x0, x1, weight_type, weight_power_factor
            )
        except RuntimeError as e:
            if plan == 'uot_wfm':
                logger.error(f"Numerical error in OT computation at epoch {k}: {e}")
                logger.error(f"Parameters: reg={FLAGS.reg}, tau_b={FLAGS.tau_b}, batch_size={batch_size}")
            raise

        t = torch.rand(x0_recoupled.shape[0]).type_as(x0_recoupled)
        # Training uses a fixed path noise of 0.01.
        xt = sample_conditional_pt(x0_recoupled, x1_recoupled, t, sigma=0.01)
        ut = compute_conditional_vector_field(x0_recoupled, x1_recoupled)

        vt = model(torch.cat([xt, t[:, None]], dim=-1))
        loss = torch.mean(((vt - ut) ** 2) * fm_weight)

        loss.backward()
        optimizer.step()

        if (k + 1) % visual_interval != 0:
            continue

        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(
            torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        )
        with torch.no_grad():
            if data_mode == 'gaussian':
                src_for_traj = sample_unbalanced_kgaussians(batch_size*100, len(source_centers), source_centers, source_variances, source_weights)
            else:
                scx, scy, sw, sh, _ = source_square_spec
                src_for_traj = sample_uniform_square(scx, scy, sw, sh, FLAGS.traj_points)

            traj = node.trajectory(
                src_for_traj,
                t_span=torch.linspace(0, 1, 100),
            )

            # Visualization uses a fresh batch independent from the training batch
            requested_vis_bs = FLAGS.vis_batch_size
            if data_mode == 'gaussian':
                vis_bs = requested_vis_bs if requested_vis_bs and requested_vis_bs > 0 else batch_size
                x0_vis = sample_unbalanced_kgaussians(vis_bs, len(source_centers), source_centers, source_variances, source_weights)
                x1_vis = sample_unbalanced_kgaussians(vis_bs, len(target_centers), target_centers, target_variances, target_weights)
            else:
                vis_bs = requested_vis_bs if requested_vis_bs and requested_vis_bs > 0 else int(source_square_spec[4])
                x0_vis, x1_vis = _sample_square_vis_batches(source_square_spec, target_squares_specs, vis_bs, plan)

            # Compute OT on the visualization batch
            try:
                x0_vis_recoupled, x1_vis_recoupled, fm_weight_vis = _recouple(
                    plan, ot_sampler, x0_vis, x1_vis, weight_type, weight_power_factor
                )
            except RuntimeError as e:
                if plan == 'uot_wfm':
                    logger.error(f"Numerical error in OT computation (viz) at epoch {k}: {e}")
                    logger.error(f"Parameters: reg={FLAGS.reg}, tau_b={FLAGS.tau_b}, vis_batch_size={vis_bs}")
                raise

            plot_sample_points_weightvis(
                x0_vis, x1_vis, x0_vis_recoupled, x1_vis_recoupled, fm_weight_vis,
                visual_count=FLAGS.visual_count,
                square_mode=(data_mode == 'squares'),
                source_square=(source_square_spec[:4] if data_mode == 'squares' else None),
                target_squares=(list(target_squares_specs) if data_mode == 'squares' else None),
                draw_square_outlines=(FLAGS.draw_square_outlines if data_mode == 'squares' else False),
                savepath=os.path.join(savedir, f"sample_points_{suffix}.png"),
                save_bbox_inches='tight',
                save_pad_inches=0.1,
            )

            traj_np = traj.cpu().numpy()
            plot_trajectories_with_ratios(traj_np, vis_grid=True, centers=centers_for_names, names=target_names)
            plt.savefig(os.path.join(savedir, f"traj_{suffix}.png"), bbox_inches='tight', pad_inches=0.1)
            plt.close()

            # Compact trajectory plots on random subsets (1/20 and 1/40) of the
            # sample axis.
            # NOTE(review): --traj_vis_count is not read here; confirm whether it
            # was meant to size these subsets.
            n_samples = traj_np.shape[1]
            for divisor, prefix in ((20, "traj_small"), (40, "traj_tiny")):
                if n_samples <= 0:
                    continue
                n_small = max(1, n_samples // divisor)
                idx = np.random.choice(n_samples, n_small, replace=False)
                plot_trajectories_with_ratios(traj_np[:, idx, :], vis_grid=True, centers=centers_for_names, names=target_names)
                plt.savefig(os.path.join(savedir, f"{prefix}_{suffix}.png"), bbox_inches='tight', pad_inches=0.1)
                plt.close()

if __name__ == "__main__":
    # absl parses the command-line flags registered above, then calls main(argv).
    app.run(main)