import os
import sys
dirof = os.path.dirname
sys.path.insert(0, dirof(__file__))
sys.path.insert(0, dirof(dirof(dirof(__file__))))

import torch
import torch.nn as nn
import lightning as L
from torch import optim
from utils_common.utils import jpath, read_json, get_latest_checkpoint
from utils_midi import remi_utils
from m2m.evaluate import Metric
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
import math
from piano_model_moyu import *
from [anonymous] import MultiTrack, Bar
from evaluations.piano_evaluator import bar_level_pitch_wer_from_proll, bar_level_pos_wer_from_proll, bar_level_note_f1_from_proll


def load_lit_model(config):
    '''
    Restore a Lightning module from the latest checkpoint of an experiment.

    The checkpoint is searched under ``jpath(config['result_root'], config['out_dir'])``
    and loaded through the ``load_from_checkpoint`` classmethod of the class
    named by ``config['lit_model_class']``.

    :param config: mapping providing 'result_root', 'out_dir' and 'lit_model_class'
    :return: the restored module, with its ``config`` attribute set to ``config``
    '''
    experiment_dir = jpath(config['result_root'], config['out_dir'])

    # get_latest_checkpoint() is unpacked into two values; only the second one
    # (the checkpoint file path) is used here.
    _, checkpoint_path = get_latest_checkpoint(experiment_dir)

    # NOTE(review): eval() executes whatever expression the config contains.
    # Only use with trusted config files.
    module_cls = eval(config['lit_model_class'])

    restored = module_cls.load_from_checkpoint(checkpoint_path, config=config)
    # Re-attach the caller's config object to the restored module.
    restored.config = config
    return restored


def get_lit_model(config):
    '''
    Instantiate the Lightning module class named by ``config['lit_model_class']``.

    The class name is printed before it is resolved, then the class is called
    with ``config`` as its only argument.

    :param config: mapping providing 'lit_model_class' (a name resolvable in this module)
    :return: a new instance of the named class
    '''
    class_name = config['lit_model_class']
    print(class_name)

    # NOTE(review): eval() executes whatever expression the config contains.
    # Only use with trusted config files.
    return eval(class_name)(config)


