import torch
import numpy as np
import os
import datetime
import shutil
import toml
import torch.nn.functional as F
import torch.distributed as dist
import utils
import torch.backends.cudnn as cudnn
from time import time
from torch.utils.data import DataLoader
from m6araw_utils.dataset import NanoporeDS, collate_fn
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from sklearn.metrics import balanced_accuracy_score, roc_curve, auc, precision_recall_curve


def get_args_parser():
    """Build the command-line parser for m6A classifier training.

    The parser is created with ``add_help=False`` so that it can be used as a
    parent parser (``ArgumentParser(..., parents=[get_args_parser()])``).

    Returns:
        argparse.ArgumentParser: parser defining data, model and training options.
    """
    parser = ArgumentParser('train_segmentation',
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False
    )
    parser.add_argument("--train_dir", default=None, required=True)
    parser.add_argument("--val_dir", default=None, required=True)
    parser.add_argument("--save_dir", default=None, required=True)

    # Model setting
    parser.add_argument("--model_config", default= "/home/christopher/bonito_models/dna_r10.4_e8.1_fast@v3.4/")
    parser.add_argument("--feature_dim", default=128, type=int)
    parser.add_argument("--kmer_embedding_dim", default=4, type=int)
    parser.add_argument("--num_signal_lstm_layers", default=3, type=int)
    parser.add_argument("--lstm_hidden_size", default=64, type=int)
    parser.add_argument("--lstm_num_layers", default=5, type=int)
    parser.add_argument("--pretrained_dir", default=None, type=str)
    parser.add_argument("--test_run", default=False, action="store_true")
    parser.add_argument("--include_signal_head", default=False, action="store_true")
    parser.add_argument("--train_max_reads", default=20, type=int)
    parser.add_argument("--sample_reads", default=0, type=int)
    parser.add_argument("--fine_tune_feature_extractor", default=False, action="store_true")

    # Training settings
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--min_lr", default=1e-6, type=float)
    parser.add_argument("--lr", default=4e-4, type=float)
    parser.add_argument("--warmup_epochs", default=0, type=int)
    parser.add_argument("--apply_scheduler", default=False, action="store_true")
    parser.add_argument("--seed", default=25, type=int)
    parser.add_argument("--sequence_context", default=10, type=int)
    parser.add_argument("--min_reads", default=20, type=int)
    parser.add_argument("--max_reads", default=100, type=int)
    parser.add_argument("--epochs", default=20, type=int)
    parser.add_argument("--num_workers", default=25, type=int)
    # 'RMSprop' is handled when the optimizer is built; without it in `choices`
    # that branch could never be selected from the command line.
    parser.add_argument("--optim", default='AdamW', type=str, choices=['AdamW', 'SGD', 'RMSprop'],
                        help="Optimizer used for training.")
    parser.add_argument("--weight_decay", dest="weight_decay", default=0, type=float)
    parser.add_argument("--clear_save_dir", default=False, action="store_true")
    parser.add_argument("--mixed_precision", default=False, action="store_true")
    parser.add_argument("--resume", default=False, action="store_true")
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser


