# pylint: disable=W0401, C0413, E0602
import os
import sys
import torch
import math
import sacrebleu
import copy
import librosa
from pytorch_lightning import seed_everything
from transformers.tokenization_utils_base import BatchEncoding
from transformers.feature_extraction_utils import BatchFeature
import pandas as pd
import time

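# BASE_DIR is the directory two levels above this script's directory (three dirname() calls on the file path); it is appended to sys.path so the project imports below resolve.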
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
from models.simul7.pl_module import *
from config.parse_yaml_args import parse_args_and_yaml

seed_everything(32)


def deep_to_device(obj, device, half=False):
    """Recursively move every tensor inside a nested container to ``device``.

    Args:
        obj: A tensor, a mapping (``dict``, ``BatchEncoding`` or
            ``BatchFeature``), a ``list``/``tuple``, or any other object.
        device: Target device, forwarded to ``torch.Tensor.to``.
        half: When true, tensors whose dtype is ``torch.float32`` are also
            cast to half precision. Other dtypes are left untouched.

    Returns:
        The moved tensor; the same mapping object with its values replaced
        in place; a new sequence of the same kind as the input (``list``
        stays ``list``, ``tuple`` stays ``tuple``, namedtuples keep their
        class); or ``obj`` itself for any other type.
    """
    if isinstance(obj, torch.Tensor):
        obj = obj.to(device)
        if half and obj.dtype == torch.float32:
            obj = obj.half()
        return obj

    if isinstance(obj, (list, tuple)):
        moved = [deep_to_device(item, device, half) for item in obj]
        if isinstance(obj, tuple):
            # Keep the container kind: callers that unpack, hash, or access
            # namedtuple fields must not suddenly receive a plain list.
            if hasattr(obj, "_fields"):
                return type(obj)(*moved)
            return tuple(moved)
        return moved

    if isinstance(obj, (dict, BatchEncoding, BatchFeature)):
        # Mappings are updated in place; only existing keys are reassigned,
        # so the key set does not change while iterating.
        for key, value in obj.items():
            obj[key] = deep_to_device(value, device, half)
        return obj

    return obj


# Device that the module and every batch are moved to.
target_device = torch.device("cuda:0")
# Alternatives kept for debugging:
# target_device = torch.device('cpu')
# torch.set_default_device(target_device)

# Switch the working directory to the folder that contains this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(_script_dir)

# Parse the .yaml configuration file.
cfg = parse_args_and_yaml(config_path="PLACEHOLDER")

# Override the loaded configuration for this run: fixed batch sizes, no
# dataloader workers, DeepSpeed off, and only the first two MT / ST paths.
_data_cfg = cfg["data_cfg"]
_train_cfg = _data_cfg["train"]
_train_cfg["mt"]["batch_size"] = 2
_train_cfg["st"]["batch_size"] = 20
_data_cfg["validation"]["batch_size"] = 20
_data_cfg["num_worker"] = 0
cfg["use_deepspeed"] = False
for _task in ("mt", "st"):
    _train_cfg[_task]["paths"] = _train_cfg[_task]["paths"][:2]

# Only referenced by the commented-out load_from_checkpoint call.
checkpoint_path = "PLACEHOLDER"


# module = Simul2Module.load_from_checkpoint(
#     checkpoint_path=checkpoint_path,
#     cfg=cfg,
#     map_location=target_device,
# ).half().eval().to(target_device)
# Build the Lightning module straight from cfg, then cast it to half
# precision, switch it to eval mode, and move it onto the target device.
module = Simul7Module(cfg=cfg)
module = module.half()
module = module.eval()
module = module.to(target_device)


# Build the optimizer / scheduler pair from the module's own config.
# The results are not used further in the visible part of this script.
optimizer, scheduler = configure_optimizer_schedular(
    cfg=module.cfg,
    params_generator=module.named_parameters,
    num_training_steps=100_000,
)


# inputs = extractor(
#     [audio],
#     sampling_rate=16000,
#     return_tensors="pt",
# )
# if torch.all(inputs['attention_mask'][:, -1] == 0):
#             inputs['attention_mask'] = inputs['attention_mask'][:, :-1]
#             inputs['input_features'] = inputs['input_features'][:, :-1]

# inputs['chunk_mask'] = stream_mask
# Convenience handles onto objects owned by the Lightning module.
model = module.model
tokenizer = module.tokenizer

# Dataloaders are requested in the same order as before: train, then validation.
train_loader = module.train_dataloader()
val_loader = module.val_dataloader()

# Accumulators; nothing in the visible part of this script fills them.
hyps = []
refs = []

# Everything below runs with autograd disabled, including the training step.
with torch.no_grad():
    # Training path: put the network in train mode and run the first batch only.
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        if batch_idx != 0:
            break
        batch = deep_to_device(batch, target_device, half=True)
        loss = module.training_step(batch, 0)
        print(loss, "\n")

    # Validation path: put the network in eval mode and run the first batch only.
    model.eval()
    for batch_idx, batch in enumerate(val_loader):
        if batch_idx != 0:
            break
        batch = deep_to_device(batch, target_device, half=True)
        res = module.validation_step(batch, 0, 0)
        print(res)