class LitMoyuPiano(L.LightningModule):
    '''
    Lightning wrapper around the piano model named by ``config['model_class']``.

    The wrapped model is called with a whole batch and its output is read as a
    dict with the keys 'loss_tot', 'loss_note', 'loss_pos', 'loss_bar' and
    'pred' (those are the entries this class accesses).
    '''

    # Metric watched by the plateau ("anneal") scheduler. It must be a key that
    # is actually logged: log_losses() emits 'valid_loss_tot' during validation.
    _PLATEAU_MONITOR = 'valid_loss_tot'

    def __init__(self, config, infer=False):
        '''
        Build the wrapped model and record the hyperparameters.

        :param config: mapping of settings; 'model_class' is read here, the
            optimizer-related keys are read in configure_optimizers()
        :param infer: not used by this class; kept so existing callers still work
        '''
        super().__init__()
        self.config = config

        # NOTE(review): eval() executes whatever expression the config contains.
        # Only use with trusted config files.
        model_class = eval(config['model_class'])
        self.model = model_class(config)

        self.save_hyperparameters(config)

        self.test_results = {}

    @staticmethod
    def _batch_size(batch):
        '''
        Return the number of samples in ``batch`` for metric weighting.

        ``len()`` of a dict batch is its number of keys, not its number of
        samples, so when the batch is a dict carrying 'piano_prolls' (the
        per-sample targets iterated in evaluation_step) that entry's length is
        used. Any other batch falls back to ``len(batch)``.
        '''
        if isinstance(batch, dict) and 'piano_prolls' in batch:
            return len(batch['piano_prolls'])
        return len(batch)

    def log_losses(self, out, split, bs):
        '''
        Log every loss term of a model output as ``{split}_{loss_name}``.

        :param out: model output holding 'loss_tot', 'loss_note', 'loss_pos', 'loss_bar'
        :param split: prefix for the logged names, e.g. 'train' or 'valid'
        :param bs: batch size forwarded to ``self.log``
        '''
        for loss_name in ('loss_tot', 'loss_note', 'loss_pos', 'loss_bar'):
            self.log(f'{split}_{loss_name}', out[loss_name], batch_size=bs)

    def training_step(self, batch, batch_idx):
        '''
        Run the model on one training batch, log its losses, return 'loss_tot'.
        '''
        self.model.train()

        # Calculate losses
        out = self.model(batch)
        self.log_losses(out, 'train', self._batch_size(batch))

        return out['loss_tot']

    def validation_step(self, batch, batch_idx):
        '''
        Log the losses and the evaluation metrics of one validation batch
        under the 'valid_' prefix.
        '''
        bs = self._batch_size(batch)
        self.model.eval()

        # Calculate losses
        out = self.model(batch)
        self.log_losses(out, 'valid', bs)

        # Calculate metrics
        metrics = self.evaluation_step(batch, out)
        for k, v in metrics.items():
            self.log(f'valid_{k}', v, batch_size=bs)

    def test_step(self, batch, batch_idx):
        '''
        Log the total loss ('test_loss') and the evaluation metrics of one
        test batch under the 'test_' prefix.
        '''
        bs = self._batch_size(batch)
        self.model.eval()

        # Calculate losses
        out = self.model(batch)
        loss = out['loss_tot']
        self.log("test_loss", loss, batch_size=bs)

        # Calculate metrics. These are test-split values, so they are logged
        # as 'test_*' rather than reusing the validation names.
        metrics = self.evaluation_step(batch, out)
        for k, v in metrics.items():
            self.log(f'test_{k}', v, batch_size=bs)

    def evaluation_step(self, batch, out):
        '''
        Calculate metrics for a single batch.

        Predictions (``out['pred']``) and targets (``batch['piano_prolls']``)
        are paired sample by sample; pitch WER, position WER and note F1 are
        computed per pair and averaged over the batch.

        :return: dict with the keys 'pitch_wer', 'pos_wer' and 'note_f1'
        '''
        prolls_out = out['pred']  # [bs, pos=16, pitch=128]
        prolls_tgt = batch['piano_prolls']  # [bs, pos=16, pitch=128]

        p_wers = []
        pos_wers = []
        note_f1s = []
        for proll_out, proll_tgt in zip(prolls_out, prolls_tgt):
            p_wers.append(bar_level_pitch_wer_from_proll(proll_out, proll_tgt))
            pos_wers.append(bar_level_pos_wer_from_proll(proll_out, proll_tgt))
            note_f1s.append(bar_level_note_f1_from_proll(proll_out, proll_tgt))

        return {
            'pitch_wer': sum(p_wers) / len(p_wers),
            'pos_wer': sum(pos_wers) / len(pos_wers),
            'note_f1': sum(note_f1s) / len(note_f1s),
        }

    def calculate_metrics(self, inp_seq, tgt_seq, out_seq) -> dict:
        '''
        Calculate the sequence-level metrics for a single sample.

        :param inp_seq: not used by this method; kept for interface compatibility
        :param tgt_seq: reference sequence
        :param out_seq: generated sequence, compared against ``tgt_seq``
        :return: dict with the keys 'pitch_wer', 'pitch_iou', 'pos_wer' and 'pos_iou'
        '''
        metric = Metric()
        ret = {}

        # Pitch sequence similarity
        ret['pitch_wer'] = metric.calculate_pitch_wer(out_seq, tgt_seq)
        ret['pitch_iou'] = metric.calculate_pitch_iou(out_seq, tgt_seq)

        # Groove similarity (the second value of the pair is not reported)
        pos_wer, _ = metric.calculate_groove_wer_sor_mbar(out_seq, tgt_seq)
        ret['pos_wer'] = pos_wer
        ret['pos_iou'] = metric.calculate_groove_iou_mbar(out_seq, tgt_seq)

        return ret

    def configure_optimizers(self):
        '''
        Create the AdamW optimizer and the scheduler selected by the config.

        Config keys read: 'lr', 'weight_decay', 'lr_scheduler' and, depending
        on the scheduler, 'warmup_steps' or 'lr_anneal_patience'.

        :return: dict with 'optimizer' and, unless the scheduler is 'none',
            an 'lr_scheduler' configuration
        :raises ValueError: if ``config['lr_scheduler']`` is not one of
            'none', 'linear', 'anneal'
        '''
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.config['lr'],
            weight_decay=self.config['weight_decay']
        )

        scheduler_name = self.config['lr_scheduler']

        if scheduler_name == 'none':
            return {"optimizer": optimizer}

        if scheduler_name == 'linear':
            # Imported here because the visible module-level imports do not
            # bring `transformers` into scope.
            import transformers

            # Linear scheduler
            max_steps = self.num_training_steps()
            scheduler = transformers.get_linear_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=max_steps,
            )
            # Warmup and total length are counted in optimizer steps, so the
            # scheduler has to advance once per step, not once per epoch.
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                },
            }

        if scheduler_name == 'anneal':
            # Annealing
            # NOTE(review): `verbose` is deprecated in recent PyTorch releases;
            # confirm against the pinned version.
            scheduler = ReduceLROnPlateauPatch(
                optimizer,
                mode='min',
                factor=0.5,
                patience=self.config['lr_anneal_patience'],
                verbose=True
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": self._PLATEAU_MONITOR,
                },
            }

        raise ValueError(
            f"Unknown lr_scheduler {scheduler_name!r}; "
            "expected 'none', 'linear' or 'anneal'"
        )

    def num_training_steps(self) -> int:
        '''
        Get the number of training steps.

        Uses ``trainer.max_steps`` when it is set (> -1); otherwise the number
        of training batches per epoch times ``trainer.max_epochs``.
        '''
        if self.trainer.max_steps > -1:
            return self.trainer.max_steps

        # Make sure the train dataloader exists before asking for its length.
        self.trainer.fit_loop.setup_data()
        dataset_size = len(self.trainer.train_dataloader)
        return dataset_size * self.trainer.max_epochs

    def get_step_per_epoch(self):
        '''
        Return the number of batches of the training dataloader, setting the
        dataloader up first when the trainer does not have one yet.
        '''
        if self.trainer.train_dataloader is None:
            self.trainer.fit_loop.setup_data()
        return len(self.trainer.train_dataloader)

    def on_validation_epoch_end(self):
        '''
        Step a ReduceLROnPlateau scheduler with the monitored validation loss.
        Other scheduler types are left alone.
        '''
        scheduler = self.lr_schedulers()

        # LR anneal update, only if the selected scheduler is a
        # ReduceLROnPlateau scheduler.
        # NOTE(review): with automatic optimization Lightning also steps a
        # scheduler that declares a `monitor`, so the plateau logic may run
        # twice per validation epoch. Confirm this extra step is intended.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(self.trainer.callback_metrics[self._PLATEAU_MONITOR])


class ReduceLROnPlateauPatch(ReduceLROnPlateau, _LRScheduler):
    def get_lr(self):
        return [ group['lr'] for group in self.optimizer.param_groups ]
