import torch
from custom_datasets.embedding import EmbeddingDataset
from models import REPVLM
from torch.utils.data import DataLoader
import os
import argparse
from utils import ema
from utils.seed import set_seed

# Imports for flow matching
from flow_matching.path import GeodesicProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.utils.manifolds import Sphere

# Imports for DDP
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


def setup_ddp():
    """Bind this process to its GPU and initialize the NCCL process group.

    Reads ``LOCAL_RANK`` (set by the launcher, e.g. ``torchrun``) to select the
    CUDA device. The device is selected *before* the process group is created,
    following the PyTorch torchrun DDP tutorial, so NCCL uses this rank's GPU
    from the start.

    Returns:
        int: the local rank (CUDA device index on this node) of this process.
            Callers that ignore the return value keep working unchanged.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    return local_rank


def cleanup_ddp():
    """Tear down the default (world) process group at the end of training."""
    # group=None is the default and targets the default process group.
    dist.destroy_process_group(group=None)


def _flow_matching_batch(manifold, path, z_image, z_text, device):
    """Build one flow-matching example from a batch of paired embeddings.

    Image and text embeddings are stacked into one batch of endpoints ``x1``
    and paired with standard-normal noise ``x0``. Both are projected onto
    ``manifold``, and ``path`` is sampled at times ``t ~ U[0, 1)``.

    Returns:
        tuple: ``(xt, t, y, target)`` where ``xt`` is the path sample at time
        ``t``, ``y`` is the modality label (0 for image rows, 1 for text rows)
        and ``target`` is the path's ``dx_t`` that the model regresses.
    """
    z_image, z_text = z_image.to(device), z_text.to(device)
    num_image, num_text = z_image.shape[0], z_text.shape[0]

    x1 = torch.cat([z_image, z_text], dim=0)
    x0 = torch.randn_like(x1)

    # Project both endpoints onto the manifold before sampling the path.
    x0 = manifold.projx(x0)
    x1 = manifold.projx(x1)

    t = torch.rand(x1.size(0), device=device)
    samples = path.sample(t=t, x_0=x0, x_1=x1)

    y = torch.cat([
        torch.zeros(num_image, dtype=torch.long, device=device),
        torch.ones(num_text, dtype=torch.long, device=device)], dim=0)
    return samples.x_t, t, y, samples.dx_t


@torch.no_grad()
def _evaluate(ema_model, valloader, manifold, path, device):
    """Return the mean per-batch MSE of ``ema_model`` over ``valloader``.

    Returns NaN for an empty loader (e.g. when the 1% validation split rounds
    down to zero samples) instead of raising ZeroDivisionError.
    """
    total, num_batches = 0.0, 0
    for z_image, z_text in valloader:
        vxt, vt, vy, vtarget = _flow_matching_batch(
            manifold, path, z_image, z_text, device)
        vpred = ema_model(vxt, vt, vy)
        total += torch.pow(vpred - vtarget, 2).mean().item()
        num_batches += 1
    return total / num_batches if num_batches else float("nan")


def main(args):
    """Train REPVLM with geodesic flow matching on a ``Sphere`` manifold under DDP.

    Expects a launcher that sets ``LOCAL_RANK`` and the process-group env vars
    (e.g. ``torchrun``). Every ``save_every`` steps, global rank 0 evaluates the
    EMA model on a held-out 1% split, prints train/val loss and writes the EMA
    weights to ``checkpoints/<dataset>/repvlm/<seed>/model.pth``.
    """
    setup_ddp()
    # LOCAL_RANK selects the GPU on this node. The *global* rank identifies the
    # process across all nodes and must drive data sharding and the election
    # of the single process that evaluates / writes checkpoints.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = local_rank
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    # NOTE(review): every rank uses the same seed; if set_seed seeds torch's
    # global RNG, all ranks draw identical noise/time samples each step.
    set_seed(args.seed)

    dataset = EmbeddingDataset(
        f'embeddings/{args.dataset}/image.pth',
        f'embeddings/{args.dataset}/text.pth',)

    # Hold out 1% for validation. The fixed generator makes the split
    # identical on every rank.
    val_size = int(0.01 * len(dataset))
    train_size = len(dataset) - val_size
    trainset, valset = torch.utils.data.random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42))
    valloader = DataLoader(
        valset, batch_size=256, shuffle=False,
        num_workers=16, pin_memory=True, prefetch_factor=2)

    # Per-GPU batch size; the effective (global) batch is batch_size * world_size.
    batch_size = 2048 // world_size

    sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    # shuffle=False here because shuffling is delegated to the sampler.
    dataloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=False,
        num_workers=16, pin_memory=True, prefetch_factor=2, sampler=sampler)

    model = REPVLM().to(device)
    model = torch.compile(model)
    model = DDP(model, device_ids=[device])

    # EMA copy of the weights. model.module is the torch.compile wrapper and
    # ._orig_mod is the plain REPVLM underneath it.
    ema_model = REPVLM().to(device)
    ema_model.load_state_dict(model.module._orig_mod.state_dict())
    for param in ema_model.parameters():
        param.detach_()
    ema_model.eval()

    lr = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    num_steps = int(4e5)  # total number of optimizer steps
    # Linear warm-up from lr * 1e-3 to lr over the first 10k steps.
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-3, total_iters=int(1e4))

    manifold = Sphere()
    path = GeodesicProbPath(scheduler=CondOTScheduler(), manifold=manifold)

    model.train()
    avg_loss = 0.
    save_every = int(1e3)
    step = 0
    for epoch in range(int(1e5)):
        sampler.set_epoch(epoch)  # different shuffle order each epoch

        for z_image, z_text in dataloader:
            xt, t, y, target = _flow_matching_batch(
                manifold, path, z_image, z_text, device)

            optimizer.zero_grad(set_to_none=True)
            pred = model(xt, t, y)
            loss = torch.pow(pred - target, 2).mean()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            ema_model = ema(model.module._orig_mod, ema_model, decay=0.9999)

            step += 1
            # Average the loss across ranks so the log reflects the global batch.
            reduced_loss = loss.detach()
            dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG)
            avg_loss += reduced_loss.item()

            if step % save_every == 0:
                if rank == 0:
                    avg_loss /= save_every
                    val_loss = _evaluate(ema_model, valloader, manifold, path, device)
                    print(f"Step {step}: train loss = {avg_loss:.6f}, "
                          f"val loss = {val_loss:.6f}", flush=True)

                    save_path = f"checkpoints/{args.dataset}/repvlm/{args.seed}"
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(ema_model.state_dict(), f"{save_path}/model.pth")
                # Reset on every rank so the accumulator does not grow unboundedly.
                avg_loss = 0.
                # All ranks wait here while global rank 0 evaluates and saves.
                dist.barrier()

            if step >= num_steps:
                break
        if step >= num_steps:
            break

    cleanup_ddp()


if __name__ == '__main__':
    # Command-line entry point: --dataset selects embeddings/<dataset>/ and
    # --seed is forwarded to set_seed and used in the checkpoint path.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, required=True)
    cli.add_argument('--seed', type=int, default=0,
                     help='random seed for initialization')
    main(cli.parse_args())