def get_roc_auc(y_true, y_pred):
    """Return the area under the ROC curve.

    Args:
        y_true: ground-truth binary labels.
        y_pred: scores for the positive class (e.g. predicted probabilities).

    Returns:
        float: ROC AUC computed from ``roc_curve`` and integrated with ``auc``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_pred)
    return auc(false_pos_rate, true_pos_rate)


def get_pr_auc(y_true, y_pred):
    """Return the area under the precision-recall curve, treating label 1 as positive.

    The curve from ``precision_recall_curve`` is integrated over recall with
    ``sklearn.metrics.auc``.

    Args:
        y_true: ground-truth binary labels.
        y_pred: scores for the positive class.

    Returns:
        float: PR AUC.
    """
    curve = precision_recall_curve(y_true, y_pred, pos_label=1)
    precision, recall = curve[0], curve[1]
    return auc(recall, precision)


def get_accuracy(y_true, y_pred):
    """Return the balanced accuracy of hard predictions ``y_pred`` against ``y_true``."""
    score = balanced_accuracy_score(y_true, y_pred)
    return score


def print_clf_metrics(train_stats, val_stats, curr_epoch, total_epochs):
    """Print one epoch's training and validation metrics to stdout.

    Output is an ``Epoch:[curr/total]`` header, then one line of training
    metrics, one line of validation metrics, and a separator line. Each metric
    is rendered as ``name:value`` with three decimals, followed by ``" \\t "``.

    Args:
        train_stats: mapping of metric name to numeric value for training.
        val_stats: mapping of metric name to numeric value for validation.
        curr_epoch: epoch number shown in the header.
        total_epochs: total epoch count shown in the header.
    """
    def _render(stats):
        # One line per mapping, metrics kept in insertion order.
        return "".join("{}:{:.3f} \t ".format(name, value) for name, value in stats.items())

    print("Epoch:[{epoch}/{n_epoch}]".format(epoch=curr_epoch, n_epoch=total_epochs))
    print(_render(train_stats))
    print(_render(val_stats))
    print("=====================================")


def get_model(args):
    """Build the classifier on top of a ``BaseCaller`` feature extractor.

    The extractor is configured from ``<args.model_config>/config.toml``. If
    ``args.pretrained_dir`` is set, its state dict is loaded (onto CPU) and the
    last layer of the extractor's ``encoder`` is dropped.

    Args:
        args: parsed CLI arguments (reads model_config, pretrained_dir,
            sequence_context, lstm_hidden_size, lstm_num_layers,
            kmer_embedding_dim, feature_dim, sample_reads,
            fine_tune_feature_extractor).

    Returns:
        The ``Model`` instance (not yet moved to any device).
    """
    # Project-local modules are imported at call time.
    from bonito_model import BaseCaller
    from m6a_utils_short_seq.model import Model

    config_path = os.path.join(args.model_config, "config.toml")
    extractor = BaseCaller(toml.load(config_path))

    if args.pretrained_dir is not None:
        # NOTE(review): despite its name, pretrained_dir is passed to torch.load as a file path.
        pretrained_state = torch.load(args.pretrained_dir, map_location='cpu')
        extractor.load_state_dict(pretrained_state)
        # Rebuild the encoder from every layer except the last, preserving order.
        kept_layers = list(extractor.encoder)[:-1]
        extractor.encoder = nn.Sequential(*kept_layers)

    # NOTE(review): presumably the flattened size of the bidirectional LSTM output
    # (2 directions x hidden size) over the (2 * context + 1)-long window — confirm against Model.
    window_len = 2 * args.sequence_context + 1
    seq_dim = window_len * args.lstm_hidden_size * 2

    lstm_params = dict(
        input_size=args.kmer_embedding_dim,
        hidden_size=args.lstm_hidden_size,
        num_layers=args.lstm_num_layers,
        batch_first=True,
        bidirectional=True,
    )

    return Model(extractor, args.feature_dim, args.kmer_embedding_dim, lstm_params, seq_dim,
                 args.sample_reads, args.fine_tune_feature_extractor)


def random_fn(x):
    """DataLoader ``worker_init_fn`` that gives each worker a distinct NumPy seed.

    Inside a DataLoader worker, ``torch.initial_seed()`` returns that worker's
    torch seed. PyTorch derives it from a per-iterator base seed plus the worker
    id, so it differs across workers and epochs, yet it stays reproducible from
    the main-process seed. NumPy only accepts 32-bit seeds, hence the modulo.

    Seeding from the wall-clock second instead would give every worker started
    within the same second an identical NumPy stream, with only 60 distinct
    seeds, and would ignore the run's fixed random seed.

    Args:
        x: worker id supplied by the DataLoader. It is unused because
            ``torch.initial_seed()`` already incorporates it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def validate(model, val_dl, criterion, fp16_scaler, args):
    """Evaluate ``model`` on the validation set and aggregate results over all ranks.

    Each rank runs over its own ``val_dl``. The class-1 probabilities and labels
    from every rank are then gathered, so ROC AUC, PR AUC and balanced accuracy
    are computed once over the pooled predictions. Averaging per-rank AUCs is not
    the AUC of the whole set, and it becomes NaN whenever a shard holds only one class.

    If ``val_dl`` uses a ``DistributedSampler`` with ``drop_last=False`` (its
    default), shards are padded by repeating samples. A few predictions may
    therefore be counted twice in the pooled metrics.

    Args:
        model: classifier called as ``model(signals, signals_lengths,
            positions_masks, sequences, sequences_length, mode='eval')``.
        val_dl: validation DataLoader yielding 6-tuples as unpacked below.
        criterion: loss applied to ``(outputs, labels)``.
        fp16_scaler: GradScaler when mixed precision is on (enables autocast), else None.
        args: parsed CLI arguments (reads ``test_run``).

    Returns:
        On the main process, a dict with these keys:

        - ``'Val Loss'``: mean over ranks of each rank's mean batch loss.
        - ``'Val ROC AUC'``, ``'Val PR AUC'``: computed on the pooled predictions.
        - ``'Val Accuracy'``: balanced accuracy at a 0.5 threshold.
        - ``'Val Time'``: wall time of the slowest rank, in seconds.

        Other ranks return None. This must be called on every rank because it
        performs collective communication.
    """
    model.eval()
    start = time()

    all_labels = []
    all_preds = []
    val_loss = []

    with torch.no_grad():
        for data in val_dl:
            signals, signals_lengths, labels, positions_masks, sequences, sequences_length = data
            positions_masks = [pos_mask.cuda() for pos_mask in positions_masks]
            sequences, labels = sequences.cuda(), labels.cuda()

            with autocast(enabled=fp16_scaler is not None):
                outputs = model(signals, signals_lengths, positions_masks, sequences, sequences_length, mode='eval')
                val_loss.append(criterion(outputs, labels).item())
                # Probability of class 1 for every sample in the batch.
                all_preds.append(F.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                all_labels.append(labels.cpu().numpy())

            if args.test_run:  # test run over one mini batch
                break

    local_labels = np.concatenate(all_labels) if all_labels else np.empty(0)
    local_preds = np.concatenate(all_preds) if all_preds else np.empty(0)

    # Measure this rank's evaluation time before any cross-rank communication.
    compute_time = torch.Tensor([time() - start]).cuda()
    val_loss = torch.tensor([np.mean(val_loss)]).cuda()

    # Pool every rank's predictions so the ranking metrics cover the whole set.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, (local_labels, local_preds))

    dist.all_reduce(compute_time, op=dist.ReduceOp.MAX)
    dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)

    if not utils.is_main_process():
        return None

    pooled_labels = np.concatenate([rank_labels for rank_labels, _ in gathered])
    pooled_preds = np.concatenate([rank_preds for _, rank_preds in gathered])

    return {'Val Loss': val_loss.item() / utils.get_world_size(),
            'Val ROC AUC': get_roc_auc(pooled_labels, pooled_preds),
            'Val PR AUC': get_pr_auc(pooled_labels, pooled_preds),
            'Val Accuracy': get_accuracy(pooled_labels, (pooled_preds >= 0.5) * 1),
            'Val Time': compute_time.item()}


