import os
import tqdm
import functools
import ipdb
import torch
from torch import nn, Tensor

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import wandb

from model_toy import ToyMLP
from utils import get_args
from dataset.dataset import *

# flow_matching
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL

def tau_to_str(tau_value):
    """Return the fractional digits of ``tau_value`` as a short string tag.

    The value is rendered in plain decimal notation and the digits after the
    decimal point are returned with trailing zeros stripped (leading zeros are
    kept). If nothing remains after stripping, '0' is returned. A value whose
    text has no decimal point is returned as that text unchanged.

    e.g. 0.001 -> '001', 0.01 -> '01', 0.1 -> '1', 0.25 -> '25',
         1.0 -> '0', 1e-05 -> '00001', 5 -> '5'

    Note that the integer part is discarded, so distinct values can share a
    tag (0.5 and 1.5 both give '5').
    """
    from decimal import Decimal, InvalidOperation

    tau_str = str(tau_value)
    if 'e' in tau_str.lower():
        # str() uses scientific notation for very small/large floats
        # ('1e-05'); expand it so the digit extraction below sees '0.00001'
        # instead of leaking the exponent into the tag.
        try:
            tau_str = format(Decimal(tau_str), 'f')
        except InvalidOperation:
            # Not a number after all: keep the raw text, as before.
            pass

    if '.' not in tau_str:
        return tau_str

    decimal_part = tau_str.split('.')[1].rstrip('0')
    return decimal_part if decimal_part else '0'

def train(args, pretrained_model, data_loader, info, start_epoch=0, tau_str="001", n_str="5000"):
    """Train ``pretrained_model`` with the mixture-path generalized KL loss.

    A fresh Adam optimizer (lr=1e-3) is created on every call, so no
    optimizer state is carried over from an earlier run.

    Args:
        args: Run configuration; ``args.device`` and ``args.save_model`` are
            read here.
        pretrained_model: Model called as ``model(x=x_t, t=t)`` to produce
            logits; updated in place.
        data_loader: Iterable yielding batches of target samples ``x``
            (tensors; the original comment gives the shape as (B, l)).
        info: Dict providing "vocab_size", "K" and "tau".
        start_epoch: First epoch index; training runs up to ``2000 * K``.
        tau_str: Tag for tau used in the checkpoint directory name.
        n_str: Tag for the dataset size used in the checkpoint directory
            name and logged to wandb.

    Raises:
        ValueError: If ``data_loader`` yields no samples in an epoch.
    """
    optimizer = Adam(pretrained_model.parameters(), lr=1e-3)
    scheduler = PolynomialConvexScheduler(n=1)  # linear scheduler
    path = MixtureDiscreteProbPath(scheduler=scheduler)  # mixture discrete path
    loss_fn = MixturePathGeneralizedKL(path=path)  # loss function

    # info
    vocab_size = info["vocab_size"]
    k = info["K"]
    # tau truncates the sampled time range: t is drawn from [0, 1 - tau).
    tau = info["tau"]
    n_epochs = 2000 * k
    # The checkpoint location does not change during training.
    save_dir = os.path.join("./models", "toy_n{}tau{}_{}".format(n_str, tau_str, k))

    tqdm_epoch = tqdm.trange(start_epoch, n_epochs)
    for epoch in tqdm_epoch:
        total_loss = 0.
        num_items = 0
        for x in data_loader:
            # Keep the batch on the same device as t and the model; this is a
            # no-op when the loader already yields tensors on args.device.
            x = x.to(args.device)
            batch_size = x.shape[0]
            # Source sample: uniform random tokens with the same shape as x.
            x_0 = torch.randint_like(x, high=vocab_size)
            t = torch.rand(batch_size).to(args.device) * (1 - tau)
            # sample probability path
            path_sample = path.sample(t=t, x_0=x_0, x_1=x)
            logits = pretrained_model(x=path_sample.x_t, t=path_sample.t)
            loss = loss_fn(logits=logits, x_1=x, x_t=path_sample.x_t, t=path_sample.t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            total_loss += loss.item() * batch_size
            num_items += batch_size

        if num_items == 0:
            raise ValueError("data_loader yielded no samples; cannot compute the average loss")
        epoch_loss = total_loss / num_items
        tqdm_epoch.set_description('Average Loss: {:5f}'.format(epoch_loss))

        # Overwrite the checkpoint every 100 epochs (only the model weights
        # are stored; the epoch counter and optimizer state are not).
        if epoch % 100 == 99 and args.save_model:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(pretrained_model.state_dict(), os.path.join(save_dir, "ckpt.pth"))
        wandb.log({"loss": epoch_loss, "epoch": epoch, "k": k, "tau": tau, "n": n_str})

def main(args):
    """Train one ToyMLP per (n, tau, K) configuration, resuming weights if present.

    For every combination of dataset size ``n``, ``tau`` and ``K`` in 1..5,
    a model of sequence length ``3 * K`` is built, an existing checkpoint at
    ``./models/toy_n{n}tau{tau_str}_{K}/ckpt.pth`` is loaded when available,
    a dataset is generated and ``train`` is run.

    Args:
        args: Run configuration; ``args.seed``, ``args.vocab_size`` and
            ``args.device`` are read here, and ``args`` is forwarded to
            ``train``.
    """
    # "./models/toy_100000.0" is created as in the original script although
    # nothing in this file writes to it.
    output_dirs = (
        "./models",
        "./toylogs",
        os.path.join("./models", "toy_{}".format(1e5)),
    )
    for directory in output_dirs:
        os.makedirs(directory, exist_ok=True)

    n_values = [100000]
    tau_values = [0.25]

    wandb.init(project="toy_loss_monitoring", name="toy_continue_training_multi_n")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    vocab_size = args.vocab_size

    for n in n_values:
        n_str = str(n)
        batch_size = 512

        print(f"\n{'='*80}")
        print(f"Training with n={n}, batch_size={batch_size}")
        print(f"{'='*80}\n")

        for tau in tau_values:
            tau_str = tau_to_str(tau)

            print(f"\n{'-'*60}")
            print(f"Processing tau = {tau} (tau_str = '{tau_str}')")
            print(f"{'-'*60}\n")

            for k in [1, 2, 3, 4, 5]:
                info = {
                    "vocab_size": vocab_size,
                    "K": k,
                    "tau": tau,
                    "n": n,
                }

                checkpoint_path = os.path.join("./models", f"toy_n{n_str}tau{tau_str}_{k}", "ckpt.pth")
                # The checkpoint holds model weights only, so the epoch
                # counter always restarts at 0 even when weights are loaded.
                start_epoch = 0

                pretrained_model = ToyMLP(vocab_size=vocab_size, hidden_dim=256, length=3*k).to(args.device)

                if os.path.exists(checkpoint_path):
                    print(f"Loading checkpoint from {checkpoint_path}")
                    pretrained_model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
                    print(f"Resuming training for n={n}, tau={tau}, K={k} from epoch {start_epoch}...")
                else:
                    print(f"Checkpoint not found at {checkpoint_path}")
                    print(f"Starting training from scratch for n={n}, tau={tau}, K={k}...")

                # generate_3k_discrete_data / NumpyDataset are presumably
                # provided by the `dataset.dataset` star import.
                np_data = generate_3k_discrete_data(n=n, K=k)
                dataset = NumpyDataset(np_data)
                data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

                train(args, pretrained_model, data_loader, info, start_epoch=start_epoch, tau_str=tau_str, n_str=n_str)
                print(f"Finished training for n={n}, tau={tau}, K={k}\n")

if __name__ == "__main__":
    # Script entry point: obtain the run arguments and hand them to main().
    main(get_args())