from copy import deepcopy
import torch
import random
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from torch.utils.data import DataLoader
from torchmetrics.text import SacreBLEUScore, CharErrorRate, WordErrorRate

from transformers import AutoTokenizer
# from whisper_normalizer.basic import BasicTextNormalizer

# normalizer = BasicTextNormalizer()

from .conf_opt_sch import configure_optimizer_schedular
from .dataset import STCollator, STDataset
from .net.model import Whisper
# from .criterion import STCriterion
from .Normalizer.utils import text_norm, remove_punctuation

def unique(l: list):
    """Collapse runs of consecutive equal elements into ``[value, (start, end)]``.

    Only adjacent elements are compared, so a value that reappears after a
    different value starts a new run.

    Args:
        l: Indexable sequence supporting ``len`` and ``!=`` between elements.

    Returns:
        A list with one ``[value, (start, end)]`` entry per run, where
        ``start`` is inclusive, ``end`` is exclusive and ``value`` is the last
        element of the run. Empty input yields an empty list.
    """
    size = len(l)
    if size == 0:
        return []

    # Indices at which an element differs from its predecessor, i.e. the
    # positions where a new run begins.
    boundaries = [idx for idx in range(1, size) if l[idx] != l[idx - 1]]

    run_starts = [0, *boundaries]
    run_ends = [*boundaries, size]
    return [[l[end - 1], (start, end)] for start, end in zip(run_starts, run_ends)]

def get_slice(item, slices):
    """Return a sliced deep copy of ``item``, leaving ``chunk_mask`` unsliced.

    The input mapping is deep-copied first, so ``item`` itself is not
    modified and the returned values do not alias the caller's objects.

    Args:
        item: Dict-like object whose values support indexing with ``slices``.
        slices: Index or slice applied to every value except ``'chunk_mask'``.

    Returns:
        A new dict with each value indexed by ``slices``. If ``item`` has a
        ``'chunk_mask'`` entry that is not ``None``, its deep copy is included
        unchanged as the last key.
    """
    working = deepcopy(item)
    # Take chunk_mask out before slicing so it is carried over whole.
    mask = working.pop('chunk_mask', None)

    sliced = {}
    for name in working.keys():
        sliced[name] = working[name][slices]

    if mask is not None:
        sliced['chunk_mask'] = mask
    return sliced


class DigsstModule(LightningModule):
    """Lightning module wrapping the project's ``Whisper`` model.

    Builds the model, tokenizer and collate function from ``cfg`` and
    provides the training/validation steps, the optimizer/scheduler setup
    and the train/validation dataloaders.
    """

    def __init__(self, cfg) -> None:
        """Create the model, tokenizer and collate function from ``cfg``.

        Args:
            cfg: Configuration mapping. ``cfg['model_cfg']`` is passed to
                ``Whisper``; its ``'Whisper'`` entry supplies
                ``'huggingface_path'`` (tokenizer location) and ``'n_mels'``
                (forwarded to the collator). The whole ``cfg`` is stored on
                the instance for the optimizer and dataloader hooks.
        """
        super().__init__()
        model_cfg = cfg['model_cfg']
        self.model = Whisper(model_cfg).train()

        whisper_cfg = model_cfg['Whisper']
        self.tokenizer = AutoTokenizer.from_pretrained(
            whisper_cfg['huggingface_path'],
            trust_remote_code=True,
            use_fast=True,
        )
        self.st_collate_fn = STCollator(self.tokenizer, n_mels=whisper_cfg['n_mels'])

        self.cfg = cfg

    def forward(self, **kwargs):
        """Call the wrapped model with the given keyword arguments."""
        return self.model(**kwargs)

    def on_train_start(self) -> None:
        """Step the learning-rate scheduler once when training starts."""
        # NOTE(review): the reason for this extra initial step is not stated
        # here - presumably to move past the scheduler's initial LR; confirm.
        scheduler = self.lr_schedulers()
        scheduler.step()

    def on_load_checkpoint(self, checkpoint) -> None:
        """Fix the checkpoint loading issue for deepspeed.

        When no trainer is attached and the checkpoint has no
        ``'state_dict'`` entry, expose ``checkpoint['module']`` under
        ``'state_dict'``. In every other case the checkpoint is left as is.
        """
        if self._trainer is not None or "state_dict" in checkpoint:
            return
        checkpoint['state_dict'] = checkpoint['module']

    def training_step(self, batch, batch_idx):
        """Run the model on ``batch['inputs']``, log and return its loss."""
        outputs = self.model(**batch['inputs'])
        train_loss = outputs['loss']
        self.log(
            "train/loss",
            train_loss,
            batch_size=outputs['bsz'],
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return train_loss

    def validation_step(self, batch, batch_id, dataloader_idx=None):
        """Run the model on ``batch['inputs']`` and log the epoch-level loss.

        Returns:
            A dict with the model loss under ``'loss'``.
        """
        outputs = self.model(**batch['inputs'])
        val_loss = outputs['loss']
        self.log(
            "val/loss",
            val_loss,
            batch_size=outputs['bsz'],
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return {'loss': val_loss}

    def configure_optimizers(self):
        """Build the optimizer and learning-rate scheduler.

        Both are created by ``configure_optimizer_schedular`` from the stored
        config, this module's ``named_parameters`` and the trainer's
        estimated number of stepping batches. They are also kept on the
        instance as ``self.optimizer`` and ``self.scheduler``.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler is
            configured to step every training step.
        """
        optimizer, scheduler = configure_optimizer_schedular(
            cfg=self.cfg,
            params_generator=self.named_parameters,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        self.optimizer, self.scheduler = optimizer, scheduler

        scheduler_cfg = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler_cfg]

    def train_dataloader(self):
        """Return the shuffled training ``DataLoader``.

        Dataset paths and batch size come from ``cfg['data_cfg']['train']['st']``;
        the worker count comes from ``cfg['data_cfg']['num_worker']``.
        Incomplete final batches are dropped.
        """
        data_cfg = self.cfg['data_cfg']
        st_cfg = data_cfg['train']['st']

        train_set = STDataset(st_cfg['paths'], train=True)
        return DataLoader(
            train_set,
            batch_size=st_cfg['batch_size'],
            shuffle=True,
            drop_last=True,
            num_workers=data_cfg['num_worker'],
            collate_fn=self.st_collate_fn,
        )

    def val_dataloader(self):
        """Return the validation ``DataLoader`` covering all configured splits.

        ``cfg['data_cfg']['validation']['splits']`` maps split names to paths;
        both lists are handed to a single ``STDataset`` in the mapping's
        iteration order. No shuffling is done and no batch is dropped.
        """
        data_cfg = self.cfg['data_cfg']
        val_cfg = data_cfg['validation']

        # Materialise the pairs once so names and paths stay aligned.
        split_items = list(val_cfg['splits'].items())
        target_split = [split for split, _ in split_items]
        dataset_paths = [path for _, path in split_items]

        val_set = STDataset(dataset_paths, train=False, target_split=target_split)
        return DataLoader(
            val_set,
            batch_size=val_cfg['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=data_cfg['num_worker'],
            collate_fn=self.st_collate_fn,
        )

    def test_dataloader(self):
        """No test dataloader is defined; returns ``None``."""
        return None
