from copy import deepcopy
import torch
import random
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from torch.utils.data import DataLoader
from torchmetrics.text import SacreBLEUScore, CharErrorRate, WordErrorRate

from transformers import AutoTokenizer
# from whisper_normalizer.basic import BasicTextNormalizer

# normalizer = BasicTextNormalizer()

from .conf_opt_sch import configure_optimizer_schedular
from .dataset import STCollator, STDataset
from .net.model import Whisper
from .net_teacher.model import WhisperTextTeacher
from .criterion import STCriterion
from .Normalizer.utils import text_norm, remove_punctuation

def unique(l: list):
    """Group consecutive equal elements of ``l`` into runs.

    Args:
        l: Sequence to scan. Only adjacent elements are compared, so a value
            that reappears after a different value starts a new run.

    Returns:
        A list with one ``[value, (start, stop)]`` entry per run, in order,
        where ``l[start:stop]`` is the run (``stop`` is exclusive). An empty
        input yields an empty list.

    Example:
        ``unique(['a', 'a', 'b'])`` -> ``[['a', (0, 2)], ['b', (2, 3)]]``
    """
    total = len(l)
    if total == 0:
        return []

    # Indices at which an element differs from its predecessor, i.e. the
    # positions where a new run begins.
    cuts = [idx for idx in range(1, total) if l[idx] != l[idx - 1]]

    # Every run starts at 0 or at a cut, and ends at the next cut or at the
    # end of the sequence.
    starts = [0] + cuts
    stops = cuts + [total]
    return [[l[stop - 1], (start, stop)] for start, stop in zip(starts, stops)]

def get_slice(item, slices):
    """Return a sliced copy of ``item``, carrying ``chunk_mask`` over unsliced.

    The input mapping is deep-copied first, so ``item`` itself is never
    modified and the returned values are taken from the copy.

    Args:
        item: Mapping whose values support indexing with ``slices``. An
            optional ``'chunk_mask'`` entry is exempt from slicing.
        slices: Index object applied to every value except ``'chunk_mask'``.

    Returns:
        A new dict with the same keys as ``item``. ``'chunk_mask'`` is placed
        last and only included when it was present and not ``None``.
    """
    cloned = deepcopy(item)
    mask = cloned.pop('chunk_mask', None)

    sliced = {}
    for key, value in cloned.items():
        sliced[key] = value[slices]

    if mask is None:
        return sliced

    # The mask is passed through whole rather than indexed with `slices`.
    sliced['chunk_mask'] = mask
    return sliced


