import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pdb
import os
import argparse
import numpy as np
from tqdm import tqdm
import shutil

from dataset import frame_interp_dataset, embedding_dataset
from model import InterpolationModel


def _predict_normalized_gaps(model, video, idx, device, total_span):
    """Run the model on one batch and return ``(pred_gap, gap)``.

    ``gap`` is the difference between consecutive entries of ``idx`` along
    dim 1. ``pred_gap`` is ``model.forward_embedding(video)`` rescaled so each
    row sums to ``total_span``. ``gap`` is cast to ``pred_gap``'s dtype.
    """
    video = video.to(device)
    idx = idx.to(device)
    pred_gap = model.forward_embedding(video)
    # Normalise each sample's predicted gaps so they add up to a fixed span.
    pred_gap = pred_gap / pred_gap.sum(dim=1, keepdim=True) * total_span
    # idx presumably holds integer frame indices; MSELoss backward requires the
    # target to share the input's floating dtype, so cast explicitly.
    gap = (idx[:, 1:] - idx[:, :-1]).to(pred_gap.dtype)
    return pred_gap, gap


def _safe_div(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (empty loader)."""
    return total / count if count else float("nan")


def train(model, train_loader, val_loader, epochs, lr, device, save_dir, total_span=80):
    """Train ``model`` to predict normalised frame gaps, then save its weights.

    Each epoch runs one Adam optimisation pass over ``train_loader`` and one
    no-grad evaluation pass over ``val_loader``, printing MSE/RMSE for both.
    After the last epoch the model's ``state_dict`` is written to
    ``<save_dir>/model_final.pth``.

    Args:
        model: module exposing ``forward_embedding(video)``; its output is
            normalised along dim 1 and compared to the ground-truth gaps
            (presumably shape ``(B, T)`` — confirm against the model).
        train_loader: iterable of ``(video, idx, _)`` batches; gaps are taken
            between consecutive ``idx`` entries along dim 1.
        val_loader: same batch format as ``train_loader``.
        epochs: number of training epochs.
        lr: Adam learning rate.
        device: device each batch is moved to.
        save_dir: directory for the final checkpoint (created if missing).
        total_span: value each sample's predicted gaps are rescaled to sum to.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        n_total = 0
        train_sse = 0.0
        pred_sum = 0.0
        for video, idx, _ in tqdm(train_loader, total=len(train_loader)):
            pred_gap, gap = _predict_normalized_gaps(model, video, idx, device, total_span)
            loss = criterion(pred_gap, gap)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pred_sum += pred_gap.sum().item()
            train_sse += (pred_gap - gap).pow(2).sum().item()
            n_total += gap.numel()

        train_mse = _safe_div(train_sse, n_total)
        train_rmse = np.sqrt(train_mse)
        print(f"Epoch {epoch+1}/{epochs}, MSE: {train_mse:.4f}, RMSE: {train_rmse:.4f}, mean_gap: {_safe_div(pred_sum, n_total)}")

        # ---- validation pass ----
        model.eval()
        val_n_total = 0
        val_sse = 0.0
        with torch.no_grad():
            for video, idx, _ in val_loader:
                pred_gap, gap = _predict_normalized_gaps(model, video, idx, device, total_span)
                val_sse += (pred_gap - gap).pow(2).sum().item()
                val_n_total += gap.numel()

        val_mse = _safe_div(val_sse, val_n_total)
        val_rmse = np.sqrt(val_mse)
        print(f"Epoch {epoch+1}/{epochs}, Val MSE: {val_mse:.4f}, Val RMSE: {val_rmse:.4f}")

    # Persist the final weights so the "Model saved." message is truthful.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "model_final.pth"))
    print("Training complete. Model saved.")

def encode_video(model, dataloader, device, save_dir):
    """Encode every video in ``dataloader`` and cache one tensor file per video.

    For each batch, ``model.encode_video`` is run on the video tensor. Each
    sample's embedding is written to ``<save_dir>/<stem>.pt``. Here ``stem`` is
    the basename (after the last ``/``) of the matching ``video_name`` entry,
    truncated at its first ``.``.

    Args:
        model: module exposing ``encode_video(video)``, whose output is indexed
            per sample (presumably batch-first — confirm).
        dataloader: yields ``(video, idx, video_name)`` batches; ``idx`` is unused.
        device: device the video batch is moved to.
        save_dir: existing directory that receives the ``.pt`` files.
    """
    print("SAVING video cache to:", save_dir)
    model.eval()
    with torch.no_grad():
        for video, _, video_name in tqdm(dataloader, total=len(dataloader)):
            vid_embedding = model.encode_video(video.to(device))
            for b, path in enumerate(video_name):
                # NOTE(review): embedding_dataset presumably looks files up by
                # this same stem — verify before changing the naming scheme.
                n = path.split("/")[-1].split(".")[0]
                # clone(): on CPU, vid_embedding[b] is a view of the whole batch,
                # and torch.save would serialize the entire underlying storage.
                emb = vid_embedding[b].cpu().clone()
                torch.save(emb, os.path.join(save_dir, f"{n}.pt"))


def _ensure_embedding_cache(model, data_dir, cache_dir, *, batch_size, n_workers, device, pin_memory, regenerate):
    """Encode the videos in ``data_dir`` into ``cache_dir`` when needed.

    The cache is rebuilt from scratch when ``regenerate`` is set, or when
    ``cache_dir`` is missing or empty (e.g. an earlier run was interrupted
    before writing any file). Otherwise the existing cache is reused as-is.
    A partially written cache is NOT detected.
    """
    if not regenerate and os.path.isdir(cache_dir) and os.listdir(cache_dir):
        return
    shutil.rmtree(cache_dir, ignore_errors=True)
    os.makedirs(cache_dir, exist_ok=True)
    dataset = frame_interp_dataset(data_dir)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=pin_memory)
    encode_video(model, loader, device, cache_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--save_dir", type=str, default="./interpolation_model/rpd33")
    parser.add_argument("--regenerate_cache", action="store_true")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--n_workers", type=int, default=12)
    parser.add_argument("--train_data", type=str, default="/path to libero_dataset/finetune_dataset/libero_90_rpd33")
    parser.add_argument("--train_cache", type=str, default="/path to libero_dataset/interpolation_model/cache/libero_90_rpd33")
    parser.add_argument("--val_data", type=str, default="/path to libero_dataset/finetune_dataset/libero_10_rpd33")
    parser.add_argument("--val_cache", type=str, default="/path to libero_dataset/interpolation_model/cache/libero_10_rpd33")
    args = parser.parse_args()

    # Pinned host memory only helps host-to-GPU transfers; skip it on CPU.
    pin_memory = args.device.startswith("cuda")

    # Create model
    model = InterpolationModel().to(args.device)

    # Build (or reuse) the per-video embedding caches for both splits.
    for data_dir, cache_dir in ((args.train_data, args.train_cache), (args.val_data, args.val_cache)):
        _ensure_embedding_cache(
            model,
            data_dir,
            cache_dir,
            batch_size=args.batch_size,
            n_workers=args.n_workers,
            device=args.device,
            pin_memory=pin_memory,
            regenerate=args.regenerate_cache,
        )

    # Load embedding datasets backed by the caches
    train_dataset = embedding_dataset(args.train_data, args.train_cache)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_workers, pin_memory=pin_memory)

    val_dataset = embedding_dataset(args.val_data, args.val_cache)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers, pin_memory=pin_memory)

    os.makedirs(args.save_dir, exist_ok=True)
    # Train model
    train(model, train_loader, val_loader, args.epochs, args.lr, args.device, args.save_dir)