# This source code is provided for the purposes of scientific reproducibility
# under the following limited license from Element AI Inc. The code is an
# implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
# expansion analysis for interpretable time series forecasting,
# https://arxiv.org/abs/1905.10437). The copyright to the source code is
# licensed under the Creative Commons - Attribution-NonCommercial 4.0
# International license (CC BY-NC 4.0):
# https://creativecommons.org/licenses/by-nc/4.0/.  Any commercial use (whether
# for the benefit of third parties or internally in production) requires an
# explicit license. The subject-matter of the N-BEATS model and associated
# materials are the property of Element AI Inc. and may be subject to patent
# protection. No license to patents is granted hereunder (whether express or
# implied). Copyright 2020 Element AI Inc. All rights reserved.

"""
M4 Summary
"""
from collections import OrderedDict

import numpy as np
import pandas as pd

from data_provider.m4 import M4Dataset
from data_provider.m4 import M4Meta
import os


def group_values(values, groups, group_name):
    """Select the series of one group and strip the NaN entries from each.

    :param values: Collection of series that can be indexed with a boolean mask.
    :param groups: Group labels aligned with ``values``.
    :param group_name: Label of the group to extract.
    :return: Object-dtype array built from the NaN-free series of the group.
    """
    members = values[groups == group_name]
    stripped = []
    for series in members:
        keep = np.logical_not(np.isnan(series))
        stripped.append(series[keep])
    # dtype=object lets series of different lengths live in one array.
    return np.array(stripped, dtype=object)


def mase(forecast, insample, outsample, frequency):
    """Return the mean absolute forecast error scaled by the in-sample
    mean absolute difference at lag ``frequency``.

    :param forecast: Forecast values.
    :param insample: History of the series, used to compute the scaling term.
    :param outsample: Actual values the forecast is compared with.
    :param frequency: Lag used for the in-sample differences.
    :return: Ratio of the two mean absolute errors.
    """
    forecast_error = np.mean(np.abs(forecast - outsample))
    lagged_diff = insample[:-frequency] - insample[frequency:]
    scale = np.mean(np.abs(lagged_diff))
    return forecast_error / scale


def smape_2(forecast, target):
    """Return the element-wise symmetric absolute percentage error.

    Computes ``200 * |forecast - target| / (|target| + |forecast|)``. Scalars
    are accepted as well as arrays.

    :param forecast: Forecast value(s).
    :param target: Actual value(s), broadcast-compatible with ``forecast``.
    :return: Error per element; 0 where forecast and target are both zero.
    """
    # np.asarray turns a scalar sum into a 0-d array so the masked assignment
    # below works; an array sum is already a fresh array and is not copied.
    denom = np.asarray(np.abs(target) + np.abs(forecast))
    # Divide by 1.0 instead of 0.0: the denominator is zero only when both
    # values are zero, in which case the numerator is 0.0 anyway.
    denom[denom == 0.0] = 1.0
    return 200 * np.abs(forecast - target) / denom


def mape(forecast, target):
    """Return the element-wise absolute percentage error.

    Computes ``100 * |forecast - target| / |target|``. Scalars are accepted as
    well as arrays.

    :param forecast: Forecast value(s).
    :param target: Actual value(s), broadcast-compatible with ``forecast``.
    :return: Error per element. Where ``target`` is zero the divisor is
        replaced by 1.0, so the result there is ``100 * |forecast|``.
    """
    # np.asarray turns a scalar into a 0-d array so the masked assignment
    # below works; np.abs of an array is already a fresh array, so the
    # caller's ``target`` is never modified.
    denom = np.asarray(np.abs(target))
    # Avoid division by zero. Unlike sMAPE, the numerator is not necessarily
    # zero here: it equals |forecast| when the target is zero.
    denom[denom == 0.0] = 1.0
    return 100 * np.abs(forecast - target) / denom


