import common_org
import common_renewal as common
import os
import numpy as np
import random
import torch
from random import SystemRandom
from scipy.optimize import fsolve

from datasets.speech_commands import get_data
from parse import parse_args

# Model names routed through the `common_org` builder/training loop in `main`;
# every other model name goes through `common_renewal`.
BASELINE_MODELS = [
    "ncde",
    "odernn",
    "dt",
    "decay",
    "gruode",
    "odernn_forecasting",
]

# Command-line arguments are parsed at import time because `main` binds its
# parameter defaults from `args` when it is defined.
args = parse_args()
# Restrict visible GPUs to the requested one (the variable only takes effect
# if it is set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

def _solve_renewal_rates(p_target, m_target, horizon):
    """Solve for ``(lambda1, lambda2)`` so the two closed-form expressions hit the targets.

    With ``s = lambda1 + lambda2`` and ``d = 1 - exp(-s * T)``, the expressions are

        p = lambda1 / s * d
        m = lambda1 * lambda2 / s * T - lambda1 * lambda2 / s**2 * d

    and they are driven to ``(p_target, m_target)`` with
    :func:`scipy.optimize.fsolve`, starting from ``(0.1, 0.1)``. When either
    target is zero the solver is skipped and both rates are ``0``.

    Args:
        p_target: target value for ``p``.
        m_target: target value for ``m``.
        horizon: the time ``T`` at which both expressions are evaluated. It is
            converted to a Python ``float`` so a 0-d tensor such as
            ``times[-1]`` is accepted (assumes ``times`` is 1-D — TODO confirm).

    Returns:
        The pair ``(lambda1, lambda2)``.
    """
    if p_target == 0 or m_target == 0:
        return 0, 0

    T = float(horizon)

    def equations(lambdas):
        lambda_1, lambda_2 = lambdas
        total = lambda_1 + lambda_2
        decay = 1 - np.exp(-total * T)
        p = lambda_1 / total * decay
        m = lambda_1 * lambda_2 / total * T - lambda_1 * lambda_2 / total**2 * decay
        return [p - p_target, m - m_target]

    lambda1, lambda2 = fsolve(equations, [0.1, 0.1])
    return lambda1, lambda2


