import argparse
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import Ridge, RidgeClassifier
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr

import torch
import pandas as pd
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset

torch.manual_seed(0)

from sequence_models.utils import Tokenizer
from sequence_models.convolutional import MaskedConv1d
from sequence_models.datasets import CSVDataset

AAINDEX_ALPHABET = 'ARNDCQEGHILKMFPSTWYV'


def _str2bool(value):
    """Parse a textual command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (including ``'False'``), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # '' stays False, matching the old bool('') behaviour
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


def _iter_count(value):
    """Parse an iteration limit as a positive int, accepting float notation such as ``1e6``.

    scikit-learn's ``Ridge``/``RidgeClassifier`` validate ``max_iter`` as an
    integer, so passing a float like ``1e6`` is rejected by recent releases.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a positive whole number.
    """
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError('expected an integer, got %r' % value) from None
    if not number.is_integer() or number < 1:
        raise argparse.ArgumentTypeError('expected a positive integer, got %r' % value)
    return int(number)


parser = argparse.ArgumentParser()
parser.add_argument('data_fpath', type=str, help='file path to the data CSV')
parser.add_argument('--out_fpath', required=False,
                    help='CSV path where the --low_n sweep results are written')
parser.add_argument('--bin', action='store_true',
                    help='fit a RidgeClassifier and report test AUROC instead of MSE/Spearman')
parser.add_argument('--low_n', action='store_true',
                    help='run the training-set-size sweep with Ridge regression')
parser.add_argument('--scale', type=_str2bool, default=False,
                    help='standardize the regression targets (true/false)')
parser.add_argument('--solver', type=str, default='lsqr',
                    help='solver passed to sklearn Ridge / RidgeClassifier')
parser.add_argument('--max_iter', type=_iter_count, default=1000000,
                    help='maximum solver iterations (accepts e.g. 1e6)')
parser.add_argument('--tol', type=float, default=1e-4,
                    help='solver tolerance')
args = parser.parse_args()


# grab data
data_fpath = args.data_fpath
# read the CSV once (it was previously read twice, the first result discarded)
df = pd.read_csv(data_fpath)
if 'miser' in data_fpath:
    # 'miser' files: use the 'gapped' column as the sequence (presumably an
    # aligned sequence containing '-' gaps -- confirm against the data)
    df['sequence'] = df['gapped']
    if not args.low_n:
        # the full-data run uses the alternative 'split2' train/test assignment
        df['split'] = df['split2']
ds_train = CSVDataset(df=df, split='train', outputs=['tgt'])
ds_test = CSVDataset(df=df, split='test', outputs=['tgt'])


# tokenize train data
# each dataset item is indexed as (sequence, target) -- presumably CSVDataset
# yields the sequence first and the 'tgt' output second; verify if it changes
all_train = list(ds_train)
X_train = [item[0] for item in all_train]
y_train = np.array([item[1] for item in all_train])[:, None]  # column vector (n, 1)
if 'miser' in data_fpath:
    # gapped sequences contain '-', so it must be part of the vocabulary;
    # this also widens the one-hot encoding below by one symbol
    AAINDEX_ALPHABET += '-'
# a single tokenizer over the final alphabet serves both splits
tokenizer = Tokenizer(AAINDEX_ALPHABET)
X_train = [torch.tensor(tokenizer.tokenize(seq)).view(-1, 1) for seq in X_train]


# tokenize test data
# the test set is shuffled with a fixed seed so its order is reproducible
all_test = list(ds_test)
idx = np.arange(len(all_test))
np.random.seed(0)
np.random.shuffle(idx)
X_test = [all_test[i][0] for i in idx]
y_test = np.array([all_test[i][1] for i in idx])[:, None]  # column vector (n, 1)
X_test = [torch.tensor(tokenizer.tokenize(seq)).view(-1, 1) for seq in X_test]


# padding
# every sequence is right-padded to the longest sequence in EITHER split so
# that the train and test one-hot matrices have the same number of columns
maxlen = max(len(t) for t in X_train + X_test)


def _pad_tokens(tokens, length):
    """Right-pad each (L, 1) token-id tensor along dim 0 with id 0 up to `length` rows."""
    # NOTE(review): pad id 0 is also a real token id (presumably the first
    # alphabet symbol), so padded positions are one-hot encoded exactly like
    # that symbol -- confirm this is intended.
    return [F.pad(t, (0, 0, 0, length - t.shape[0]), "constant", 0.) for t in tokens]


def _one_hot_flat(tokens, n_symbols):
    """One-hot encode each (L, 1) token-id tensor and stack the flattened rows.

    Returns a float32 numpy array of shape (len(tokens), L * n_symbols).
    """
    rows = []
    for t in tokens:
        onehot = torch.zeros(t.shape[0], n_symbols, dtype=torch.float32)
        onehot.scatter_(1, t, 1)  # set column t[r, 0] of row r to 1
        rows.append(onehot.view(-1).numpy())
    return np.array(rows)


X_train = _pad_tokens(X_train, maxlen)
X_train_enc = _one_hot_flat(X_train, len(AAINDEX_ALPHABET))

X_test = _pad_tokens(X_test, maxlen)
X_test_enc = _one_hot_flat(X_test, len(AAINDEX_ALPHABET))


# scale y: standardize the regression TARGETS (the one-hot features are left
# untouched); the scaler is fit on train only and then applied to test
if args.scale:
    if args.bin:
        # standardizing 0/1 labels yields non-binary floats, which breaks
        # roc_auc_score downstream, so class labels are never scaled
        print('Ignoring --scale: targets are class labels when --bin is set')
    else:
        scaler = StandardScaler()
        y_train = scaler.fit_transform(y_train)
        y_test = scaler.transform(y_test)

print('Parameters...')
print('Solver: %s, MaxIter: %s, Tol: %s' % (args.solver, args.max_iter, args.tol))

if args.low_n:
    # training-set-size sweep: for each of `reps` random orderings of the
    # training set, fit Ridge on the first b examples for every b in `bs`
    # and score it on the full test set
    reps = 16
    bs = np.array([16, 32, 64, 128, 256, 512, 1024, 2048, 4096])
    rows = []
    np.random.seed(32)
    with tqdm(total=len(bs) * reps) as pbar:
        for _ in range(reps):
            idx = np.arange(len(y_train)).astype(int)
            np.random.shuffle(idx)
            for b in bs:
                lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter)
                lr.fit(X_train_enc[idx[:b]], y_train[idx[:b]])
                preds = lr.predict(X_test_enc)
                rho = spearmanr(y_test, preds).correlation
                mse = mean_squared_error(y_test, preds)
                rows.append({'n_train': b, 'mse': mse, 'rho': rho})
                pbar.update(1)
    # build the frame once instead of growing it row by row
    results = pd.DataFrame(rows, columns=['n_train', 'mse', 'rho'])
    if args.out_fpath:
        results.to_csv(args.out_fpath, index=False)
    else:
        # DataFrame.to_csv(None) returns the CSV text instead of writing it,
        # so without --out_fpath the results would otherwise be silently lost
        print(results.to_csv(index=False))
else:
    print('Training...')
    if args.bin:
        lr = RidgeClassifier(solver=args.solver, tol=args.tol, max_iter=args.max_iter)
        # the classifier expects 1-D class labels, not an (n, 1) column
        lr.fit(X_train_enc, y_train.ravel())
        # AUROC needs continuous scores: predict() returns hard labels, which
        # collapses the ROC curve to a single operating point
        scores = lr.decision_function(X_test_enc)
        roc = roc_auc_score(y_test.ravel(), scores)
        print('TEST AUCROC: ', roc)
    else:
        lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter)
        lr.fit(X_train_enc, y_train)
        preds = lr.predict(X_test_enc)
        mse = mean_squared_error(y_test, preds)
        print('TEST MSE: ', mse)
        print('TEST RHO: ', spearmanr(y_test, preds).correlation)