class M4Summary:
    """Compute M4 evaluation metrics for forecasts stored as per-group CSV files.

    Model forecasts are read from
    ``<file_path>/<group><seed_suffix>_forecast.csv`` and compared with the
    test set and with the Naive2 forecasts stored in
    ``<root_path>/submission-Naive2.csv``.
    """

    def __init__(self, file_path, root_path, seed_suffix: str = ""):
        """
        :param file_path: Directory containing the model forecast CSV files.
        :param root_path: Location passed to ``M4Dataset.load``; also the
            directory that holds ``submission-Naive2.csv``.
        :param seed_suffix: Text placed between the group name and
            ``_forecast.csv`` in forecast file names.
        """
        self.file_path = file_path
        self.seed_suffix = seed_suffix  # e.g., "_seed42"; empty string for legacy paths
        self.training_set = M4Dataset.load(training=True, dataset_file=root_path)
        self.test_set = M4Dataset.load(training=False, dataset_file=root_path)
        self.naive_path = os.path.join(root_path, 'submission-Naive2.csv')

    def _load_naive2_forecasts(self):
        """Read the Naive2 CSV and drop the NaN entries of every row."""
        # The first CSV column is skipped; the remainder is parsed as float32.
        padded = pd.read_csv(self.naive_path).values[:, 1:].astype(np.float32)
        # dtype=object is required to hold per-series slices of different
        # lengths; recent NumPy refuses to build a ragged array implicitly.
        return np.array([row[~np.isnan(row)] for row in padded], dtype=object)

    def _group_count(self, group_name):
        """Return how many test-set series carry the label ``group_name``."""
        return len(np.where(self.test_set.groups == group_name)[0])

    def _group_scores(self, group_name, naive2_forecasts):
        """Compute the unrounded metrics of one group.

        :param group_name: Label of the group to score.
        :param naive2_forecasts: Result of ``_load_naive2_forecasts``.
        :return: Tuple ``(mase_model, mase_naive2, smape_model, smape_naive2,
            mape_model)``.
        :raises FileNotFoundError: If the forecast CSV of the group is missing.
        """
        file_name = os.path.join(self.file_path, f"{group_name}{self.seed_suffix}_forecast.csv")
        if not os.path.exists(file_name):
            raise FileNotFoundError(file_name)
        model_forecast = pd.read_csv(file_name).values

        groups = self.test_set.groups
        naive2_forecast = group_values(naive2_forecasts, groups, group_name)
        target = group_values(self.test_set.values, groups, group_name)
        # All time series within a group have the same frequency.
        frequency = self.training_set.frequencies[groups == group_name][0]
        insample = group_values(self.training_set.values, groups, group_name)

        # Model and Naive2 are scored on the same series, paired row by row.
        model_errors = []
        naive2_errors = []
        for model_row, naive2_row, history, actual in zip(
                model_forecast, naive2_forecast, insample, target):
            model_errors.append(mase(forecast=model_row, insample=history,
                                     outsample=actual, frequency=frequency))
            naive2_errors.append(mase(forecast=naive2_row, insample=history,
                                      outsample=actual, frequency=frequency))

        mase_model = np.mean(model_errors)
        mase_naive2 = np.mean(naive2_errors)
        smape_naive2 = np.mean(smape_2(naive2_forecast, target))
        smape_model = np.mean(smape_2(forecast=model_forecast, target=target))
        mape_model = np.mean(mape(forecast=model_forecast, target=target))
        return mase_model, mase_naive2, smape_model, smape_naive2, mape_model

    def evaluate_single(self, group_name: str):
        """Evaluate sMAPE, OWA, MAPE, MASE for a single M4 group.

        :param group_name: Label of the group; converted with ``str``.
        :return: Tuple ``(smape_g, owa_g, mape_g, mase_g)`` of floats rounded
            to 3 decimals.
        :raises FileNotFoundError: If the forecast CSV of the group is missing.
        """
        group_name = str(group_name)
        naive2_forecasts = self._load_naive2_forecasts()
        mase_model, mase_naive2, smape_model, smape_naive2, mape_model = self._group_scores(
            group_name, naive2_forecasts)

        # OWA: mean of the model's MASE and sMAPE, each relative to Naive2.
        owa_g = (mase_model / mase_naive2 + smape_model / smape_naive2) / 2

        # Round to 3 decimals to match evaluate()
        return tuple(float(np.round(score, 3))
                     for score in (smape_model, owa_g, mape_model, mase_model))

    def evaluate(self):
        """
        Evaluate forecasts of every M4 seasonal pattern using the M4 test dataset.

        :return: Four dicts (sMAPE, OWA, MAPE, MASE), each holding the keys
            produced by ``summarize_groups``, with values rounded to 3 decimals.
        :raises FileNotFoundError: If the forecast CSV of any group is missing.
        """
        naive2_forecasts = self._load_naive2_forecasts()

        model_mases = {}
        naive2_smapes = {}
        naive2_mases = {}
        grouped_smapes = {}
        grouped_mapes = {}
        for group_name in M4Meta.seasonal_patterns:
            # A missing forecast file raises here instead of silently reusing
            # the forecasts loaded for the previous group.
            (model_mases[group_name],
             naive2_mases[group_name],
             grouped_smapes[group_name],
             naive2_smapes[group_name],
             grouped_mapes[group_name]) = self._group_scores(group_name, naive2_forecasts)

        grouped_smapes = self.summarize_groups(grouped_smapes)
        grouped_mapes = self.summarize_groups(grouped_mapes)
        grouped_model_mases = self.summarize_groups(model_mases)
        grouped_naive2_smapes = self.summarize_groups(naive2_smapes)
        grouped_naive2_mases = self.summarize_groups(naive2_mases)

        grouped_owa = OrderedDict()
        for k in grouped_model_mases:
            grouped_owa[k] = (grouped_model_mases[k] / grouped_naive2_mases[k] +
                              grouped_smapes[k] / grouped_naive2_smapes[k]) / 2

        def round_all(d):
            return {key: np.round(value, 3) for key, value in d.items()}

        return (round_all(grouped_smapes), round_all(grouped_owa),
                round_all(grouped_mapes), round_all(grouped_model_mases))

    def evaluate_groups(self, groups):
        """Evaluate sMAPE/OWA/MAPE/MASE restricted to a subset of M4 groups.

        :param groups: Iterable of group labels; each is converted with ``str``.
        :return: Four dicts (sMAPE, OWA, MAPE, MASE) with one entry per group
            and an additional key 'SubsetAverage' holding the average across
            the provided groups, weighted by the number of test series per
            group. All values are floats rounded to 3 decimals.
        :raises FileNotFoundError: If the forecast CSV of any group is missing.
        """
        groups = [str(g) for g in groups]
        naive2_forecasts = self._load_naive2_forecasts()

        model_mases = {}
        naive2_smapes = {}
        naive2_mases = {}
        grouped_smapes = {}
        grouped_mapes = {}

        # Compute metrics per requested group
        for g in groups:
            (model_mases[g],
             naive2_mases[g],
             grouped_smapes[g],
             naive2_smapes[g],
             grouped_mapes[g]) = self._group_scores(g, naive2_forecasts)

        def weighted_avg(d):
            # Average of the per-group scores weighted by series count.
            total = 0.0
            count = 0
            for k, v in d.items():
                c = self._group_count(k)
                total += v * c
                count += c
            return total / count if count > 0 else float('nan')

        # OWA over subset
        subset_owa = (weighted_avg(model_mases) / weighted_avg(naive2_mases) +
                      weighted_avg(grouped_smapes) / weighted_avg(naive2_smapes)) / 2

        # Round results
        def rdict(d):
            return {k: float(np.round(v, 3)) for k, v in d.items()}

        smape_out = rdict(grouped_smapes)
        mape_out = rdict(grouped_mapes)
        mase_out = rdict(model_mases)

        smape_out['SubsetAverage'] = float(np.round(weighted_avg(grouped_smapes), 3))
        mape_out['SubsetAverage'] = float(np.round(weighted_avg(grouped_mapes), 3))
        mase_out['SubsetAverage'] = float(np.round(weighted_avg(model_mases), 3))
        owa_out = {
            k: float(np.round((model_mases[k] / naive2_mases[k] +
                               grouped_smapes[k] / naive2_smapes[k]) / 2, 3))
            for k in groups
        }
        owa_out['SubsetAverage'] = float(np.round(subset_owa, 3))

        return smape_out, owa_out, mape_out, mase_out

    def summarize_groups(self, scores):
        """
        Re-group scores respecting M4 rules.

        'Yearly', 'Quarterly' and 'Monthly' are reported as given; 'Weekly',
        'Daily' and 'Hourly' are merged into 'Others' as an average weighted
        by the number of series per group; 'Average' is the weighted sum of
        all groups divided by the total number of test series.

        :param scores: Scores per group; must contain all six group names.
        :return: Ordered mapping with the keys 'Yearly', 'Quarterly',
            'Monthly', 'Others' and 'Average'.
        """
        scores_summary = OrderedDict()

        weighted_score = {}
        for g in ['Yearly', 'Quarterly', 'Monthly']:
            weighted_score[g] = scores[g] * self._group_count(g)
            scores_summary[g] = scores[g]

        others_score = 0
        others_count = 0
        for g in ['Weekly', 'Daily', 'Hourly']:
            count = self._group_count(g)
            others_score += scores[g] * count
            others_count += count
        weighted_score['Others'] = others_score
        scores_summary['Others'] = others_score / others_count

        average = np.sum(list(weighted_score.values())) / len(self.test_set.groups)
        scores_summary['Average'] = average

        return scores_summary
