import os
import torch
import logging
import sys

from tqdm import tqdm
from typing import Dict
from transformers import PreTrainedTokenizerBase
from torch.utils.tensorboard import SummaryWriter

from CoLM import utils
from CoLM.data import LMCorpusMemmapDataset, build_dataloader
from CoLM.distributed_utils import build_distributed_configuration
from CoLM.option import TrainArg
from CoLM.criterion import build_criterion
from CoLM.optim import build_optimizer, build_scheduler


# Root logging setup: INFO and above goes to stdout, one record per line as
# "<timestamp> | <level> | <logger name> | <message>".
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

# Module-wide logger used by the training code below.
logger = logging.getLogger("Training")


class Trainer:
    """Run language-model pre-training: data, optimisation, validation, checkpoints.

    Construction builds the datasets, criterion, data loaders, optimizer and
    learning-rate scheduler from ``args``, passes model/optimizer/scheduler/
    loaders through ``self.dist.prepare`` and logs a model summary.
    ``train()`` then runs the loop.

    Note: the scheduler attribute is spelled ``lr_schduler`` (sic). The
    spelling is kept because it is part of this object's attribute surface.
    """

    def __init__(
        self,
        *,
        args: TrainArg,
        model: torch.nn.Module = None,
        tokenizer: PreTrainedTokenizerBase = None,
    ):
        """Build every training component and prepare it for distributed use.

        Args:
            args: Training options; also forwarded to the dataset, dataloader,
                optimizer, scheduler and distributed-configuration builders.
            model: Model to train. Defaults to ``None``, but the optimizer
                builder and ``print_status`` use it, so it is needed in
                practice.
            tokenizer: Tokenizer. Defaults to ``None``, but ``vocab_size``
                calls ``len()`` on it and ``make_dataloader`` reads
                ``bos_token_id``, so it is needed in practice.
        """
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.dist = build_distributed_configuration(args)

        # Order matters: loaders need the datasets, and the optimizer needs
        # ``self.dist`` and the model.
        self.datasets = self.make_lm_datasets()
        self.criterion = self.make_criterion()
        self.train_loader, self.valid_loader = self.make_dataloader()
        self.optimizer, self.lr_schduler = self.make_optimizer_and_scheduler()

        # Only the master process owns a TensorBoard writer.
        self.writer = SummaryWriter(log_dir=args.save_dir) if self.is_master else None

        self.prepare()
        self.print_status()

    def print_status(self) -> None:
        """Log the model, its trainable-parameter count and its configuration."""
        logger.info(self.model)
        logger.info(f"The trainable parameters of model is {utils.count_trainable_parameters(self.model)}")

        # This runs after prepare(), which may hand back a wrapper around the
        # model. If the object does not expose ``config`` itself, look on its
        # ``module`` attribute instead of raising AttributeError.
        config = getattr(self.model, "config", None)
        if config is None:
            config = getattr(getattr(self.model, "module", None), "config", None)
        logger.info(f"The configuration of model is {config}")

    def prepare(self) -> None:
        """Pass model, optimizer, scheduler and loaders through ``self.dist.prepare``.

        Also attaches transformers' ``DebugUnderflowOverflow`` to the prepared
        model.
        """
        (
            self.model,
            self.optimizer,
            self.lr_schduler,
            self.train_loader,
            self.valid_loader,
        ) = self.dist.prepare(
            self.model,
            self.optimizer,
            self.lr_schduler,
            self.train_loader,
            self.valid_loader,
        )

        from transformers.debug_utils import DebugUnderflowOverflow

        # The instance is created for its effect on the model; keep a handle
        # on it instead of dropping it into an unused local.
        # NOTE(review): this is enabled unconditionally. The transformers docs
        # describe it as a debugging aid; confirm it should stay on for
        # regular runs.
        self._overflow_debugger = DebugUnderflowOverflow(self.model)

    def make_criterion(self):
        """Build the training criterion named by ``args.criterion``."""
        return build_criterion(name=self.args.criterion)

    def make_optimizer_and_scheduler(self):
        """Build the optimizer over the model's trainable parameters and its scheduler.

        Returns:
            An ``(optimizer, lr_scheduler)`` tuple.
        """
        # ``use_deepseed`` (sic) is the attribute name exposed by ``self.dist``.
        is_deepspeed = self.dist.use_deepseed

        optimizer = build_optimizer(
            params=utils.get_trainable_parameters(self.model),
            args=self.args,
            is_deepspeed=is_deepspeed,
        )
        lr_schduler = build_scheduler(
            args=self.args,
            optimizer=optimizer,
            is_deepspeed=is_deepspeed,
        )
        return optimizer, lr_schduler

    def make_lm_datasets(self):
        """Build the ``'train'`` and ``'valid'`` LM datasets for pre-training.

        Returns:
            A dict with keys ``'train'`` and ``'valid'``.
        """
        logger.info("Loading datasets...")

        splits = {
            'train': self.args.train_split,
            'valid': self.args.valid_split,
        }
        datasets = {
            name: LMCorpusMemmapDataset.build_dataset(split, self.args, self.vocab_size)
            for name, split in splits.items()
        }

        logger.info("Dataset loaded.")
        return datasets

    def make_dataloader(self):
        """Build the training and validation DataLoaders.

        Returns:
            A ``(train_loader, valid_loader)`` tuple.
        """
        # NOTE(review): the ``eos_token`` argument is fed the tokenizer's
        # *BOS* id. Left unchanged because the intent cannot be confirmed
        # from this file; verify against ``build_dataloader``.
        boundary_token = self.tokenizer.bos_token_id

        train_loader = build_dataloader(
            dataset=self.train_dataset,
            args=self.args,
            batch_size=self.args.batch_size,
            eos_token=boundary_token,
        )
        valid_loader = build_dataloader(
            dataset=self.valid_dataset,
            args=self.args,
            batch_size=self.args.eval_batch_size,
            eos_token=boundary_token,
        )
        return train_loader, valid_loader

    @property
    def train_dataset(self):
        """The dataset stored under ``'train'``."""
        return self.datasets['train']

    @property
    def valid_dataset(self):
        """The dataset stored under ``'valid'``."""
        return self.datasets['valid']

    @property
    def vocab_size(self):
        """``len(self.tokenizer)``."""
        return len(self.tokenizer)

    @property
    def is_master(self):
        """Whether the distributed configuration marks this process as master."""
        return self.dist.is_master

    def save_checkpoint(self, prefix: str = ""):
        """Save the model's ``state_dict`` to ``<save_dir>/<prefix>/colm.pth``.

        Args:
            prefix: Sub-directory of ``args.save_dir`` to write into; the
                empty default writes directly into ``args.save_dir``.
        """
        checkpoint_dir = os.path.join(self.args.save_dir, prefix)
        # One call creates ``save_dir`` and the sub-directory, and tolerates
        # the directory already existing (no check-then-create race).
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(checkpoint_dir, "colm.pth")
        torch.save({
            'model_state_dict': self.model.state_dict(),
        }, checkpoint_path)

    def set_epoch(self, epoch: int = 0):
        """Forward the epoch number to the train loader's ``set_epoch`` hook.

        The sampler's ``set_epoch`` is preferred; otherwise the loader's own
        is used. A loader without a ``sampler`` attribute no longer raises,
        and if neither object has the hook the call does nothing.

        Args:
            epoch: Epoch index passed to ``set_epoch``.
        """
        sampler = getattr(self.train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        elif hasattr(self.train_loader, "set_epoch"):
            self.train_loader.set_epoch(epoch)

    def train(self):
        """Run the training loop until ``args.max_update`` steps have been taken.

        One step is counted per batch drawn from the train loader. The
        progress bar advances only when ``accelerator.sync_gradients`` is
        true. Every ``args.save_every_N_steps`` steps the model is validated
        and, on the master process, checkpointed; a falsy interval (0 or
        ``None``) disables these periodic checkpoints. When the loop ends, a
        final validation runs and the master process closes the writer and
        saves the ``"final"`` checkpoint.
        """
        max_epoch = 10000  # hard upper bound on passes over the train loader
        max_update = self.args.max_update
        save_every = self.args.save_every_N_steps
        num_steps, total_loss, total_num = 0, 0, 0

        pbar = tqdm(range(max_update), disable=not self.is_master)
        self.model.train()
        self.optimizer.zero_grad()

        for epoch in range(max_epoch):
            self.set_epoch(epoch)

            for batch in self.train_loader:
                outputs = self.train_step(batch)
                num_steps += 1

                # Token-weighted running mean of the loss over all steps so far.
                total_loss += outputs["loss"].detach().item() * outputs["tokens_num"]
                total_num += outputs["tokens_num"]

                loss = total_loss / total_num
                # The criterion's logging helper is given the running mean,
                # not the raw per-batch loss tensor.
                outputs["loss"] = loss
                if self.is_master:
                    self.writer.add_scalar("train/loss", loss, num_steps)

                if self.dist.accelerator.sync_gradients:
                    pbar.set_postfix(self.criterion.logging_outputs(outputs))
                    pbar.update(1)

                # A zero interval would otherwise raise ZeroDivisionError.
                if save_every and num_steps % save_every == 0:
                    # validate() restores train mode before returning.
                    self.validate()
                    if self.is_master:
                        self.save_checkpoint(prefix=f"step_{num_steps}")

                if num_steps >= max_update:
                    break

            if num_steps >= max_update:
                break

        pbar.close()

        self.validate()
        if self.is_master:
            self.writer.flush()
            self.writer.close()
            self.save_checkpoint(prefix="final")

    def train_step(self, batch: Dict[str, torch.Tensor]):
        """Run forward, backward and the optimizer/scheduler step for one batch.

        Args:
            batch: Batch passed to the criterion together with the model.

        Returns:
            The criterion's output mapping; ``outputs["loss"]`` is the value
            that was back-propagated.
        """
        accelerator = self.dist.accelerator

        with accelerator.accumulate(self.model):
            # Autocast covers only the forward pass and loss computation.
            # PyTorch's autocast documentation advises against running
            # backward passes under it.
            with accelerator.autocast():
                outputs = self.criterion(self.model, batch)

            accelerator.backward(outputs["loss"])
            # Clip only when gradients are synchronised and clipping is enabled.
            if accelerator.sync_gradients and self.args.max_grad_norm > 0.0:
                accelerator.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
            self.optimizer.step()
            self.lr_schduler.step()
            self.optimizer.zero_grad()
        return outputs

    def validate(self):
        """Evaluate on the validation loader, showing a running mean loss.

        The model is switched to eval mode for the pass and put back into
        whatever mode it was in before the call, so a validation in the middle
        of training does not leave the model in eval mode.
        """
        was_training = self.model.training
        self.model.eval()

        total_loss, total_num = 0, 0
        try:
            with torch.no_grad(), tqdm(self.valid_loader, disable=not self.is_master) as pbar:
                # Iterating the bar advances it once per batch.
                for batch in pbar:
                    outputs = self.criterion(self.model, batch)

                    # Token-weighted running mean over the batches seen so far.
                    total_loss += outputs["loss"].detach().item() * outputs["tokens_num"]
                    total_num += outputs["tokens_num"]

                    loss = total_loss / total_num
                    outputs["loss"] = loss

                    pbar.set_postfix(self.criterion.logging_outputs(outputs))
        finally:
            self.model.train(was_training)