import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter

# Public API of this module. The exported name is spelled 'ANPTraniner'
# because that is how the class is defined in this file.
__all__ = ['ANPTraniner']

class Trainer(object):
    """Base class holding the state shared by concrete trainers.

    Stores the model, optimizer, scheduler and data loaders, makes sure the
    checkpoint directory exists and, when ``args.resume`` is truthy, restores
    the previous run from disk. ``train``, ``valid`` and ``fit`` raise
    ``NotImplementedError`` here and are meant to be overridden.

    Args:
        model: Network to train. May already be wrapped in ``nn.DataParallel``.
        optimizer: Object exposing ``state_dict()`` / ``load_state_dict()``.
        scheduler: Object exposing ``state_dict()`` / ``load_state_dict()``.
        trainloader: Iterable of training batches.
        validloader: Iterable of validation batches.
        testloader: Optional iterable of test batches.
        args: Run configuration. Although the parameter defaults to ``None``
            it is required in practice: the attributes ``dataset``, ``stage``,
            ``run`` and ``resume`` are read from it.
    """

    def __init__(self, model, optimizer, scheduler, trainloader, validloader, testloader=None, args=None):
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.trainloader = trainloader
        self.validloader = validloader
        self.testloader = testloader
        self.start_epoch = 0

        self.init_paths()

        if args.resume:
            self.resume()

    def _checkpoint_path(self):
        """Return the checkpoint file path for the configured dataset/stage/run."""
        return 'checkpoints/{}/{}/{}.pth'.format(self.args.dataset, self.args.stage, self.args.run)

    def _unwrapped_model(self):
        """Return the underlying module when the model is wrapped in ``nn.DataParallel``."""
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    def init_paths(self):
        """Create ``checkpoints/<dataset>/<stage>`` if it does not exist yet."""
        # exist_ok avoids the check-then-create race when several runs start
        # at the same time; intermediate directories are created as needed.
        os.makedirs('checkpoints/{}/{}'.format(self.args.dataset, self.args.stage), exist_ok=True)

    def save_checkpoint(self, epoch):
        """Write epoch number plus model/optimizer/scheduler state to disk.

        The model state is always saved from the unwrapped module so the
        stored keys carry no ``module.`` prefix.

        Args:
            epoch: Index of the epoch that has just finished.
        """
        checkpoint = {
            'epoch': epoch,
            'model': self._unwrapped_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        torch.save(checkpoint, self._checkpoint_path())

    def resume(self):
        """Restore model, optimizer, scheduler and start epoch from the checkpoint.

        After loading, the model is wrapped in ``nn.DataParallel`` when more
        than one CUDA device is visible and it is not wrapped already.
        """
        # NOTE: torch.load unpickles the file -- only resume from checkpoints
        # you trust. Loading onto the CPU first lets a checkpoint written on a
        # GPU be resumed on any machine; load_state_dict copies the values to
        # wherever the live parameters are.
        checkpoint = torch.load(self._checkpoint_path(), map_location='cpu')
        self.start_epoch = checkpoint['epoch'] + 1
        # The checkpoint holds un-prefixed keys, so load into the bare module
        # even if the caller handed us a DataParallel wrapper.
        self._unwrapped_model().load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        if torch.cuda.device_count() > 1 and not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)
        print('Resuming from epoch : {}'.format(self.start_epoch))

    def train(self):
        """Run one training epoch; must be implemented by subclasses."""
        raise NotImplementedError

    def valid(self):
        """Run validation; must be implemented by subclasses."""
        raise NotImplementedError

    def fit(self):
        """Run the full training loop; must be implemented by subclasses."""
        raise NotImplementedError

