import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pdb
import os
import argparse
import numpy as np
from tqdm import tqdm
import shutil

from dataset import frame_interp_dataset, embedding_dataset
from model import InterpolationModel


def test(args, model, test_loader, device, total=80):
    """Evaluate ``model`` on ``test_loader`` and collect integer gap predictions.

    For every batch, the model's raw gap predictions are rescaled so that each
    row sums to ``total``. They are compared against the ground-truth gaps
    (differences of consecutive entries of ``idx``) to accumulate RMSE, and then
    rounded per sample with ``float_to_int_gap``.

    Args:
        args: parsed CLI arguments (unused; kept for call compatibility).
        model: network mapping a batch of videos to per-gap float predictions;
            presumably output shape (batch, num_gaps) — TODO confirm.
        test_loader: yields ``(video, idx, video_name)`` batches. ``idx`` is
            indexed as 2-D here, presumably (batch, num_gaps + 1) frame
            indices — TODO confirm against the dataset.
        device: device the video batch and ground-truth gaps are moved to.
        total: value every predicted gap row is normalized and rounded to.

    Returns:
        ``(results, rmse)``. ``results`` is a list of
        ``(video_name, int_gap_ndarray)`` pairs. ``rmse`` is the root mean
        squared error of the normalized float gaps over all gap entries, or
        ``nan`` if the loader produced no data.
    """
    results = []
    sq_err = 0.0
    n = 0
    # Pure inference: don't build autograd graphs for every batch.
    with torch.no_grad():
        for video, idx, video_name in tqdm(test_loader, total=len(test_loader)):
            video = video.to(device)
            pred_gap = model(video)
            # Rescale each row so the predicted gaps sum to `total`.
            pred_gap = pred_gap / pred_gap.sum(dim=1, keepdim=True) * total
            gt_gap = (idx[:, 1:] - idx[:, :-1]).to(device)
            sq_err += ((pred_gap - gt_gap) ** 2).sum().item()
            n += pred_gap.shape[0] * pred_gap.shape[1]
            for name, row in zip(video_name, pred_gap):
                int_gap = float_to_int_gap(row, L=total)
                results.append((name, int_gap.cpu().numpy()))

    if n == 0:
        # Empty loader: avoid ZeroDivisionError, report an undefined RMSE.
        return results, float('nan')
    return results, np.sqrt(sq_err / n)

def float_to_int_gap(pred_gap, L=80):
    """Round a 1-D tensor of float gaps to integers >= 1 that sum to ``L``.

    Entries are clamped to >= 1 and rounded half-up to the nearest integer.
    The remaining difference to ``L`` is then distributed greedily to keep each
    integer as close as possible to its (clamped) float value:

    * increments go to the entries that were rounded *down* the most;
    * decrements go to the entries that were rounded *up* the most, and no
      entry ever drops below 1.

    Args:
        pred_gap: 1-D float tensor of predicted gaps (any device; gradients
            are ignored).
        L: required total of the returned gaps.

    Returns:
        1-D ``torch.long`` tensor on ``pred_gap``'s device. Every entry is
        >= 1 and the entries sum to ``L``.

    Raises:
        ValueError: if no valid split exists, i.e. there are more gaps than
            ``L`` (each gap must be at least 1).
    """
    # 1. Enforce the minimum gap of 1 (detached: this is post-processing only).
    values = torch.clamp(pred_gap.detach(), min=1.0)
    n = values.numel()
    if n > L or (n == 0 and L != 0):
        raise ValueError(f'cannot split total {L} into {n} gaps of at least 1')

    # 2. Round half up to the nearest integer.
    int_gap = torch.floor(values + 0.5).long()
    # Rounding residual per entry: > 0 means it was rounded down, < 0 up.
    residual = values - int_gap.to(values.dtype)

    # 3. Adjust so the total is exactly L.
    diff = L - int(int_gap.sum().item())
    if diff > 0:
        # Spread whole rounds evenly, then give the remainder to the entries
        # that were rounded down the most (largest residual).
        full_rounds, rest = divmod(diff, n)
        int_gap += full_rounds
        if rest:
            int_gap[torch.argsort(residual, descending=True)[:rest]] += 1
    elif diff < 0:
        # Remove one unit at a time from the entry rounded up the most
        # (smallest residual), skipping entries already at 1. Since n <= L and
        # the current sum exceeds L, some entry is always > 1.
        blocked = torch.full_like(residual, float('inf'))
        for _ in range(-diff):
            i = torch.argmin(torch.where(int_gap > 1, residual, blocked))
            int_gap[i] -= 1
            residual[i] += 1
    return int_gap

def save_results(results, save_path):
    """Write predictions as two line-aligned text files in ``save_path``.

    ``video_name.txt`` receives one video name per line. ``idx.txt`` receives
    the matching integer gaps, space-separated, on the same line number. Any
    existing files with those names are overwritten; an empty ``results``
    yields two empty files.

    Args:
        results: iterable of ``(video_name, int_gap)`` pairs, where ``int_gap``
            is any iterable of integers (e.g. the arrays returned by ``test``).
        save_path: output directory; created if it does not exist.
    """
    os.makedirs(save_path, exist_ok=True)
    video_name_txt = os.path.join(save_path, 'video_name.txt')
    idx_txt = os.path.join(save_path, 'idx.txt')
    # Open each file once in 'w' mode (truncating stale output) instead of
    # deleting it and re-opening in append mode for every single result.
    with open(video_name_txt, 'w', encoding='utf-8') as name_f, \
            open(idx_txt, 'w', encoding='utf-8') as idx_f:
        for video_name, int_gap in results:
            name_f.write(video_name + '\n')
            idx_f.write(' '.join(str(i) for i in int_gap) + '\n')

if __name__ == '__main__':
    # Evaluate a trained InterpolationModel and report gap RMSE; optionally
    # write the rounded predictions next to the test data.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='./interpolation_model/rpd17/model_final.pth')
    parser.add_argument('--test_path', type=str, default='/path to libero_dataset/finetune_dataset/libero_spatial_rpd17')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--save_results', action='store_true',
                        help='write video_name.txt / idx.txt into --test_path')
    args = parser.parse_args()

    model = InterpolationModel()
    # map_location lets a checkpoint saved on GPU be loaded with --device cpu.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.model_path, map_location=args.device))
    model.to(args.device)
    model.eval()

    # is_test=False kept from the original; presumably this mode yields the
    # ground-truth indices needed for RMSE — TODO confirm against the dataset.
    test_dataset = frame_interp_dataset(args.test_path, is_test=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    results, rmse = test(args, model, test_loader, args.device)
    print('rmse:', rmse)
    if args.save_results:
        save_results(results, args.test_path)
