############################################################
# Inspired by https://github.com/w86763777/pytorch-ddpm
# DDP version aligned with customized train_cifar10.py
############################################################

import copy
import os
from datetime import datetime

import torch
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for faster training
torch.backends.cudnn.allow_tf32 = True        # TF32 for faster training

from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchvision import datasets, transforms
from tqdm import trange

from utils_cifar import (
    ema,
    generate_samples,
    infiniteloop,
    plot_fm_weights_histogram,
    save_value_to_txt,
    setup,
)

import sys
import os as _os
# add repo root to sys.path like train_cifar10.py
current_dir = _os.path.dirname(_os.path.abspath(__file__))
root_dir = _os.path.join(current_dir, "../../../")
sys.path.insert(0, _os.path.abspath(root_dir))

from torchcfm.utils import (
    compute_cifar10_mean_std,
    energy_weight,
    CIFAR10LTDataset_regacy,
    ImbalanceCIFAR10,
    ImbalanceCIFAR100,
    exp_naming,
)
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    SinkhornOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS

# ---------------------------------------------------------------------------
# Model / output (option names mirror train_cifar10.py)
# ---------------------------------------------------------------------------
flags.DEFINE_string(
    name="model",
    default="icfm",
    help="flow matching model type[otcfm, sinkhorn_otcfm, sinkhorn_otwfm, icfm, itfm, si]",
)
flags.DEFINE_string(name="output_dir", default="./results/", help="output_directory")
flags.DEFINE_integer(name="num_channel", default=128, help="base channel of UNet")

# ---------------------------------------------------------------------------
# Optimisation schedule
# ---------------------------------------------------------------------------
flags.DEFINE_integer(
    name="resume_step",
    default=0,
    help="resume from step, 0 to start from scratch",
)
flags.DEFINE_float(name="lr", default=2e-4, help="target learning rate")
flags.DEFINE_float(name="grad_clip", default=1.0, help="gradient norm clipping")
flags.DEFINE_integer(name="total_steps", default=400001, help="total training steps")
flags.DEFINE_integer(name="warmup", default=5000, help="learning rate warmup")
# Global batch size: train() divides it by the number of GPUs when running DDP.
flags.DEFINE_integer(name="batch_size", default=128, help="global batch size")
flags.DEFINE_integer(name="num_workers", default=32, help="workers of Dataloader")
flags.DEFINE_float(name="ema_decay", default=0.9999, help="ema decay rate")
flags.DEFINE_bool(name="parallel", default=True, help="enable DistributedDataParallel")

# ---------------------------------------------------------------------------
# Data selection and normalisation
# ---------------------------------------------------------------------------
flags.DEFINE_string(
    name="dataset_name",
    default="cifar10",
    help="dataset name [cifar10, cifar10_lt, cifar100_lt, cifar10_lt_regacy]",
)
# When left as None, train() falls back to "./data" for the known datasets.
flags.DEFINE_string(name="data_root", default=None, help="data root")
flags.DEFINE_string(
    name="data_norm",
    default="default",
    help="['adaptive','default','cifar10','cifar100','cifar10_lt','cifar100_lt']",
)

# ---------------------------------------------------------------------------
# Sinkhorn optimal-transport coupling
# ---------------------------------------------------------------------------
flags.DEFINE_float(name="reg", default=1.0, help="regularization parameter for Sinkhorn")
flags.DEFINE_float(name="tau_b", default=1.0, help="regularization parameter b for Sinkhorn")
flags.DEFINE_string(
    name="method",
    default="unbalanced",
    help="method for Sinkhorn: ['unbalanced_knopp', 'unbalanced']",
)
flags.DEFINE_bool(name="normalize_cost", default=True, help="normalize cost of OT")
flags.DEFINE_bool(
    name="recoupling",
    default=True,
    help="sinkhorn change the coupling of x0 and x1",
)
flags.DEFINE_bool(name="fixed_source", default=False, help="sinkhorn fixed source")
flags.DEFINE_bool(name="fixed_target", default=False, help="sinkhorn fixed target")

