import os
import torch
import shutil
import logging
from tqdm import tqdm
import horovod.torch as hvd
from core.metric import class_eval
from core.meters import AverageMeter
from core.loss_func import DockingLoss
from core.dataset import batch_to_device
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts

class DockingTrainer():
    def __init__(self, data_loaders, model, config):
        """Set up data loaders, loss, model, optimizer, LR scheduler and output dirs.

        Args:
            data_loaders: 4-tuple unpacked as ``(train_loader, valid_loader,
                test_loader, train_sampler)``.
            model: network to train; it is moved to CUDA here.
            config: run configuration object. Attributes read in this method:
                ``fp16``, ``metric``, ``epochs``, ``test_freq``,
                ``clip_grad_norm``, ``run_name``, ``hadoop_dir``,
                ``fine_tune``, ``serial``, ``optimizer``, ``lr``,
                ``lr_scheduler`` and the scheduler-specific ``lr_*`` fields.

        Raises:
            NotImplementedError: if ``config.optimizer`` or
                ``config.lr_scheduler`` names an unsupported choice.
        """
        self.config = config
        self.fp16 = config.fp16
        self.train_loader, self.valid_loader, self.test_loader, self.train_sampler = data_loaders
        self.scaler = torch.cuda.amp.GradScaler()
        self.criterion = DockingLoss(self.config).cuda()
        self.best_performance = 0
        self.metric = config.metric
        self.epochs = config.epochs
        self.test_freq = config.test_freq
        self.clip_grad_norm = config.clip_grad_norm
        self.checkpoint_dir = f'./{config.run_name}_checkpoints/'
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.hadoop_dir = config.hadoop_dir
        self.fine_tune = config.fine_tune

        # Only one process reports; hvd.rank() is not touched in serial mode.
        is_main = config.serial or hvd.rank() == 0

        # model
        self.model = model
        if is_main and not self.fine_tune:
            logging.info(self.model)
            learnable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            print(f'\nNumber of learnable model parameters: {learnable_params}')

        self.model.cuda()

        # optimizer
        if config.optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=config.lr,
                                              betas=(0.9, 0.999))
        elif config.optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=config.lr,
                                             momentum=0.9)
        else:
            raise NotImplementedError(f'Unsupported optimizer: {config.optimizer!r}')

        if not config.serial:
            # broadcast parameters & optimizer state
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            # wrap optimizer with DistributedOptimizer
            compression = hvd.Compression.fp16 if self.fp16 else hvd.Compression.none
            self.optimizer = hvd.DistributedOptimizer(self.optimizer, compression=compression,
                                                      named_parameters=self.model.named_parameters())

        # LR scheduler
        if config.lr_scheduler == 'CosineAnnealingWarmRestarts':
            self.scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer,
                                                         T_0=config.lr_t0,
                                                         T_mult=config.lr_tmult,
                                                         eta_min=config.lr_eta_min)
        elif config.lr_scheduler == 'StepLR':
            self.scheduler = StepLR(optimizer=self.optimizer,
                                    step_size=config.lr_step_size,
                                    gamma=config.lr_gamma)
        else:
            raise NotImplementedError(f'Unsupported LR scheduler: {config.lr_scheduler!r}')

        if self.fine_tune:
            if is_main:
                print(f'=> loading model from {self.checkpoint_dir}')
            # Every rank reads the pretrain checkpoint itself. Deserialize onto
            # the CPU so the tensors are not restored onto whichever GPU they
            # were saved from; load_state_dict then copies them into the
            # model's own (already CUDA-resident) parameters.
            checkpoint = torch.load(os.path.join(self.checkpoint_dir, 'model_pretrain_best.pt'),
                                    map_location='cpu')
            self.model.load_state_dict(checkpoint['state_dict'])

        # upload results to HDFS
        if self.hadoop_dir:
            status = os.system(f'hadoop fs -mkdir -p {self.hadoop_dir}')
            if status != 0:
                # Best effort: training can proceed, but later uploads will likely fail too.
                logging.warning('hadoop fs -mkdir for %s exited with status %s',
                                self.hadoop_dir, status)

    def train(self):
        for epoch in range(1, self.epochs+1):
            torch.cuda.empty_cache()
            lr = self.scheduler.get_last_lr()
            if self.config.serial or hvd.rank() == 0:
                logging.info('Epoch: {}, LR: {:.6f}'.format(epoch, lr[0]))
            
            self._train_epoch(epoch)
            performance = self._valid_epoch(test_mode=False)
            
            self.scheduler.step()
            is_best = performance > self.best_performance
            self.best_performance = max(performance, self.best_performance)
            
            # save checkpoint
            if is_best and (self.config.serial or hvd.rank() == 0):
                self._save_checkpoint(epoch, is_best=True)
            
            # predict on test set using the latest model
            if epoch % self.test_freq == 0:
                torch.cuda.empty_cache()
                if self.config.serial or hvd.rank() == 0:
                    print('Evaluating the latest model on test set')
                self._valid_epoch(test_mode=True)
                self._save_checkpoint(epoch, is_best=False)
        
        # evaluate best model on test set
        torch.cuda.empty_cache()
        if self.config.serial or hvd.rank() == 0:
            print('\n---------Evaluate Best Model on Test Set---------------', flush=True)
        phase = 'fine_tune' if self.fine_tune else 'pretrain'
        best_checkpoint = torch.load(os.path.join(self.checkpoint_dir, f'model_{phase}_best.pt'))
        self.model.load_state_dict(best_checkpoint['state_dict'])
        self._valid_epoch(test_mode=True)

    def _save_checkpoint(self, epoch, is_best=False):
        state_dict = {
            'state_dict': self.model.state_dict(),
        }
        phase = 'fine_tune' if self.fine_tune else 'pretrain'
        if is_best:
            filename = os.path.join(self.checkpoint_dir, f'model_{phase}_best.pt')
        else:
            filename = os.path.join(self.checkpoint_dir, f'model_{phase}_epoch{epoch}.pt')
        torch.save(state_dict, filename)
        if self.hadoop_dir:
            src = self.checkpoint_dir
            dst = self.hadoop_dir
            os.system(f'hadoop fs -put -f {src} {dst}')    

    def _train_epoch(self, epoch):
        """Train for one pass over ``self.train_loader`` and print averaged statistics.

        For every batch: forward pass (under autocast when ``self.fp16``),
        loss computation, metric bookkeeping, backward pass, gradient-norm
        clipping to ``self.clip_grad_norm`` and an optimizer step. After the
        loop, each running average is passed through ``self.avg`` (called on
        every rank) and the summary is printed by one process only.

        Args:
            epoch: epoch index; in distributed mode it is passed to
                ``self.train_sampler.set_epoch``.
        """
        # One running average per reported quantity, in reporting order.
        meter_names = ('LigBSPLoss', 'LigPrec', 'LigRec', 'LigFsc', 'LigAUC', 'LigAP',
                       'RecBSPLoss', 'RecPrec', 'RecRec', 'RecFsc', 'RecAUC', 'RecAP',
                       'LigAttnLoss', 'RecAttnLoss', 'NCELoss')
        meters = {name: AverageMeter(name, ':5.3f') for name in meter_names}

        distributed = not self.config.serial
        # hvd.rank() is only consulted when not running serially.
        is_main = self.config.serial or hvd.rank() == 0

        # train model
        self.model.train()
        if distributed:
            # Horovod: set epoch to sampler for shuffling.
            self.train_sampler.set_epoch(epoch)
        for batch in tqdm(self.train_loader):
            # send data to device and compute model output
            batch = batch_to_device(batch)
            self.optimizer.zero_grad()
            # autocast(enabled=False) is a no-op, so one code path serves both modes.
            with torch.cuda.amp.autocast(enabled=self.fp16):
                output = self.model(batch)
                lig_bsp_loss, rec_bsp_loss, lig_attn_loss, rec_attn_loss, nce_loss = self.criterion(output)
                # Both attention terms enter the objective with a negative sign.
                total_loss = lig_bsp_loss + rec_bsp_loss + \
                             -1 * self.config.attn_loss_weight * lig_attn_loss + \
                             -1 * self.config.attn_loss_weight * rec_attn_loss + \
                             self.config.nce_loss_weight * nce_loss

            lig_precision, lig_recall, lig_fscore, lig_AUC, lig_AP = class_eval(output['lig_dict'])
            if output['rec_dict']['bsp'] is not None:
                rec_precision, rec_recall, rec_fscore, rec_AUC, rec_AP = class_eval(output['rec_dict'])
            else:
                # No receptor prediction for this batch: record zeros.
                rec_precision, rec_recall, rec_fscore, rec_AUC, rec_AP = torch.zeros((5, 1))
            bsize = len(batch['lig_dict']['num_verts'])

            meters['LigBSPLoss'].update(lig_bsp_loss.item(), bsize)
            for name, per_sample in (('LigPrec', lig_precision), ('LigRec', lig_recall),
                                     ('LigFsc', lig_fscore), ('LigAUC', lig_AUC),
                                     ('LigAP', lig_AP)):
                meters[name].update(sum(per_sample) / bsize, bsize)
            meters['RecBSPLoss'].update(rec_bsp_loss.item(), bsize)
            for name, per_sample in (('RecPrec', rec_precision), ('RecRec', rec_recall),
                                     ('RecFsc', rec_fscore), ('RecAUC', rec_AUC),
                                     ('RecAP', rec_AP)):
                meters[name].update(sum(per_sample) / bsize, bsize)
            meters['LigAttnLoss'].update(lig_attn_loss.item(), bsize)
            meters['RecAttnLoss'].update(rec_attn_loss.item(), bsize)
            meters['NCELoss'].update(nce_loss.item(), bsize)

            # compute gradient and optimize
            if self.fp16:
                self.scaler.scale(total_loss).backward()
                # Make sure all async allreduces are done
                if distributed:
                    self.optimizer.synchronize()
                # Gradients still carry the loss scale at this point; unscale
                # them in place FIRST so that the clipping threshold is
                # compared against the true gradient norm.
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
                if distributed:
                    # Gradients were already synchronized above.
                    with self.optimizer.skip_synchronize():
                        self.scaler.step(self.optimizer)
                else:
                    self.scaler.step(self.optimizer)
                # Update scaler in case of overflow/underflow
                self.scaler.update()
            else:
                total_loss.backward()
                if distributed:
                    self.optimizer.synchronize()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
                self.optimizer.step()

            # The warning stays limited to full precision: under fp16 a
            # non-finite norm is expected on steps the scaler skips.
            if is_main and not self.fp16 and grad_norm > 10 * self.clip_grad_norm:
                print(f'\n!!! gradient norm exploded: {grad_norm:.2f} !!!', flush=True)

        # self.avg may perform a collective op, so it runs on EVERY rank and
        # in a fixed order; only the printing is restricted to one process.
        avgs = {name: self.avg(meters[name].avg) for name in meter_names}
        print_info = ['***** Training\n',
                      f"LigBSPLoss: {avgs['LigBSPLoss']:.3f}, ",
                      f"LigPrec: {avgs['LigPrec']:.3f}, LigRec: {avgs['LigRec']:.3f}, ",
                      f"LigFsc: {avgs['LigFsc']:.3f}, LigAUC: {avgs['LigAUC']:.3f}, ",
                      f"LigAP: {avgs['LigAP']:.3f}\n",
                      f"RecBSPLoss: {avgs['RecBSPLoss']:.3f}, ",
                      f"RecPrec: {avgs['RecPrec']:.3f}, RecRec: {avgs['RecRec']:.3f}, ",
                      f"RecFsc: {avgs['RecFsc']:.3f}, RecAUC: {avgs['RecAUC']:.3f}, ",
                      f"RecAP: {avgs['RecAP']:.3f}\n",
                      f"LigAttnLoss: {avgs['LigAttnLoss']:.3f}\n",
                      f"RecAttnLoss: {avgs['RecAttnLoss']:.3f}\n",
                      f"NCELoss: {avgs['NCELoss']:.3f}\n",
                      '*****\n']
        if is_main:
            print(''.join(print_info))

    def avg(self, val):
        """Combine a per-process statistic across workers.

        In serial mode ``val`` is returned untouched. Otherwise it is wrapped
        in a tensor, passed through ``hvd.allreduce`` (called with its default
        reduction) and returned as a Python scalar. Because this is a
        collective call, every rank must invoke it.

        Args:
            val: the local value to combine.

        Returns:
            ``val`` itself in serial mode, else the reduced value via ``.item()``.
        """
        if not self.config.serial:
            reduced = hvd.allreduce(torch.tensor(val))
            return reduced.item()
        return val

    def _valid_epoch(self, test_mode=False):
        """Evaluate the model on the validation or test loader and print averaged stats.

        The model is put in eval mode and run under ``torch.no_grad()``. Each
        running average is passed through ``self.avg`` exactly once (on every
        rank); the results are used both for the printed summary and for the
        returned score.

        Args:
            test_mode: evaluate on ``self.test_loader`` when true, otherwise on
                ``self.valid_loader``.

        Returns:
            The sum of the ligand and receptor averages of the quantity chosen
            by ``self.metric``: AP when it is ``'AP'``, AUC when it is ``'AUC'``.

        Raises:
            AssertionError: if ``self.metric`` is neither ``'AP'`` nor ``'AUC'``.
        """
        # One running average per reported quantity, in reporting order.
        meter_names = ('LigBSPLoss', 'LigPrec', 'LigRec', 'LigFsc', 'LigAUC', 'LigAP',
                       'RecBSPLoss', 'RecPrec', 'RecRec', 'RecFsc', 'RecAUC', 'RecAP',
                       'LigAttnLoss', 'RecAttnLoss', 'NCELoss')
        meters = {name: AverageMeter(name, ':5.3f') for name in meter_names}

        data_loader = self.test_loader if test_mode else self.valid_loader

        # evaluate model
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(data_loader):
                # send data to device and compute model output
                batch = batch_to_device(batch)
                output = self.model(batch)

                lig_bsp_loss, rec_bsp_loss, lig_attn_loss, rec_attn_loss, nce_loss = self.criterion(output)
                lig_precision, lig_recall, lig_fscore, lig_AUC, lig_AP = class_eval(output['lig_dict'])
                if output['rec_dict']['bsp'] is not None:
                    rec_precision, rec_recall, rec_fscore, rec_AUC, rec_AP = class_eval(output['rec_dict'])
                else:
                    # No receptor prediction for this batch: record zeros.
                    rec_precision, rec_recall, rec_fscore, rec_AUC, rec_AP = torch.zeros((5, 1))
                bsize = len(batch['lig_dict']['num_verts'])

                meters['LigBSPLoss'].update(lig_bsp_loss.item(), bsize)
                for name, per_sample in (('LigPrec', lig_precision), ('LigRec', lig_recall),
                                         ('LigFsc', lig_fscore), ('LigAUC', lig_AUC),
                                         ('LigAP', lig_AP)):
                    meters[name].update(sum(per_sample) / bsize, bsize)
                meters['RecBSPLoss'].update(rec_bsp_loss.item(), bsize)
                for name, per_sample in (('RecPrec', rec_precision), ('RecRec', rec_recall),
                                         ('RecFsc', rec_fscore), ('RecAUC', rec_AUC),
                                         ('RecAP', rec_AP)):
                    meters[name].update(sum(per_sample) / bsize, bsize)
                meters['LigAttnLoss'].update(lig_attn_loss.item(), bsize)
                meters['RecAttnLoss'].update(rec_attn_loss.item(), bsize)
                meters['NCELoss'].update(nce_loss.item(), bsize)

        # self.avg may perform a collective op, so it runs on EVERY rank and
        # in a fixed order. Reduce each value once and reuse it below rather
        # than issuing further collectives for the selection metric.
        avgs = {name: self.avg(meters[name].avg) for name in meter_names}

        mode = 'Test' if test_mode else 'Valid'
        print_info = [f'***** {mode}\n',
                      f"LigBSPLoss: {avgs['LigBSPLoss']:.3f}, ",
                      f"LigPrec: {avgs['LigPrec']:.3f}, LigRec: {avgs['LigRec']:.3f}, ",
                      f"LigFsc: {avgs['LigFsc']:.3f}, LigAUC: {avgs['LigAUC']:.3f}, ",
                      f"LigAP: {avgs['LigAP']:.3f}\n",
                      f"RecBSPLoss: {avgs['RecBSPLoss']:.3f}, ",
                      f"RecPrec: {avgs['RecPrec']:.3f}, RecRec: {avgs['RecRec']:.3f}, ",
                      f"RecFsc: {avgs['RecFsc']:.3f}, RecAUC: {avgs['RecAUC']:.3f}, ",
                      f"RecAP: {avgs['RecAP']:.3f}\n",
                      f"LigAttnLoss: {avgs['LigAttnLoss']:.3f}\n",
                      f"RecAttnLoss: {avgs['RecAttnLoss']:.3f}\n",
                      f"NCELoss: {avgs['NCELoss']:.3f}\n",
                      '*****\n']
        if self.config.serial or hvd.rank() == 0:
            print(''.join(print_info))

        if self.metric == 'AP':
            metric = avgs['LigAP'] + avgs['RecAP']
        else:
            assert self.metric == 'AUC'
            metric = avgs['LigAUC'] + avgs['RecAUC']

        return metric


