#!/usr/bin/env python3
import json
import math
import os
import sys
import time
from copy import deepcopy
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

from models.conditional_flow_matching import *
from models.fm_models import *
from utils import *

def oat_loss_compute(xt, ut, v0, v1, vt, alpha=0.8):
    """
    Compute the OAT loss as a weighted sum of curvature and impulse penalties.

    Every penalty is the batch mean of the squared L2 norm (over the last
    dimension) of a velocity difference.

    Args:
        xt: Current position at time t (accepted for interface compatibility;
            not used in the computation).
        ut: Target velocity (per the caller's flow matcher).
        v0: Velocity at the start point, from model_old.
        v1: Velocity at the end point, from model_old.
        vt: Velocity predicted by the model being trained.
        alpha: Weight on the curvature terms; the impulse terms get
            ``1 - alpha`` (default: 0.8).

    Returns:
        Scalar tensor ``alpha * curvature + (1 - alpha) * impulse``.
    """
    def _mean_sq_norm(diff):
        # Batch-averaged squared Euclidean norm along the feature axis.
        return torch.mean(torch.norm(diff, dim=-1) ** 2)

    # Curvature: how far each half-step average velocity strays from ut.
    curvature = _mean_sq_norm((v0 + vt) / 2 - ut) + _mean_sq_norm((vt + v1) / 2 - ut)

    # Impulse: how abruptly the velocity changes at either end.
    impulse = _mean_sq_norm(vt - v0) + _mean_sq_norm(v1 - vt)

    return alpha * curvature + (1 - alpha) * impulse
    
def init_logger(log_dir, method_name):
    """Create fresh JSON and plain-text training logs for one method.

    Existing log files with the same names in *log_dir* are overwritten.

    Args:
        log_dir: Directory for the log files; created if it does not exist.
        method_name: Tag used in the file names and in the log headers.

    Returns:
        Tuple ``(json_log_path, txt_log_path)``.
    """
    os.makedirs(log_dir, exist_ok=True)

    json_log_path = os.path.join(log_dir, f"{method_name}_training_log.json")
    txt_log_path = os.path.join(log_dir, f"{method_name}_training_log.txt")

    # Take a single timestamp so the JSON and text headers report the same
    # start time (two separate now() calls can straddle a second boundary).
    started = datetime.now()

    log_data = {
        "method": method_name,
        "start_time": started.isoformat(),
        "training_records": [],
    }

    with open(json_log_path, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)

    with open(txt_log_path, "w", encoding="utf-8") as f:
        f.write(f"Training Log for {method_name}\n")
        f.write(f"Start Time: {started.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("=" * 50 + "\n")

    return json_log_path, txt_log_path

def log_training_step(json_log_path, txt_log_path, step, loss, elapsed_time, extra_info=None):
    """Append one training record to both the JSON and the text log.

    The JSON log is read in full, extended, and rewritten on every call, so the
    cost of a call grows with the number of records already logged.

    Args:
        json_log_path: JSON log whose top-level object has a
            ``"training_records"`` list (as written by ``init_logger``).
        txt_log_path: Plain-text log; one line is appended.
        step: Training step number.
        loss: Loss value; converted with ``float()``.
        elapsed_time: Seconds elapsed; converted with ``float()``.
        extra_info: Optional mapping merged into the JSON record (its keys can
            override the standard ones) and appended as ``key value`` pairs to
            the text line.
    """
    # Convert once so the JSON record and the text line show the same values.
    loss = float(loss)
    elapsed_time = float(elapsed_time)

    log_entry = {
        "step": step,
        "loss": loss,
        "elapsed_time": elapsed_time,
        "timestamp": datetime.now().isoformat(),
    }
    if extra_info:
        log_entry.update(extra_info)

    with open(json_log_path, "r", encoding="utf-8") as f:
        log_data = json.load(f)
    log_data["training_records"].append(log_entry)

    # Write to a sibling temp file and atomically swap it in, so an interrupted
    # run cannot leave a truncated (unparseable) JSON log behind.
    tmp_path = json_log_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)
    os.replace(tmp_path, json_log_path)

    log_line = f"Step {step}: loss {loss:.6f}, time {elapsed_time:.2f}s"
    if extra_info:
        log_line += ", " + ", ".join(f"{k} {v}" for k, v in extra_info.items())
    with open(txt_log_path, "a", encoding="utf-8") as f:
        f.write(log_line + "\n")

