import torch
import numpy as np
import os
import datetime
import shutil
import toml
import torch.nn.functional as F
import pandas as pd
import torch.distributed as dist
import utils
import torch.backends.cudnn as cudnn
import parasail
from criterion import CTCLoss
from time import time
from torch.utils.data import DataLoader, TensorDataset
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from basecalling_utils import accuracy, decode_ref


def get_args_parser():
    """Build the command-line parser for the training script.

    The parser is created with ``add_help=False`` so that it can be passed
    as a ``parents=[...]`` entry to another ``ArgumentParser`` without a
    duplicated ``-h`` option.

    Returns:
        ArgumentParser: parser exposing the data/model/training options.
    """
    parser = ArgumentParser(
        'train_segmentation',
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False,
    )
    parser.add_argument("--root_dir", default=None, required=True)
    parser.add_argument("--save_dir", default=None, required=True)

    # Model setting
    parser.add_argument("--model_config", default= "/home/christopher/bonito_models/dna_r10.4_e8.1_fast@v3.4/")
    parser.add_argument("--kmer_embedding_dim", default=4, type=int)
    parser.add_argument("--lstm_hidden_size", default=64, type=int)
    parser.add_argument("--lstm_num_layers", default=5, type=int)
    parser.add_argument("--attention_heads", default=1, type=int)
    parser.add_argument("--kdim", default=320, type=int)
    parser.add_argument("--vdim", default=320, type=int)
    parser.add_argument("--pretrained", default=False, action="store_true")
    parser.add_argument("--test_run", default=False, action="store_true")

    # Training settings
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--min_lr", default=1e-6, type=float)
    parser.add_argument("--lr", default=4e-4, type=float)
    parser.add_argument("--warmup_epochs", default=0, type=int)
    parser.add_argument("--apply_scheduler", default=False, action="store_true")
    parser.add_argument("--seed", default=25, type=int)
    parser.add_argument("--epochs", default=20, type=int)
    parser.add_argument("--num_workers", default=25, type=int)
    # 'RMSprop' is listed because train_and_save has a dedicated branch for
    # it; without it in `choices` that branch could never be selected.
    parser.add_argument("--optim", default='AdamW', type=str, choices=['AdamW', 'SGD', 'RMSprop'])
    parser.add_argument("--weight_decay", dest="weight_decay", default=0, type=float)
    parser.add_argument("--clear_save_dir", default=False, action="store_true")
    parser.add_argument("--mixed_precision", default=False, action="store_true")
    parser.add_argument("--resume", default=False, action="store_true")
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser


def readFasta(transcript_fasta_paths_or_urls):
    """Parse a FASTA file into a ``{record_id: sequence}`` dictionary.

    Args:
        transcript_fasta_paths_or_urls: path of the FASTA file; it is passed
            directly to ``open`` in text mode.

    Returns:
        dict: maps each record id to its sequence with line breaks removed.
        The id is the header line (the text after ``>``) truncated at the
        first ``"."``; a header without a dot is kept whole. When two
        records share an id, the later one wins.
    """
    # Context manager guarantees the handle is closed even if reading fails.
    with open(transcript_fasta_paths_or_urls, "r") as fasta:
        content = fasta.read()

    sequences = {}
    for record in content.split(">"):
        lines = record.split("\n")
        header = lines[0]
        # Skips the empty chunk before the first '>' and blank headers.
        if not header.split():
            continue
        record_id = header.split(".")[0]
        sequences[record_id] = "".join(lines[1:])
    return sequences


def print_clf_metrics(train_stats, val_stats, curr_epoch, total_epochs):
    """Print an epoch header followed by the training and validation metrics.

    Args:
        train_stats: mapping of metric name to numeric value (training).
        val_stats: mapping of metric name to numeric value (validation).
        curr_epoch: epoch index shown in the header.
        total_epochs: epoch count shown in the header.

    Each mapping is rendered on its own line as ``name:value`` pairs with
    three decimals, followed by a separator line.
    """
    print("Epoch:[{epoch}/{n_epoch}]".format(epoch=curr_epoch, n_epoch=total_epochs))

    # Training metrics first, validation second - one output line each.
    for stats in (train_stats, val_stats):
        rendered = "".join(
            "{}:{:.3f} \t ".format(name, value) for name, value in stats.items()
        )
        print(rendered)

    print("=====================================")


