"""
PARTIALLY COPIED FROM https://colab.research.google.com/drive/1wyVEQd4R3HYLTUOXEEQmp_I8aNC_aLhL
"""
import dataclasses
import torch.optim as optim
import torch
from torch.utils.data import DataLoader

import sys
from datetime import datetime
from time import time
import subprocess
import pickle

from lib.data_pipeline import *
from lib.model import *
from lib.settings import *

### CONFIGURATION ###
@dataclasses.dataclass
class Configuration:
    """Hyperparameters and data/model options for one fine-tuning run.

    Every attribute is a dataclass field, so it can be overridden through
    the constructor, appears in ``print()`` output, and is stored when the
    instance is pickled.
    """

    num_epochs: int = 100  # ~100 seems fine
    use_padding: bool = True
    dataset_name: str = 'human_enhancers_cohn'
    batch_size: int = 256
    learning_rate: float = 6e-4  # good default for Hyena
    rc_aug: bool = True  # reverse complement augmentation
    add_eos: bool = False  # add end of sentence token
    weight_decay: float = 0.1
    pretrained_model_name: str = 'hyenadna-tiny-1k-seqlen'
    use_head: bool = True
    # Optional backbone configuration handed to the model loader.
    # The annotation is required: without it the dataclass decorator treats
    # this as a plain class attribute, so it could not be set via the
    # constructor and was silently missing from print() and from pickles.
    # NOTE(review): exact expected type (dict?) is defined by the model
    # loader -- confirm and tighten the annotation.
    backbone_cfg: object = None

    def print(self):
        """Write every field as ``name: value`` to stdout, then flush.

        Values exposing ``__name__`` in ``dir()`` (e.g. functions, modules)
        are shown by that name instead of their full repr.
        """
        print("===== Configuration =====")
        for k, v in vars(self).items():
            if "__name__" in dir(v):
                print(f"{k}: {v.__name__}")
            else:
                print(f"{k}: {v}")
        # Flush so the configuration shows up promptly in redirected logs.
        sys.stdout.flush()

def train(model, device, train_loader, optimizer, epoch, loss_fn, log_interval=10):
    """Run one training epoch over ``train_loader``.

    For each batch: move it to ``device``, zero the gradients, run the
    forward pass, compute ``loss_fn``, back-propagate and take one
    optimizer step.

    Args:
        model: Module called as ``model(data)``; switched to train mode.
        device: Device that the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that also
            exposes ``len()`` and a sized ``.dataset`` (used for logging).
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only in the progress message.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        log_interval: Print progress every this many batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Flatten instead of squeeze(): squeeze() would also drop the batch
        # dimension when a batch holds a single sample, leaving a 0-d target
        # that no longer matches the (1, n_classes) output.
        # Assumes one class index per sample -- TODO confirm for new datasets.
        loss = loss_fn(output, target.reshape(-1))
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()), flush=True)

def test(model, device, test_loader, loss_fn):
    """Test loop."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target.squeeze()).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# Script entry point: fine-tune the pretrained model on the configured
# dataset, evaluating after every epoch, then save the weights and report the
# elapsed wall-clock time. Usage: python <script> [output_dir]
if __name__ == "__main__":
    import os  # local import: only the script entry point needs it

    start_time = time()
    conf = Configuration()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_classes = get_n_class(conf.dataset_name)
    max_length = 500
    # Output directory: first CLI argument, defaulting to "result".
    savedirname = sys.argv[1] if len(sys.argv) > 1 else "result"

    # save configuration
    print("Using device:", device)
    conf.print()
    # Create the output directory up front so neither the config dump here
    # nor the model save at the very end fails on a missing directory.
    os.makedirs(savedirname, exist_ok=True)
    with open(f"{savedirname}/config.pkl", "wb") as f:
        pickle.dump(conf, f)

    model = HyenaDNAPreTrainedModel.from_pretrained(
        './checkpoints',
        conf.pretrained_model_name,
        download=False, # download=True,
        config=conf.backbone_cfg,
        device=device,
        use_head=conf.use_head,
        n_classes=n_classes,
    )

    # create tokenizer
    tokenizer = CharacterTokenizer(
        characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
        model_max_length=max_length + 2,  # to account for special tokens, like EOS
        add_special_tokens=False,  # we handle special tokens elsewhere
        padding_side='left', # since HyenaDNA is causal, we pad on the left
    )

    # create datasets
    ds_train = GenomicBenchmarkDataset(
        max_length = max_length,
        use_padding = conf.use_padding,
        split = 'train',
        tokenizer=tokenizer,
        dataset_name=conf.dataset_name,
        rc_aug=conf.rc_aug,
        add_eos=conf.add_eos,
    )

    # NOTE(review): rc_aug is also applied to the test split; if the
    # augmentation is random this makes evaluation non-deterministic --
    # confirm this is intended.
    ds_test = GenomicBenchmarkDataset(
        max_length = max_length,
        use_padding = conf.use_padding,
        split = 'test',
        tokenizer=tokenizer,
        dataset_name=conf.dataset_name,
        rc_aug=conf.rc_aug,
        add_eos=conf.add_eos,
    )

    train_loader = DataLoader(ds_train, batch_size=conf.batch_size, shuffle=True)
    test_loader = DataLoader(ds_test, batch_size=conf.batch_size, shuffle=False)

    # loss function (referenced via torch.nn so it does not depend on a name
    # leaking in through the star-imports)
    loss_fn = torch.nn.CrossEntropyLoss()

    # move the model to its device before the optimizer is built
    model.to(device)

    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=conf.learning_rate, weight_decay=conf.weight_decay)

    # train
    for epoch in range(conf.num_epochs):
        train(model, device, train_loader, optimizer, epoch, loss_fn)
        test(model, device, test_loader, loss_fn)
        # No optimizer.step() here: train() already steps once per batch, and
        # an extra step would re-apply the last batch's stale gradients.

    # save model
    torch.save(model.state_dict(), f"{savedirname}/model.pth")
    # output time as HH:MM:SS (minutes are taken within the hour)
    elapsed_time = int(time() - start_time)
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print("Elapsed time:", f"{hours:02d}:{minutes:02d}:{seconds:02d}")
    