import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pdb
import os
import argparse
import numpy as np
from tqdm import tqdm
import shutil

from dataset import frame_interpolation_dataset
from interpolator import Interpolator
import cv2

def _mse_and_rmse(weighted_loss_sum, n_samples):
    """Return ``(mse, rmse)`` from a batch-size-weighted sum of per-batch MSE.

    Both values are NaN when ``n_samples`` is 0, so an empty loader shows up
    in the logs instead of raising ``ZeroDivisionError``.
    """
    if n_samples == 0:
        return float("nan"), float("nan")
    mse = weighted_loss_sum / n_samples
    return mse, np.sqrt(mse)


def _evaluate(model, loader, criterion, device):
    """Run ``model`` over ``loader`` without gradients and return ``(mse, rmse)``."""
    model.eval()
    weighted_loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for keyframe1, keyframe2, middle_frame, time_interval in tqdm(loader, total=len(loader)):
            keyframe1 = keyframe1.to(device)
            keyframe2 = keyframe2.to(device)
            middle_frame = middle_frame.to(device)
            time_interval = time_interval.to(device)

            pred_middle = model(keyframe1, keyframe2, time_interval)
            loss = criterion(pred_middle, middle_frame)

            # criterion averages over the batch, so weight by batch size to
            # get a correct per-sample mean when the last batch is smaller.
            batch_size = keyframe1.shape[0]
            weighted_loss_sum += loss.item() * batch_size
            n_samples += batch_size
    return _mse_and_rmse(weighted_loss_sum, n_samples)


def train(model, train_loader, val_loader, epochs, lr, device, save_dir,
          log_interval=100, val_interval=1, save_interval=10):
    """Train ``model`` with Adam on an MSE loss, validating and checkpointing as it goes.

    Both loaders must support ``len()`` and yield 4-tuples
    ``(keyframe1, keyframe2, middle_frame, time_interval)`` of tensors. The
    model is called as ``model(keyframe1, keyframe2, time_interval)`` and its
    output is compared against ``middle_frame``. The first dimension of
    ``keyframe1`` is treated as the batch size.

    Args:
        model: Module to optimise; all of ``model.parameters()`` are handed
            to the optimizer.
        train_loader: Batches used for gradient updates.
        val_loader: Batches used for evaluation. If it yields no samples,
            the validation metrics are reported as NaN.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device passed to ``Tensor.to`` for every batch tensor. The
            model itself is not moved here.
        save_dir: Directory for checkpoints; created if it does not exist.
        log_interval: Print a progress line every this many training batches.
        val_interval: Validate every this many epochs.
        save_interval: Write an intermediate checkpoint every this many epochs.

    Raises:
        ValueError: If any of the interval arguments is smaller than 1.

    Side effects:
        Writes ``model_epoch_<k>.pth`` checkpoints and a final
        ``model_final.pth`` (state dicts) into ``save_dir``, and prints
        progress to stdout.
    """
    for name, value in (
        ("log_interval", log_interval),
        ("val_interval", val_interval),
        ("save_interval", save_interval),
    ):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")

    # Make the function self-sufficient: torch.save does not create parents.
    os.makedirs(save_dir, exist_ok=True)

    # Loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    n_batches = len(train_loader)
    for epoch in range(epochs):
        model.train()
        n_samples = 0
        weighted_loss_sum = 0.0
        for i, (keyframe1, keyframe2, middle_frame, time_interval) in tqdm(enumerate(train_loader), total=n_batches):
            keyframe1 = keyframe1.to(device)
            keyframe2 = keyframe2.to(device)
            middle_frame = middle_frame.to(device)
            time_interval = time_interval.to(device)

            pred_middle = model(keyframe1, keyframe2, time_interval)
            loss = criterion(pred_middle, middle_frame)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; .item() forces a device synchronisation.
            batch_loss = loss.item()
            batch_size = keyframe1.shape[0]
            weighted_loss_sum += batch_loss * batch_size
            n_samples += batch_size

            if i % log_interval == 0:
                avg_loss = weighted_loss_sum / n_samples
                print(
                    f"Epoch {epoch+1}/{epochs}, "
                    f"Batch {i+1}/{n_batches}, "
                    f"Loss: {batch_loss:.4f}, "
                    f"Avg Loss: {avg_loss:.4f}"
                )

        train_mse, train_rmse = _mse_and_rmse(weighted_loss_sum, n_samples)
        print(f"TRAIN Epoch {epoch+1}/{epochs}, MSE: {train_mse:.4f}, RMSE: {train_rmse:.4f}")

        # Validation
        if epoch % val_interval == 0:
            val_mse, val_rmse = _evaluate(model, val_loader, criterion, device)
            print(f"VAL Epoch {epoch+1}/{epochs}, MSE: {val_mse:.4f}, RMSE: {val_rmse:.4f}")

        # Save model. The test uses the 0-based epoch index, so with the
        # default interval checkpoints are written after epochs 1, 11, 21, ...
        if epoch % save_interval == 0:
            torch.save(model.state_dict(), os.path.join(save_dir, f"model_epoch_{epoch+1}.pth"))

    torch.save(model.state_dict(), os.path.join(save_dir, "model_final.pth"))
    print("Training complete. Model saved.")

def _parse_args():
    """Parse the command-line options of the fine-tuning script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--down_sample", type=int, default=1)
    parser.add_argument("--save_dir", type=str, default="./ckpts")
    parser.add_argument("--pretrained_ckpt", type=str, default="path to libero_dataset/ECCV_interp_model/film_net_fp32.pt")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--n_workers", type=int, default=16)
    parser.add_argument("--train_data", type=str, default="path to libero_dataset/finetune_dataset/libero_90_picture")
    parser.add_argument("--train_keyframe", type=str, default="path to libero_dataset/finetune_dataset/libero_90_rpd17")
    parser.add_argument("--val_data", type=str, default="path to libero_dataset/finetune_dataset/libero_10_picture")
    parser.add_argument("--val_keyframe", type=str, default="path to libero_dataset/finetune_dataset/libero_10_rpd17")
    return parser.parse_args()


def main():
    """Load the pretrained weights, build the data loaders and run training.

    The path defaults in ``_parse_args`` are placeholders ("path to ...") and
    must be overridden on the command line.
    """
    args = _parse_args()

    # The pretrained checkpoint is a TorchScript archive. Only its weights are
    # needed: they are copied into a regular Interpolator module, which is
    # the one that gets trained.
    scripted_model = torch.jit.load(args.pretrained_ckpt, map_location="cpu")
    pretrained_state = scripted_model.state_dict()
    model = Interpolator()
    model.load_state_dict(pretrained_state)

    # NOTE: every parameter is fine-tuned. Earlier experiments froze
    # `model.extract` and `model.predict_flow` by setting requires_grad=False
    # on their parameters; re-add that here if partial fine-tuning is wanted.
    model.to(args.device)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total trainable parameters: ", n_trainable)

    # Training split: shuffled each epoch.
    train_dataset = frame_interpolation_dataset(
        args.train_data,
        args.train_keyframe,
        downsample=args.down_sample,
        data_size=30000,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_workers,
        pin_memory=True,
    )

    # Validation split: fixed order.
    val_dataset = frame_interpolation_dataset(
        args.val_data,
        args.val_keyframe,
        downsample=args.down_sample,
        data_size=1000,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.n_workers,
        pin_memory=True,
    )

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.save_dir, exist_ok=True)

    # Train model
    train(model, train_loader, val_loader, args.epochs, args.lr, args.device, args.save_dir)


if __name__ == "__main__":
    main()