def get_model(args):
    """Instantiate the basecaller described by ``args.model_config``.

    Reads ``config.toml`` from the ``args.model_config`` directory and builds
    a ``BaseCaller`` from it. When ``args.pretrained`` is set, the state dict
    stored in ``weights_1.tar`` (same directory) is loaded on CPU and applied
    to the model.

    Returns:
        The constructed ``BaseCaller`` instance.
    """
    # Imported lazily, as in the rest of this script's model handling.
    from bonito_model import Model, BaseCaller

    config_path = os.path.join(args.model_config, "config.toml")
    basecaller = BaseCaller(toml.load(config_path))

    if not args.pretrained:
        return basecaller

    weights_path = os.path.join(args.model_config, "weights_1.tar")
    state = torch.load(weights_path, map_location='cpu')
    basecaller.load_state_dict(state)
    return basecaller


def random_fn(x):
    """Seed NumPy inside a DataLoader worker (used as ``worker_init_fn``).

    Args:
        x: worker id handed in by the DataLoader. It is not needed here
            because, per the PyTorch DataLoader documentation, the worker's
            torch seed already differs per worker (base_seed + worker_id).

    The previous implementation seeded from the current wall-clock second,
    which gives only 60 possible seeds, makes workers started within the
    same second share one random stream, and ignores the run's seed.
    """
    # NumPy accepts seeds in [0, 2**32); torch seeds are 64-bit.
    np.random.seed(torch.initial_seed() % 2 ** 32)


def validate(model, val_dl, criterion, fp16_scaler, args):
    """Evaluate the model on the validation loader and aggregate across ranks.

    Args:
        model: wrapped model; ``model.module.alphabet`` is read, and the
            model is called both on signals and, with ``mode='decode'``, on
            its own outputs.
        val_dl: loader yielding ``(signals, targets, sequences_length)``.
        criterion: callable ``criterion(outputs, targets, sequences_length)``
            returning a scalar loss tensor.
        fp16_scaler: autocast is enabled iff this is not ``None``.
        args: namespace; ``args.test_run`` limits validation to one batch.

    Returns:
        dict with ``'Val Loss'``, ``'Val Acc'``, ``'Val Median Acc'`` and
        ``'Val Time'`` on the main process; ``None`` on every other process.
        Loss and accuracies are per-rank values averaged over the world size
        (so the median is a mean of per-rank medians); time is the maximum
        over ranks.
    """
    model.eval()
    start = time()

    val_loss = []
    val_accs = []
    with torch.no_grad():
        for signals, targets, sequences_length in val_dl:
            signals = signals.unsqueeze(1).cuda()
            targets = targets.cuda()
            sequences_length = sequences_length.cuda()

            with autocast(enabled=fp16_scaler is not None):
                outputs = model(signals)  # S x N x C (per the original author's note)
                loss = criterion(outputs, targets, sequences_length).item()
                seqs = model(outputs, mode='decode')

            refs = [decode_ref(target, model.module.alphabet) for target in targets]
            accs = [
                accuracy(ref, seq, min_coverage=0.5) if len(seq) else 0.
                for ref, seq in zip(refs, seqs)
            ]

            val_loss.append(loss)
            val_accs.extend(accs)

            # Break only after the batch has been recorded; breaking first
            # left the lists empty and turned every test-run metric into NaN.
            if args.test_run:  # test run over one mini batch
                break

    loss = torch.tensor([np.mean(val_loss)]).cuda()
    mean_acc = torch.tensor([np.mean(val_accs)]).cuda()
    median_acc = torch.tensor([np.median(val_accs)]).cuda()

    compute_time = torch.Tensor([time() - start]).cuda()

    # Gathering validation metrics from all processes
    dist.all_reduce(compute_time, op=dist.ReduceOp.MAX)
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(mean_acc, op=dist.ReduceOp.SUM)
    dist.all_reduce(median_acc, op=dist.ReduceOp.SUM)

    if utils.is_main_process():
        world_size = utils.get_world_size()
        return {'Val Loss': loss.item() / world_size,
                'Val Acc': mean_acc.item() / world_size,
                'Val Median Acc': median_acc.item() / world_size,
                'Val Time': compute_time.item()}


