import os
import torch
import logging
import sys
import math

from LLMProxy import utils
from LLMProxy.data import (
    build_dataset,
    build_dataloader,
    get_tokenizer,
)
from LLMProxy.models import build_model
from LLMProxy.criterion import build_criterion
from LLMProxy.optim import build_optimizer, build_scheduler
from LLMProxy.option import TrainArg, ModelArg, DistArg
from LLMProxy.distributed_utils import (
    is_master,
    create_ddp_model
)
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import DummyOptim, DummyScheduler
from typing import List, Dict
from tqdm import tqdm

from transformers.modeling_utils import unwrap_model

# Root logging setup: records at INFO and above go to stdout, rendered as
# "timestamp | level | logger name | message". Note that logging.basicConfig
# does nothing if the root logger already has handlers configured.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Named logger used by this training module.
logger = logging.getLogger(name="Training")


class Trainer:
    """Base training loop shared by the DDP, Accelerate and DeepSpeed trainers.

    Builds the model, tokenizer, criterion and the train/valid dataloaders from
    the parsed arguments. Subclasses must set ``self.optimizer`` and
    ``self.lr_scheduler`` and either implement :meth:`optimize` or override
    :meth:`train_step`.
    """

    def __init__(
        self,
        *,
        train_args: TrainArg,
        model_args: ModelArg,
        dist_args: DistArg,
    ):
        self.args = train_args

        # Model setting
        self.model = build_model(model_args)
        self.tokenizer = get_tokenizer(tokenizer=self.args.tokenizer, token=model_args.auth)
        self.criterion = build_criterion(name=self.args.criterion)

        # System setting
        self.is_master = is_master(dist_args)
        # NOTE(review): despite its name, ``utils.device`` seems to return a torch
        # dtype here (it is passed as ``autocast(dtype=...)`` and compared with
        # ``torch.float16`` by DDPTrainer) — confirm.
        self.precision = utils.device(self.args.precision)

        # Dataset setting
        self.train_dataset = build_dataset(self.args.train_split, self.args)
        self.valid_dataset = build_dataset(self.args.valid_split, self.args)

        # NOTE(review): the tokenizer's *bos* id is passed as ``eos_token``. This may
        # be a deliberate separator choice — confirm against build_dataloader.
        self.train_loader = build_dataloader(
            dataset=self.train_dataset,
            args=train_args,
            batch_size=self.args.batch_size,
            eos_token=self.tokenizer.bos_token_id,
        )
        self.valid_loader = build_dataloader(
            dataset=self.valid_dataset,
            args=train_args,
            batch_size=self.args.eval_batch_size,
            eos_token=self.tokenizer.bos_token_id,
        )

        # Set by subclasses once the model has been wrapped for distribution.
        self.optimizer = None
        self.lr_scheduler = None

    def set_epoch(self, epoch: int = 0):
        """Forward the epoch number to the train loader so it can reshuffle.

        Prefers the sampler's ``set_epoch`` and falls back to the loader's own
        ``set_epoch``. Does nothing when neither exists (e.g. a plain DataLoader
        whose sampler needs no epoch seeding).
        """
        sampler = getattr(self.train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        elif hasattr(self.train_loader, "set_epoch"):
            self.train_loader.set_epoch(epoch)

    def train(self):
        """Train until ``max_epoch`` epochs or ``max_update`` steps are reached.

        ``max_epoch == -1`` means no epoch limit (internally capped at 10**6), so
        training then stops only once ``max_update`` steps have been taken.
        """
        max_epoch = 10 ** 6 if self.args.max_epoch == -1 else self.args.max_epoch

        self.num_steps = 0
        self.model.train()

        for epoch in range(max_epoch):
            self.set_epoch(epoch)
            self.train_epoch(epoch)

            # ``>=`` so exactly ``max_update`` steps run (``>`` ran one extra step).
            if self.num_steps >= self.args.max_update:
                break

    def train_epoch(self, epoch_idx: int = 0):
        """Run one pass over ``train_loader``, then save an ``epoch_{idx}`` checkpoint.

        The progress bar shows running losses, each averaged over the batches'
        ``tokens_num``. Every ``save_every_N_steps`` steps (unless -1) the master
        saves a ``step_{n}`` checkpoint. The epoch ends early once ``max_update``
        steps have been taken.

        Args:
            epoch_idx: Epoch number, used to name the end-of-epoch checkpoint.
        """
        total_loss, total_num = 0, 0
        total_loss_std, total_loss_sparse = 0, 0
        self.optimizer.zero_grad()

        with tqdm(self.train_loader, disable=not self.is_master) as pbar:
            # The bar is advanced by hand so that skipped batches still tick it.
            for batch in self.train_loader:
                outputs = self.train_step(batch)

                # A train_step may flag an overflow (see DeepSpeedTrainer). Such a
                # batch is neither counted as a step nor folded into the averages.
                if outputs.get("overflow", False):
                    pbar.update(1)
                    continue

                self.num_steps += 1

                if self.args.save_every_N_steps != -1 and self.num_steps % self.args.save_every_N_steps == 0:
                    if self.is_master:
                        self.save_checkpoint("step_{}".format(self.num_steps))

                total_loss_sparse += outputs["loss_sparse"].detach().item() * outputs["tokens_num"]
                total_loss_std += outputs["loss_std"].detach().item() * outputs["tokens_num"]
                total_loss += outputs["loss"].detach().item() * outputs["tokens_num"]

                total_num += outputs["tokens_num"]

                loss_std = total_loss_std / total_num
                loss_sparse = total_loss_sparse / total_num
                loss = total_loss / total_num

                pbar.set_postfix({
                    "loss_sparse": '{:.2f}'.format(loss_sparse),
                    "ppl_sparse": '{:.2f}'.format(math.exp(loss_sparse)),
                    "loss_std": '{:.2f}'.format(loss_std),
                    "ppl_std": '{:.2f}'.format(math.exp(loss_std)),
                    "loss_total": '{:.2f}'.format(loss),
                })

                pbar.update(1)

                # Same ``>=`` bound as in train(): stop after exactly max_update steps.
                if self.num_steps >= self.args.max_update:
                    break

        if self.is_master:
            self.save_checkpoint("epoch_{}".format(epoch_idx))

    def _model_device(self) -> torch.device:
        """Return the device of the underlying model, looking through wrappers."""
        # Wrappers such as DistributedDataParallel may not expose ``.device``; the
        # inner model does (presumably a transformers PreTrainedModel — TODO confirm).
        return unwrap_model(self.model).device

    def train_step(self, batch: Dict[str, torch.Tensor]):
        """Compute the loss under autocast and hand it to :meth:`optimize`.

        Returns the criterion's output mapping, which must contain ``"loss"``.
        """
        self.model.train()

        with torch.cuda.amp.autocast(dtype=self.precision):
            batch = utils.move_to_cuda(batch, device=self._model_device())
            outputs = self.criterion(self.model, batch)

        self.optimize(outputs["loss"])
        return outputs

    def optimize(self, loss: torch.Tensor):
        """Backpropagate ``loss`` and update parameters; implemented by subclasses."""
        raise NotImplementedError

    def validate(self):
        """Evaluate on ``valid_loader`` without gradients and show running losses.

        Leaves the model in eval mode; :meth:`train_step` switches it back.
        """
        self.model.eval()

        total_loss, total_num = 0, 0
        total_loss_std, total_loss_sparse = 0, 0
        with torch.no_grad():
            with tqdm(self.valid_loader, disable=not self.is_master) as pbar:
                for batch in self.valid_loader:
                    batch = utils.move_to_cuda(batch, device=self._model_device())

                    outputs = self.criterion(self.model, batch)

                    total_loss_sparse += outputs["loss_sparse"].detach().item() * outputs["tokens_num"]
                    total_loss_std += outputs["loss_std"].detach().item() * outputs["tokens_num"]
                    total_loss += outputs["loss"].detach().item() * outputs["tokens_num"]

                    total_num += outputs["tokens_num"]

                    loss_std = total_loss_std / total_num
                    loss_sparse = total_loss_sparse / total_num
                    loss = total_loss / total_num

                    pbar.set_postfix({
                        "loss_sparse": '{:.2f}'.format(loss_sparse),
                        "ppl_sparse": '{:.2f}'.format(math.exp(loss_sparse)),
                        "loss_std": '{:.2f}'.format(loss_std),
                        "ppl_std": '{:.2f}'.format(math.exp(loss_std)),
                        "loss_total": '{:.2f}'.format(loss),
                    })
                    pbar.update(1)

    def save_checkpoint(self, prefix: str = ""):
        """Save the model with ``save_pretrained`` into ``<save_dir>/<prefix>``."""
        os.makedirs(self.args.save_dir, exist_ok=True)

        checkpoint_dir = os.path.join(self.args.save_dir, prefix)
        # DistributedDataParallel does not forward ``save_pretrained`` to the wrapped
        # module, so always save the inner model.
        # NOTE(review): under DeepSpeed ZeRO-3 the parameters are partitioned and would
        # need gathering before saving — confirm which ZeRO stage is used.
        unwrap_model(self.model).save_pretrained(checkpoint_dir)


class DDPTrainer(Trainer):
    """Trainer that wraps the model via ``create_ddp_model`` and steps by hand.

    Supports fp16 dynamic loss scaling (``torch.cuda.amp.GradScaler``) and
    gradient accumulation over ``accumulation_step`` micro-batches.
    """

    def __init__(
        self,
        *,
        train_args: TrainArg,
        model_args: ModelArg,
        dist_args: DistArg,
    ):
        super().__init__(train_args=train_args, model_args=model_args, dist_args=dist_args)

        self.model = create_ddp_model(self.model.cuda(), dist_args)
        self.optimizer = build_optimizer(utils.get_trainable_parameters(self.model), self.args)
        self.lr_scheduler = build_scheduler(self.args, self.optimizer)

        # Loss scaling is only used for fp16; other precisions backprop unscaled.
        self.scaler = torch.cuda.amp.GradScaler() if self.precision == torch.float16 else None

        # Backward passes accumulated so far in the current accumulation window.
        self._step = 0
        # Clamped to >= 1: a value of 0 would otherwise never trigger an optimizer step.
        self.accumulation_step = max(1, self.args.accumulation_step)

    def optimize(self, loss: torch.Tensor):
        """Accumulate gradients of ``loss``; step once every ``accumulation_step`` calls.

        The loss is divided by ``accumulation_step`` so the applied gradient is the
        mean over the window, matching ``accelerate``'s ``backward``, not the sum.
        With a GradScaler, the LR scheduler is not advanced when the scaler skipped
        the optimizer step because of inf/NaN gradients.
        """
        loss = loss / self.accumulation_step
        if self.scaler is not None:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
        # NOTE(review): gradients are all-reduced on every backward; wrapping the
        # non-final micro-steps in ``self.model.no_sync()`` would avoid that, if
        # create_ddp_model returns a DistributedDataParallel module.

        self._step += 1
        if self._step < self.accumulation_step:
            return
        self._step = 0

        step_skipped = False
        if self.scaler is not None:
            scale_before = self.scaler.get_scale()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            # update() lowers the scale only after an inf/NaN step, which step()
            # skipped. A *higher* scale (periodic growth) still means the step ran,
            # so comparing with ``==`` wrongly skipped the scheduler on growth steps.
            step_skipped = self.scaler.get_scale() < scale_before
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()

        if not step_skipped:
            self.lr_scheduler.step()


class AccelerateTrainer(Trainer):
    """Trainer that delegates mixed precision, distribution and gradient
    accumulation to a Hugging Face ``Accelerator``."""

    def __init__(
        self,
        *,
        train_args: TrainArg,
        model_args: ModelArg,
        dist_args: DistArg,
    ):
        super().__init__(train_args=train_args, model_args=model_args, dist_args=dist_args)

        self.accelerator = Accelerator(mixed_precision=self.args.precision, gradient_accumulation_steps=self.args.accumulation_step)

        optimizer = build_optimizer(utils.get_trainable_parameters(self.model), self.args)
        lr_scheduler = build_scheduler(self.args, optimizer)

        # The loaders are prepared too, so the methods below never move batches to
        # the device themselves.
        self.model, self.train_loader, self.valid_loader, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.train_loader, self.valid_loader, optimizer, lr_scheduler
        )

    def train_step(self, batch: Dict[str, torch.Tensor]):
        """Forward/backward one micro-batch and let accelerate decide when to step.

        Inside ``accelerator.accumulate``, accelerate defers the optimizer step,
        scheduler step and zero_grad until the end of each accumulation window.
        Returns the criterion's output mapping.
        """
        # validate() leaves the model in eval mode; re-enable training behaviour.
        self.model.train()

        with self.accelerator.autocast():
            with self.accelerator.accumulate(self.model):
                outputs = self.criterion(self.model, batch)

                self.accelerator.backward(outputs["loss"])
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

        return outputs

    def validate(self):
        """Evaluate on the entire validation loader."""
        self._run_validation(max_steps=None)

    def validate_with_early_stop(self, stop_step: int = 100):
        """Evaluate on at most ``stop_step`` validation batches (default 100)."""
        self._run_validation(max_steps=stop_step)

    def _run_validation(self, max_steps=None):
        """Shared no-grad evaluation loop; stops after ``max_steps`` batches if given.

        The progress bar shows, separately for the ``*_sparse`` and ``*_std``
        criterion outputs, the token-weighted running loss, its ``exp`` (ppl) and
        ``sum(acc) / tokens``. Presumably ``acc_*`` are per-batch counts of correct
        tokens — TODO confirm against the criterion.
        """
        self.model.eval()

        total_loss, total_acc, total_num = 0, 0, 0
        total_loss_std, total_acc_std, total_num_std = 0, 0, 0

        with self.accelerator.autocast():
            with torch.no_grad():
                with tqdm(self.valid_loader, disable=not self.is_master) as pbar:
                    # Count batches explicitly: tqdm does not advance ``pbar.n`` when
                    # disabled, so relying on it never stopped non-master ranks.
                    for num_batches, batch in enumerate(self.valid_loader, start=1):
                        outputs = self.criterion(self.model, batch)

                        total_loss += outputs["loss_sparse"].detach().item() * outputs["tokens_num"]
                        total_num += outputs["tokens_num"]
                        total_acc += outputs["acc_sparse"]

                        total_loss_std += outputs["loss_std"].detach().item() * outputs["tokens_num_std"]
                        total_num_std += outputs["tokens_num_std"]
                        total_acc_std += outputs["acc_std"]

                        loss = total_loss / total_num
                        acc = total_acc / total_num

                        loss_std = total_loss_std / total_num_std
                        acc_std = total_acc_std / total_num_std

                        pbar.set_postfix({
                            "loss": '{:.2f}'.format(loss),
                            "ppl": '{:.2f}'.format(math.exp(loss)),
                            "acc": '{:.2f}'.format(acc),
                            "loss_std": '{:.2f}'.format(loss_std),
                            "ppl_std": '{:.2f}'.format(math.exp(loss_std)),
                            "acc_std": '{:.2f}'.format(acc_std),
                        })
                        pbar.update(1)

                        # ``>=`` evaluates exactly ``max_steps`` batches.
                        if max_steps is not None and num_batches >= max_steps:
                            break


