"""
Implement a trainer on a single host (works with multiple devices).
Implementation relies on automatic parallelization during compilation using jit.
Advantages:
 - Simple code: like for a single-device setting
 - No need for per-device explicit-collectives programming (like all_gather).
 - Works on multiple devices if they are on the same host
Disadvantages:
 - Does not work in multihost environment
 - Less control over parallelization => Can be less optimal for custom operations
"""

from typing import Any, Dict, Tuple, Callable, Iterable, Union
from functools import partial
import os
from tqdm.auto import tqdm
import numpy as np
import torch
import transformers
from accelerate import Accelerator
import wandb
from torch.utils.data import DataLoader
import wandb.apis
import wandb.sdk
from accelerate import DistributedDataParallelKwargs as DDPK
from datetime import timedelta
from accelerate import InitProcessGroupKwargs
from latte_trans.evals.base_pt import TorchEvaluator

# Type alias for an optional Weights & Biases run handle; the alias is not
# referenced anywhere else in the visible part of this module.
# NOTE(review): `Trainer.set_logger` obtains its run from `wandb.init`, which
# presumably yields an SDK run object rather than `wandb.apis.public.Run` --
# confirm which type is intended here.
WandbRun = Union[wandb.apis.public.Run, None]


def get_scheduler(
    config: Any, total_steps: int, optim: torch.optim.Optimizer, num_processes
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build the learning-rate scheduler selected by ``config.lr_decay_fn``.

    Both the training horizon and an absolute warmup length are multiplied by
    ``num_processes`` before being handed to the scheduler factory.
    NOTE(review): presumably this compensates for the Accelerate-prepared
    scheduler advancing once per process on each ``step()`` call -- confirm.

    Args:
        config: Object exposing ``warmup`` (absolute number of warmup steps),
            ``warmup_pc`` (warmup as a fraction of the scaled horizon; only
            used when ``warmup`` is not positive) and ``lr_decay_fn``.
        total_steps: Number of training steps before scaling.
        optim: Optimizer whose learning rate is scheduled.
        num_processes: Factor applied to the step counts.

    Returns:
        A cosine schedule for ``"cosine"``, a linear schedule for
        ``"linear"`` and a constant-after-warmup schedule for any other value.
    """
    scaled_total = total_steps * num_processes

    if config.warmup > 0:
        # Absolute warmup is given in unscaled steps, so scale it as well.
        n_warmup = config.warmup * num_processes
    else:
        # Fractional warmup is taken from the already-scaled horizon;
        # a fraction of 0 means no warmup at all.
        n_warmup = int(config.warmup_pc * scaled_total)

    decay_kind = config.lr_decay_fn

    if decay_kind == "cosine":
        print(f"total = {scaled_total}, warmup = {n_warmup}")
        return transformers.get_cosine_schedule_with_warmup(
            optimizer=optim,
            num_warmup_steps=n_warmup,
            num_training_steps=scaled_total,
            num_cycles=0.5,
        )

    if decay_kind == "linear":
        return transformers.get_linear_schedule_with_warmup(
            optimizer=optim,
            num_warmup_steps=n_warmup,
            num_training_steps=scaled_total,
        )

    # Any other value: keep the learning rate constant once warmup is over.
    return transformers.get_constant_schedule_with_warmup(
        optimizer=optim, num_warmup_steps=n_warmup
    )


def prepare_optimizer(config: Any, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Create the AdamW optimizer over all parameters of ``model``.

    Args:
        config: Object exposing ``lr`` and ``weight_decay``.
        model: Module whose parameters are optimized.

    Returns:
        A ``torch.optim.AdamW`` instance configured from ``config``.
    """
    return torch.optim.AdamW(
        params=model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
    )


class Trainer:
    """Simple trainer on a single host built around ``accelerate.Accelerator``.

    The trainer owns the optimizer, the learning-rate scheduler and the
    training dataloader. While training it periodically runs the optional
    evaluators, logs the scores to Weights & Biases (main process only) and
    writes a checkpoint to ``<out_dir>/checkpoint.pth``.
    """

    # Name of the small text file, stored inside the checkpoint directory, that
    # records how many training iterations were completed at save time.
    _STEP_FILE = "trainer_step.txt"

    def __init__(
        self,
        config: Any,
        out_dir: str,
        model: torch.nn.Module,
        train_data: Iterable = None,
        train_dl: Iterable = None,
        data_collator: Callable = None,
        evaluator: TorchEvaluator = None,  # compute_metrics functions
        test_evaluator: TorchEvaluator = None,
        model_inputs_orded: Tuple[str] = ("input_ids", "labels"),
    ) -> None:
        """Set up dataloader, accelerator, optimizer, scheduler and logging.

        Args:
            config: Run configuration. This class reads ``eval_steps``,
                ``max_checkpoints``, ``batch_size``, ``shuffle_train``,
                ``mixed_precision``, ``grad_accumulation_steps``, ``epochs``,
                ``train_steps`` and the wandb-related fields used by
                ``set_logger``.
            out_dir: Directory for checkpoints and wandb files.
            model: Model to train. It is called with positional inputs and
                must return a mapping that contains ``"loss"``.
            train_data: Training dataset; wrapped in a ``DataLoader`` when
                ``train_dl`` is not given.
            train_dl: Ready-made training dataloader (takes precedence).
            data_collator: ``collate_fn`` for the internally built dataloader.
            evaluator: Optional evaluator run with the ``"eval_"`` prefix.
            test_evaluator: Optional evaluator run with the ``"test_"`` prefix.
            model_inputs_orded: Batch keys, in the positional order expected
                by the model.

        Raises:
            ValueError: If neither ``train_data`` nor ``train_dl`` is given.
        """
        if train_data is None and train_dl is None:
            raise ValueError("Either `train_data` or `train_dl` must be provided.")

        self.config = config
        self._out_dir = out_dir
        self._model = model
        self.train_data = train_data
        self.data_collator = data_collator
        self.eval_steps = self.config.eval_steps
        # Stored for callers; not used by any method of this class.
        self.max_checkpoints = self.config.max_checkpoints
        self._evaluator = evaluator
        self._test_evaluator = test_evaluator
        self.model_inputs_orded = model_inputs_orded

        # One score dictionary per evaluation round, filled by `_train`.
        self._eval_metrics = []

        if train_dl is None:
            self.train_dl = DataLoader(
                self.train_data,
                batch_size=self.config.batch_size,
                shuffle=self.config.shuffle_train,
                collate_fn=self.data_collator,
                drop_last=True,
            )
        else:
            self.train_dl = train_dl

        if self.eval_steps == 0:
            # 0 means "evaluate once per pass over the data".
            # NOTE(review): unlike `calc_total_steps`, this count is not divided
            # by the number of processes -- confirm that this is intended.
            self.eval_steps = max(1, self._batches_per_epoch())

        kwargs = DDPK(find_unused_parameters=True)
        process_group_kwargs = InitProcessGroupKwargs(
            timeout=timedelta(seconds=5400)
        )  # 1.5 hours

        self._accelerator = Accelerator(
            mixed_precision=config.mixed_precision,
            gradient_accumulation_steps=config.grad_accumulation_steps,
            project_dir=out_dir,
            kwargs_handlers=[kwargs, process_group_kwargs],
        )
        self.total_steps = self.calc_total_steps()
        self._optimizer = prepare_optimizer(self.config, model)
        self._lr_scheduler = get_scheduler(
            config, self.total_steps, self._optimizer, self._accelerator.num_processes
        )

        self.wandb_run = None
        if self._accelerator.is_main_process:
            self.wandb_run = self.set_logger()
            print(model)

    def _batches_per_epoch(self, num_processes: int = 1) -> int:
        """Return the number of batches in one pass over the training data.

        The count is derived from ``train_data`` when available and otherwise
        from the length of the (not yet prepared) ``train_dl``.
        """
        if self.train_data is not None:
            per_step = self.config.batch_size * num_processes
            return int(np.ceil(len(self.train_data) / per_step))
        return int(np.ceil(len(self.train_dl) / num_processes))

    def set_logger(self):
        """Start (or resume) the wandb run; return ``None`` when logging is off."""
        if not self.config.wandb_log:
            return None

        resume = False
        run_id = None
        if self.config.check_path is not None:
            # Training continues from a checkpoint: attach to the existing run.
            resume = "must"
            run_id = self.config.run_id

        return wandb.init(
            project=self.config.project,
            entity=self.config.entity,
            name=self.config.name,
            dir=self._out_dir,
            config=self.config,
            id=run_id,
            resume=resume,
        )

    def sample_data(self) -> Tuple[torch.Tensor]:
        """Return the first training batch as a tuple of positional model inputs."""
        data = next(iter(self.train_dl))
        return tuple(data[k] for k in self.model_inputs_orded)

    def safe_wandb_log(self, log_data: Dict[str, Any]):
        """Log ``log_data`` to wandb from the main process only.

        An optional ``"generations"`` entry is logged separately as a table
        with the columns Prompt / Expected / Generation. ``log_data`` itself
        is left unmodified.
        """
        if self.wandb_run is None or not self._accelerator.is_main_process:
            return

        # Work on a copy so the caller's dictionary stays identical on every
        # process, not just on the ones that do not log.
        payload = dict(log_data)
        generations = payload.pop("generations", None)
        self.wandb_run.log(payload)
        if generations is not None:
            columns = ["Prompt", "Expected", "Generation"]
            gen_table = wandb.Table(columns=columns, data=generations)
            self.wandb_run.log({"gen_table": gen_table})

    def calc_total_steps(self) -> int:
        """Return the number of training iterations.

        ``config.train_steps`` is used verbatim when set; otherwise the count
        is ``config.epochs`` times the number of batches per epoch, where an
        epoch is split across all processes of the accelerator.
        """
        if self.config.train_steps is not None:
            return self.config.train_steps
        per_epoch = self._batches_per_epoch(self._accelerator.num_processes)
        return int(self.config.epochs * per_epoch)

    def prepare_train(self):
        """Hand dataloaders, model, optimizer and scheduler to the accelerator.

        The validation dataloader of ``evaluator`` is prepared as well when an
        evaluator was supplied.
        """
        if self._evaluator is None:
            (
                self.train_dl,
                self._model,
                self._optimizer,
                self._lr_scheduler,
            ) = self._accelerator.prepare(
                self.train_dl,
                self._model,
                self._optimizer,
                self._lr_scheduler,
            )
            return

        self.train_dl, val_dl, self._model, self._optimizer, self._lr_scheduler = (
            self._accelerator.prepare(
                self.train_dl,
                self._evaluator.get_valdl(),
                self._model,
                self._optimizer,
                self._lr_scheduler,
            )
        )
        self._evaluator.set_valdl(val_dl)

    def _save_step(self, checkpoint_dir: str, it: int) -> None:
        """Record the completed-iteration count inside ``checkpoint_dir``."""
        if not self._accelerator.is_main_process:
            return
        os.makedirs(checkpoint_dir, exist_ok=True)
        step_file = os.path.join(checkpoint_dir, self._STEP_FILE)
        with open(step_file, "w", encoding="utf-8") as fh:
            fh.write(str(it))

    def _load_step(self, checkpoint_dir: str) -> int:
        """Read the iteration count written by ``_save_step`` (0 if absent)."""
        step_file = os.path.join(checkpoint_dir, self._STEP_FILE)
        try:
            with open(step_file, encoding="utf-8") as fh:
                return int(fh.read().strip())
        except FileNotFoundError:
            # Checkpoints written before the step file existed carry no count.
            self._accelerator.print(
                f"No {self._STEP_FILE} in {checkpoint_dir}; iteration counter "
                "restarts at 0. Pass `start_it` to `train` to override."
            )
            return 0

    def train(
        self, checkpoint_path: str = None, start_it: Union[int, None] = None
    ) -> Tuple[Union[Dict[str, list], list], torch.nn.Module]:
        """Run training, optionally resuming from a checkpoint.

        Args:
            checkpoint_path: Directory that contains ``checkpoint.pth`` as
                written by this trainer; ``None`` starts from scratch.
            start_it: Iteration to resume the counter from. When ``None`` the
                value stored with the checkpoint is used (0 without one).

        Returns:
            ``(metrics, model)`` where ``metrics`` maps each score name to the
            list of its values over all evaluation rounds (an empty list when
            no evaluation ran) and ``model`` is the unwrapped trained model.
        """
        self.prepare_train()

        first_it = 0 if start_it is None else start_it
        if checkpoint_path is not None:
            self._accelerator.print("Loading state")
            file_path = os.path.join(checkpoint_path, "checkpoint.pth")
            self._accelerator.load_state(file_path)
            if start_it is None:
                first_it = self._load_step(file_path)
            self._accelerator.print(f"Start it: {first_it}")

        self._accelerator.print(
            "Trainer total steps: ", self.total_steps, "Start it: ", first_it
        )
        model = self._train(total_steps=self.total_steps, start_it=first_it)

        # list of dicts -> dict of lists; rounds lacking a key contribute None
        metrics = self._eval_metrics
        if metrics:
            keys = dict.fromkeys(k for record in metrics for k in record)
            metrics = {key: [record.get(key) for record in metrics] for key in keys}
        return metrics, model

    def eval_step(
        self,
        batch: dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the model on ``batch`` without gradients and gather the results.

        Returns:
            ``(labels, outputs)`` as NumPy data, gathered across processes;
            ``outputs`` keeps the keys of the mapping returned by the model.
        """
        inputs = tuple(batch[k] for k in self.model_inputs_orded)
        with torch.no_grad():
            outputs = self._model(*inputs)
        outputs, labels = self._accelerator.gather_for_metrics(
            (outputs, batch["labels"])
        )
        outputs = {k: v.detach().cpu().numpy() for k, v in outputs.items()}
        return labels.detach().cpu().numpy(), outputs

    def train_step(self, batch: dict[str, torch.Tensor]):
        """Do one forward/backward pass and one optimizer + scheduler update.

        NOTE(review): the accelerator is configured with
        ``gradient_accumulation_steps`` but this step is not wrapped in
        ``accelerator.accumulate(...)``; presumably accumulation is therefore
        not applied -- confirm.

        Returns:
            The loss tensor of this process for ``batch``.
        """
        inputs = tuple(batch[k] for k in self.model_inputs_orded)
        outputs = self._model(*inputs)
        loss = outputs["loss"]
        self._accelerator.backward(loss)
        self._optimizer.step()
        self._lr_scheduler.step()
        self._optimizer.zero_grad()
        return loss

    def _evaluate_and_checkpoint(self, it: int, train_loss: list) -> None:
        """Run the evaluators, log and store the scores, then save a checkpoint.

        Args:
            it: Number of training iterations completed so far.
            train_loss: Per-step losses collected in the current epoch.
        """
        scores = {}
        self._model.eval()
        # compute scores on the validation data
        if self._evaluator is not None:
            scores = self._evaluator.evaluate(
                trainer_eval_fn=self.eval_step,
                prefix="eval_",
                accelerator=self._accelerator,
            )
        # compute scores on the test data
        if self._test_evaluator is not None:
            test_scores = self._test_evaluator.evaluate(
                trainer_eval_fn=self.eval_step,
                prefix="test_",
                accelerator=self._accelerator,
            )
            scores.update(test_scores)

        scores["train_loss"] = np.mean(train_loss)
        scores["learning_rate"] = self._lr_scheduler.get_last_lr()[-1]

        self._accelerator.print("Train Loss ", scores["train_loss"])
        self._accelerator.print("Evaluation scores: ", scores)
        self.safe_wandb_log(scores)
        # Sample generations are only meant for the wandb table; keep them out
        # of the metric history returned by `train`.
        self._eval_metrics.append(
            {k: v for k, v in scores.items() if k != "generations"}
        )

        # save checkpoint together with the iteration counter needed to resume
        file_path = os.path.join(self._out_dir, "checkpoint.pth")
        self._accelerator.save_state(file_path)
        self._save_step(file_path, it)
        self._model.train()

    def _train(
        self,
        total_steps: int,
        start_it: int = 0,
    ) -> torch.nn.Module:
        """Iterate over the training data until ``total_steps`` is reached.

        Evaluation and checkpointing happen after every ``eval_steps``
        completed iterations and once more after the final iteration.

        Args:
            total_steps: Iteration count at which training stops.
            start_it: Initial value of the iteration counter.

        Returns:
            The trained model, unwrapped from the accelerator.

        Raises:
            RuntimeError: If the training dataloader yields no batches.
        """
        self._model.train()
        param_count = sum(
            p.numel() for p in self._model.parameters() if p.requires_grad
        )
        self._accelerator.print(f"Number of parameters: {param_count / 1000000} M")

        it = start_it
        progress_bar = tqdm(
            total=total_steps,
            initial=it,  # resume the counter; `position` would be the screen row
            leave=True,
            disable=not self._accelerator.is_local_main_process,
        )
        while it < total_steps:
            train_loss = []
            for batch in self.train_dl:
                loss = self.train_step(batch)
                loss = self._accelerator.gather(loss).mean().item()
                train_loss.append(loss)
                it += 1
                progress_bar.update(1)

                # `it` now counts completed steps, so the last step always
                # triggers a final evaluation and checkpoint.
                finished = it >= total_steps
                if finished or it % self.eval_steps == 0:
                    self._evaluate_and_checkpoint(it, train_loss)
                if finished:
                    break
            if not train_loss:
                # Without this guard an empty dataloader would loop forever.
                raise RuntimeError("Training dataloader yielded no batches.")
        progress_bar.close()

        self._accelerator.wait_for_everyone()
        unwrapped_model = self._accelerator.unwrap_model(self._model)
        return unwrapped_model