def train_classification_one_epoch(model, train_dl, criterion,
                                   optimizer, lr_schedule, curr_epoch, fp16_scaler, args):
    """Run one training epoch and aggregate the loss curve across ranks.

    Args:
        model: model to optimise (called as ``model(signals)``).
        train_dl: loader yielding ``(signals, sequences, sequences_length)``;
            its sampler must expose ``set_epoch``.
        criterion: callable ``criterion(outputs, sequences, sequences_length)``.
        optimizer: optimiser whose param-group learning rates are overwritten
            from ``lr_schedule`` when a schedule is given.
        lr_schedule: per-iteration learning rates indexed by the global
            iteration, or ``None`` to leave the learning rate untouched.
        curr_epoch: zero-based epoch index.
        fp16_scaler: ``GradScaler`` for mixed precision, or ``None``.
        args: namespace; ``args.test_run`` stops after one mini-batch.

    Returns:
        dict with ``'Train Loss'`` (array of the smoothed running loss per
        iteration, averaged over ranks) and ``'Train Time'`` (max over
        ranks) on the main process; ``None`` on every other process.
    """
    all_train_loss = []

    start = time()
    running_loss = 0

    model.train()
    # Must happen before the loader iterator is created: the sampler reads
    # its epoch when iteration starts, so setting it inside the loop only
    # affected the *next* epoch and epochs 0 and 1 shared one shuffle order.
    train_dl.sampler.set_epoch(curr_epoch)
    iters_per_epoch = len(train_dl)

    for batch_idx, (signals, sequences, sequences_length) in enumerate(train_dl):
        global_it = iters_per_epoch * curr_epoch + batch_idx  # global training iteration
        if lr_schedule is not None:  # Adjust learning rate to the schedule
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_schedule[global_it]

        signals = signals.unsqueeze(1).cuda()
        sequences = sequences.cuda()
        sequences_length = sequences_length.cuda()

        # Only the forward pass and the loss belong under autocast.
        with autocast(enabled=fp16_scaler is not None):
            outputs = model(signals)
            loss = criterion(outputs, sequences, sequences_length)

        if fp16_scaler is None:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            # Unscale first so max_norm applies to the true gradient values.
            fp16_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        # Reset in both branches; the mixed-precision path previously never
        # cleared gradients, so they accumulated across iterations.
        optimizer.zero_grad()

        running_loss = utils.update_metric(running_loss, loss.item(), smoothing_factor=0.99)
        all_train_loss.append(running_loss)
        if args.test_run:  # test run over one minibatch
            break

    train_time = torch.Tensor([time() - start]).cuda()
    all_train_loss = torch.Tensor(all_train_loss).cuda()

    # Gathering training metrics from all processes
    dist.all_reduce(train_time, op=dist.ReduceOp.MAX)
    dist.all_reduce(all_train_loss, op=dist.ReduceOp.SUM)

    if utils.is_main_process():
        return {'Train Loss': all_train_loss.detach().cpu().numpy() / utils.get_world_size(),
                'Train Time': train_time.item()}