class DeepSpeedTrainer(Trainer):
    """Trainer that runs DeepSpeed through ``accelerate``.

    ``DummyOptim``/``DummyScheduler`` are accelerate's placeholders for an
    optimizer/scheduler that the DeepSpeed config (``dist_args.deepspeed_config``)
    is expected to define. ``prepare`` swaps in the real objects.
    """

    def __init__(
        self,
        *,
        train_args: TrainArg,
        model_args: ModelArg,
        dist_args: DistArg,
    ):
        super().__init__(train_args=train_args, model_args=model_args, dist_args=dist_args)

        optimizer = DummyOptim(
            params=utils.get_trainable_parameters(self.model), lr=self.args.learning_rate,
        )
        lr_scheduler = DummyScheduler(
            optimizer, total_num_steps=self.args.max_update, warmup_num_steps=self.args.num_warmup_steps,
        )
        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=dist_args.deepspeed_config)
        self.accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

        self.model, self.train_loader, self.valid_loader, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.train_loader, self.valid_loader, optimizer, lr_scheduler
        )

    def train_step(self, batch: Dict[str, torch.Tensor]):
        """Forward/backward/step one micro-batch.

        Returns the criterion's output mapping with an added ``"overflow"`` flag,
        read from the DeepSpeed optimizer. train_epoch skips batches with the
        flag set.
        """
        # Inherited validate() leaves the model in eval mode; switch back.
        self.model.train()

        with self.accelerator.autocast():
            with self.accelerator.accumulate(self.model):
                outputs = self.criterion(self.model, batch)

                self.accelerator.backward(outputs["loss"])
                self.optimizer.step()
                # The wrapped DeepSpeed optimizer may lack ``overflow`` (presumably
                # only loss-scaling optimizers define it — TODO confirm); treat a
                # missing attribute as "no overflow" instead of raising.
                is_overflow = getattr(self.optimizer.optimizer, "overflow", False)
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

        outputs["overflow"] = is_overflow
        return outputs


    