def train_classification_one_epoch(model, train_dl, criterion,
                                   optimizer, lr_schedule, curr_epoch, fp16_scaler, args):
    """Run one training epoch on this rank and aggregate loss/time across ranks.

    Args:
        model: classifier called as ``model(signals, signals_lengths,
            positions_masks, sequences, sequences_length, mode='train')``.
        train_dl: training DataLoader whose sampler provides ``set_epoch``.
        criterion: loss applied to ``(outputs, labels)``.
        optimizer: optimizer over the model parameters.
        lr_schedule: per-global-iteration learning rates, or None to leave lr unchanged.
        curr_epoch: zero-based epoch index, used for shuffling and schedule indexing.
        fp16_scaler: GradScaler for mixed precision, or None for full precision.
        args: parsed CLI arguments (reads ``test_run``).

    Returns:
        On the main process, ``{'Train Loss': np.ndarray, 'Train Time': float}``.
        ``'Train Loss'`` holds the per-iteration smoothed loss, summed over ranks
        and divided by the world size. ``'Train Time'`` is the slowest rank's
        epoch time. Other ranks return None. This must be called on every rank
        because it performs collective communication.
    """
    # The sampler's index order is fixed when iteration over the DataLoader
    # begins, so the epoch must be set before the loop, not inside it.
    train_dl.sampler.set_epoch(curr_epoch)

    all_train_loss = []
    start = time()
    running_loss = 0
    model.train()

    for step, data in enumerate(train_dl):
        global_step = len(train_dl) * curr_epoch + step
        if lr_schedule is not None:  # Adjust learning rate to the schedule
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_schedule[global_step]

        signals, signals_lengths, labels, positions_masks, sequences, sequences_length = data
        positions_masks = [pos_mask.cuda() for pos_mask in positions_masks]
        sequences, labels = sequences.cuda(), labels.cuda()

        # Only the forward pass and loss run under autocast; the AMP docs advise
        # against running backward inside the autocast region.
        with autocast(enabled=fp16_scaler is not None):
            outputs = model(signals, signals_lengths, positions_masks, sequences, sequences_length, mode='train')
            loss = criterion(outputs, labels)

        if fp16_scaler is None:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            # Unscale first so max_norm applies to the true gradients, not the
            # loss-scaled ones.
            fp16_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        # Clear gradients on both paths; otherwise backward() keeps accumulating
        # into .grad across iterations.
        optimizer.zero_grad()

        running_loss = utils.update_metric(running_loss, loss.item(), smoothing_factor=0.99)
        all_train_loss.append(running_loss)
        if args.test_run:  # test run over one minibatch
            break

    train_time = torch.Tensor([time() - start]).cuda()
    all_train_loss = torch.Tensor(all_train_loss).cuda()

    # Gathering training metrics from all processes
    dist.all_reduce(train_time, op=dist.ReduceOp.MAX)
    dist.all_reduce(all_train_loss, op=dist.ReduceOp.SUM)

    if utils.is_main_process():
        return {'Train Loss': all_train_loss.cpu().numpy() / utils.get_world_size(),
                'Train Time': train_time.item()}