def train(model, train_dl, val_dl, criterion,
          optimizer, lr_schedule, total_epochs, increment, model_save_dir, fp16_scaler, args):
    """Train and validate for the remaining epochs, checkpointing each one.

    Args:
        model, train_dl, val_dl, criterion, optimizer, lr_schedule,
        fp16_scaler, args: forwarded to the per-epoch train/validate helpers.
        total_epochs: exclusive upper bound of the epoch loop.
        increment: first epoch to run (non-zero when resuming).
        model_save_dir: directory receiving ``<epoch>/checkpoint.pth``.

    Returns:
        On the main process, a tuple ``(train_loss, val_loss)`` where
        ``train_loss`` concatenates the per-iteration losses of all epochs
        run and ``val_loss`` holds one value per epoch; both are empty
        arrays when no epoch was run. ``None`` on every other process.
    """
    all_train_loss = []
    all_val_loss = []

    start_epoch = 0 + increment
    end_epoch = total_epochs

    for epoch in range(start_epoch, end_epoch):  # loop over the dataset multiple times

        train_stats = train_classification_one_epoch(model, train_dl, criterion, optimizer, lr_schedule,
                                                     epoch, fp16_scaler, args)
        val_stats = validate(model, val_dl, criterion, fp16_scaler, args)

        # Only the main process receives stats (the others get None).
        if utils.is_main_process():
            train_losses = train_stats['Train Loss']
            val_loss = val_stats['Val Loss']
            # Report the last smoothed value as the epoch's training loss.
            train_stats['Train Loss'] = train_losses[-1]
            print_clf_metrics(train_stats, val_stats, epoch, end_epoch)
            check_point = {'epoch': epoch, 'model': model.state_dict(),
                           'optimizer': optimizer.state_dict(), 'Val Loss': val_loss}
            check_point.update(train_stats)
            check_point.update(val_stats)
            save_path = os.path.join(model_save_dir, str(epoch))
            # exist_ok: a leftover directory from an earlier run must not
            # discard an epoch of training with FileExistsError.
            os.makedirs(save_path, exist_ok=True)
            torch.save(check_point, os.path.join(save_path, "checkpoint.pth"))

            all_train_loss.append(train_losses)
            all_val_loss.append(val_loss)

        if args.test_run:  # test run over one epoch
            break

    if utils.is_main_process():
        # np.concatenate rejects an empty list, which happens when the loop
        # body never runs (e.g. resuming an already finished run).
        train_curve = np.concatenate(all_train_loss) if all_train_loss else np.array([])
        return train_curve, np.array(all_val_loss)


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine decay.

    Args:
        base_value: value reached at the end of warmup / start of the decay.
        final_value: value the cosine curve decays towards.
        epochs: total number of epochs.
        niter_per_ep: iterations per epoch.
        warmup_epochs: epochs spent ramping linearly up to ``base_value``.
        start_warmup_value: first value of the warmup ramp.

    Returns:
        np.ndarray of length ``epochs * niter_per_ep``.

    Raises:
        AssertionError: if the assembled schedule does not have that length.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    # Linear ramp from the start value up to the base value (may be empty).
    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    # Half-cosine over the remaining iterations, from base towards final.
    steps = np.arange(total_iters - warmup_iters)
    half_range = 0.5 * (base_value - final_value)
    cosine_part = final_value + half_range * (1 + np.cos(np.pi * steps / len(steps)))

    full_schedule = np.concatenate((warmup_part, cosine_part))
    assert len(full_schedule) == total_iters
    return full_schedule


