from copy import deepcopy
import torch
import random
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np
import torch.nn.functional as F
# from whisper_normalizer.basic import BasicTextNormalizer

# normalizer = BasicTextNormalizer()

from .conf_opt_sch import configure_optimizer_schedular
from .dataset import Collator, TTSDataset
from .net.model import Qwen2LM
from .criterion import Criterion
from .Normalizer.utils import text_norm, remove_punctuation

def unique(l: list):
    """Group consecutive equal elements of ``l`` into runs.

    Args:
        l: An indexable sequence whose neighbouring elements are compared
            with ``!=``.

    Returns:
        A list with one ``[value, (start, end)]`` entry per run, in order of
        appearance, where ``start`` is inclusive and ``end`` is exclusive.
        An empty input yields an empty list.
    """
    size = len(l)
    if not size:
        return []

    # Indices at which the element differs from its predecessor, i.e. the
    # positions where a new run begins.
    cuts = [pos for pos in range(1, size) if l[pos] != l[pos - 1]]

    # Pair up neighbouring edges to obtain the half-open span of every run.
    edges = [0, *cuts, size]
    return [[l[hi - 1], (lo, hi)] for lo, hi in zip(edges, edges[1:])]

def get_slice(item, slices):
    """Index every entry of ``item`` with ``slices``, except ``'chunk_mask'``.

    The mapping is deep-copied first, so ``item`` itself is left untouched
    and the returned values do not alias the caller's data.

    Args:
        item: Mapping whose values support ``value[slices]``.
        slices: Index or slice applied to each value.

    Returns:
        A new dict with the indexed values. If ``item`` carries a
        ``'chunk_mask'`` that is not ``None``, its (copied) value is added
        unindexed as the last key.
    """
    snapshot = deepcopy(item)
    # Take the mask out before indexing so it is passed through as-is.
    mask = snapshot.pop('chunk_mask', None)

    result = {}
    for name, value in snapshot.items():
        result[name] = value[slices]

    if mask is None:
        return result
    result['chunk_mask'] = mask
    return result


class CozyvoiceModule(LightningModule):
    """Lightning module that trains a ``Qwen2LM`` through a ``Criterion``.

    It owns the model, the batch collator and the criterion, logs the
    criterion's outputs during training and validation, and builds the
    optimizer, scheduler and data loaders from ``cfg``.
    """

    def __init__(self, cfg) -> None:
        """Create the model, collator and criterion from ``cfg['model_cfg']``.

        Args:
            cfg: Configuration mapping. ``cfg['model_cfg']`` must provide
                ``'model_path'`` and ``'tokenizer_path'``; other entries are
                read later by the optimizer and data loader hooks.
        """
        super().__init__()
        settings = cfg['model_cfg']

        network = Qwen2LM(settings['model_path'])
        self.model = network.train()
        self.collate_fn = Collator(settings['tokenizer_path'])
        self.criterion = Criterion(self.model)

        self.cfg = cfg

    def forward(self, **kwargs):
        """Forward all keyword arguments to the wrapped model."""
        return self.model(**kwargs)

    def on_train_start(self) -> None:
        """Advance the LR scheduler once when training starts."""
        # NOTE(review): presumably avoids running the first optimizer step
        # at the scheduler's initial learning rate - confirm.
        scheduler = self.lr_schedulers()
        scheduler.step()

    def training_step(self, batch, batch_idx):
        """Run the criterion on ``batch``, log its outputs and return the loss.

        The criterion result is expected to expose ``'loss'``, ``'bsz'`` and
        a ``'log'`` mapping of extra values to log.
        """
        result = self.criterion(batch, global_step=self.global_step)
        total_loss = result['loss']

        # Per-step logging options shared by the loss and the extra metrics.
        step_opts = dict(
            on_step=True,
            on_epoch=False,
            logger=True,
            sync_dist=True,
            batch_size=result['bsz'],
        )
        self.log("train/loss", total_loss, prog_bar=True, **step_opts)
        for name, value in result['log'].items():
            self.log(f"train/{name}", value, prog_bar=False, **step_opts)

        return total_loss

    def validation_step(self, batch, batch_id, dataloader_idx=None):
        """Run the criterion on a validation batch and log per-epoch values.

        Returns:
            A dict with the batch ``'loss'`` and the criterion's ``'logs'``.
        """
        result = self.criterion(batch)

        # Per-epoch logging options shared by the loss and the extra metrics.
        epoch_opts = dict(
            on_step=False,
            on_epoch=True,
            logger=True,
            sync_dist=True,
            batch_size=result['bsz'],
        )
        self.log("val/loss", result['loss'], prog_bar=True, **epoch_opts)
        for name, value in result['log'].items():
            self.log(f"val/{name}", value, prog_bar=False, **epoch_opts)

        summary = {'loss': result['loss'], 'logs': result['log']}
        return summary

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Do nothing; no test step is implemented."""
        return None

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler from the configuration.

        Both instances are created from ``self.cfg`` and the model's named
        parameters, stored on the module, and returned with the scheduler
        set to step once per training step.
        """
        built = configure_optimizer_schedular(
            cfg=self.cfg,
            params_generator=self.named_parameters,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        self.optimizer, self.scheduler = built

        scheduler_spec = {
            "scheduler": self.scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [self.optimizer], [scheduler_spec]

    def train_dataloader(self):
        """Return a shuffled loader over the training set, dropping the last partial batch."""
        data_cfg = self.cfg['data_cfg']
        split_cfg = data_cfg['train']

        train_set = TTSDataset(
            split_cfg['paths'],
            self.cfg['model_cfg']['onnx_path'],
            train=True,
        )
        return DataLoader(
            train_set,
            batch_size=split_cfg['batch_size'],
            shuffle=True,
            drop_last=True,
            num_workers=data_cfg['num_worker'],
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation set, keeping every batch."""
        data_cfg = self.cfg['data_cfg']
        split_cfg = data_cfg['validation']

        val_set = TTSDataset(
            split_cfg['paths'],
            self.cfg['model_cfg']['onnx_path'],
            train=False,
        )
        return DataLoader(
            val_set,
            batch_size=split_cfg['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=data_cfg['num_worker'],
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        """Do nothing; no test loader is provided."""
        return None