def train(model, train_dl, val_dl, criterion,
          optimizer, lr_schedule, total_epochs, increment, model_save_dir, fp16_scaler, args):
    """Train for epochs ``increment`` .. ``total_epochs - 1``, validating and checkpointing each.

    On the main process, each epoch's metrics are printed. Unless
    ``args.test_run`` is set, a checkpoint is written to
    ``<model_save_dir>/<epoch>/checkpoint.pth``. In test-run mode no epoch
    directory is created, so a later ``--resume`` never finds an empty one.

    Args:
        model, train_dl, val_dl, criterion, optimizer, lr_schedule, fp16_scaler:
            forwarded to ``train_classification_one_epoch`` / ``validate``.
        total_epochs: exclusive upper bound of the epoch range.
        increment: first epoch to run (non-zero when resuming).
        model_save_dir: directory under which per-epoch checkpoints are saved.
        args: parsed CLI arguments (reads ``test_run``).

    Returns:
        On the main process, ``(train_loss, val_loss)``. ``train_loss`` is the
        per-iteration training losses of all epochs run, concatenated (empty if
        none ran). ``val_loss`` has one validation loss per epoch. Other ranks
        return None.
    """
    all_train_loss = []
    all_val_loss = []

    start_epoch = increment
    end_epoch = total_epochs

    for epoch in range(start_epoch, end_epoch):  # loop over the dataset multiple times

        train_stats = train_classification_one_epoch(model, train_dl, criterion, optimizer, lr_schedule,
                                                     epoch, fp16_scaler, args)
        val_stats = validate(model, val_dl, criterion, fp16_scaler, args)

        if utils.is_main_process():
            train_losses = train_stats['Train Loss']
            val_loss = val_stats['Val Loss']
            # Report/store only the final smoothed loss of the epoch.
            train_stats['Train Loss'] = train_losses[-1]
            print_clf_metrics(train_stats, val_stats, epoch, end_epoch)

            if not args.test_run:
                check_point = {'epoch': epoch, 'model': model.state_dict(),
                               'optimizer': optimizer.state_dict(), 'Val Loss': val_loss}
                check_point.update(train_stats)
                check_point.update(val_stats)
                save_path = os.path.join(model_save_dir, str(epoch))
                # exist_ok: a crashed earlier run may have left this directory behind.
                os.makedirs(save_path, exist_ok=True)
                torch.save(check_point, os.path.join(save_path, "checkpoint.pth"))

            all_train_loss.append(train_losses)
            all_val_loss.append(val_loss)

        if args.test_run:  # test run over one epoch
            break

    if utils.is_main_process():
        # np.concatenate rejects an empty list (e.g. resuming a finished run).
        train_curve = np.concatenate(all_train_loss) if all_train_loss else np.empty(0)
        return train_curve, np.array(all_val_loss)


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine decay.

    The first ``warmup_epochs * niter_per_ep`` entries rise linearly from
    ``start_warmup_value`` to ``base_value``. The remaining entries follow half
    a cosine period from ``base_value`` toward ``final_value``.

    Args:
        base_value: value reached at the end of warmup / start of the cosine phase.
        final_value: value the cosine phase decays toward.
        epochs: total number of epochs covered by the schedule.
        niter_per_ep: iterations per epoch.
        warmup_epochs: number of epochs of linear warmup (0 disables warmup).
        start_warmup_value: value at the very first warmup iteration.

    Returns:
        np.ndarray of length ``epochs * niter_per_ep``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup = np.array([])

    decay_steps = np.arange(total_iters - warmup_iters)
    cosine_term = 1 + np.cos(np.pi * decay_steps / len(decay_steps))
    decay = final_value + 0.5 * (base_value - final_value) * cosine_term

    schedule = np.concatenate((warmup, decay))
    assert len(schedule) == total_iters
    return schedule


