import logging
from concurrent.futures import ThreadPoolExecutor
import os
import time
import signal

import wandb
import torch
from torch import nn
from omegaconf import DictConfig
import torch.distributed as dist

from il_scale.nethack.trainers.trainer import Trainer
from il_scale.nethack.agent import Agent
from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.utils.setup import DDPUtil
from il_scale.nethack.utils.model import count_params
from il_scale.nethack.logger import Logger

# Configure root logging at import time: INFO level, each record prefixed with
# its level, process id, module, line number and timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s",
)

class BCTrainer(Trainer):
    """Supervised (behavioural-cloning) trainer for an ``Agent`` on ``TTYData``.

    The training loop runs the agent on each batch under ``torch.autocast``,
    computes ``self.criterion`` on the flattened policy logits, accumulates
    gradients for ``cfg.trainer.gradient_acc`` batches and then performs one
    clipped optimizer + scheduler step through ``self.scaler``.

    Attributes such as ``criterion``, ``optimizer``, ``scheduler``, ``scaler``,
    ``use_amp``, ``seen_gameids``, ``saved_chkpts``, ``is_flop_saved`` and
    ``FLOP_TO_STR`` are not defined in this class; they are presumably set up
    by the base ``Trainer`` (constructor and/or ``_reset``) -- confirm there.
    """

    def __init__(
        self,
        cfg: DictConfig,
        logger: Logger,
        agent: Agent,
        data: TTYData,
        ddp_util: DDPUtil,
    ):
        """Store the agent, initialise the base trainer and count parameters.

        Args:
            cfg: Run configuration; this class reads ``network.*``,
                ``trainer.*`` and ``data.workers`` from it.
            logger: Project metrics logger driven by the training loop.
            agent: Agent whose ``model`` is being trained.
            data: Data source providing the training dataloader.
            ddp_util: Holder of this process' ``rank`` and the ``world_size``.
        """
        # Assigned before the base constructor runs, as in the original code;
        # presumably the base constructor reaches the agent via
        # ``_get_model`` / ``_load_weights`` -- confirm in ``Trainer``.
        self.agent = agent

        super().__init__(cfg, logger, data, ddp_util)

        self.num_model_params = count_params(agent.model)
        # Parameter count used for the FLOP estimate in
        # ``_save_flop_checkpoints``: only these three sub-modules are counted.
        self.effective_model_params = sum([
            count_params(agent.model.core),
            count_params(agent.model.policy_head),
            count_params(agent.model.modality_mixer),
        ])

        self.time_to_clean = False  # flag to clean up before exiting

    ###### INTERFACE ######

    def signal_handler(self, signum, frame):
        """Signal callback: log the signal and request a clean shutdown.

        Only sets ``self.time_to_clean``; the training loop inspects the flag
        after an optimizer step and does the actual saving there.
        """
        logging.info(f"Process {self.ddp_util.rank} received signal {signum}")
        self.time_to_clean = True

    def train(self):
        """Run the training loop until a stop condition or SIGUSR2 shutdown.

        Per batch: forward pass and loss under autocast, scaled backward pass,
        metric updates, and marking of the batch's game ids as seen. Every
        ``cfg.trainer.gradient_acc`` batches: unscale, clip, optimizer step,
        scheduler step, scaler update, then logging / checkpointing checks.
        Checkpointing, the shutdown flag and ``_stop_condition()`` are only
        evaluated on iterations that perform an optimizer step.
        """
        # register a signal for saving
        signal.signal(signal.SIGUSR2, self.signal_handler)

        self._reset()
        self.logger.start()
        self.agent.train()

        rank = self.ddp_util.rank
        agent_state = self.agent.initial_state(self.data.train_batch_size, rank)

        # Loop-invariant configuration, read once instead of on every batch.
        core_mode = self.cfg.network.core_mode
        uses_txl = core_mode == 'transformer_xl'
        grad_acc = self.cfg.trainer.gradient_acc
        clip = self.cfg.trainer.clip
        log_freq = self.cfg.trainer.log_freq

        # The attention mask only exists for transformer_xl; every other core
        # receives ``attn_mask=None``. The memory length and sequence length
        # are read only on this path, exactly as before.
        attn_mask = None
        seq_len = None
        mem_len = None
        if uses_txl:
            seq_len = self.data.train_seq_len
            mem_len = self.cfg.network.tf_mem_len
            attn_mask = torch.zeros(
                (seq_len, seq_len + mem_len, self.data.train_batch_size)
            ).to(rank)

        with ThreadPoolExecutor(max_workers=self.cfg.data.workers) as tp:
            # Retrieve training data
            train_data = self.data.get_train_dataloader(tp, rank, self.ddp_util.world_size)

            # Start training loop
            logging.info(f"Processing {len(train_data.gameids)} gameids.")
            for i, batch in enumerate(train_data, 1):
                with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.use_amp):
                    if uses_txl:
                        self._mask_finished_episodes(attn_mask, batch['done'], mem_len)

                    agent_outputs, agent_state = self.agent.predict(
                        batch, agent_state=agent_state, attn_mask=attn_mask
                    )

                    if uses_txl:
                        attn_mask = self._advance_attn_mask(attn_mask, seq_len)

                    if core_mode == 'lstm':
                        # Cut the autograd graph between consecutive batches.
                        agent_state = (agent_state[0].detach(), agent_state[1].detach())

                    # Reshape logits
                    T, B = agent_outputs['policy_logits'].shape[:2]
                    logits = agent_outputs['policy_logits'].view(B * T, -1)

                    # Loss and gradients; divided so that the accumulated
                    # gradient matches a single large batch.
                    labels = batch['labels'].contiguous().view(B * T)
                    loss = self.criterion(logits, labels) / grad_acc

                self.scaler.scale(loss).backward()

                # Metrics receive the loss with the accumulation scaling undone.
                self.logger.update_metrics(B, T, loss * grad_acc, logits, labels, labels.shape[0], batch, i)
                self.logger.sample_step(labels.shape[0])

                self.seen_gameids[batch['unique_gameids']] = 1

                # Keep accumulating until a full group of batches is collected.
                if i % grad_acc != 0:
                    continue

                # Gradients must be unscaled before clipping by norm.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self._get_model().parameters(), clip)
                self.scaler.step(self.optimizer)
                self.scheduler.step()

                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
                self.logger.gradient_step()

                if self.logger.grad_steps % log_freq == 0:
                    self.logger.log_train(rank, self.scheduler.get_last_lr()[0])
                    self.logger.reset()

                self._save_periodic_checkpoint()
                self._save_flop_checkpoints()

                if self.time_to_clean:
                    self._checkpoint_on_signal()
                    break

                # Stop training if we have seen enough samples
                if self._stop_condition():
                    self._flush_train_log()
                    break

        logging.info("Done training")

    ###### PRIVATE ######

    def _mask_finished_episodes(self, attn_mask, done, mem_len):
        """Write 1.0 into ``attn_mask`` (in place) around episode ends.

        For every batch column ``b`` and every time index ``idx`` at which
        ``done`` is set, rows ``idx:`` of that column get columns
        ``:mem_len + idx`` set to 1.0. The meaning of 1.0 (presumably
        "blocked") is defined by the agent's model -- confirm there.
        """
        for b in range(self.data.train_batch_size):
            done_b = done[:, b, ...]
            if not torch.any(done_b):
                continue
            done_idxs = torch.argwhere(done_b.long().flatten()).flatten()
            for idx in done_idxs:
                attn_mask[idx:, :mem_len + idx, b] = 1.0

    @staticmethod
    def _advance_attn_mask(attn_mask, seq_len):
        """Return the mask rolled left by ``seq_len`` along dim 1.

        The last ``seq_len`` columns of the rolled mask are reset to 0.0.
        """
        rolled = torch.roll(attn_mask, -seq_len, dims=1)
        rolled[:, -seq_len:, :] = 0.0
        return rolled

    def _save_periodic_checkpoint(self):
        """Save ``model_<k>.tar`` the first time checkpoint slot ``k`` is reached.

        ``k`` is ``total_samples // cfg.trainer.chkpt_freq``. The block issues
        collective calls (``barrier`` / ``all_reduce``), so it relies on every
        rank taking the same branch; rank 0 does the saving and the
        ``saved_chkpts`` flags are then shared with a MAX all-reduce.
        """
        # Computed once so the slot that is checked is the slot that is marked.
        chkpt_num = int(self._get_total_samples() // self.cfg.trainer.chkpt_freq)
        if self.saved_chkpts[chkpt_num]:
            return

        logging.info('Entered saving checkpoint block ...')

        # sync seen gameids before saving
        self._sync_gameids(self.ddp_util.rank)

        if self.ddp_util.rank == 0:
            self.saved_chkpts[chkpt_num] = 1
            self._save(f"model_{chkpt_num}.tar")

        # sync saved checkpoints
        dist.barrier()
        dist.all_reduce(self.saved_chkpts, op=dist.ReduceOp.MAX)

    def _save_flop_checkpoints(self):
        """On rank 0, save a checkpoint for each newly crossed FLOP milestone.

        Does nothing unless ``cfg.trainer.save_flops`` is truthy. Training
        FLOPs are estimated as ``6 * effective_model_params * total_samples``.
        """
        if self.ddp_util.rank != 0 or not self.cfg.trainer.save_flops:
            return

        total_flops = 6 * self.effective_model_params * self._get_total_samples()  # 6ND approx.
        for flop, already_saved in self.is_flop_saved.items():
            if total_flops >= flop and not already_saved:
                self._save(f"model_{self.FLOP_TO_STR[flop]}.tar")
                self.is_flop_saved[flop] = True

    def _flush_train_log(self):
        """Log and reset pending train metrics unless the logger was just reset."""
        if not self.logger.just_reset:
            self.logger.log_train(self.ddp_util.rank)
            self.logger.reset()

    def _checkpoint_on_signal(self):
        """Handle a requested shutdown: sync game ids, save, flush metrics.

        Rank 0 writes ``model_last.tar``. Barriers keep the ranks together
        around the save and around the final metric flush.
        """
        rank = self.ddp_util.rank

        logging.info('Hit time to clean!')
        # sync seen gameids before saving
        self._sync_gameids(rank)
        logging.info('Done syncing!')

        if rank == 0:
            logging.info('Saving last checkpoint of this run ...')
            self._save("model_last.tar")
        dist.barrier()

        self._flush_train_log()
        dist.barrier()

    def _sync_gameids(self, rank):
        """Merge ``seen_gameids`` across ranks with a MAX all-reduce.

        Args:
            rank: This process' rank; used only in the log messages.
        """
        logging.info(f'Syncing gameids in rank {rank} ...')
        dist.barrier()
        logging.info('Reducing gameids ...')
        dist.all_reduce(self.seen_gameids, op=dist.ReduceOp.MAX)
        logging.info(f'Gameids successfully synced in rank {rank}!')

    def _get_model(self):
        """Return ``agent.model.module`` when ``agent.ddp`` is set, else ``agent.model``.

        ``.module`` is presumably the model inside a DistributedDataParallel
        wrapper -- confirm in ``Agent``.
        """
        return self.agent.model.module if self.agent.ddp else self.agent.model

    def _load_weights(self, state_dict):
        """Forward ``state_dict`` to ``agent.load``."""
        self.agent.load(state_dict)