class ANPTraniner(Trainer):
    """Trainer whose loss depends on ``args.stage``.

    Supported stages are ``candidate``, ``autoregressive``, ``sss``,
    ``randomautoregressive`` and ``random``. For each stage the model is
    expected to return a fixed sequence of tensors (see ``_STAGE_OUTPUTS``).

    Besides the attributes used by ``Trainer``, the following are read from
    ``args``: ``mode``, ``debug``, ``device``, ``max_output_points`` and
    ``epochs``.

    Attributes:
        trainwriter: TensorBoard writer, created only when
            ``args.mode == 'train'`` and ``args.debug`` is falsy; else ``None``.
        validwriter: TensorBoard writer created in every other case; else
            ``None``.
    """

    # Names of the tensors the model returns for each stage, in return order.
    _STAGE_OUTPUTS = {
        'candidate': ('log_p', 'log_p_real', 'candidate_mask', 'candidate_mask_real', 'candidate_reg'),
        'autoregressive': ('log_p', 'log_p_real', 'subset_mask', 'subset_mask_real'),
        'sss': ('log_p', 'log_p_real', 'candidate_mask', 'candidate_mask_real', 'candidate_reg',
                'subset_mask', 'subset_mask_real'),
        'randomautoregressive': ('log_p', 'log_p_real', 'random_mask', 'subset_mask', 'subset_mask_real'),
        'random': ('log_p', 'random_mask'),
    }

    # (output name, stat key, negate) for every term summed into the loss.
    # The order is both the summation order and the reporting order.
    _LOSS_TERMS = (
        ('log_p', 'RCLoss', True),
        ('log_p_real', 'RCLossReal', True),
        ('candidate_reg', 'CandidateReg', False),
    )

    # (output name, stat key) for every mask whose mean value is reported,
    # listed in reporting order.
    _MASK_STATS = (
        ('candidate_mask', 'CandidateMask'),
        ('candidate_mask_real', 'CandidateMaskReal'),
        ('random_mask', 'RandomMask'),
        ('subset_mask', 'SubsetMask'),
        ('subset_mask_real', 'SubsetMaskReal'),
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Only one writer is created per run; the other one stays None so that
        # code using it can check for its presence instead of crashing.
        self.trainwriter = None
        self.validwriter = None
        run_tag = (self.args.dataset, self.args.stage, self.args.run)
        if self.args.mode == 'train' and not self.args.debug:
            self.trainwriter = SummaryWriter(log_dir='results/{}/{}/train/{}'.format(*run_tag))
        else:
            self.validwriter = SummaryWriter(log_dir='results/{}/{}/valid/{}'.format(*run_tag))

    @staticmethod
    def _mask_ratio(mask):
        """Return the sum of ``mask`` divided by its number of elements."""
        return mask.sum().item() / np.prod(mask.shape)

    def _loss_and_stats(self, output):
        """Convert the model output of the current stage into loss and stats.

        Args:
            output: Sequence of tensors returned by the model.

        Returns:
            Tuple ``(loss, stats)`` where ``loss`` is the scalar tensor to
            back-propagate and ``stats`` is an ``OrderedDict`` of floats.

        Raises:
            NotImplementedError: If ``args.stage`` is not a supported stage.
            ValueError: If the model returned an unexpected number of outputs.
        """
        stage = self.args.stage
        if stage not in self._STAGE_OUTPUTS:
            raise NotImplementedError('{} not implemented'.format(stage))

        names = self._STAGE_OUTPUTS[stage]
        tensors = tuple(output)
        if len(tensors) != len(names):
            raise ValueError('stage {} expects {} model outputs, got {}'.format(
                stage, len(names), len(tensors)))
        named = dict(zip(names, tensors))

        stats = OrderedDict()
        loss = None
        for name, key, negate in self._LOSS_TERMS:
            if name not in named:
                continue
            # Sum over dim 1, then average over what remains.
            term = named[name].sum(dim=1).mean()
            if negate:
                term = -term
            loss = term if loss is None else loss + term
            stats[key] = term.item()

        for name, key in self._MASK_STATS:
            if name in named:
                stats[key] = self._mask_ratio(named[name])
        return loss, stats

    def train(self):
        """Run one pass over ``trainloader`` and update the model.

        Returns:
            ``OrderedDict`` mapping each statistic name to the list of its
            per-batch values.
        """
        self.model.train()
        train_stats = OrderedDict()
        batches = tqdm(self.trainloader, total=len(self.trainloader), ncols=75, leave=False)
        for context, target in batches:
            context = context.to(self.args.device)
            target = target.to(self.args.device)
            output = self.model(context=context, target=target, subset_size=self.args.max_output_points)

            self.optimizer.zero_grad()
            loss, batch_stats = self._loss_and_stats(output)
            loss.backward()
            self.optimizer.step()

            for key, value in batch_stats.items():
                train_stats.setdefault(key, []).append(value)
        return train_stats

    def fit(self):
        """Train from ``start_epoch`` to ``args.epochs``, checkpointing every epoch.

        After each epoch the scheduler is stepped, the mean of every statistic
        is printed and, when a training writer exists, logged to TensorBoard.
        """
        for epoch in range(self.start_epoch, self.args.epochs):
            train_stats = self.train()
            self.scheduler.step()

            print('Epoch: {:<3}'.format(epoch))
            for key, values in train_stats.items():
                value = np.mean(values)
                print('\t\t{:<40} : {:.4f}'.format(key, value))
                # No training writer is created in debug mode; skip logging
                # rather than failing on a missing attribute.
                if self.trainwriter is not None:
                    self.trainwriter.add_scalar(key, value, epoch)

            self.save_checkpoint(epoch=epoch)