def main(
    manual_seed=args.seed,
    device="cuda",
    max_epochs=200,
    *,  # training parameters
    model_name=args.model,
    hidden_channels=args.h_channels,
    hidden_hidden_channels=args.hh_channels,
    num_hidden_layers=args.layers,
    lr=args.lr,
    slope_check=args.slope_check,
    soft=args.soft,
    timewise=args.timewise,
    attention_channel=args.attention_channel,
    attention_attention_channel=args.attention_attention_channel,  # model parameters
    step_mode=args.step_mode,
    dry_run=False,
    c1=args.c1,
    c2=args.c2,
    rtol=args.rtol,
    atol=args.atol,
    gpu=args.gpu,
    p,
    m,
    **kwargs
):  # kwargs passed on to cdeint
    """Train and evaluate one model on the speech-commands data.

    Parameter defaults are bound from the command-line ``args`` parsed at
    import time.

    Args:
        manual_seed: seed for the Python, NumPy and PyTorch (CPU and CUDA) RNGs.
        device: device string forwarded to the training loop.
        max_epochs: maximum number of epochs, forwarded to the training loop.
        model_name: model to build. Names in ``BASELINE_MODELS`` are built and
            trained through ``common_org``; all others through ``common_renewal``.
        hidden_channels, hidden_hidden_channels, num_hidden_layers,
        attention_channel, attention_attention_channel: size settings
            forwarded to the model builder.
        lr: base learning rate; multiplied by ``batch_size / 32`` before use.
        slope_check, soft, timewise: forwarded to the attention-model builder
            (``slope_check`` is also forwarded to its training loop).
        step_mode: forwarded to the attention training loop. The baseline
            training loop always receives ``step_mode=True``.
        dry_run: if true, ``name=None`` is passed to the training loop.
        c1, c2: forwarded to the attention training loop.
        rtol, atol: forwarded to the attention-model builder.
        gpu: not used here. GPU selection happens at import time through
            ``CUDA_VISIBLE_DEVICES``.
        p, m: required targets used to solve for ``lambda1``/``lambda2``
            (see ``_solve_renewal_rates``). These rates are only passed to the
            attention-model builder.
        **kwargs: passed to the training loop as a single dict.

    Returns:
        Whatever ``common_org.main`` (baseline models) or
        ``common_renewal.main`` (attention models) returns.
    """
    # Seed every RNG used by the run. torch.manual_seed is the same function as
    # torch.random.manual_seed, and cuda.manual_seed_all covers the current device.
    np.random.seed(manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)

    batch_size = 1024
    # Linear learning-rate scaling: the supplied lr is treated as the rate
    # for a reference batch size of 32.
    lr = lr * (batch_size / 32)

    PATH = os.path.dirname(os.path.abspath(__file__))

    # get_data is asked for intensity channels only for these models.
    intensity_data = model_name in ("odernn", "dt", "decay")
    times, train_dataloader, val_dataloader, test_dataloader = get_data(
        intensity_data, batch_size
    )

    lambda1, lambda2 = _solve_renewal_rates(p, m, times[-1])

    num_classes = 10
    output_channels = num_classes
    # Presumably one time channel plus 20 feature channels, with the 20 doubled
    # when intensity channels are requested — confirm against get_data.
    input_channels = 1 + (1 + intensity_data) * 20

    # SystemRandom draws from the OS entropy source and is unaffected by the
    # seeding above, so runs sharing a seed still get distinct output files.
    experiment_id = int(SystemRandom().random() * 100000)
    # NOTE(review): this speech-commands script writes into "Sepsis_h_prime/";
    # confirm the directory name is intended.
    SAVED_PATH = os.path.join(PATH, "Sepsis_h_prime")
    file = os.path.join(SAVED_PATH, f"{experiment_id}.npy")
    os.makedirs(SAVED_PATH, exist_ok=True)

    is_baseline = model_name in BASELINE_MODELS
    if is_baseline:
        make_model = common_org.make_model(
            model_name,
            input_channels,
            output_channels,
            hidden_channels,
            hidden_hidden_channels,
            num_hidden_layers,
            use_intensity=False,
            initial=True,
        )
    else:  # attention models
        make_model = common.make_model(
            model_name,
            input_channels,
            output_channels,
            hidden_channels,
            hidden_hidden_channels,
            attention_channel,
            attention_attention_channel,
            num_hidden_layers,
            rtol=rtol,
            atol=atol,
            use_intensity=False,
            slope_check=slope_check,
            soft=soft,
            timewise=timewise,
            file=file,
            initial=True,
            lambda1=lambda1,
            lambda2=lambda2,
        )

    def new_make_model():
        """Build the model and scale gradients reaching its final linear layer by 100.

        Baseline builders return ``(model, regularise)``; attention builders
        return ``(model, regularise1, regularise2)``. The same tuple shape is
        returned here.
        """
        model, *regularisers = make_model()
        model.linear.weight.register_hook(lambda grad: 100 * grad)
        model.linear.bias.register_hook(lambda grad: 100 * grad)
        return (model, *regularisers)

    name = None if dry_run else "speech_commands"

    if is_baseline:
        # Baseline models are built by common_org, so they are trained with
        # common_org's loop, which takes this positional layout.
        return common_org.main(
            name,
            times,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            device,
            new_make_model,
            num_classes,
            max_epochs,
            lr,
            kwargs,
            step_mode=True,
        )

    experiments = "speech_commands" + str(manual_seed)
    return common.main(
        experiments,
        model_name,
        name,
        times,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        device,
        new_make_model,
        num_classes,
        max_epochs,
        lr,
        slope_check,
        kwargs,
        step_mode=step_mode,
        c1=c1,
        c2=c2,
    )


def run_all(device, model_names=("ncde", "odernn", "dt", "decay", "gruode"), *, p=0, m=0):
    """Train each model in ``model_names`` five times with fixed size settings.

    Args:
        device: device string forwarded to ``main`` as ``device``.
        model_names: names looked up in the per-model settings table below.
            An unknown name raises ``KeyError``.
        p, m: targets forwarded to ``main``, which requires them. The default
            of 0 makes ``main`` skip the rate solve (both rates set to 0).
    """
    model_kwargs = dict(
        ncde=dict(hidden_channels=90, hidden_hidden_channels=40, num_hidden_layers=4),
        odernn=dict(
            hidden_channels=128, hidden_hidden_channels=64, num_hidden_layers=4
        ),
        dt=dict(
            hidden_channels=160, hidden_hidden_channels=None, num_hidden_layers=None
        ),
        decay=dict(
            hidden_channels=160, hidden_hidden_channels=None, num_hidden_layers=None
        ),
        gruode=dict(
            hidden_channels=160, hidden_hidden_channels=None, num_hidden_layers=None
        ),
    )
    for model_name in model_names:
        # Hyperparameters selected as what ODE-RNN did best with.
        for _ in range(5):
            # `device` must be passed by keyword: main's first positional
            # parameter is `manual_seed`.
            main(
                device=device,
                model_name=model_name,
                p=p,
                m=m,
                **model_kwargs[model_name],
            )


if __name__ == "__main__":
    # Sweep every (p, m) pair on a fixed grid, repeating the whole sweep 5 times.
    # NOTE(review): every repeat uses main's default seed (args.seed), so
    # repeats differ only through nondeterminism — confirm that is intended.
    p_grid = np.linspace(0.1, 0.5, num=5)
    m_grid = [5, 10, 50, 100]
    for _repeat in range(5):
        for p_value in p_grid:
            for m_value in m_grid:
                main(p=p_value, m=m_value)
