# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Anonymous

import copy
import os
from datetime import datetime

import torch
torch.backends.cuda.matmul.allow_tf32 = True # For faster training
torch.backends.cudnn.allow_tf32 = True # For faster training
import traceback

from absl import app, flags
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, plot_fm_weights_histogram, save_value_to_txt, get_real_dataset

import sys
import os
# Move to repository root from this file's directory
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.join(current_dir, '../../../')
sys.path.insert(0, os.path.abspath(root_dir))

from torchcfm.utils import compute_dataset_mean_std
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    SinkhornOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper
from torchcfm.utils import energy_weight, CIFAR10LTDataset_regacy, ImbalanceCIFAR10, ImbalanceCIFAR100, exp_naming

# Command-line configuration. Every flag below is read through the global
# FLAGS object; train() also dumps all registered flags to log.txt.
FLAGS = flags.FLAGS

# NOTE(review): train() in this file accepts otcfm, sinkhorn_otcfm,
# sinkhorn_otwfm, icfm/fm_ot, itfm/fm_dif/fm and si. "sinkhorn_otwfm_dv" is
# listed in the help text but has no branch in this file's matcher selection.
flags.DEFINE_string("model", "icfm", help="flow matching model type[otcfm, sinkhorn_otwfm, sinkhorn_otcfm, sinkhorn_otwfm_dv, icfm, itfm, si]")
flags.DEFINE_string("output_dir", "./results/", help="output_directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_integer("resume_step", 0, help="resume from step, 0 to start from scratch")
flags.DEFINE_float("lr", 2e-4, help="target learning rate")  # learning rate reached once warmup has finished
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer("total_steps", 400001, help="total training steps")  # Anonymous et al. use 400k

flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")  # number of steps of linear warmup (see warmup_lr)
flags.DEFINE_integer("batch_size", 128, help="batch size")  # Anonymous et al. use 128
flags.DEFINE_integer("num_workers", 32, help="workers of Dataloader") # default 4
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string("dataset_name", "cifar10", help="dataset name")

# Sinkhorn coupling options; these are only forwarded to
# SinkhornOptimalTransportConditionalFlowMatcher (normalize_cost also goes to
# the exact-OT matcher).
flags.DEFINE_float("reg", 0.05, help="regularization parameter")
flags.DEFINE_float("tau_b", 1.0, help="regularization parameter b for Sinkhorn")
flags.DEFINE_string("method", "unbalanced", help="method for Sinkhorn: ['unbalanced_knopp', 'unbalanced']")

# Loss re-weighting options. train() rejects any weight_type other than
# "none" unless the model name contains "sinkhorn".
flags.DEFINE_float("beta", 1.0, help="beta for energy-weighted flow matching")
flags.DEFINE_bool("efm", False, help="energy-weighted flow matching")
flags.DEFINE_string("weight_type", "none", help="weight type for flow matching [inv_tnu]")
flags.DEFINE_float("weight_power_factor", 1.0, help="power factor for weight")
flags.DEFINE_bool("save_weights_plot", True, help="save weights plot")
flags.DEFINE_bool("normalize_cost", True, help="normalize cost of OT")
flags.DEFINE_bool("recoupling", True, help="sinkhorn change the coupling of x0 and x1")
flags.DEFINE_bool("fixed_source", False, help="sinkhorn fixed source")
flags.DEFINE_bool("fixed_target", False, help="sinkhorn fixed target")

# Data location and normalization. When data_root is left unset, train()
# falls back to "./data" for the four known dataset names.
flags.DEFINE_string("data_root", None, help="data root")
# NOTE(review): besides 'adaptive' and 'default', train() also accepts the
# fixed presets 'cifar10', 'cifar100', 'cifar10_lt' and 'cifar100_lt'.
flags.DEFINE_string("data_norm", "default", help="choose data normalization: ['adaptive', 'default']")
# Evaluation
flags.DEFINE_integer(
    "save_step",
    200000,
    help="frequency of saving checkpoints, 0 to disable during training",
)
flags.DEFINE_string("device", "cuda:0", help="Device to use [cuda:0, cuda:1]")

# Long-tail dataset options.
# NOTE(review): only imb_factor is passed on by train() in this file;
# imb_type and rand_number are presumably consumed elsewhere (e.g. by
# exp_naming) -- verify.
flags.DEFINE_string("imb_type", "exp", help="imbalance type for long-tail dataset [exp, step]")
flags.DEFINE_float("imb_factor", 0.01, help="imbalance factor for long-tail dataset")
flags.DEFINE_integer("rand_number", 0, help="random seed for dataset generation")




def warmup_lr(step):
    """Return the learning-rate multiplier for ``step`` under linear warmup.

    The multiplier rises linearly from 0 at step 0 to 1 at ``FLAGS.warmup``
    and stays at 1 afterwards. It is used as the ``lr_lambda`` of the
    ``LambdaLR`` scheduler created in ``train``.

    Args:
        step: Scheduler step count.

    Returns:
        The factor by which the base learning rate is scaled. When
        ``FLAGS.warmup`` is zero or negative, warmup is disabled and the
        factor is always 1.0.
    """
    warmup_steps = FLAGS.warmup
    if warmup_steps <= 0:
        # ``--warmup 0`` means "no warmup"; dividing by it would raise
        # ZeroDivisionError.
        return 1.0
    return min(step, warmup_steps) / warmup_steps


def _resolve_normalization():
    """Return the per-channel ``(mean, std)`` selected by ``FLAGS.data_norm``.

    ``"adaptive"`` measures the statistics of the training split of
    ``FLAGS.dataset_name`` with ``compute_dataset_mean_std``; every other
    accepted value selects a fixed, hard-coded pair.

    Returns:
        Tuple ``(mean, std)`` to pass to ``transforms.Normalize``.

    Raises:
        ValueError: If ``FLAGS.data_norm`` is not recognised, or if
            ``"adaptive"`` is requested for a dataset name without a known
            dataset class.
    """
    if FLAGS.data_norm == "adaptive":
        dataset_classes = {
            "cifar10": datasets.CIFAR10,
            "cifar100": datasets.CIFAR100,
            "cifar10_lt": ImbalanceCIFAR10,
            "cifar100_lt": ImbalanceCIFAR100,
        }
        if FLAGS.dataset_name not in dataset_classes:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt'] or set data_root manually")
        mean, std = compute_dataset_mean_std(
            dataset_classes[FLAGS.dataset_name],
            FLAGS.data_root,
            train=True,
            batch_size=1024,
            num_workers=4,
        )
        print(f"Adaptive normalization for {FLAGS.dataset_name} mean: {mean}, std: {std}")
        return mean, std

    # data_norm value -> (label used in the console message, mean, std)
    fixed_stats = {
        # Default setting used in prior experiments (torchcfm default)
        "default": ("default", (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # The remaining presets map to zero mean / unit variance.
        "cifar10": ("CIFAR10", (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        "cifar100": ("CIFAR100", (0.5071, 0.48651, 0.44091), (0.2673, 0.2564, 0.2762)),
        "cifar10_lt": ("CIFAR10-LT", (0.4989, 0.5044, 0.4926), (0.2513, 0.2485, 0.2734)),
        "cifar100_lt": ("CIFAR100-LT", (0.5228, 0.4929, 0.4420), (0.2677, 0.2617, 0.2780)),
    }
    if FLAGS.data_norm not in fixed_stats:
        raise ValueError(f"Unknown data normalization {FLAGS.data_norm}, must be one of {['adaptive', *fixed_stats]}")
    label, mean, std = fixed_stats[FLAGS.data_norm]
    print(f"Fixed {label} normalization for {FLAGS.dataset_name} mean: {mean}, std: {std}")
    return mean, std


def _build_dataset(mean, std):
    """Create the training dataset named by ``FLAGS.dataset_name``.

    Every dataset uses the same pipeline: random horizontal flip, conversion
    to tensor, then normalization with ``mean`` / ``std``.

    Args:
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to ``transforms.Normalize``.

    Returns:
        The dataset object to wrap in a ``DataLoader``.

    Raises:
        ValueError: If ``FLAGS.dataset_name`` is not supported.
    """
    augment = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    name = FLAGS.dataset_name
    if name == "cifar10":
        return datasets.CIFAR10(root=FLAGS.data_root, train=True, download=True, transform=augment)
    if name == "cifar100":
        return datasets.CIFAR100(root=FLAGS.data_root, train=True, download=True, transform=augment)
    if name == "cifar10_lt_regacy":
        return CIFAR10LTDataset_regacy(data_dir=FLAGS.data_root, split="train", transform=augment)
    if name in ("cifar10_lt", "cifar100_lt"):
        return get_real_dataset(
            dataset_name=name,
            split="train",
            transform=augment,
            data_root=FLAGS.data_root,
            imb_factor=FLAGS.imb_factor,
        )
    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt', 'cifar10_lt_regacy']")


def _build_flow_matcher(sigma=0.0):
    """Instantiate the flow matcher selected by ``FLAGS.model``.

    Args:
        sigma: Noise level forwarded to the matcher constructor.

    Returns:
        A flow matcher exposing ``sample_location_and_conditional_flow``.

    Raises:
        NotImplementedError: If ``FLAGS.model`` is not a supported name.
    """
    model = FLAGS.model
    if model == "otcfm":
        return ExactOptimalTransportConditionalFlowMatcher(sigma=sigma, normalize_cost=FLAGS.normalize_cost)
    if model in ("sinkhorn_otcfm", "sinkhorn_otwfm"):
        return SinkhornOptimalTransportConditionalFlowMatcher(
            sigma=sigma,
            method=FLAGS.method,
            reg=FLAGS.reg,
            reg_m=(float("inf"), FLAGS.tau_b),
            normalize_cost=FLAGS.normalize_cost,
            recoupling=FLAGS.recoupling,
            fixed_source=FLAGS.fixed_source,
            fixed_target=FLAGS.fixed_target,
        )
    if model in ("icfm", "fm_ot"):  # (independent) OT path flow matching (Anonymous et al.)
        return ConditionalFlowMatcher(sigma=sigma)
    if model in ("itfm", "fm_dif", "fm"):  # (independent) diffusion path flow matching (Anonymous et al.)
        return TargetConditionalFlowMatcher(sigma=sigma)
    if model == "si":  # Stochastic Interpolants flow matching (Anonymous et al.)
        return VariancePreservingConditionalFlowMatcher(sigma=sigma)
    raise NotImplementedError(
        f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'sinkhorn_otcfm', 'sinkhorn_otwfm', 'icfm', 'fm_ot', 'itfm', 'fm_dif', 'fm', 'si']"
    )


def _report_weight_statistics(savedir, step, stats):
    """Plot histograms of the recorded loss weights and log their averages.

    Args:
        savedir: Experiment directory receiving the plots and ``log.txt``.
        step: Last training step, written into the log line.
        stats: Dict of lists filled during training with keys ``"batch"``
            (per-batch weight tensors), ``"min"``, ``"max"``, ``"mean"``,
            ``"std"`` (per-batch weight statistics) and ``"tnu_std"``,
            ``"tnu_mean"`` (per-batch statistics of the un-inverted weights).
    """
    print(f"Saving weights plot for {FLAGS.weight_type}")
    extra_info = f"reg{FLAGS.reg}_tau{FLAGS.tau_b}"
    fm_weight_tensor = torch.cat(stats["batch"], dim=0)
    plot_fm_weights_histogram(fm_weight_tensor.reshape(-1), FLAGS.weight_type, savedir, extra_info=extra_info)
    for key in ("min", "max", "mean"):
        plot_fm_weights_histogram(
            torch.tensor(stats[key]),
            FLAGS.weight_type,
            savedir,
            data_type=f"{key} of batch",
            extra_info=extra_info,
        )

    # Average the per-batch statistics and append them to log.txt.
    avg_weight_std = torch.tensor(stats["std"]).mean().item()
    avg_weight_mean = torch.tensor(stats["mean"]).mean().item()
    avg_weight_inv_std = torch.tensor(stats["tnu_std"]).mean().item()
    avg_weight_inv_mean = torch.tensor(stats["tnu_mean"]).mean().item()
    with open(os.path.join(savedir, 'log.txt'), 'a') as f:
        f.write(f"step: {step} avg_weight_std: {avg_weight_std:.6f} avg_weight_mean: {avg_weight_mean:.6f} avg_weight_inv_std: {avg_weight_inv_std:.6f} avg_weight_inv_mean: {avg_weight_inv_mean:.6f}\n")


def train(argv):
    """Train a flow-matching UNet on a CIFAR variant, configured through FLAGS.

    The function builds the dataset, the UNet and its EMA copy, and the flow
    matcher. It optionally resumes from a checkpoint in the experiment
    directory, then runs the optimisation loop. Progress, hyperparameters
    and timing are written to ``log.txt`` in that directory. Samples and
    checkpoints are written every ``FLAGS.save_step`` steps.

    Args:
        argv: Positional command-line arguments forwarded by ``absl.app.run``;
            not used.

    Raises:
        ValueError: For an unknown dataset, normalization or weight type, for
            a weight type combined with a non-Sinkhorn model, or when the
            checkpoint requested by ``FLAGS.resume_step`` cannot be loaded.
        NotImplementedError: For an unknown ``FLAGS.model``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device(FLAGS.device if use_cuda else "cpu")

    print(
        "lr, total_steps, ema decay, save_step:",
        FLAGS.lr,
        FLAGS.total_steps,
        FLAGS.ema_decay,
        FLAGS.save_step,
    )

    # Set dataset root location. The flag itself is updated so that the
    # resolved value appears in the FLAGS dump written to log.txt.
    if FLAGS.data_root is None:
        if FLAGS.dataset_name not in ("cifar10", "cifar100", "cifar10_lt", "cifar100_lt"):
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}, must be one of ['cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt'] or set data_root manually")
        FLAGS.data_root = "./data"

    mean, std = _resolve_normalization()

    # DATASETS/DATALOADER
    dataset = _build_dataset(mean, std)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )
    datalooper = infiniteloop(dataloader)

    # MODELS
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)  # new dropout + bs of 128

    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    if FLAGS.parallel:
        print(
            "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory."
        )
        net_model = torch.nn.DataParallel(net_model)
        ema_model = torch.nn.DataParallel(ema_model)

    # show model size
    model_size = sum(param.data.nelement() for param in net_model.parameters())
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    # FLOW MATCHER
    FM = _build_flow_matcher(sigma=0.0)
    # Sinkhorn matchers return extra outputs (coupling plan, marginal weights,
    # re-coupling indices); every other matcher returns only (t, xt, ut).
    uses_sinkhorn = "sinkhorn" in FLAGS.model

    if not uses_sinkhorn and FLAGS.weight_type != "none":
        raise ValueError(f"Weight type {FLAGS.weight_type} is not supported for {FLAGS.model}")
    # Validate once up front instead of failing inside the first training step.
    if FLAGS.weight_type not in ("none", "inv_tnu"):
        raise ValueError(f"Unknown weight type {FLAGS.weight_type}, must be one of ['none', 'inv_tnu']")

    savedir = os.path.join(FLAGS.output_dir, exp_naming(FLAGS))
    os.makedirs(savedir, exist_ok=True)
    log_path = os.path.join(savedir, 'log.txt')

    start_step = 0
    if FLAGS.resume_step > 0:
        checkpoint_path = os.path.join(savedir, f"{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{FLAGS.resume_step}.pt")
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            net_model.load_state_dict(checkpoint["net_model"])
            ema_model.load_state_dict(checkpoint["ema_model"])
            optim.load_state_dict(checkpoint["optim"])
            sched.load_state_dict(checkpoint["sched"])
            start_step = checkpoint["step"] + 1
        except Exception as err:
            # Keep the ValueError type, but name the path and chain the real
            # cause; KeyboardInterrupt/SystemExit are no longer intercepted.
            raise ValueError(f"Failed to resume from checkpoint {checkpoint_path}") from err
        print(f"Resuming from step {start_step}")

    # JIT compile (use if needed)
    # net_model = torch.compile(net_model)

    # Per-batch loss-weight statistics, recorded for the first steps only.
    weight_stats = {
        "batch": [],
        "min": [],
        "max": [],
        "mean": [],
        "std": [],
        "tnu_std": [],
        "tnu_mean": [],
    }

    # Make log file. Append when resuming so the earlier log is preserved.
    start_time = datetime.now()
    with open(log_path, 'a' if FLAGS.resume_step > 0 else 'w') as f:
        f.write("===== Hyperparameters(FLAGS) =====\n")
        for name in FLAGS:
            f.write(f"{name}: {getattr(FLAGS, name)}\n")
        f.write("===== Additional Information =====\n")
        f.write(f"data is normalized to (0,I) from mean: {mean}, std: {std}\n")
        f.write("==================================\n")
        f.write(f"Start Training at {start_time}\n")
        f.write('\n')

    # One "|" is logged roughly every 1% of training. Integer division is
    # required: with the float interval total_steps / 100 (4000.01 for the
    # default 400001 steps) the modulo test only ever matched step 0.
    progress_interval = max(1, FLAGS.total_steps // 100)

    with trange(start_step, FLAGS.total_steps, dynamic_ncols=True) as pbar:
        for step in pbar:
            optim.zero_grad()
            x1 = next(datalooper).to(device)
            x0 = torch.randn_like(x1)
            if uses_sinkhorn:
                t, xt, ut, pi, _w_u, _w_v, _i, j = FM.sample_location_and_conditional_flow(x0, x1)
            else:
                t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)

            if FLAGS.weight_type == "none":
                fm_weight = 1.0
            else:  # "inv_tnu" (validated before the loop)
                # Column sums of the coupling plan, rescaled by the batch size.
                tnu = pi.sum(dim=0)
                tnu = tnu.reshape(tnu.size(0), 1, 1, 1).to(device)
                tnu = tnu / (1 / x1.size(0))  # normalized by batch size
                fm_weight = 1 / tnu  # inverse weight (minority)
                fm_weight = fm_weight[j]  # select weight for re-coupling target samples
                fm_weight = fm_weight ** FLAGS.weight_power_factor
                if FLAGS.efm:
                    fm_weight = energy_weight(fm_weight, beta=FLAGS.beta)
                if FLAGS.save_weights_plot and step < 10:
                    recorded = fm_weight.detach()  # do not keep autograd history
                    weight_stats["batch"].append(recorded)
                    weight_stats["min"].append(recorded.min())
                    weight_stats["max"].append(recorded.max())
                    weight_stats["mean"].append(recorded.mean())
                    weight_stats["std"].append(recorded.std())
                    weight_stats["tnu_std"].append(tnu.detach().std())
                    weight_stats["tnu_mean"].append(tnu.detach().mean())

            loss = torch.mean(((vt - ut) ** 2) * fm_weight)  # weighted FM loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
            optim.step()
            sched.step()
            ema(net_model, ema_model, FLAGS.ema_decay)  # new

            if step % progress_interval == 0:
                with open(log_path, 'a') as f:
                    f.write("|")

            # sample and Saving the weights
            if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal", device=device)
                generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema", device=device)
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
                        "ema_model": ema_model.state_dict(),
                        "sched": sched.state_dict(),
                        "optim": optim.state_dict(),
                        "step": step,
                    },
                    os.path.join(savedir, f"{FLAGS.model}_{FLAGS.dataset_name}_weights_step_{step}.pt"),
                )

    print(f"fm_weight_list: {len(weight_stats['batch'])}, FLAGS.save_weights_plot: {FLAGS.save_weights_plot}")
    if FLAGS.save_weights_plot and weight_stats["batch"]:
        # A non-empty list implies the loop ran, so ``step`` is bound here.
        _report_weight_statistics(savedir, step, weight_stats)

    end_time = datetime.now()
    with open(log_path, 'a') as f:
        f.write(f"\nEnd Training at {end_time}\n")
        f.write(f"Total training time: {end_time - start_time}\n")


def _log_training_failure(trace_text):
    """Append a crash marker and ``trace_text`` to the experiment's log.txt.

    The experiment directory is derived from FLAGS in the same way as in
    ``train`` and is created if it does not exist yet.
    """
    crash_dir = os.path.join(FLAGS.output_dir, exp_naming(FLAGS))
    os.makedirs(crash_dir, exist_ok=True)
    with open(os.path.join(crash_dir, 'log.txt'), 'a') as log_file:
        log_file.write("\nERROR during training\n")
        log_file.write(trace_text)


if __name__ == "__main__":
    try:
        app.run(train)
    except Exception:
        # Best effort: record the full traceback next to the run's outputs.
        try:
            _log_training_failure(traceback.format_exc())
        except Exception:
            # Logging problems must never hide the original training error.
            pass
        raise
