import numpy as np

from typing import Type, Optional
from dataclasses import dataclass
from sklearn.metrics import r2_score
from tqdm import tqdm
from .models import TimeSeriesModel
from .data import DataGenerator, TimeSeriesDataset
from .util import cartProd


@dataclass
class ModelClass:
    """Recipe for building a model: its class, base kwargs and tuning options."""

    # Model class to instantiate; getPredictions calls its oneShot()/updatable()
    # hooks and, on instances, fit()/update()/predict().
    declarator: Type[TimeSeriesModel]
    # Keyword arguments passed to ``declarator(**args)``; tuned values override these.
    args: dict
    # Search space handed to ``cartProd``; each yielded candidate is merged over
    # ``args``. Presumably maps arg name -> list of candidate values (confirm
    # against cartProd). None means "no hyperparameter tuning".
    crossvalArgs: Optional[dict] = None
    # If True, the best candidate is picked independently for each series and
    # each tuned arg is passed to the model as a per-series list of values.
    tuneSeparately: bool = False


def getPredictions(
    modelClass: ModelClass,
    dataset: TimeSeriesDataset,
    numStepsTrain: int,
    windowSize: int,
    numWindows: int,
):
    """Forecast ``numWindows`` consecutive windows that start at ``numStepsTrain``.

    All fitting uses ``dataset.noisyObservations``. Three strategies are used,
    chosen by the model class hooks:

    * one-shot models are fitted once on the first ``numStepsTrain`` steps and
      asked for all ``numWindows * windowSize`` steps in a single call;
    * updatable models are fitted once before the first window, and before
      each later window are fed the window that just elapsed via ``update``;
    * all other models are refitted before every window on the full history
      up to that window's start.

    Returns:
        The one-shot model's ``predict`` output, or the per-window predictions
        concatenated along axis 0 (presumably shape
        ``(numWindows * windowSize, numSeries)`` — confirm against the models).
    """
    declarator = modelClass.declarator
    observations = dataset.noisyObservations

    if declarator.oneShot():
        print("This is a one-shot model, so we'll just train it once.")
        oneShotModel = declarator(**modelClass.args).fit(observations[:numStepsTrain])
        return oneShotModel.predict(numWindows * windowSize)

    forecasts = []
    model = None
    for windowIdx in tqdm(range(numWindows)):
        # First time step covered by this window's forecast.
        windowStart = numStepsTrain + windowIdx * windowSize
        if declarator.updatable() and windowIdx > 0:
            assert model is not None
            # Feed the window predicted in the previous iteration. update()'s
            # return value is ignored, so it is assumed to mutate in place.
            elapsedStart = windowStart - windowSize
            model.update(observations[elapsedStart:][:windowSize])
        else:
            model = declarator(**modelClass.args).fit(observations[:windowStart, :])
        forecasts.append(model.predict(windowSize))
    return np.concatenate(forecasts)


