import os.path
from tqdm import tqdm
from typing import List, Any
from datetime import datetime
import argparse
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score
import pandas as pd

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

from sequence_models.utils import Tokenizer
from sequence_models.convolutional import MaskedConv1d
from sequence_models.datasets import CSVDataset


# Base alphabet handed to the tokenizer and used to size the one-hot encoding
# (20 single-letter codes). `train` appends '-' to it for gapped inputs.
AAINDEX_ALPHABET = 'ARNDCQEGHILKMFPSTWYV'


class ASCollater(object):
    """Collate ``(sequence, target, ...)`` examples into a one-hot batch.

    Args:
        alphabet: Symbols of the encoding; its length is the one-hot width.
        tokenizer: Object exposing ``tokenize(sequence)`` that returns one
            integer index per position (each index must be smaller than
            ``len(alphabet)``).
        pad: Stored for API compatibility; padding is always applied.
        pad_tok: Index written at padded positions before one-hot encoding.
        backwards: Stored for API compatibility; not used by ``__call__``.
    """

    def __init__(self, alphabet: str, tokenizer: object, pad=False, pad_tok=0., backwards=False):
        self.pad = pad
        self.pad_tok = pad_tok
        self.tokenizer = tokenizer
        self.backwards = backwards
        self.alphabet = alphabet

    def __call__(self, batch: List[Any], ) -> List[torch.Tensor]:
        """Build the padded one-hot inputs, targets and length mask.

        Args:
            batch: Examples whose first element is the sequence and whose
                second element is the target value.

        Returns:
            ``(one_hot, y, mask)`` where ``one_hot`` is a float tensor of shape
            ``(batch, max_len, len(alphabet))``, ``y`` has shape ``(batch, 1)``
            and ``mask`` is a float tensor of shape ``(batch, max_len)`` with
            1 at real positions and 0 at padding. Padded positions are NOT
            zero vectors: they are encoded as class ``pad_tok`` and have to be
            removed with ``mask`` downstream.
        """
        columns = tuple(zip(*batch))
        sequences, targets = columns[0], columns[1]

        # scatter_ requires int64 indices; force the dtype so tokenizers that
        # return int32 arrays (or plain lists) all behave the same.
        tokens = [
            torch.as_tensor(self.tokenizer.tokenize(s), dtype=torch.long).view(-1, 1)
            for s in sequences
        ]
        lengths = [t.shape[0] for t in tokens]
        maxlen = max(lengths)

        # Pad every (length, 1) index column at the end, up to the batch maximum.
        padded = torch.stack([
            F.pad(t, (0, 0, 0, maxlen - n), "constant", self.pad_tok)
            for t, n in zip(tokens, lengths)
        ])

        # 1 for positions inside the sequence, 0 for padding.
        positions = torch.arange(maxlen).unsqueeze(0)
        mask = (positions < torch.tensor(lengths).unsqueeze(1)).float()

        y = torch.tensor(targets).unsqueeze(-1)

        # One batched scatter instead of a Python loop over samples.
        one_hot = torch.zeros(len(tokens), maxlen, len(self.alphabet))
        one_hot.scatter_(2, padded, 1)

        return one_hot, y, mask