class Simul7Module(LightningModule):
    """Lightning wrapper that optimizes a ``Whisper`` model via ``STCriterion``.

    The module owns the student model, a ``WhisperTextTeacher`` instance, the
    tokenizer and collator used by its data loaders, and per-split metric
    containers. The training and validation steps in this class only log the
    criterion loss and the scalars the criterion reports.
    """

    def __init__(self, cfg) -> None:
        """Build the models, tokenizer, collator, criterion and metrics.

        Args:
            cfg: Nested configuration mapping. This constructor reads
                ``cfg['model_cfg']`` (including the ``'Whisper'`` entries
                ``'huggingface_path'`` and ``'n_mels'``) and the keys of
                ``cfg['data_cfg']['validation']['splits']``; each key is
                split on ``'-'`` into exactly a source and a target language.
        """
        super().__init__()
        model_cfg = cfg['model_cfg']

        # The student is put in train mode and the teacher in eval mode here.
        self.model = Whisper(model_cfg).train()
        self.teacher = WhisperTextTeacher(model_cfg).eval()

        whisper_cfg = model_cfg['Whisper']
        self.tokenizer = AutoTokenizer.from_pretrained(
            whisper_cfg['huggingface_path'],
            trust_remote_code=True,
            use_fast=True,
        )
        self.st_collate_fn = STCollator(self.tokenizer, n_mels=whisper_cfg['n_mels'])
        self.st_criterion = STCriterion(self.model, self.teacher)

        self.cfg = cfg

        # One BLEU metric per validation split and one error-rate metric per
        # source language.
        st_scores = torch.nn.ModuleDict()
        asr_scores = torch.nn.ModuleDict()
        for split_name in cfg['data_cfg']['validation']['splits']:
            src_lang, tgt_lang = split_name.split('-')

            # Chinese targets use the 'zh' tokenizer, everything else '13a'.
            if tgt_lang == 'zh':
                st_scores[split_name] = SacreBLEUScore(tokenize='zh')
            else:
                st_scores[split_name] = SacreBLEUScore(tokenize="13a")

            # Chinese sources are scored per character, others per word.
            if src_lang == 'zh':
                asr_scores[src_lang] = CharErrorRate()
            else:
                asr_scores[src_lang] = WordErrorRate()

        # NOTE: an 'mt' metric group (a copy of the BLEU metrics) is disabled.
        self.val_metrics = torch.nn.ModuleDict({
            'st': st_scores,
            'asr': asr_scores,
        })

    def forward(self, **kwargs):
        """Forward the keyword arguments to the wrapped student model."""
        return self.model(**kwargs)

    def on_train_start(self) -> None:
        """Step the LR scheduler once and set the 'st' metrics to float32."""
        self.lr_schedulers().step()
        for metric in self.val_metrics['st'].values():
            metric.set_dtype(torch.float32)

    def on_load_checkpoint(self, checkpoint) -> None:
        """Fix the checkpoint loading issue for deepspeed.

        When no trainer is attached and the checkpoint has no
        ``'state_dict'`` entry, the ``'module'`` entry is exposed under
        ``'state_dict'``. In every other case the checkpoint is left as is.

        Args:
            checkpoint: Loaded checkpoint dict; may be modified in place.
        """
        if self._trainer is not None or "state_dict" in checkpoint:
            return
        checkpoint['state_dict'] = checkpoint['module']

    def training_step(self, batch, batch_idx):
        """Compute the criterion loss for one batch and log it per step.

        Args:
            batch: Batch passed straight to ``self.st_criterion``.
            batch_idx: Index of the batch; not used here.

        Returns:
            The ``'loss'`` entry of the criterion output.
        """
        # NOTE: a separate MT criterion / combined MT+ST loss is disabled;
        # only the ST criterion contributes to the loss.
        result = self.st_criterion(batch, global_step=self.global_step)
        loss = result['loss']

        # Keyword arguments shared by every log call of this step.
        step_log = dict(
            on_step=True,
            on_epoch=False,
            logger=True,
            sync_dist=True,
            batch_size=result['bsz'],
        )
        self.log("train/loss", loss, prog_bar=True, **step_log)
        for name, value in result['log'].items():
            self.log(f"train/{name}", value, prog_bar=False, **step_log)
        return loss

    def validation_step(self, batch, batch_id, dataloader_idx=None):
        """Compute the criterion loss for one validation batch and log it.

        Values are logged with epoch-level aggregation only.

        Args:
            batch: Batch passed straight to ``self.st_criterion``.
            batch_id: Index of the batch; not used here.
            dataloader_idx: Index of the dataloader; not used here.

        Returns:
            Dict with ``'loss'`` (criterion loss) and ``'logs'`` (the
            criterion's ``'log'`` mapping).
        """
        result = self.st_criterion(batch)

        # Keyword arguments shared by every log call of this step.
        epoch_log = dict(
            on_step=False,
            on_epoch=True,
            logger=True,
            sync_dist=True,
            batch_size=result['bsz'],
        )
        self.log("val/loss", result['loss'], prog_bar=True, **epoch_log)
        for name, value in result['log'].items():
            self.log(f"val/{name}", value, prog_bar=False, **epoch_log)

        # Disabled: generation-based ASR / ST scoring with self.val_metrics.
        # Kept for reference, as nothing else in this class updates those
        # metrics.
        #
        # texts_asr, texts_ast = [], []
        # for wav, splits, src_text, tgt_text in zip(batch['wavs'], batch['splits'], batch['src_txt'], batch['tgt_txt']):
        #     src_lang, tgt_lang = splits.split('-')
        #     _, text_asr = self.model.generate(wav, tokenizer=self.tokenizer)
        #     _, text_ast = self.model.generate(wav, tokenizer=self.tokenizer, task='translate', lang_id=tgt_lang)
        #     texts_asr.append(text_asr)
        #     texts_ast.append(text_ast)
        #     text_asr = remove_punctuation(text_norm(text_asr, src_lang))
        #     text_ast = remove_punctuation(text_norm(text_ast, tgt_lang))
        #     src_text = remove_punctuation(text_norm(src_text, src_lang))
        #     tgt_text = remove_punctuation(text_norm(tgt_text, tgt_lang))
        #     self.val_metrics['asr'][src_lang]([text_asr], [src_text])
        #     self.val_metrics['st'][splits]([text_ast], [[tgt_text]])
        # for s in set(batch['splits']):
        #     self.log(f"val/st/{s}", self.val_metrics['st'][s], on_step=False, on_epoch=True)
        # for s in set([splits.split('-')[0] for splits in batch['splits']]):
        #     self.log(f"val/asr/{s}", self.val_metrics['asr'][s], on_step=False, on_epoch=True)
        #
        # out_info = '\n\n'.join([f'Split: {s}\nSrc: {src}\nLabel: {l}\nST: {st}\nASR: {asr}'
        #             for s, src, l, st, asr in zip(batch['splits'], batch['src_txt'], batch['tgt_txt'], texts_ast, texts_asr)])
        # print(out_info)

        return {
            'loss': result['loss'],
            'logs': result['log'],
        }

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Placeholder: testing is not implemented; does nothing."""
        return None

    # NOTE: a disabled `on_before_optimizer_step` debugging hook used to print
    # every parameter's gradient (or add it as a logger histogram) every 25
    # global steps.

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler from the configuration.

        Both objects are produced by ``configure_optimizer_schedular`` from
        the config, this module's named parameters and the trainer's
        estimated number of stepping batches. They are also stored on
        ``self.optimizer`` and ``self.scheduler``.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler config
            uses ``interval="step"`` and ``frequency=1``.
        """
        optimizer, scheduler = configure_optimizer_schedular(
            cfg=self.cfg,
            params_generator=self.named_parameters,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        self.optimizer, self.scheduler = optimizer, scheduler

        scheduler_cfg = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler_cfg]

    def train_dataloader(self):
        """Build the shuffled ST training loader.

        Reads ``cfg['data_cfg']['train']['st']`` for the dataset paths and
        batch size, and ``cfg['data_cfg']['num_worker']`` for the worker
        count. Incomplete final batches are dropped.
        """
        st_cfg = self.cfg['data_cfg']['train']['st']
        train_set = STDataset(st_cfg['paths'], train=True)

        # NOTE: an MT dataset/loader and a CombinedLoader over
        # {'st': ..., 'mt': ...} in 'max_size_cycle' mode are disabled; only
        # the ST loader is returned.
        return DataLoader(
            train_set,
            batch_size=st_cfg['batch_size'],
            shuffle=True,
            num_workers=self.cfg['data_cfg']['num_worker'],
            collate_fn=self.st_collate_fn,
            drop_last=True,
        )

    def val_dataloader(self):
        """Build a single, unshuffled loader over all validation splits.

        ``cfg['data_cfg']['validation']['splits']`` maps each split name to
        its data path; names and paths are handed to ``STDataset`` as two
        parallel lists in the mapping's iteration order.
        """
        val_cfg = self.cfg['data_cfg']['validation']

        split_items = list(val_cfg['splits'].items())
        split_names = [name for name, _ in split_items]
        split_paths = [path for _, path in split_items]

        val_set = STDataset(split_paths, train=False, target_split=split_names)
        return DataLoader(
            val_set,
            batch_size=val_cfg['batch_size'],
            drop_last=False,
            shuffle=False,
            collate_fn=self.st_collate_fn,
            num_workers=self.cfg['data_cfg']['num_worker'],
        )

    def test_dataloader(self):
        """Placeholder: no test loader is provided; returns ``None``."""
        return None