def _make_method_dirs(method_root):
    """Create ``checkpoints``, ``visualizations`` and ``logs`` under *method_root*.

    Returns:
        Tuple ``(ckpt_dir, vis_dir, log_dir)`` of the created paths.
    """
    dirs = tuple(
        os.path.join(method_root, sub)
        for sub in ("checkpoints", "visualizations", "logs")
    )
    for d in dirs:
        os.makedirs(d, exist_ok=True)
    return dirs


def _print_banner(title, width):
    """Print *title* framed by lines of ``=`` of the given width."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width)


def _plot_model_trajectories(model, method, step, save_dir):
    """Integrate *model* from 1024 fresh 8-Gaussian samples and plot the paths.

    Uses a dopri5 NeuralODE over 100 evenly spaced times in [0, 1]; the plot
    itself is produced by the project's ``plot_trajectories`` helper.
    """
    node = NeuralODE(
        torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
    )
    with torch.no_grad():
        traj = node.trajectory(
            sample_8gaussians(1024),
            t_span=torch.linspace(0, 1, 100),
        )
        plot_trajectories(traj.cpu().numpy(), method=method, step=step, save_dir=save_dir)


def _train_base_method(fm_cls, fm_kwargs, tag, ckpt_dir, vis_dir, log_dir, *,
                       dim=2, batch_size=256, n_steps=20000, vis_freq=5000):
    """Train a fresh MLP with plain MSE flow matching from 8 Gaussians to moons.

    Every *vis_freq* steps it prints and logs the loss and plots trajectories.
    At the end it saves the whole model to ``<ckpt_dir>/<tag>_v1.pt``.

    Args:
        fm_cls: Flow-matcher class; instantiated as ``fm_cls(**fm_kwargs)``.
        fm_kwargs: Keyword arguments for the flow matcher (e.g. ``sigma``).
        tag: Method tag used for log names, plots and the checkpoint name.
        ckpt_dir, vis_dir, log_dir: Output directories.

    Returns:
        Path of the saved checkpoint.
    """
    # Construction order (model, optimizer, matcher) is kept so RNG use is unchanged.
    model = MLP(dim=dim, time_varying=True)
    optimizer = torch.optim.Adam(model.parameters())
    FM = fm_cls(**fm_kwargs)

    json_log, txt_log = init_logger(log_dir, tag)

    start = time.time()
    for k in range(n_steps):
        step = k + 1
        optimizer.zero_grad()

        x0 = sample_8gaussians(batch_size)
        x1 = sample_moons(batch_size)

        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

        vt = model(torch.cat([xt, t[:, None]], dim=-1))
        loss = torch.mean((vt - ut) ** 2)

        loss.backward()
        optimizer.step()

        if step % vis_freq == 0:
            end = time.time()
            elapsed_time = end - start
            print(f"{step}: loss {loss.item():0.3f} time {elapsed_time:0.2f}")
            log_training_step(json_log, txt_log, step, loss.item(), elapsed_time,
                              {"visualization": True})
            start = end
            _plot_model_trajectories(model, tag, step, vis_dir)

    ckpt_path = os.path.join(ckpt_dir, f"{tag}_v1.pt")
    torch.save(model, ckpt_path)
    return ckpt_path


def _refine_with_oat(name, model_path, ckpt_dir, vis_dir, log_dir, method_tag, *,
                     sigma=0.1, batch_size=256, n_steps=20000,
                     hardcopy_update_freq=500, vis_freq=5000):
    """Refine a saved baseline model with the OAT loss.

    A teacher copy (``model_old``) supplies boundary velocities at t=0 and t=1.
    Every *hardcopy_update_freq* steps the teacher is replaced by a copy of the
    current model. Every *vis_freq* steps a checkpoint is saved and trajectories
    are plotted.

    Returns:
        ``False`` if the baseline could not be loaded (the method is skipped),
        otherwise ``True``.
    """
    try:
        # weights_only=False unpickles arbitrary objects: only point this at
        # checkpoints this script wrote itself.
        model_old = torch.load(model_path, weights_only=False)
        model = deepcopy(model_old)
        print(f"Successfully loaded {name} model")
    except Exception as e:
        # Best-effort: a missing/corrupt baseline skips only this method.
        print(f"Failed to load {name} model: {e}")
        return False

    optimizer = torch.optim.Adam(model.parameters())
    FM = OATConditionalFlowMatcher(sigma=sigma)

    oat_json_log, oat_txt_log = init_logger(log_dir, method_tag)

    start = time.time()
    for k in range(n_steps):
        step = k + 1
        optimizer.zero_grad()

        x0 = sample_8gaussians(batch_size)
        x1 = sample_moons(batch_size)

        # The teacher only provides targets. Without no_grad, every backward()
        # would also backprop into (and accumulate .grad on) model_old's
        # parameters, which are never optimized.
        with torch.no_grad():
            v0_init = model_old(torch.cat([x0, torch.zeros(x0.shape[0], 1)], dim=-1))
            v1_init = model_old(torch.cat([x1, torch.ones(x1.shape[0], 1)], dim=-1))

        t, xt, ut, v0_sampled, v1_sampled = FM.sample_location_and_conditional_flow(
            (x0, v0_init), (x1, v1_init)
        )

        vt = model(torch.cat([xt, t[:, None]], dim=-1))
        loss = oat_loss_compute(xt, ut, v0_sampled, v1_sampled, vt)

        loss.backward()
        optimizer.step()

        if step % hardcopy_update_freq == 0:
            model_old = deepcopy(model)
            log_training_step(oat_json_log, oat_txt_log, step, loss.item(), 0,
                              {"hardcopy_update": True})

        if step % vis_freq == 0:
            end = time.time()
            elapsed_time = end - start
            print(f"  {step}: loss {loss.item():0.3f} time {elapsed_time:0.2f}")

            # Save first so the log never claims a checkpoint that was not written.
            torch.save(model, os.path.join(ckpt_dir, f"{method_tag}_step{step}.pt"))
            log_training_step(oat_json_log, oat_txt_log, step, loss.item(), elapsed_time,
                              {"visualization": True, "checkpoint_saved": True})

            start = end
            _plot_model_trajectories(model, method_tag, step, vis_dir)

    print(f"{name} OAT-FM refinement completed!")
    return True


def main():
    """Train five CFM baselines (8 Gaussians -> moons), then OAT-FM-refine each.

    Everything is written under ``ckpts/8gs-moons``. Each baseline gets
    ``<tag>/{checkpoints,visualizations,logs}``. Each refined model gets
    ``oatfm/<tag>_oat/{checkpoints,visualizations,logs}``.
    """
    base_savedir = "ckpts/8gs-moons"
    os.makedirs(base_savedir, exist_ok=True)
    oatfm_base_dir = os.path.join(base_savedir, "oatfm")

    # (tag, display name, banner title, flow-matcher class, flow-matcher kwargs)
    base_specs = [
        ("cfm", "CFM", "1. Conditional Flow Matching",
         ConditionalFlowMatcher, {"sigma": 0.1}),
        ("otcfm", "OT-CFM", "2. Optimal Transport Conditional Flow Matching",
         ExactOptimalTransportConditionalFlowMatcher, {"sigma": 0.1}),
        ("targetcfm", "Target CFM", "3. Target Conditional Flow Matching ([anonymous])",
         TargetConditionalFlowMatcher, {"sigma": 0.1}),
        ("sbcfm", "SB-CFM", "4. Schrodinger Bridge Conditional Flow Matching",
         SchrodingerBridgeConditionalFlowMatcher, {"sigma": 0.5, "ot_method": "exact"}),
        ("vpcfm", "VP-CFM",
         "5. Variance Preserving Conditional Flow Matching (Stochastic Interpolants)",
         VariancePreservingConditionalFlowMatcher, {"sigma": 0.1}),
    ]

    # Create every output directory up front, baselines first, then OAT-FM runs.
    base_dirs = {
        tag: _make_method_dirs(os.path.join(base_savedir, tag))
        for tag, *_ in base_specs
    }
    oat_dirs = {
        tag: _make_method_dirs(os.path.join(oatfm_base_dir, f"{tag}_oat"))
        for tag, *_ in base_specs
    }

    print("Starting Conditional Flow Matching experiments...")

    # 1-5. Baseline flow-matching methods.
    base_ckpts = {}
    for tag, name, title, fm_cls, fm_kwargs in base_specs:
        _print_banner(title, 50)
        ckpt_dir, vis_dir, log_dir = base_dirs[tag]
        base_ckpts[tag] = _train_base_method(fm_cls, fm_kwargs, tag, ckpt_dir, vis_dir, log_dir)
        print(f"{name} model saved successfully!")

    # 6. OAT-FM refinement of every baseline.
    _print_banner("OAT-FM REFINEMENT PHASE - Refining all base methods", 60)
    for tag, name, *_ in base_specs:
        ckpt_dir, vis_dir, log_dir = oat_dirs[tag]
        _refine_with_oat(
            name,
            base_ckpts[tag],
            ckpt_dir,
            vis_dir,
            log_dir,
            f"{tag}_oat",
            sigma=0.1,
            batch_size=256,
            n_steps=20000,
            hardcopy_update_freq=500,
            vis_freq=5000,
        )

    
if __name__ == "__main__":
    main()