# ---------------------------------------------------------------------------
# Loss re-weighting
# ---------------------------------------------------------------------------
flags.DEFINE_bool(name="efm", default=False, help="energy-weighted flow matching")
flags.DEFINE_string(
    name="weight_type",
    default="none",
    help="weight type for flow matching [none, inv_tnu]",
)
flags.DEFINE_float(name="weight_power_factor", default=1.0, help="power factor for weight")
flags.DEFINE_float(name="beta", default=1.0, help="beta for energy-weighted flow matching")
flags.DEFINE_bool(name="save_weights_plot", default=True, help="save weights plot")

# ---------------------------------------------------------------------------
# DDP rendezvous (forwarded to setup() in train())
# ---------------------------------------------------------------------------
flags.DEFINE_string(name="master_addr", default="localhost", help="master address for DDP")
flags.DEFINE_string(name="master_port", default="12355", help="master port for DDP")

# ---------------------------------------------------------------------------
# Checkpointing / sampling cadence
# ---------------------------------------------------------------------------
flags.DEFINE_integer(
    name="save_step",
    default=200000,
    help="frequency of saving checkpoints, 0 to disable during training",
)


def warmup_lr(step: int) -> float:
    """Return the linear warmup multiplier applied to the base learning rate.

    The multiplier grows linearly from 0 to 1 over the first ``FLAGS.warmup``
    steps and stays at 1 afterwards.

    Args:
        step: Scheduler step counter (as passed by ``LambdaLR``).

    Returns:
        A factor in ``[0, 1]`` for non-negative ``step``; always ``1.0`` when
        warmup is disabled (``FLAGS.warmup <= 0``).
    """
    # Guard against --warmup=0, which would otherwise divide by zero.
    if FLAGS.warmup <= 0:
        return 1.0
    return min(step, FLAGS.warmup) / FLAGS.warmup


