import os 
import json
import torch

from torch.utils.tensorboard import SummaryWriter
import dataset.protein_data as protein_data
import models.protein_model as protein_model
import utils.dist_util as dist_util
import torch.nn.functional as F
import options.option_main as option_main
import utils.utils as utils_main
import warnings
# Suppress every Python warning for the whole run (note: this also hides
# deprecation notices coming from torch and other libraries).
warnings.filterwarnings('ignore')

##### ---- Exp dirs ---- #####
# Despite its name, get_args_parser() is used here as returning the parsed
# options object (attributes are read directly below).
args = option_main.get_args_parser()

# Each experiment writes into its own sub-directory: <out_dir>/<exp_name>.
args.out_dir = os.path.join(args.out_dir, '{}'.format(args.exp_name))

#### ---- DistributedDataParallel ---- #####
# Initialise distributed training; presumably returns this process's device
# (TODO confirm in dist_util.distributions_init).
device = dist_util.distributions_init(args)

##### ---- Logger ---- #####
# Text logger and TensorBoard writer both point at the experiment directory.
logger = utils_main.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)

# Record the full configuration once, from the rank-0 process only.
if args.local_rank == 0:
    logger.info(json.dumps(vars(args), sort_keys=True, indent=4))

##### ---- Dataloader ---- #####
# val_loader and translate_loader are not used by the training loop in this
# file's visible portion.
train_loader, val_loader, translate_loader = protein_data.get_dataloader(args)
# Presumably an endless iterator over train_loader: the loop below calls
# next() on it with no StopIteration handling (TODO confirm in protein_data).
train_loader_iter = protein_data.cycle(train_loader)

##### ---- Network ---- #####
net = protein_model.get_model(args)
# Swap BatchNorm layers for SyncBatchNorm so normalisation statistics are
# shared across the DDP processes.
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

if args.resume_pth is not None:
    # Report through the run logger on rank 0 (like the rest of the script)
    # instead of print(), so the message lands in the log file and is not
    # duplicated once per process.
    if args.local_rank == 0:
        logger.info('loading checkpoint from {}'.format(args.resume_pth))
    # Loaded onto the CPU first; the weights move to the GPU with net.cuda().
    # The state dict is loaded strictly BEFORE the DDP wrap, so ckpt['net']
    # must use the unwrapped model's keys (no 'module.' prefix).
    # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
    ckpt = torch.load(args.resume_pth, map_location='cpu')
    net.load_state_dict(ckpt['net'], strict=True)
net.cuda()
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank], output_device=args.local_rank)

##### ---- Optimizer & Scheduler ---- #####
optimizer = utils_main.initial_optim(args.lr, args.weight_decay, net, args.optimizer)
# The scheduler is stepped once per training iteration, so the milestones in
# args.lr_scheduler are counted in iterations, not epochs.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)


##### ---- Training ---- #####
# Running sums over the current print window; reset after every report.
avg_loss_cls = 0.
acc_num, num_sample = 0., 0.
nb_iter = 0
# Label value excluded from both the loss and the accuracy.
ignore_index = -1
# `<` performs exactly args.total_iter optimizer steps; the previous `<=`
# ran one extra step whose statistics were never logged.
while nb_iter < args.total_iter:

    batch = next(train_loader_iter)
    _, _, protein_seq, func_text, _ = batch
    protein_seq, func_text = protein_seq.cuda(), func_text.cuda()
    bs = protein_seq.shape[0]
    # True at real tokens, False at padding (token id 0 is treated as [PAD]).
    gt_mask = (func_text != 0)
    labels = torch.where(gt_mask, func_text, ignore_index)

    # Teacher forcing: the model sees tokens [0, T-1) and predicts [1, T).
    global_feature, protein_feature, function_logits = net(protein_seq, func_text[:, :-1])
    target = labels[:, 1:]
    target_mask = gt_mask[:, 1:]
    # Assumes function_logits is (batch, T-1, vocab) -- TODO confirm. The
    # transpose moves the class dim to dim 1, as cross_entropy expects.
    loss_cls = F.cross_entropy(function_logits.transpose(-2, -1), target, ignore_index=ignore_index)

    # Correct predictions, counted only at real (non-pad) target positions.
    cls_pred_index = function_logits.argmax(dim=-1)
    acc_mask = (cls_pred_index == target) & target_mask

    ## global loss
    optimizer.zero_grad()
    loss_cls.backward()
    optimizer.step()
    scheduler.step()

    acc_num = acc_num + acc_mask.sum().item()
    num_sample = num_sample + target_mask.sum().item()
    avg_loss_cls = avg_loss_cls + loss_cls.item()

    nb_iter += 1
    if nb_iter % args.print_iter == 0:
        # Only rank 0 reports, and the figures come from rank 0's own batches
        # (they are not reduced across processes).
        if args.local_rank == 0:
            avg_loss_cls = avg_loss_cls / args.print_iter
            # max(..., 1) guards a window that contained no real target tokens.
            avg_acc = acc_num * 100 / max(num_sample, 1)
            writer.add_scalar('./Loss/train', avg_loss_cls, nb_iter)
            writer.add_scalar('./ACC/train', avg_acc, nb_iter)
            msg = f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}, ACC. {avg_acc:.4f}"
            logger.info(msg)
        # Every rank resets its window accumulators.
        avg_loss_cls = 0.
        acc_num = 0.
        num_sample = 0.
