
import csv
import torch
import numpy as np
import logging

# from torch_mir_eval.separation import bss_eval_sources
from ..losses import (
    PITLossWrapper,
    pairwise_neg_sisdr,
    pairwise_neg_snr,
    singlesrc_neg_sisdr,
)

logger = logging.getLogger(__name__)


class SPlitMetricsTracker:
    """Track SNR / SI-SNR and their improvements for a 2 + 1 source split.

    Each example needs at least three sources. Sources 0 and 1 form the "two"
    group and source 2 forms the "one" group. First, the estimates are
    reordered with SNR-based PIT over all sources. Then each group is scored
    separately in two ways:

    - against the clean references;
    - with the unprocessed mixture as the estimate, which is the baseline
      for the ``*_i`` improvement columns.

    One CSV row is written per example. ``final()`` appends the column means
    and closes the file. The tracker can also be used as a context manager,
    which guarantees the file is closed. Note that leaving the ``with`` block
    does not write the ``avg`` row; call ``final()`` for that.

    Args:
        save_file: Path of the CSV file to create (truncated if it exists).
        pit_sisnr: Optional PIT-wrapped SI-SNR loss. Defaults to
            ``PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")``.
        pit_snr: Optional PIT-wrapped SNR loss. Defaults to
            ``PITLossWrapper(pairwise_neg_snr, pit_from="pw_mtx")``. It must
            accept ``return_ests=True`` and then return
            ``(loss, reordered_estimates)``.
    """

    # Metric columns in CSV order. Each group contributes
    # (snr, snr_i, si-snr, si-snr_i), the same order `_group_scores` returns.
    _METRIC_COLUMNS = (
        "one_snr",
        "one_snr_i",
        "one_si-snr",
        "one_si-snr_i",
        "two_snr",
        "two_snr_i",
        "two_si-snr",
        "two_si-snr_i",
    )
    # Source index ranges of the two groups. Slices keep the source dimension.
    _ONE_GROUP = slice(2, 3)
    _TWO_GROUP = slice(0, 2)

    def __init__(self, save_file: str = "", *, pit_sisnr=None, pit_snr=None):
        self.one_all_snrs = []
        self.one_all_snrs_i = []
        self.one_all_sisnrs = []
        self.one_all_sisnrs_i = []
        self.two_all_snrs = []
        self.two_all_snrs_i = []
        self.two_all_sisnrs = []
        self.two_all_sisnrs_i = []
        # Build the metric wrappers before opening the file so that a failure
        # here cannot leak an open file handle.
        if pit_sisnr is None:
            pit_sisnr = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
        if pit_snr is None:
            pit_snr = PITLossWrapper(pairwise_neg_snr, pit_from="pw_mtx")
        self.pit_sisnr = pit_sisnr
        self.pit_snr = pit_snr
        # The csv module requires newline="". Otherwise its "\r\n" row
        # terminator is translated again, producing blank lines on Windows.
        self.results_csv = open(save_file, "w", newline="")
        self.writer = csv.DictWriter(
            self.results_csv, fieldnames=["snt_id", *self._METRIC_COLUMNS]
        )
        self.writer.writeheader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Close the CSV file. Calling this more than once is safe."""
        if not self.results_csv.closed:
            self.results_csv.close()

    def _accumulators(self):
        """Map each metric column to the list that accumulates its values."""
        return {
            "one_snr": self.one_all_snrs,
            "one_snr_i": self.one_all_snrs_i,
            "one_si-snr": self.one_all_sisnrs,
            "one_si-snr_i": self.one_all_sisnrs_i,
            "two_snr": self.two_all_snrs,
            "two_snr_i": self.two_all_snrs_i,
            "two_si-snr": self.two_all_sisnrs,
            "two_si-snr_i": self.two_all_sisnrs_i,
        }

    def _group_scores(self, ests, refs, mixes):
        """Score one source group.

        Returns:
            A tuple ``(snr, snr_i, si_snr, si_snr_i)`` of Python floats.

        The PIT wrappers return negated scores (losses), so every value is
        negated back. Improvements are measured relative to using the
        mixture itself as the estimate.
        """
        snr = -self.pit_snr(ests, refs).item()
        snr_mix = -self.pit_snr(mixes, refs).item()
        sisnr = -self.pit_sisnr(ests, refs).item()
        sisnr_mix = -self.pit_sisnr(mixes, refs).item()
        return snr, snr - snr_mix, sisnr, sisnr - sisnr_mix

    def __call__(self, mix, clean, estimate, key):
        """Score one example, write its CSV row and accumulate its metrics.

        Args:
            mix: Mixture waveform. It is replicated once per source of
                ``clean``; presumably shape (time,) -- TODO confirm.
            clean: Reference sources along dim 0; presumably
                (n_src, time) with n_src >= 3 -- TODO confirm.
            estimate: Separated sources, same layout as ``clean``.
            key: Utterance identifier written to the ``snt_id`` column.
        """
        # Metrics only: no autograd graph is needed.
        with torch.no_grad():
            refs = clean.unsqueeze(0)
            # Reorder all estimates once (SNR-based PIT over every source) so
            # that each group slice below lines up with its references.
            _, ests = self.pit_snr(estimate.unsqueeze(0), refs, return_ests=True)
            mixes = torch.stack([mix] * clean.shape[0], dim=0).unsqueeze(0)
            one, two = self._ONE_GROUP, self._TWO_GROUP
            one_scores = self._group_scores(ests[:, one], refs[:, one], mixes[:, one])
            two_scores = self._group_scores(ests[:, two], refs[:, two], mixes[:, two])

        values = dict(zip(self._METRIC_COLUMNS, one_scores + two_scores))
        self.writer.writerow({"snt_id": key, **values})
        for column, store in self._accumulators().items():
            store.append(values[column])

    def final(self):
        """Write an ``avg`` row of per-column means, close the file, return the row.

        Columns with no accumulated values average to NaN. This is checked
        explicitly, so numpy's empty-slice warning is not emitted.
        """
        row = {"snt_id": "avg"}
        for column, store in self._accumulators().items():
            row[column] = float(np.mean(store)) if store else float("nan")
        self.writer.writerow(row)
        self.close()
        return row