class LengthMaxPool1D(nn.Module):
    """Max-pool over dimension 1, optionally after a Linear + ReLU projection.

    Args:
        linear: When true, a ``nn.Linear(in_dim, out_dim)`` followed by ReLU is
            applied to the input before pooling.
        in_dim: Input feature size of the optional projection.
        out_dim: Output feature size of the optional projection.
    """

    def __init__(self, linear=False, in_dim=1024, out_dim=2048):
        super(LengthMaxPool1D, self).__init__()
        self.linear = linear
        # The projection only exists when requested, so the state dict of a
        # non-linear pool stays empty.
        if linear:
            self.layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return the element-wise maximum of ``x`` along dimension 1.

        Dimension 1 is presumably the sequence-length axis of a
        (batch, length, features) tensor -- confirm against callers.
        """
        features = F.relu(self.layer(x)) if self.linear else x
        pooled, _ = features.max(dim=1)
        return pooled


class FluorescenceModel(nn.Module):
    """Masked 1D convolution -> Linear/ReLU max-pool -> linear head.

    Args:
        n_tokens: Size of the one-hot alphabet (input channels of the encoder).
    """

    def __init__(self, n_tokens):
        super(FluorescenceModel, self).__init__()
        hidden_dim = 1024
        embed_dim = 2048
        self.encoder = MaskedConv1d(n_tokens, hidden_dim, kernel_size=5)
        self.embedding = LengthMaxPool1D(linear=True, in_dim=hidden_dim, out_dim=embed_dim)
        self.decoder = nn.Linear(embed_dim, 1)
        self.n_tokens = n_tokens

    def forward(self, x, mask):
        """Predict one value per sequence.

        Args:
            x: One-hot inputs; assumed (batch, length, n_tokens) -- this is
                what the collater in this file produces.
            mask: 2-D (batch, length) tensor, 1 at real positions and 0 at
                padding.

        Returns:
            The decoder output, one value per batch element.
        """
        # (batch, length) -> (batch, length, 1) so it broadcasts over channels.
        mask = mask.unsqueeze(-1)
        # encoder: the input mask keeps the original (batch, length, n_tokens)
        # shape but as a view instead of a materialised copy.
        x = F.relu(self.encoder(x, input_mask=mask.expand(-1, -1, self.n_tokens)))
        # Zero the padded positions again; broadcasting removes the need to
        # repeat the encoder width here.
        x = x * mask
        # embed
        x = self.embedding(x)
        # decoder
        output = self.decoder(x)
        return output


def train(args):
    """Train and evaluate a FluorescenceModel on a CSV dataset.

    The CSV at ``args.data_fpath`` must provide a ``split`` column (values
    ``train`` / ``valid`` / ``test`` are read here) and a ``tgt`` column; the
    sequence column is consumed by ``CSVDataset``. Paths containing ``miser``
    use the ``gapped`` column as the sequence (and ``split2`` as the split
    unless ``--low_n``); paths containing ``cas9`` take the target from column
    ``args.col`` and drop rows where it is missing.

    Two modes:

    * ``args.low_n``: 16 repetitions over training-set sizes 16..1024. Each
      run re-initialises the model, trains until more than 128 optimizer steps
      have been taken (or ``b`` epochs), evaluates on the test subset and
      appends ``n_train,mse,rho`` to the file ``args.out_fpath``.
    * otherwise: trains up to 1000 epochs with early stopping (patience 20) on
      the validation metric, writing ``args.out_fpath + 'checkpoint<step>.tar'``
      every 10 epochs and ``args.out_fpath + 'bestmodel.tar'`` on improvement,
      then reports the test metric and MSE of the best model.

    The metric is ROC-AUC when ``args.bin`` is set, Spearman correlation
    otherwise. The test set is a fixed random subset of at most 5000 rows.

    Args:
        args: Namespace with ``seed``, ``gpu``, ``data_fpath``, ``out_fpath``,
            ``bin``, ``low_n`` and ``col``.
    """
    # set up training environment
    torch.manual_seed(args.seed)

    batch_size = 256
    epochs = 1000
    patience = 20
    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.gpu)
    else:
        device = torch.device('cpu')
    alphabet = AAINDEX_ALPHABET
    if 'miser' in args.data_fpath:
        alphabet += '-'
    tokenizer = Tokenizer(alphabet)
    print('USING OHE HOT ENCODING')
    criterion = nn.BCEWithLogitsLoss() if args.bin else nn.MSELoss()
    model = FluorescenceModel(len(alphabet))
    model = model.to(device)

    def make_optimizer():
        # Per-module learning rates / weight decay.
        return optim.Adam([
            {'params': model.encoder.parameters(), 'lr': 1e-3, 'weight_decay': 0},
            {'params': model.embedding.parameters(), 'lr': 5e-5, 'weight_decay': 0.05},
            {'params': model.decoder.parameters(), 'lr': 5e-6, 'weight_decay': 0.05}
        ])

    optimizer = make_optimizer()

    # grab data
    data_fpath = args.data_fpath
    df = pd.read_csv(data_fpath)
    if 'miser' in data_fpath:
        df['sequence'] = df['gapped']
        if not args.low_n:
            df['split'] = df['split2']
    elif 'cas9' in data_fpath:
        df['tgt'] = df[args.col]
        df = df.dropna(subset=['tgt'])
    ds_valid = CSVDataset(df=df, split='valid', outputs=['tgt'])
    # Fixed-seed random subset of the test split (at most 5000 rows).
    df_test = df[df['split'] == 'test'].reset_index()
    idx = np.arange(len(df_test))
    np.random.seed(0)
    np.random.shuffle(idx)
    ds_test = CSVDataset(df=df_test.loc[idx[:5000]], split='test', outputs=['tgt'])

    # setup dataloaders
    collate = ASCollater(alphabet, tokenizer, pad=True, pad_tok=0.)

    def make_loader(dataset, n_batch, shuffle=False, num_workers=16):
        return DataLoader(dataset, collate_fn=collate, batch_size=n_batch,
                          shuffle=shuffle, num_workers=num_workers)

    dl_valid = make_loader(ds_valid, batch_size)
    dl_test = make_loader(ds_test, batch_size * 2)

    def step(batch, train=True):
        """Run one batch; update the weights when ``train`` is true."""
        src, tgt, mask = batch
        src = src.to(device).float()
        tgt = tgt.to(device).float()
        mask = mask.to(device).float()
        # No autograd graph is needed (or built) for evaluation batches.
        with torch.set_grad_enabled(train):
            output = model(src, mask)
            loss = criterion(output, tgt)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return loss.item(), output.detach().cpu(), tgt.detach().cpu()

    def run_epoch(loader, train):
        """One pass over ``loader``; returns (n_batches, targets, outputs)."""
        model.train(train)
        outputs = []
        tgts = []
        for batch in loader:
            _, output, tgt = step(batch, train)
            outputs.append(output)
            tgts.append(tgt)
        return len(outputs), torch.cat(tgts).numpy(), torch.cat(outputs).numpy()

    def score(tgts, outputs):
        """ROC-AUC for binary targets, Spearman correlation otherwise."""
        if args.bin:
            return roc_auc_score(tgts, outputs)
        return spearmanr(tgts, outputs).correlation

    def save_checkpoint(path, nsteps):
        torch.save({
            'step': nsteps,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, path)

    if args.low_n:
        reps = 16
        if not os.path.exists(args.out_fpath):
            with open(args.out_fpath, 'w') as f:
                f.write('n_train,mse,rho\n')
        bs = np.array([16, 32, 64, 128, 256, 512, 1024])
        max_steps = 128
        st = np.minimum(4 * bs, max_steps)
        df_train = df[df['split'] == 'train'].reset_index()
        idx = np.arange(len(df_train))
        test_steps = len(ds_test) // (batch_size * 4)
        np.random.seed(32)
        with tqdm(total=reps * (st.sum() + test_steps * len(bs))) as pbar:
            for _ in range(reps):
                np.random.shuffle(idx)
                for b in bs:
                    # Every run starts from a freshly initialised model ...
                    for module in model.modules():
                        if hasattr(module, 'reset_parameters'):
                            module.reset_parameters()
                    # ... and a fresh optimizer, so Adam moments and step
                    # counts from the previous run do not leak into this one.
                    optimizer = make_optimizer()
                    np.random.shuffle(idx)
                    df_ = df_train.loc[idx[:b]]
                    ds_train = CSVDataset(df=df_, split='train', outputs=['tgt'])
                    dl_train = make_loader(ds_train, int(b // 4), shuffle=True)
                    nsteps = 0
                    for _epoch in range(int(b)):
                        s, _, _ = run_epoch(dl_train, True)
                        nsteps += s
                        pbar.update(s)
                        if nsteps > max_steps:
                            break
                    _, tgt, pre = run_epoch(dl_test, False)
                    val_rho = score(tgt, pre)
                    mse = ((tgt - pre) ** 2).mean()
                    with open(args.out_fpath, 'a') as f:
                        f.write('%d,%f,%f\n' % (b, mse, val_rho))
                    pbar.update(test_steps)
    else:
        ds_train = CSVDataset(df=df, split='train', outputs=['tgt'])
        dl_train = make_loader(ds_train, batch_size, shuffle=True, num_workers=4)
        nsteps = 0
        p = 0  # epochs since the validation metric last improved
        best_rho = -1
        for e in range(epochs):
            s, _, _ = run_epoch(dl_train, True)
            _, tgt, pre = run_epoch(dl_valid, False)
            val_rho = score(tgt, pre)
            print(val_rho)
            nsteps += s
            if (e % 10 == 0) or (e == epochs - 1):
                save_checkpoint(args.out_fpath + 'checkpoint%d.tar' % nsteps, nsteps)
            # NOTE: a NaN metric never compares greater, so it counts as
            # "no improvement" and does not overwrite the best model.
            if val_rho > best_rho:
                p = 0
                best_rho = val_rho
                save_checkpoint(args.out_fpath + 'bestmodel.tar', nsteps)
            else:
                p += 1
            if p == patience:
                print('MET PATIENCE')
                break
        print('Testing...')
        sd = torch.load(args.out_fpath + 'bestmodel.tar', map_location=device)
        model.load_state_dict(sd['model_state_dict'])
        _, tgt, pre = run_epoch(dl_test, False)
        val_rho = score(tgt, pre)
        print('rho = %.4f' % val_rho)
        val_mse = ((tgt - pre) ** 2).mean()
        print('mse = %.4f' % val_mse)


def main():
    """Parse the command line and run :func:`train`.

    Exits with an argparse usage error when the data path contains ``cas9``
    but ``--col`` is missing, because ``train`` reads the target from that
    column for such datasets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str, help='CSV file with the dataset')
    parser.add_argument('out_fpath', type=str,
                        help='results CSV file with --low_n, otherwise a path '
                             'prefix for checkpoint files')
    parser.add_argument('--bin', action='store_true',
                        help='binary targets: BCE loss and ROC-AUC metric')
    parser.add_argument('--low_n', action='store_true',
                        help='run the low-N training-set-size sweep')
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device index')
    parser.add_argument('--col', help='target column (required for cas9 data)')
    parser.add_argument('--seed', default=0, type=int, help='torch random seed')
    args = parser.parse_args()
    # Fail early with a usage message instead of a KeyError deep inside train.
    if 'cas9' in args.data_fpath and args.col is None:
        parser.error('--col is required when data_fpath contains "cas9"')
    train(args)


if __name__ == '__main__':
    main()