def train_and_save(args):
    """Set up distributed training, run it, and save the loss curves.

    Loads ``chunks.npy``, ``references.npy`` and ``reference_lengths.npy``
    from ``args.root_dir`` (training) and ``args.root_dir/validation``,
    builds the DDP-wrapped model, optimiser and optional cosine schedule,
    optionally resumes from the latest checkpoint under
    ``args.save_dir/model_states``, trains, and on the main process writes
    ``train_loss.npy`` and ``val_loss.npy`` into ``args.save_dir``.

    Args:
        args: parsed command-line namespace. ``utils.init_distributed_mode``
            is expected to populate ``args.gpu`` (read below) - behaviour of
            the ``utils`` helpers is defined outside this file.

    Raises:
        FileNotFoundError: when resuming but no checkpoint directory exists.
        ValueError: for an optimiser name without a matching branch.
    """
    save_dir = args.save_dir
    root_dir = args.root_dir
    weight_decay = args.weight_decay
    lr = args.lr

    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    cudnn.benchmark = True
    # Use the same device index that DDP is given below; the global rank
    # only coincides with the local GPU index on a single node.
    torch.cuda.set_device(args.gpu)

    batch_size_per_gpu = utils.get_batch_size(args.batch_size)

    # Loading signal chunks and their reference sequences
    train_signals = torch.Tensor(np.load(os.path.join(root_dir, "chunks.npy")))
    train_seqs = torch.LongTensor(np.load(os.path.join(root_dir, "references.npy")).astype('int64'))
    train_lengths = torch.LongTensor(np.load(os.path.join(root_dir, "reference_lengths.npy")).astype('int64'))

    val_dir = os.path.join(root_dir, "validation")
    val_signals = torch.Tensor(np.load(os.path.join(val_dir, "chunks.npy")))
    val_seqs = torch.LongTensor(np.load(os.path.join(val_dir, "references.npy")).astype('int64'))
    val_lengths = torch.LongTensor(np.load(os.path.join(val_dir, "reference_lengths.npy")).astype('int64'))

    train_ds = TensorDataset(train_signals, train_seqs, train_lengths)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_ds, shuffle=True
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size_per_gpu, num_workers=args.num_workers, pin_memory=True,
                          worker_init_fn=random_fn, sampler=train_sampler)

    val_ds = TensorDataset(val_signals, val_seqs, val_lengths)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_ds, shuffle=False
    )
    val_dl = DataLoader(val_ds, batch_size=batch_size_per_gpu, num_workers=args.num_workers,
                        sampler=val_sampler, pin_memory=True)

    # Every rank reaches this line; exist_ok avoids the check-then-create
    # race between processes.
    os.makedirs(save_dir, exist_ok=True)

    model_save_dir = os.path.join(save_dir, "model_states")
    model = get_model(args).cuda()
    # The criterion is built from the unwrapped model, before DDP hides
    # its attributes behind `.module`.
    criterion = CTCLoss(model.seqdist.idx.clone()).cuda()
    if utils.has_batchnorms(model):
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[args.gpu])

    if args.optim == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), weight_decay=weight_decay, lr=lr)
    elif args.optim == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), weight_decay=weight_decay, lr=lr, momentum=0.9, nesterov=True)
    elif args.optim == 'RMSprop':
        optimizer = torch.optim.RMSprop(model.parameters(), weight_decay=weight_decay, lr=lr)
    else:
        # Previously fell through and failed later with an unbound name.
        raise ValueError("Unsupported optimizer: {}".format(args.optim))

    lr_schedule = None
    if args.apply_scheduler:
        lr_schedule = cosine_scheduler(lr, args.min_lr, args.epochs, len(train_dl), args.warmup_epochs)

    increment = 0
    if args.resume:
        saved_epochs = [int(x) for x in os.listdir(model_save_dir)]
        if not saved_epochs:
            raise FileNotFoundError("No checkpoints to resume from in {}".format(model_save_dir))
        last_epoch = max(saved_epochs)
        # Load on CPU: load_state_dict copies into the tensors already on
        # this rank's GPU, and it avoids every rank staging the checkpoint
        # on the device named by --local_rank (0 unless explicitly set).
        check_point = torch.load(os.path.join(model_save_dir, str(last_epoch), "checkpoint.pth"),
                                 map_location='cpu')
        print("Resuming from epoch {}".format(last_epoch))
        model.load_state_dict(check_point['model'])
        optimizer.load_state_dict(check_point['optimizer'])
        increment = last_epoch + 1

    fp16_scaler = GradScaler() if args.mixed_precision else None

    # All ranks must run the training loop; only the main process gets the
    # loss curves back (the others receive None).
    losses = train(model, train_dl, val_dl, criterion,
                   optimizer, lr_schedule,
                   args.epochs, increment, model_save_dir, fp16_scaler, args)

    if utils.is_main_process():
        train_loss, val_loss = losses
        np.save(os.path.join(save_dir, "train_loss.npy"), train_loss)
        np.save(os.path.join(save_dir, "val_loss.npy"), val_loss)

    
if __name__ == '__main__':
    # Script entry point: parse the CLI on top of the shared argument set.
    cli = ArgumentParser('train_segmentation', parents=[get_args_parser()])
    args = cli.parse_args()

    # Optionally wipe earlier results, then make sure the output dir exists.
    if args.clear_save_dir:
        shutil.rmtree(args.save_dir, ignore_errors=True)
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)

    train_and_save(args)