class ExperimentManager:
    """Tune, then evaluate, every registered model on one sampled dataset.

    Each call to :meth:`run` samples a fresh dataset and splits it
    chronologically into a training prefix, a validation segment (used to pick
    hyperparameters by R^2 against the ground truth) and a test segment (on
    which the tuned models' predictions are returned).
    """

    def __init__(self, dataGenerator: DataGenerator):
        # Provides the dataset via ``sample(numSteps)`` at the start of each run.
        self.dataGenerator = dataGenerator
        # name -> ModelClass; filled by ``register`` and iterated in insertion order.
        self.modelClasses = {}

    def register(self, name: str, modelClass: ModelClass):
        """Add (or replace) the model evaluated under ``name``.

        NOTE: ``run`` also stores ``"groundTruth"`` and ``"noisyObservations"``
        in its result dict, so a model registered under either of those names
        has its predictions overwritten in the result.
        """
        self.modelClasses[name] = modelClass

    def run(
        self,
        numStepsTrain: int,
        windowSize: int,
        numWindowsVal: int,
        numWindowTest: int,
    ):
        """Tune each registered model on validation data, then forecast the test segment.

        Args:
            numStepsTrain: Length of the initial training prefix.
            windowSize: Number of steps predicted per forecast window.
            numWindowsVal: Number of validation windows used for tuning.
            numWindowTest: Number of test windows forecast by the tuned models.

        Returns:
            Dict mapping each registered model name to its test-segment
            predictions, plus ``"groundTruth"`` and ``"noisyObservations"``
            holding the matching test slices of the sampled dataset.
        """
        numStepsEval = windowSize * numWindowsVal
        numStepsTest = windowSize * numWindowTest
        dataset = self.dataGenerator.sample(numStepsTrain + numStepsEval + numStepsTest)
        numSeries = dataset.noisyObservations.shape[1]
        # First step of the held-out test segment (train + validation precede it).
        testStart = numStepsTrain + numStepsEval
        valGroundTruth = dataset.groundTruth[numStepsTrain:testStart]

        predictions = {}
        for name, modelClass in self.modelClasses.items():
            print("Tuning hyperparameters for", name)
            bestArgs = self._tuneHyperparameters(
                name,
                modelClass,
                dataset,
                numStepsTrain,
                windowSize,
                numWindowsVal,
                valGroundTruth,
                numSeries,
            )
            # Fit with train + validation as history and forecast the test windows.
            predictions[name] = getPredictions(
                ModelClass(modelClass.declarator, {**modelClass.args, **bestArgs}),
                dataset,
                testStart,
                windowSize,
                numWindowTest,
            )
        predictions["groundTruth"] = dataset.groundTruth[testStart:]
        predictions["noisyObservations"] = dataset.noisyObservations[testStart:]
        return predictions

    def _tuneHyperparameters(
        self,
        name: str,
        modelClass: ModelClass,
        dataset: TimeSeriesDataset,
        numStepsTrain: int,
        windowSize: int,
        numWindowsVal: int,
        valGroundTruth,
        numSeries: int,
    ) -> dict:
        """Return the argument overrides that score best on the validation segment.

        Returns ``{}`` when ``modelClass.crossvalArgs`` is None. In per-series
        mode the result maps each tuned arg name to a list with one value per
        series.
        """
        if modelClass.crossvalArgs is None:
            print("No hyperparameters to tune for", name)
            return {}

        def evaluate(candArgs: dict):
            # Per-series validation R^2 of the model with ``candArgs`` merged in.
            print("Trying", candArgs)
            candClass = ModelClass(
                modelClass.declarator, {**modelClass.args, **candArgs}
            )
            preds = getPredictions(
                candClass, dataset, numStepsTrain, windowSize, numWindowsVal
            )
            return r2_score(valGroundTruth, preds, multioutput="raw_values")

        candidates = cartProd(modelClass.crossvalArgs)
        if modelClass.tuneSeparately:
            perSeriesArgs = self._tunePerSeries(candidates, evaluate, numSeries)
            print("Best args for", name, "are", perSeriesArgs)
            return self._mergePerSeriesArgs(perSeriesArgs)

        bestArgs = self._tuneJointly(candidates, evaluate)
        print("Best args for", name, "are", bestArgs)
        return bestArgs

    @staticmethod
    def _tuneJointly(candidates, evaluate) -> dict:
        """Pick the single candidate with the highest mean R^2 across series (first wins ties)."""
        bestArgs, bestScore = {}, -np.inf
        for candArgs in candidates:
            score = np.mean(evaluate(candArgs))
            if score > bestScore:
                print("New best is", score, "with", candArgs)
                bestArgs, bestScore = candArgs, score
        return bestArgs

    @staticmethod
    def _tunePerSeries(candidates, evaluate, numSeries: int) -> list:
        """Pick, independently for each series, the candidate with the highest R^2.

        Returns a list of ``numSeries`` arg dicts; a series that never beat the
        initial ``-inf`` keeps ``{}``.
        """
        bestArgs = [{} for _ in range(numSeries)]
        bestScore = np.full(numSeries, -np.inf)
        for candArgs in candidates:
            allScores = evaluate(candArgs)
            # Strict ``>`` keeps the earliest candidate on ties, per series.
            bestArgs = [
                candArgs if newScore > oldScore else oldArg
                for oldArg, oldScore, newScore in zip(bestArgs, bestScore, allScores)
            ]
            bestScore = np.maximum(allScores, bestScore)
            print(bestArgs)
        return bestArgs

    @staticmethod
    def _mergePerSeriesArgs(perSeriesArgs: list) -> dict:
        """Turn ``[{arg: v_0}, {arg: v_1}, ...]`` into ``{arg: [v_0, v_1, ...]}``.

        Arg names are taken from the first series; an empty list (no series)
        yields ``{}`` instead of raising IndexError.
        """
        if not perSeriesArgs:
            return {}
        return {
            argName: [args[argName] for args in perSeriesArgs]
            for argName in perSeriesArgs[0]
        }
