
import os
import torch
import torch.distributed as dist

from ..vocabularies import GeneralGamestateDeltaVocabByBen


class BaseRunner(object):
    """Shared epoch loop, checkpointing, gamestate-vocab construction and gather helper.

    Attributes and methods read here but set elsewhere (presumably by subclasses):
    ``start_epoch``, ``n_epochs``, ``rank``, ``world_size``, ``mode``, ``model``,
    ``args``, ``save_model_every``, ``force_model_save``, ``force_model_save_fp`` and
    ``run_one_epoch(epoch, mode)``.
    """

    # Modes that run with gradients enabled and the model in train() mode.
    _GRAD_MODES = ('train', 'pt')

    def run(self):
        """Run epochs ``start_epoch`` through ``n_epochs - 1`` in ``self.mode``.

        ``'train'`` / ``'pt'`` modes: the model is put in train mode and each epoch
        runs with gradients. After the epoch:

        * rank 0 saves a checkpoint to
          ``os.path.join(args.model_save_dir, args.ckpt_file_tmplt.format(epoch))``
          when ``save_model_every > 0`` and the epoch is a multiple of it or is the
          last epoch;
        * every rank runs a ``'train-dev'`` pass in eval mode under
          ``torch.no_grad()`` by the same rule using ``args.dev_every``.

        Any other mode: the model is put in eval mode and each epoch runs under
        ``torch.no_grad()``; if ``force_model_save`` is truthy, rank 0 writes the
        weights to ``force_model_save_fp`` after every epoch.
        """
        is_training = self.mode in self._GRAD_MODES
        last_epoch = self.n_epochs - 1

        for epoch in range(self.start_epoch, self.n_epochs):
            if self.rank == 0:
                print('Performing epoch {} of {}'.format(epoch, self.n_epochs))

            if is_training:
                self.model.train()
                if self.rank == 0:
                    print('&& Running with grad &&')
                self.run_one_epoch(epoch, self.mode)
            else:
                self.model.eval()
                with torch.no_grad():
                    self.run_one_epoch(epoch, self.mode)

            if self.rank == 0:
                print('Done with epoch {} of {}'.format(epoch, self.n_epochs))

            is_last_epoch = epoch == last_epoch
            if is_training:
                # The interval is checked for positivity *before* being used as a
                # modulus: 0 used to raise ZeroDivisionError and negative values
                # used to trigger a save on every epoch.
                if self._should_run(epoch, self.save_model_every, is_last_epoch) and self.rank == 0:
                    print('Saving model...')
                    ckpt_fp = os.path.join(self.args.model_save_dir,
                                           self.args.ckpt_file_tmplt.format(epoch))
                    self._save_model_state(ckpt_fp)
                    # NOTE(review): no barrier here; other ranks continue while rank 0 writes.

                if self._should_run(epoch, self.args.dev_every, is_last_epoch):
                    if self.rank == 0:
                        print('Performing train-dev for epoch {} of {}'.format(epoch, self.n_epochs))
                    self.model.eval()
                    with torch.no_grad():
                        self.run_one_epoch(epoch, 'train-dev')
            elif self.force_model_save and self.rank == 0:
                print('FORCE SAVING MODEL')
                print('\tself.force_model_save_fp: {}'.format(self.force_model_save_fp))
                self._save_model_state(self.force_model_save_fp)

    @staticmethod
    def _should_run(epoch, every, is_last_epoch):
        """Return True when a periodic action (save / dev pass) is due at ``epoch``.

        Enabled only for a positive interval ``every``; it then fires on epochs
        divisible by ``every`` and always on the last epoch. A zero or negative
        interval disables the action (and never evaluates ``epoch % every``).
        """
        return every > 0 and (epoch % every == 0 or is_last_epoch)

    def _save_model_state(self, path):
        """Write the model's ``state_dict()`` to ``path`` with ``torch.save``.

        When ``args.on_cpu`` is falsy the weights are taken from ``self.model.module``
        (presumably a DistributedDataParallel/DataParallel wrapper — confirm);
        otherwise ``self.model`` itself is saved.
        """
        model = self.model if self.args.on_cpu else self.model.module
        torch.save(model.state_dict(), path)

    @staticmethod
    def _build_gamestate_vocab(args):
        """Construct a ``GeneralGamestateDeltaVocabByBen`` from ``args``.

        Every ``gamestate_*`` field is optional on ``args``; missing fields fall back
        to the defaults below. ``gamestate_vocab_use_balls_strikes`` controls both
        the balls and the strikes deltas.
        """
        use_balls_strikes = getattr(args, 'gamestate_vocab_use_balls_strikes', True)
        return GeneralGamestateDeltaVocabByBen(
            bos_inning_no=getattr(args, 'gamestate_vocab_bos_inning_no', False),
            max_inning_no=getattr(args, 'gamestate_n_innings', 10),
            bos_score_diff=getattr(args, 'gamestate_vocab_bos_score_diff', False),
            bos_max_score_diff=getattr(args, 'gamestate_max_score_diff', 6),
            bos_base_occ=getattr(args, 'gamestate_vocab_bos_base_occ', True),
            balls_delta=use_balls_strikes,
            strikes_delta=use_balls_strikes,
            outs_delta=getattr(args, 'gamestate_vocab_use_outs', True),
            score_delta=getattr(args, 'gamestate_vocab_use_score_diff', True),
            base_occ_delta=getattr(args, 'gamestate_vocab_use_base_occupancy', True),
            swing_status=getattr(args, 'gamestate_vocab_use_swing_status', False),
        )

    def make_gamestate_vocab(self):
        """Build ``self.gamestate_vocab`` from ``self.args`` and record its sizes/ids.

        Writes ``n_gamestate_tokens``, ``n_gamestate_bos_tokens`` and ``gsd_mask_id``
        onto ``self.args``. The separator id is stored on ``self.sep_id`` (not on
        ``self.args``, unlike ``make_gamestate_vocab_from_args``).
        """
        self.gamestate_vocab = self._build_gamestate_vocab(self.args)

        self.args.n_gamestate_tokens = len(self.gamestate_vocab)
        self.args.n_gamestate_bos_tokens = len(self.gamestate_vocab.bos_vocab)
        self.args.gsd_mask_id = self.gamestate_vocab.mask_id
        self.sep_id = self.gamestate_vocab.sep_id

    def make_gamestate_vocab_from_args(self, args):
        """Build a gamestate vocab from ``args`` without touching ``self``.

        ``args`` is mutated in place: ``n_gamestate_tokens``,
        ``n_gamestate_bos_tokens``, ``gsd_mask_id`` and ``sep_id`` are set on it.

        Returns:
            ``(args, gamestate_vocab)`` — the same (mutated) ``args`` object and the
            newly built vocab.
        """
        gamestate_vocab = self._build_gamestate_vocab(args)

        args.n_gamestate_tokens = len(gamestate_vocab)
        args.n_gamestate_bos_tokens = len(gamestate_vocab.bos_vocab)
        args.gsd_mask_id = gamestate_vocab.mask_id
        args.sep_id = gamestate_vocab.sep_id

        return args, gamestate_vocab

    def gather(self, x):
        """All-gather ``x`` along dim 0 when ranks hold different numbers of rows.

        Every rank pads its tensor with zero rows to a common length, so that
        ``dist.all_gather`` sees equal shapes. Each rank's chunk is then trimmed back
        to its real row count. The result is the concatenation, in rank order, of
        every rank's rows. Must be called collectively by all ranks.

        With ``world_size == 1`` the input is returned as-is and
        ``torch.distributed`` is not touched.
        """
        if self.world_size == 1:
            return x

        n_local = torch.tensor([x.shape[0]], device=x.device)
        n_list = [torch.zeros_like(n_local) for _ in range(self.world_size)]
        dist.all_gather(n_list, n_local)
        counts = [int(n.item()) for n in n_list]

        # One row beyond the largest count: every rank always sends a non-empty
        # buffer (even if all ranks hold zero rows); padding rows are dropped below.
        max_size = max(counts) + 1
        padding = torch.zeros(max_size - x.shape[0], *x.shape[1:], device=x.device, dtype=x.dtype)
        padded = torch.cat([x, padding], dim=0)

        chunks = [torch.zeros_like(padded) for _ in range(self.world_size)]
        dist.all_gather(chunks, padded)

        # Row counts are known on every rank, so no separate padding mask is gathered.
        return torch.cat([chunk[:n] for chunk, n in zip(chunks, counts)], dim=0)
