import os
import pandas as pd

from learn.constrnew import const
from learn.learn import learner
import torch
import torch.optim as optim  
from pmlb import fetch_data
from sklearn.model_selection import train_test_split

from learn.nnforop import PolicyOp
from learn.nnfortwovar import PolicyTWOVAR
from learn.nnforvar import PolicyVAR
import numpy as np

def fit(X, y, seed, lib, n_epochs, batch_news, stop_reward, max_time_step,
        batch_size, lr, risk_factor, gamma_decay, entropy_weight):
    """Train the three policy networks on a split of ``(X, y)`` and score the result.

    A 75/25 split is drawn with ``random_state=seed`` and only the 75% part
    is used. One ``PolicyVAR``, one ``PolicyOp`` and one ``PolicyTWOVAR``
    network are built, each with its own Adam optimizer at learning rate
    ``lr``, and handed to ``learner``. The first value ``learner`` returns is
    then passed to ``const`` together with the training tensors.

    Args:
        X: Feature array accepted by ``train_test_split`` and
            ``torch.tensor``; must have at least two dimensions because its
            second dimension sizes the networks.
        y: Target array; a trailing axis is appended with ``unsqueeze(1)``.
        seed: ``random_state`` for the train/test split.
            NOTE(review): network initialisation is not seeded in this
            function -- confirm whether ``learner`` or the caller seeds torch.
        lib: Collection of operator names. Its length is the output width of
            the operator policy; an ``'id'`` entry is not counted when sizing
            the variable policies.
        n_epochs: Forwarded to ``learner``; also used to size the variable
            policies.
        batch_news, stop_reward, max_time_step, batch_size, risk_factor,
        gamma_decay, entropy_weight: Forwarded unchanged to ``learner``.
        lr: Learning rate shared by the three Adam optimizers.

    Returns:
        tuple: ``(beststr, r2train)`` where ``beststr`` is the first value
        returned by ``learner`` and ``r2train`` is the result of
        ``const(beststr, X_train, y_train, mu, std)``.
    """
    # Only the training part of the split is needed here; the held-out part
    # is discarded instead of being converted to tensors that are never read.
    x_train, _, y_train, _ = train_test_split(
        X, y, test_size=0.25, random_state=seed
    )
    X_train = torch.tensor(x_train, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

    # Passed through to learner/const unchanged.
    # NOTE(review): presumably target normalisation parameters (identity
    # here) -- confirm against learner/const.
    mu = 0
    std = 1

    n_ops = len(lib)
    n_ops_without_id = n_ops - 1 if 'id' in lib else n_ops

    # The literal 5 appears in all three input widths.
    # NOTE(review): presumably a fixed number of state features shared by the
    # policies -- confirm against the network definitions.
    base_inputs = 5
    var_width = X_train.shape[1] + 1 + n_ops_without_id * (n_epochs - 1)

    # Networks are created in the same order as before (var, op, two-var) so
    # that random weight initialisation consumes the RNG identically.
    policyvar = PolicyVAR(base_inputs, var_width)
    optimizervar = optim.Adam(policyvar.parameters(), lr=lr)

    policyop = PolicyOp(var_width + base_inputs, n_ops)
    optimizerop = optim.Adam(policyop.parameters(), lr=lr)

    policytwovar = PolicyTWOVAR(var_width + base_inputs + n_ops, var_width)
    optimizertwovar = optim.Adam(policytwovar.parameters(), lr=lr)

    # learner returns three values; only the first is used here.
    beststr, _, _ = learner(
        X_train=X_train,
        y_train=y_train,
        mu=mu,
        std=std,
        nnvar=policyvar,
        optimizervar=optimizervar,
        nnop=policyop,
        optimizerop=optimizerop,
        nntwovar=policytwovar,
        optimizertwovar=optimizertwovar,
        n_epochs=n_epochs,
        risk_factor=risk_factor,
        gamma_decay=gamma_decay,
        entropy_weight=entropy_weight,
        stop_reward=stop_reward,
        lib=lib,
        batch_news=batch_news,
        batch_size=batch_size,
        max_time_step=max_time_step,
    )
    r2train = const(beststr, X_train, y_train, mu, std)
    return beststr, r2train

def main():
    """Run ``fit`` for every configured dataset/seed pair and print the results.

    Each dataset is fetched once through ``pmlb.fetch_data`` (cached under
    ``./datasets``) and then sub-sampled per seed before being passed to
    ``fit``. The first and second values returned by ``fit`` are printed.
    """
    datasets = ["feynman_III_15_12"]
    seeds = [860]
    for dataset in datasets:
        # The data does not depend on the seed, so load it once per dataset
        # instead of once per (dataset, seed) pair.
        X, y = fetch_data(dataset, return_X_y=True, local_cache_dir="./datasets")
        for seed in seeds:
            # test_size=0.999 keeps only 0.1% of the rows as the working set;
            # the remaining 99.9% is discarded.
            x_train, _, y_train, _ = train_test_split(
                X, y, test_size=0.999, random_state=seed
            )
            # Fresh list objects are built for every call, as before, so a
            # callee mutating them cannot leak state between runs.
            beststr, r2train = fit(
                X=x_train,
                y=y_train,
                seed=seed,
                lib=['add', 'sub', 'mul', 'div', 'sin', 'cos', 'sig', 'log', 'sqrt', 'id'],
                n_epochs=6,
                batch_news=[50, 100, 150, 200, 250, 250],
                stop_reward=0.99999,
                max_time_step=7,
                batch_size=200,
                lr=0.0025,
                risk_factor=0.05,
                gamma_decay=0.7,
                entropy_weight=0.005,
            )
            print(beststr)
            print(r2train)
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
    