import logging
import os
import time

import wandb
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, SequentialLR
from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup
from omegaconf import DictConfig, OmegaConf

from il_scale.nethack.logger import Logger
from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.utils.setup import DDPUtil
from il_scale.nethack.utils.model import load_checkpoint

# Configure the process-wide root logger (this does not create a per-file
# logger): INFO level, each record prefixed with level, PID, module:line and
# timestamp.
_LOG_FORMAT = "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

class FakeScheduler():
    """No-op stand-in for an LR scheduler that always reports a fixed LR.

    Implements the scheduler methods the trainer uses: ``step`` and
    ``get_last_lr`` during training, plus ``state_dict`` / ``load_state_dict``
    so checkpoint save and resume also work with this object.
    """

    def __init__(self, lr: float):
        # The constant learning rate reported by get_last_lr().
        self.lr = lr

    def step(self):
        """Do nothing; the learning rate never changes."""
        pass

    def get_last_lr(self):
        """Return the learning rate as a one-element list (one entry per param group)."""
        return [self.lr]

    def state_dict(self):
        """Return the scheduler state (just the fixed LR) for checkpointing."""
        return {"lr": self.lr}

    def load_state_dict(self, state_dict):
        """Restore state produced by ``state_dict``.

        Mirrors torch LR schedulers, which update their attributes from the dict.
        """
        self.__dict__.update(state_dict)