def get_weights(counts):
    """Return inverse-frequency class weights normalised to sum to 1.

    Args:
        counts: per-class sample counts (a sequence or array of positive numbers).

    Returns:
        torch.Tensor: 1-D tensor of the default float dtype, where rarer
        classes receive larger weights.
    """
    inverse = 1. / torch.tensor(counts, dtype=torch.float)
    normalised = inverse / inverse.sum()
    # Rebuild from Python floats to get a fresh tensor of the default dtype.
    return torch.tensor(normalised.tolist())


def train_and_save(args):
    """Set up distributed training, build data/model/optimizer, train, and save loss curves.

    Must run on every rank because training performs collective communication.
    The main process saves ``train_loss.npy`` and ``val_loss.npy`` into
    ``args.save_dir``. Per-epoch checkpoints go under ``<save_dir>/model_states``.

    Args:
        args: parsed CLI arguments from ``get_args_parser``.

    Raises:
        ValueError: if ``args.optim`` is not a supported optimizer.
        FileNotFoundError: if ``args.resume`` is set but no saved checkpoint exists.
    """
    save_dir = args.save_dir
    weight_decay = args.weight_decay

    lr = args.lr
    min_lr = args.min_lr
    warmup_epochs = args.warmup_epochs
    apply_scheduler = args.apply_scheduler

    mixed_precision = args.mixed_precision
    resume = args.resume

    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    cudnn.benchmark = True
    # Pin this process to the same device DDP uses below (args.gpu). args.rank is
    # the global rank and is not a valid local device index on multi-node runs.
    torch.cuda.set_device(args.gpu)

    batch_size_per_gpu = utils.get_batch_size(args.batch_size)

    model_save_dir = os.path.join(save_dir, "model_states")
    model = get_model(args).cuda()
    if utils.has_batchnorms(model):
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[args.gpu])

    # Loading transcripts information
    train_ds = NanoporeDS(args.train_dir, mode='Train', min_reads=args.min_reads, max_reads=args.train_max_reads, sequence_context=args.sequence_context)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, shuffle=True)
    train_dl = DataLoader(train_ds, batch_size=batch_size_per_gpu, num_workers=args.num_workers, pin_memory=True,
                          worker_init_fn=random_fn, sampler=train_sampler, collate_fn=collate_fn)

    # NOTE(review): validation also caps reads with args.train_max_reads; --max_reads is
    # parsed but not read anywhere in this file — confirm which cap is intended.
    val_ds = NanoporeDS(args.val_dir, mode='Val', min_reads=args.min_reads, max_reads=args.train_max_reads, sequence_context=args.sequence_context)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds, shuffle=False)
    val_dl = DataLoader(val_ds, batch_size=batch_size_per_gpu, num_workers=args.num_workers,
                        sampler=val_sampler, pin_memory=True, collate_fn=collate_fn)

    # exist_ok avoids a race when several ranks reach this line together.
    os.makedirs(save_dir, exist_ok=True)

    # Inverse-frequency class weights from the training labels.
    _, counts = np.unique(train_ds.data_info.modification_status.values, return_counts=True)
    class_weights = get_weights(counts)
    criterion = nn.CrossEntropyLoss(weight=class_weights).cuda()

    if args.optim == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), weight_decay=weight_decay, lr=lr)
    elif args.optim == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), weight_decay=weight_decay, lr=lr, momentum=0.9, nesterov=True)
    elif args.optim == 'RMSprop':
        optimizer = torch.optim.RMSprop(model.parameters(), weight_decay=weight_decay, lr=lr)
    else:
        raise ValueError("Unsupported optimizer: {}".format(args.optim))

    lr_schedule = cosine_scheduler(lr, min_lr, args.epochs, len(train_dl), warmup_epochs) if apply_scheduler else None
    increment = 0
    if resume:
        # Only epoch directories that actually contain a checkpoint are candidates.
        saved_epochs = sorted(
            int(name) for name in os.listdir(model_save_dir)
            if name.isdigit() and os.path.isfile(os.path.join(model_save_dir, name, "checkpoint.pth"))
        )
        if not saved_epochs:
            raise FileNotFoundError("--resume was set but no checkpoint.pth was found under {}".format(model_save_dir))
        last_epoch = saved_epochs[-1]
        # Load onto CPU: load_state_dict copies values into the already-placed
        # model/optimizer tensors, and no rank materialises the checkpoint on
        # another rank's GPU.
        check_point = torch.load(os.path.join(model_save_dir, str(last_epoch), "checkpoint.pth"),
                                 map_location='cpu')
        print("Resuming from epoch {}".format(last_epoch))
        model.load_state_dict(check_point['model'])
        optimizer.load_state_dict(check_point['optimizer'])
        increment = last_epoch + 1

    fp16_scaler = GradScaler() if mixed_precision else None

    # train() returns the loss curves on the main process only (None elsewhere).
    losses = train(model, train_dl, val_dl, criterion,
                   optimizer, lr_schedule,
                   args.epochs, increment, model_save_dir, fp16_scaler, args)

    if utils.is_main_process():
        train_loss, val_loss = losses
        np.save(os.path.join(save_dir, "train_loss.npy"), train_loss)
        np.save(os.path.join(save_dir, "val_loss.npy"), val_loss)

    
if __name__ == '__main__':
    # Script entry point: parse options, optionally wipe the output directory,
    # ensure it exists, then train.
    parser = ArgumentParser('Train m6a detection', parents=[get_args_parser()])
    args = parser.parse_args()

    save_dir = Path(args.save_dir)
    if args.clear_save_dir:
        shutil.rmtree(save_dir, ignore_errors=True)
    save_dir.mkdir(parents=True, exist_ok=True)

    train_and_save(args)