def _normalization_stats():
    """Return the per-channel ``(mean, std)`` pair selected by ``--data_norm``.

    Raises:
        ValueError: If ``FLAGS.data_norm`` is not one of the supported options.
    """
    if FLAGS.data_norm == "adaptive":
        # Delegated to the project helper, pointed at the data under data_root.
        return compute_cifar10_mean_std(FLAGS.data_root, train=True, batch_size=1024, num_workers=4)
    if FLAGS.data_norm == "default":
        return (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    if FLAGS.data_norm in ["cifar10", "cifar10_lt"]:
        return (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    if FLAGS.data_norm in ["cifar100", "cifar100_lt"]:
        return (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
    raise ValueError(
        f"Unknown data normalization {FLAGS.data_norm}"
    )


def _build_dataset(mean, std):
    """Instantiate the training dataset selected by ``--dataset_name``.

    All datasets share one pipeline: random horizontal flip, tensor conversion,
    then normalization with the given per-channel statistics.

    Args:
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to ``transforms.Normalize``.

    Raises:
        ValueError: If ``FLAGS.dataset_name`` is not a supported dataset.
    """
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    if FLAGS.dataset_name == "cifar10":
        return datasets.CIFAR10(root=FLAGS.data_root, train=True, download=True, transform=transform)
    if FLAGS.dataset_name == "cifar10_lt_regacy":
        return CIFAR10LTDataset_regacy(data_dir=FLAGS.data_root, split="train", transform=transform)
    if FLAGS.dataset_name == "cifar10_lt":
        return ImbalanceCIFAR10(root=FLAGS.data_root, train=True, download=True, transform=transform)
    if FLAGS.dataset_name == "cifar100_lt":
        return ImbalanceCIFAR100(root=FLAGS.data_root, train=True, download=True, transform=transform)
    raise ValueError(
        f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10','cifar10_lt','cifar100_lt','cifar10_lt_regacy']"
    )


def _build_flow_matcher(sigma: float):
    """Create the flow-matcher object selected by ``--model``.

    Args:
        sigma: Noise level forwarded to the flow-matcher constructor.

    Raises:
        NotImplementedError: If ``FLAGS.model`` is not a known model name.
    """
    if FLAGS.model == "otcfm":
        return ExactOptimalTransportConditionalFlowMatcher(sigma=sigma, normalize_cost=FLAGS.normalize_cost)
    if FLAGS.model in ["sinkhorn_otcfm", "sinkhorn_otwfm"]:
        return SinkhornOptimalTransportConditionalFlowMatcher(
            sigma=sigma,
            method=FLAGS.method,
            reg=FLAGS.reg,
            reg_m=(float("inf"), FLAGS.tau_b),
            normalize_cost=FLAGS.normalize_cost,
            recoupling=FLAGS.recoupling,
            fixed_source=FLAGS.fixed_source,
            fixed_target=FLAGS.fixed_target,
        )
    if FLAGS.model in ["icfm", "fm_ot"]:
        return ConditionalFlowMatcher(sigma=sigma)
    if FLAGS.model in ["itfm", "fm_dif"]:
        return TargetConditionalFlowMatcher(sigma=sigma)
    if FLAGS.model == "si":
        return VariancePreservingConditionalFlowMatcher(sigma=sigma)
    raise NotImplementedError(
        f"Unknown model {FLAGS.model}"
    )


def train(rank, total_num_gpus: int, argv):
    """Train a flow-matching UNet on CIFAR, optionally with DistributedDataParallel.

    Distributed mode is active only when ``--parallel`` is set and
    ``total_num_gpus > 1``. Logging, sampling and checkpointing are done by the
    master process only (``RANK`` environment variable equal to 0).

    Args:
        rank: Process rank. In distributed mode it is used as the device and as
            the DDP device id; otherwise it is not used for device selection.
        total_num_gpus: World size.
        argv: Remaining command-line arguments from absl (unused here).

    Raises:
        ValueError: For an unknown dataset, normalization or weight type, or when
            the checkpoint requested by ``--resume_step`` does not exist.
        NotImplementedError: For an unknown ``--model``.
    """
    is_distributed = FLAGS.parallel and total_num_gpus > 1
    is_master = (int(os.getenv("RANK", 0)) == 0) if is_distributed else True

    use_cuda = torch.cuda.is_available()
    device = rank if is_distributed else ("cuda" if use_cuda else "cpu")

    if is_distributed:
        # batch size per GPU to keep global batch size fixed
        batch_size_per_gpu = max(1, FLAGS.batch_size // total_num_gpus)
        setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port)
    else:
        batch_size_per_gpu = FLAGS.batch_size

    # Data root default
    if FLAGS.data_root is None:
        if FLAGS.dataset_name in ["cifar10", "cifar10_lt", "cifar100_lt", "cifar10_lt_regacy"]:
            FLAGS.data_root = "./data"
        else:
            raise ValueError(
                f"Unknown dataset {FLAGS.dataset_name}, set --data_root explicitly"
            )

    # Data normalization, dataset and loader
    mean, std = _normalization_stats()
    dataset = _build_dataset(mean, std)

    sampler = DistributedSampler(dataset) if is_distributed else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        sampler=sampler,
        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        shuffle=sampler is None,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )
    datalooper = infiniteloop(dataloader)

    # Model
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)

    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)

    if is_distributed:
        net_model = DistributedDataParallel(net_model, device_ids=[rank])
        ema_model = DistributedDataParallel(ema_model, device_ids=[rank])

    # Print model size (rank 0 only)
    if is_master:
        model_size = sum(p.data.nelement() for p in net_model.parameters())
        print("Model params: %.2f M" % (model_size / 1024 / 1024))

    # FM object
    FM = _build_flow_matcher(sigma=0.0)
    # Sinkhorn matchers return extra coupling information from sampling. Deciding
    # this once keeps the training loop in agreement with _build_flow_matcher, so
    # the aliases "fm_ot" / "fm_dif" are no longer rejected inside the loop.
    is_sinkhorn = "sinkhorn" in FLAGS.model

    # Save dir and resume
    savedir = os.path.join(FLAGS.output_dir, exp_naming(FLAGS))
    if is_master:
        os.makedirs(savedir, exist_ok=True)

    if FLAGS.resume_step > 0:
        ckpt_path = os.path.join(savedir, f"{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.resume_step}.pt")
        # Every rank loads the checkpoint (shared filesystem assumed), so every
        # rank reports a missing file with the same error.
        if not os.path.exists(ckpt_path):
            raise ValueError(f"Checkpoint not found at {ckpt_path}")
        # torch.load accepts a callable, torch.device, str or dict as map_location;
        # in distributed mode `device` is a bare integer rank, so convert it.
        map_location = torch.device("cuda", device) if isinstance(device, int) else device
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        net_model.load_state_dict(checkpoint["net_model"])
        ema_model.load_state_dict(checkpoint["ema_model"])
        optim.load_state_dict(checkpoint["optim"])
        sched.load_state_dict(checkpoint["sched"])
        start_step = int(checkpoint["step"]) + 1
        if is_master:
            print(f"Resuming from step {start_step}")
    else:
        start_step = 0

    # fm weight stats (only filled on rank 0)
    fm_weight_list = []
    fm_weight_min_list = []
    fm_weight_max_list = []
    fm_weight_mean_list = []
    fm_weight_std_list = []
    fm_weight_inv_std_list = []
    fm_weight_inv_mean_list = []

    # Make log file (rank 0 only)
    if is_master:
        start_time = datetime.now()
        # Append when resuming so the log of the earlier run is not truncated.
        log_mode = 'a' if FLAGS.resume_step > 0 else 'w'
        with open(os.path.join(savedir, 'log.txt'), log_mode) as f:
            f.write("===== Hyperparameters(FLAGS) =====\n")
            for name in FLAGS:
                f.write(f"{name}: {getattr(FLAGS, name)}\n")
            f.write("===== Additional Information =====\n")
            f.write(f"data is normalized to (0,I) from mean: {mean}, std: {std}\n")
            f.write("==================================\n")
            f.write(f"Start Training at {start_time}\n\n")

    # Training loop by steps (aligned with train_cifar10.py)
    with trange(start_step, FLAGS.total_steps, dynamic_ncols=True, disable=not is_master) as pbar:
        for step in pbar:
            if sampler is not None:
                sampler.set_epoch(step)

            optim.zero_grad()
            x1 = next(datalooper).to(device)
            x0 = torch.randn_like(x1)

            if is_sinkhorn:
                t, xt, ut, pi, _w_u, _w_v, _i, j = FM.sample_location_and_conditional_flow(x0, x1)
            else:
                t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

            vt = net_model(t, xt)

            # Per-sample loss weights are only defined for Sinkhorn couplings.
            fm_weight = 1.0
            if is_sinkhorn and FLAGS.weight_type != "none":
                if FLAGS.weight_type != "inv_tnu":
                    raise ValueError(f"Unknown weight type {FLAGS.weight_type}")
                # Column sums of the coupling, rescaled by the batch size so that
                # a value of 1/batch maps to 1.
                tnu = pi.sum(dim=0)
                tnu = tnu.reshape(tnu.size(0), 1, 1, 1).to(device)
                tnu = tnu / (1 / x1.size(0))
                fm_weight = 1 / tnu
                # `j` comes from the matcher; presumably the indices of the
                # sampled x1 entries (TODO confirm against torchcfm).
                fm_weight = fm_weight[j]
                fm_weight = fm_weight ** FLAGS.weight_power_factor
                if FLAGS.efm:
                    fm_weight = energy_weight(fm_weight, beta=FLAGS.beta)
                if FLAGS.save_weights_plot and is_master:
                    # Keep CPU copies: holding one device tensor per statistic
                    # per step would grow accelerator memory for the whole run.
                    weight_cpu = fm_weight.detach().cpu()
                    tnu_cpu = tnu.detach().cpu()
                    fm_weight_list.append(weight_cpu)
                    fm_weight_min_list.append(weight_cpu.min())
                    fm_weight_max_list.append(weight_cpu.max())
                    fm_weight_mean_list.append(weight_cpu.mean())
                    fm_weight_std_list.append(weight_cpu.std())
                    fm_weight_inv_std_list.append(tnu_cpu.std())
                    fm_weight_inv_mean_list.append(tnu_cpu.mean())

            loss = torch.mean(((vt - ut) ** 2) * (fm_weight if isinstance(fm_weight, torch.Tensor) else 1.0))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)
            optim.step()
            sched.step()
            ema(net_model, ema_model, FLAGS.ema_decay)

            # Sample & save (master only)
            if is_master and FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                # Pass whether the models were actually wrapped in DDP rather than
                # the raw --parallel flag: with a single process the models stay
                # unwrapped even though --parallel defaults to True.
                # NOTE(review): assumes generate_samples uses this argument to
                # decide whether to unwrap `.module` - confirm in utils_cifar.
                generate_samples(net_model, is_distributed, savedir, step, net_="normal", device=device)
                generate_samples(ema_model, is_distributed, savedir, step, net_="ema", device=device)
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
                        "ema_model": ema_model.state_dict(),
                        "sched": sched.state_dict(),
                        "optim": optim.state_dict(),
                        "step": step,
                    },
                    os.path.join(savedir, f"{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{step}.pt"),
                )

    # Final weight stats & log (master only)
    if is_master and FLAGS.save_weights_plot and len(fm_weight_list) > 0:
        fm_weight_tensor = torch.cat(fm_weight_list, dim=0)
        plot_fm_weights_histogram(fm_weight_tensor.reshape(-1), FLAGS.weight_type, savedir, extra_info=f"reg{FLAGS.reg}_tau{FLAGS.tau_b}")
        plot_fm_weights_histogram(torch.tensor(fm_weight_min_list), FLAGS.weight_type, savedir, data_type="min of batch", extra_info=f"reg{FLAGS.reg}_tau{FLAGS.tau_b}")
        plot_fm_weights_histogram(torch.tensor(fm_weight_max_list), FLAGS.weight_type, savedir, data_type="max of batch", extra_info=f"reg{FLAGS.reg}_tau{FLAGS.tau_b}")
        plot_fm_weights_histogram(torch.tensor(fm_weight_mean_list), FLAGS.weight_type, savedir, data_type="mean of batch", extra_info=f"reg{FLAGS.reg}_tau{FLAGS.tau_b}")

        avg_weight_std = torch.tensor(fm_weight_std_list).mean().item()
        avg_weight_mean = torch.tensor(fm_weight_mean_list).mean().item()
        avg_weight_inv_std = torch.tensor(fm_weight_inv_std_list).mean().item()
        avg_weight_inv_mean = torch.tensor(fm_weight_inv_mean_list).mean().item()

        with open(os.path.join(savedir, 'log.txt'), 'a') as f:
            f.write(
                f"step: {FLAGS.total_steps-1} avg_weight_std: {avg_weight_std:.6f} avg_weight_mean: {avg_weight_mean:.6f} avg_weight_inv_std: {avg_weight_inv_std:.6f} avg_weight_inv_mean: {avg_weight_inv_mean:.6f}\n"
            )

    if is_master:
        with open(os.path.join(savedir, 'log.txt'), 'a') as f:
            f.write(f"End Training at {datetime.now()}\n")


def main(argv):
    """absl entry point: read the launcher environment and start training.

    ``WORLD_SIZE`` (default 1) decides whether the run is distributed. In the
    distributed case the integer ``RANK`` (default 0) is handed to ``train``;
    otherwise a ``torch.device`` for the first GPU, or the CPU when CUDA is
    unavailable, is passed in its place.

    Args:
        argv: Positional command-line arguments left over after flag parsing.
    """
    world_size = int(os.getenv("WORLD_SIZE", 1))
    run_distributed = FLAGS.parallel and world_size > 1

    if not run_distributed:
        # Single-process path: pick the first GPU when one is available.
        single_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        train(rank=single_device, total_num_gpus=world_size, argv=argv)
        return

    process_rank = int(os.getenv("RANK", 0))
    train(rank=process_rank, total_num_gpus=world_size, argv=argv)


# Script entry point: absl parses the command-line flags, then calls main(argv).
if __name__ == "__main__":
    app.run(main)