class Trainer:
    """Base trainer holding optimizer, LR scheduler, AMP scaler and checkpoint state.

    Subclasses must implement ``train`` and ``_get_model``. ``_get_model`` is
    called inside ``__init__`` (to build the optimizer and, on resume, via
    ``_load_weights``), so the model must exist before ``Trainer.__init__`` runs.

    Attributes used here but never set by this class (a subclass must provide
    them before the corresponding method is called):
    ``_load_weights`` (used by ``_maybe_resume``), ``effective_model_params``
    (used by ``_save``), and ``dev_min_loss`` / ``patience`` (used by
    ``_save_chkpts``).
    """

    def __init__(
        self, 
        cfg: DictConfig, 
        logger: Logger, 
        data: TTYData, 
        ddp_util: DDPUtil
    ):
        """Build optimizer/scheduler/scaler, bookkeeping tensors, then maybe resume.

        Args:
            cfg: run configuration; the ``optimizer``, ``trainer`` and ``setup``
                groups are read here.
            logger: training logger; its ``grad_steps``, ``tot_samples``,
                ``log_samples`` and ``start_time`` fields are used by this class.
            data: training data; ``train_batch_size``, ``train_seq_len`` and
                ``train_gameids`` are read here.
            ddp_util: distributed helper; ``rank`` is used as the tensor device and
                ``world_size`` when counting samples.

        Raises:
            NotImplementedError: if ``cfg.optimizer.type`` is neither ``'adam'``
                nor ``'adamw'``.
        """
        self.cfg = cfg
        self.logger = logger
        self.data = data
        self.ddp_util = ddp_util

        # FLOP budgets tracked by is_flop_saved / FLOP_TO_STR. Both tables are
        # derived from this single list so they cannot drift apart.
        flop_budgets = (
            "1e14", "2e14", "5e14",
            "1e15", "2e15", "5e15",
            "1e16", "2e16", "5e16",
            "1e17", "2e17", "5e17",
            "1e18", "2e18", "3e18", "4e18", "5e18", "6e18", "7e18", "8e18", "9e18",
            "1e19",
        )
        # int budget -> flag; persisted in checkpoints (see _save / _maybe_resume).
        self.is_flop_saved = {int(float(s)): False for s in flop_budgets}
        # float budget -> short string form (e.g. 1e14 -> "1e14").
        self.FLOP_TO_STR = {float(s): s for s in flop_budgets}

        self.criterion = nn.CrossEntropyLoss()
        if self.cfg.optimizer.type == 'adam':
            self.optimizer = torch.optim.Adam(
                self._get_model().parameters(), 
                lr=self.cfg.optimizer.lr,
                weight_decay=self.cfg.optimizer.weight_decay
            )
        elif self.cfg.optimizer.type == 'adamw':
            print('using adamw optimizer ...')
            self.optimizer = torch.optim.AdamW(
                self._get_model().parameters(), 
                betas=(0.9, 0.95),
                eps=1e-5,
                lr=self.cfg.optimizer.lr,
                weight_decay=self.cfg.optimizer.weight_decay
            )
        else:
            raise NotImplementedError(f'Optimizer {self.cfg.optimizer.type} not implemented!')
        
        logging.info(f'LR: {self.cfg.optimizer.lr}')
        logging.info(f'Weight Decay: {self.cfg.optimizer.weight_decay}')

        # Optimizer steps implied by the sample budget: total_samples divided by
        # samples consumed per step (batch * seq_len * GPUs * grad accumulation).
        self.total_steps = self.cfg.trainer.total_samples / (
            self.data.train_batch_size
            * self.data.train_seq_len
            * self.cfg.setup.num_gpus
            * self.cfg.trainer.gradient_acc
        )
        warmup_steps = self.cfg.optimizer.optim_warmup_steps
        post_warmup_steps = self.total_steps - warmup_steps
        scheduler_type = self.cfg.optimizer.scheduler_type

        if scheduler_type == 'linear':
            print('Using linear lr schedule ...')
            # The HF linear schedule decays to 0 at its final step. Stretching the
            # decay horizon by 1 / (1 - lr_end_fraction) makes the LR equal
            # lr_end_fraction * lr exactly at total_steps. Computed only here so
            # other schedules never evaluate this division.
            schedule_total_steps = warmup_steps + post_warmup_steps * 1 / (1 - self.cfg.optimizer.lr_end_fraction)
            self.scheduler = get_linear_schedule_with_warmup(self.optimizer, warmup_steps, schedule_total_steps)
        
        elif scheduler_type == 'constant':
            print(f'Using constant lr schedule with {warmup_steps} warmup steps ...')
            self.scheduler = get_constant_schedule_with_warmup(
                self.optimizer, 
                warmup_steps
            )

        elif scheduler_type == "cosine":
            print('Using cosine lr schedule ...')
            # Warmup then constant LR until the milestone, then cosine-anneal from
            # lr down to lr_end_fraction * lr over the remaining steps.
            scheduler1 = get_constant_schedule_with_warmup(self.optimizer, warmup_steps)
            scheduler2 = CosineAnnealingLR(self.optimizer, T_max=post_warmup_steps, eta_min=self.cfg.optimizer.lr_end_fraction * self.cfg.optimizer.lr)
            self.scheduler = SequentialLR(self.optimizer, schedulers=[scheduler1, scheduler2], milestones=[warmup_steps])
        # NOTE(review): any other scheduler_type leaves self.scheduler unset; _save
        # (and _maybe_resume when resuming) will then fail with AttributeError.

        self.use_amp = self.cfg.setup.use_amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        self.train_min_loss = 1e10
        self.time_budget = self.cfg.trainer.train_time_budget * 60 * 60  # hours -> seconds
        # One int8 flag per periodic-checkpoint index (see _save_chkpts). Index 0
        # is pre-marked, so no periodic checkpoint is written before the first
        # chkpt_freq samples. The tensor grows on demand past 1000 entries.
        self.saved_chkpts = torch.zeros((1_000,), dtype=torch.int8).to(self.ddp_util.rank)
        self.saved_chkpts[0] = 1

        # One int8 flag per game id, indexed directly by the id value.
        max_gameid = max(self.data.train_gameids)
        logging.info(f'Setting seen gameids max to {max_gameid + 1}')
        self.seen_gameids = torch.zeros((max_gameid + 1,), dtype=torch.int8).to(self.ddp_util.rank)

        self._maybe_resume()

    ###### INTERFACE ######

    def train(self):
        """Run training; must be implemented by subclasses."""
        raise NotImplementedError()

    ###### PRIVATE ######

    def _reset(self):
        """Reset the sample counter and best training loss."""
        self.num_samples = 0
        self.train_min_loss = 1e10

    def _stop_condition(self):
        """Return True once the wall-clock budget or the step budget is exhausted."""
        if time.time() - self.logger.start_time > self.time_budget:
            logging.info('Running out of time ...')
            return True
        elif self.logger.grad_steps > self.total_steps:
            logging.info('Running out of grad steps ...')
            return True
        else:
            return False
        
    def _get_total_samples(self):
        """Samples seen so far: logged total plus this log window across all ranks."""
        return self.logger.tot_samples + self.logger.log_samples * self.ddp_util.world_size

    def _save_chkpts(self, dev_metrics: dict, model_name: str = "model"):
        """Save best-dev-loss, periodic, and latest checkpoints.

        Args:
            dev_metrics: must contain ``'dev_loss'`` and ``'num_samples'``.
            model_name: filename prefix for the written checkpoints.
        """
        # Save if dev loss improves
        if dev_metrics['dev_loss'] < self.dev_min_loss:
            self.dev_min_loss = dev_metrics['dev_loss']
            # dev loss improved so reset patience
            self.patience = 0
            self._save(f"{model_name}_loss.tar")
        else:
            # dev loss didn't improve so increase patience
            self.patience += 1

        # Periodic checkpoint: at most one file per chkpt_freq-sized bucket of
        # samples. saved_chkpts is a flag tensor indexed by bucket number (not a
        # set), so test and set the flag at that index.
        chkpt_num = dev_metrics['num_samples'] // self.cfg.trainer.chkpt_freq
        slot = int(chkpt_num)
        if slot >= self.saved_chkpts.numel():
            # Grow the flag tensor instead of raising IndexError on long runs.
            padding = torch.zeros(
                (slot + 1 - self.saved_chkpts.numel(),),
                dtype=self.saved_chkpts.dtype,
                device=self.saved_chkpts.device,
            )
            self.saved_chkpts = torch.cat([self.saved_chkpts, padding])
        if self.saved_chkpts[slot].item() == 0:
            self._save(f"{model_name}_{chkpt_num}.tar")
            self.saved_chkpts[slot] = 1

        # Save latest checkpoint always
        self._save(f"{model_name}_latest.tar")

    def _save(self, chkpt_name: str):
        """Write full training state to ``models/<wandb run id>/<chkpt_name>`` and wandb.save it."""
        checkpoint_dir = os.path.join('models', wandb.run.id)
        checkpointpath = os.path.join(checkpoint_dir, chkpt_name)
        os.makedirs(checkpoint_dir, exist_ok=True)
        model = self._get_model()
        logging.info("Saving checkpoint to %s", checkpointpath)
        wandb_conf = OmegaConf.to_container(self.cfg, resolve=True, throw_on_missing=True)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "scaler_state_dict": self.scaler.state_dict(),
                "flags": wandb_conf,
                "num_samples": self._get_total_samples(),
                "gradient_steps": self.logger.grad_steps,
                "train_min_loss": self.train_min_loss,
                "params": self.effective_model_params,
                "is_flop_saved": self.is_flop_saved,
                "seen_gameids": self.seen_gameids,
                "saved_chkpts": self.saved_chkpts,
            },
            checkpointpath,
        )
        wandb.save(checkpointpath)
        logging.info('Model saved!')

    def _maybe_resume(self):
        """Restore state from a previous run, if configured.

        - ``cfg.setup.wandb_id`` set: full resume (weights, optimizer, scheduler,
          optional scaler / FLOP flags / seen game ids / checkpoint flags, counters).
          Game ids already marked as seen are removed from the training set.
        - else ``cfg.setup.wandb_load_dir`` set: load model weights only.
        - else: nothing is loaded.
        """
        if self.cfg.setup.wandb_id:
            wandb_id = self.cfg.setup.wandb_id
            logging.info(f"Resuming state from wandb_id {wandb_id}")

            # Get checkpoint
            checkpoint = load_checkpoint(self.cfg.setup.model_load_name, wandb_id, overwrite=False, savedir=self.cfg.setup.wandb_load_dir)

            # Load weights for agents
            self._load_weights(checkpoint['model_state_dict'])
            logging.info(f"Loaded weights!")

            # Load trainer states
            logging.info(f"Loading optimizer ...")
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            # Optional keys: older checkpoints may lack them.
            if 'scaler_state_dict' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

            if 'is_flop_saved' in checkpoint:
                self.is_flop_saved = checkpoint['is_flop_saved']

            if 'seen_gameids' in checkpoint:
                self.seen_gameids = checkpoint['seen_gameids'].to(self.ddp_util.rank)
                # Restrict training to game ids whose flag is still zero.
                seen_gameids_set = set(torch.nonzero(checkpoint['seen_gameids']).flatten().cpu().tolist())
                remaining_gameids = list(set(self.data.train_gameids) - seen_gameids_set)
                self.data.gameids = remaining_gameids
                self.data.train_gameids = remaining_gameids

                logging.info(f'There are {len(remaining_gameids)} remaining gameids to train on.')

            if 'saved_chkpts' in checkpoint:
                self.saved_chkpts = checkpoint['saved_chkpts'].to(self.ddp_util.rank)
                
            self.logger.grad_steps = checkpoint['gradient_steps']
            self.logger.tot_samples = checkpoint['num_samples']
            self.train_min_loss = checkpoint['train_min_loss']
        elif self.cfg.setup.wandb_load_dir:
            logging.info(f"Starting run from {self.cfg.setup.wandb_load_dir}")
            checkpoint = load_checkpoint(self.cfg.setup.model_load_name, savedir=self.cfg.setup.wandb_load_dir)

            # Load weights for agents
            self._load_weights(checkpoint['model_state_dict'])
            logging.info(f"Loaded weights!")
        else:
            logging.info("No wandb_id specified to resume from.")

    def _get_model(self):
        """Return the model being trained; must be implemented by subclasses."""
        raise NotImplementedError()