import gemclus.gemini

from douglas import Douglas
from torchdouglas import TorchDouglas


def run(args, X):
    """Build a DOUGLAS model from parsed command-line options and cluster ``X``.

    The backend is chosen by ``args.torch``:

    * truthy  -> ``TorchDouglas``, which receives the GEMINI as a string
      ``"<distance>_<mode>"`` (for example ``"wasserstein_ovo"``). When
      ``args.distance`` is ``"mi"`` the string ``"mmd_ova"`` is passed instead.
    * falsy   -> ``Douglas``, which receives an instantiated ``gemclus.gemini``
      object: ``MMDOvA``/``MMDOvO`` for ``"mmd"``,
      ``WassersteinOvA``/``WassersteinOvO`` for ``"wasserstein"`` (the OvA
      variant when ``args.mode == "ova"``, OvO for any other mode), and
      ``MI`` for every other distance value.

    Both backends are built with ``verbose=True`` and the same training
    hyper-parameters taken from ``args``.

    Args:
        args: Namespace-like object exposing ``torch``, ``distance``, ``mode``,
            ``n_clusters``, ``n_cuts``, ``temperature``, ``n_epochs``,
            ``batch_size`` and ``learning_rate``.
        X: Data to cluster, forwarded unchanged to ``fit_predict``.

    Returns:
        The result of ``model.fit_predict(X)`` for the selected backend.
    """
    if args.torch:
        # TorchDouglas identifies its objective by name.
        requested = args.distance + "_" + args.mode
        # NOTE(review): "mi" is swapped for "mmd_ova" on this backend,
        # presumably because TorchDouglas has no MI objective; confirm.
        gemini = "mmd_ova" if args.distance == "mi" else requested
        model_cls = TorchDouglas
    else:
        # Douglas expects a ready-made gemclus GEMINI instance. Each supported
        # distance maps to its (one-vs-all, one-vs-one) pair of classes.
        if args.distance == "mmd":
            variants = (gemclus.gemini.MMDOvA, gemclus.gemini.MMDOvO)
        elif args.distance == "wasserstein":
            variants = (gemclus.gemini.WassersteinOvA, gemclus.gemini.WassersteinOvO)
        else:
            variants = None

        if variants is None:
            # Any other distance value falls back to mutual information.
            gemini = gemclus.gemini.MI()
        else:
            one_vs_all, one_vs_one = variants
            gemini = one_vs_all() if args.mode == "ova" else one_vs_one()
        model_cls = Douglas

    model = model_cls(n_clusters=args.n_clusters, n_cuts=args.n_cuts, temperature=args.temperature,
                      n_epochs=args.n_epochs, batch_size=args.batch_size, learning_rate=args.learning_rate,
                      gemini=gemini, verbose=True)

    return model.fit_predict(X)