import time
import json
from pathlib import Path
from typing import Optional
from datetime import datetime
from shutil import copyfile

import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch import nn, Tensor
import torch.nn.functional as F
import accelerate
from accelerate import DistributedType
from accelerate.optimizer import AcceleratedOptimizer
from accelerate.scheduler import AcceleratedScheduler
from transformers import PreTrainedTokenizerBase

from arguments import Args
from utils import (
    inspect_params,
    get_non_embed_param_count,
    get_param_count,
    Logger,
)
from data.pretraining_dataset_jeeves import (
    MixedDataset,
    PackedMixedDataset,
)


def data_collator(examples, return_attn_args: bool = False) -> dict:
    """Collate a list of per-example dicts into one batch of stacked tensors.

    Each example must carry ``input_ids`` and ``labels`` (tensors of a common
    shape; they are stacked along a new leading dim and cast to
    ``torch.long``) and ``seq_len`` (gathered into a 1-D tensor stored under
    the ``seq_lens`` key).

    Args:
        examples: sequence of per-example dicts produced by the dataset.
        return_attn_args (bool): if True, each batch also contains:
            - position_ids
            - cu_seqlens
            - max_seqlen
          each stacked from the example entry of the same name.

    Returns:
        dict with keys ``input_ids``, ``labels``, ``seq_lens`` and, when
        requested, the three attention keys above (in that order).
    """
    def _stack(key):
        # Stack one field across all examples along a new batch dimension.
        return torch.stack([example[key] for example in examples])

    batch = {
        "input_ids": _stack("input_ids").to(dtype=torch.long),
        "labels": _stack("labels").to(dtype=torch.long),
        "seq_lens": torch.tensor([example["seq_len"] for example in examples]),
    }

    if return_attn_args:
        for key in ("position_ids", "cu_seqlens", "max_seqlen"):
            batch[key] = _stack(key)
    return batch


class Trainer:
    """Training / evaluation driver built around an `accelerate.Accelerator`.

    Holds the model, tokenizer, optimizer and LR scheduler; builds the packed
    pretraining dataloaders; runs the training loop, carrying the model's
    returned ``states`` (``self.memory_states``) from one step to the next and
    resetting them periodically; logs to the console and TensorBoard; and saves
    checkpoints under ``<args.output_dir>/<args.model>/<run_name>``.
    """

    def __init__(
        self,
        args: Args,
        model: nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        optimizer: Optional[torch.optim.Optimizer] = None,
        lr_scheduler = None,
        run_name: Optional[str] = None,
        device='cuda',
    ):
        """
        Args:
            args: experiment configuration.
            model: model to train. `train_step` calls it with ``input_ids``,
                ``labels``, ``states`` and ``grad_ckpt`` and reads ``.logits``
                and ``.states`` from the result.
            tokenizer: forwarded to the datasets.
            optimizer: required by `train`; may be None otherwise.
            lr_scheduler: required by `train`; may be None otherwise.
            run_name: name of this run's sub-directory; generated by
                `get_run_name` when None.
            device: stored as ``self.device``.
        """
        self.args = args

        if run_name is None:
            run_name = self.get_run_name()
        # Keep the resolved name so `get_accelerator` puts TensorBoard logs
        # under the same name as the output directory instead of generating a
        # second (possibly different) timestamped name.
        self.run_name = run_name

        self.output_dir = Path(self.args.output_dir) / self.args.model / run_name
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = self.get_accelerator()
        self.logger = Logger(accelerator=self.accelerator)
        self.device = device

        # States returned by the model on the previous step; None = fresh start.
        self.memory_states = None
        # Index of the first training batch; overwritten by `handle_load`.
        self.start_step = 0
        self.save_config(args)

    def save_config(self, args):
        """Copy the model's ``config.json`` into the output directory.

        Uses ``args.model_config`` when given, otherwise
        ``<args.pretrained_path>/config.json``.
        """
        self.output_dir.mkdir(exist_ok=True, parents=True)
        if args.model_config is None:
            copyfile(Path(args.pretrained_path) / 'config.json', self.output_dir / 'config.json')
        else:
            copyfile(args.model_config, self.output_dir / 'config.json')

    def get_batch_cache(self):
        """Return the states carried over from the last training step (or None)."""
        return self.memory_states

    def get_training_dataloader(self, discard_leftover: bool = False):
        """Build the packed, mixed-source training dataloader.

        Args:
            discard_leftover: forwarded to `MixedDataset`.

        Returns:
            A `DataLoader`. With ``args.one_sequence_batch`` it yields raw
            lists of one example each (identity collate); otherwise it yields
            dicts from `data_collator` with ``args.batch_size`` examples.
        """
        assert self.accelerator is not None, "Accelerator is None while trying to instantiate a training dataloader."
        mixed_indexed_dataset = MixedDataset(
            cfg_path=self.args.data_config,
            cfg_json_str=None,
            tokenizer=self.tokenizer,
            num_processes=self.accelerator.num_processes,
            process_index=self.accelerator.process_index,
            local_process_index=self.accelerator.local_process_index,
            max_length=self.args.max_length,
            nthreads=self.args.dataloader_num_threads,
            prefetch_slice=self.args.dataloader_prefetch,
            weight_by_size=True,
            discard_leftover=discard_leftover,
        )

        # `args.packing_count` is forwarded to `PackedMixedDataset` and controls
        # how many sequences are packed into each example (see that class for
        # the exact semantics).
        batched_dataset = PackedMixedDataset(
            mixed_indexed_dataset,
            max_length=self.args.max_length,
            packing_count=self.args.packing_count,
            repeat_data=bool(self.args.repeat_data),
        )

        if self.args.one_sequence_batch:
            print("!!! USING ONE SEQUENCE BATCH")
            # NOTE(review): the identity collate yields a *list* of one example,
            # while `train_step` indexes the batch with ``batch["input_ids"]``;
            # confirm this mode is still supported by the training loop.
            dataloader = DataLoader(
                batched_dataset,
                batch_size=1,
                collate_fn=lambda x: x,
                num_workers=self.args.dataloader_num_workers,
                prefetch_factor=self.args.dataloader_prefetch_factor,
            )
        else:
            print("!!! USING NORMAL BATCH")
            dataloader = DataLoader(
                batched_dataset,
                batch_size=self.args.batch_size,
                shuffle=False,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                prefetch_factor=self.args.dataloader_prefetch_factor,
            )
        return dataloader

    def get_validation_dataloader(self, batch_size: int = 1):
        """Build a validation dataloader over the same data config.

        Unlike the training loader, every example holds a single sequence
        (``packing_count=1``).
        """
        mixed_indexed_dataset = MixedDataset(
            cfg_path=self.args.data_config,
            cfg_json_str=None,
            tokenizer=self.tokenizer,
            num_processes=self.accelerator.num_processes,
            process_index=self.accelerator.process_index,
            local_process_index=self.accelerator.local_process_index,
            max_length=self.args.max_length,
            nthreads=self.args.dataloader_num_threads,
            prefetch_slice=self.args.dataloader_prefetch,
            weight_by_size=True,
        )

        batched_dataset = PackedMixedDataset(
            mixed_indexed_dataset,
            max_length=self.args.max_length,
            packing_count=1,
        )

        dataloader = DataLoader(
            batched_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            prefetch_factor=self.args.dataloader_prefetch_factor,
        )
        return dataloader

    def get_run_name(self) -> str:
        """Return a run name encoding the key hyper-parameters plus a YYMMDDhhmmss timestamp."""
        time_str = datetime.now().strftime("%y%m%d%H%M%S")
        run_name = (
            f"{self.args.model}"
            f"_lr{self.args.lr}"
            f"_T{self.args.max_length}"
            f"_B{self.args.batch_size}"
            f"_GA{self.args.grad_accum}"
            f"_P{self.args.packing_count}"
            f"_SR{self.args.state_reset_interval}"
            f"_RD{self.args.repeat_data}"
            f"_RI{self.args.rand_init}"
            f"_{time_str}"
        )
        return run_name

    def get_accelerator(self, project_dir: str = "result"):
        """Create an `accelerate.Accelerator` logging to TensorBoard and initialise its trackers.

        The scheduler is not stepped together with the optimizer
        (``step_scheduler_with_optimizer=False``); `train_step` steps it itself.
        """
        print("Initting accelerator...")
        Path(project_dir).mkdir(exist_ok=True, parents=True)
        if self.args.tensorboard is None:
            logging_dir = Path(f"exp{self.args.exp_group}_{self.args.exp_name}_lr{self.args.lr}")
        else:
            # Reuse the run name resolved in __init__ so the TensorBoard dir
            # matches the output dir; fall back for calls made before that.
            run_name = getattr(self, "run_name", None) or self.get_run_name()
            logging_dir = Path(self.args.tensorboard) / run_name

        project_config = accelerate.utils.ProjectConfiguration(
            project_dir=project_dir,
            logging_dir=str(logging_dir),
        )
        accelerator = accelerate.Accelerator(
            log_with="tensorboard",
            project_config=project_config,
            gradient_accumulation_steps=self.args.grad_accum,
            step_scheduler_with_optimizer=False,
        )
        accelerator.init_trackers(
            self.args.project_name,
            config=self.args.as_dict(),
        )
        return accelerator

    def compute_loss(
        self,
        logits: Tensor,
        labels: Tensor,
    ):
        """Mean next-token cross-entropy.

        Args:
            logits: model output; the last dim is the vocabulary.
            labels: token ids aligned with the logits (2-D, batch first, as
                implied by the ``labels[:, :1]`` indexing below).

        Returns:
            Scalar loss where position t is scored against ``labels[t + 1]``;
            the final position has no target and is ignored.
        """
        loss_fct = nn.CrossEntropyLoss()
        # Enable model parallelism
        labels = labels.to(logits.device)
        # Shift labels left by one and pad the last column with ignore_index.
        labels = torch.cat(
            (labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)),
            1,
        )
        # `reshape` (not `view`) so non-contiguous logits are also accepted.
        loss = loss_fct(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
        return loss

    @staticmethod
    def _last_lr(lr_scheduler) -> float:
        """Return the scheduler's most recent learning rate as a float.

        torch LR schedulers return a list (one entry per param group) from
        `get_last_lr()`; such a list cannot be ``:.3e``-formatted or passed to
        `float()`, so the first group's value is used. Scalars pass through.
        """
        last_lr = lr_scheduler.get_last_lr()
        if isinstance(last_lr, (list, tuple)):
            last_lr = last_lr[0]
        return float(last_lr)

    def train_step(
        self,
        accelerator: accelerate.Accelerator,
        model: nn.Module,
        optimizer: AcceleratedOptimizer,
        lr_scheduler: AcceleratedScheduler,
        cur_step: int,
        batch,  # type: ignore
    ):
        """Process one micro-batch: forward, backward, clip, optimizer/scheduler step, log.

        Gradient accumulation is handled by the accelerator (`train` wraps this
        call in ``accelerator.accumulate(model)``); clipping and the scheduler
        step only happen when ``accelerator.sync_gradients`` is True.

        Args:
            accelerator: the run's accelerator.
            model, optimizer, lr_scheduler: the accelerator-prepared objects.
            cur_step: 0-based index of this micro-batch (offset by `start_step`).
            batch: dict with at least ``input_ids`` and ``labels`` tensors.
        """
        if self.args.debug:
            self.logger.log(batch)
            self.logger.log_all_rank(batch)

        iter_start_time = time.time()
        input_ids: Tensor = batch["input_ids"].to(torch.long).cuda()
        labels: Tensor = batch["labels"].to(torch.long).cuda()

        # Linearly warm up the state-reset interval over the first
        # `reset_interval_warmup_steps` optimizer steps (counted in batches).
        warmup_batches = self.args.reset_interval_warmup_steps * self.args.grad_accum
        if cur_step < warmup_batches:
            p = cur_step / warmup_batches
            cache_reset_interval = max(1, int(p * self.args.state_reset_interval))
        else:
            cache_reset_interval = self.args.state_reset_interval

        # Drop the carried-over states every `cache_reset_interval` batches.
        if cur_step % cache_reset_interval == 0:
            self.memory_states = None

        # Forward pass
        forward_start_time = time.time()
        outputs = model(
            input_ids=input_ids,
            labels=labels,
            states=self.memory_states,
            grad_ckpt=self.args.grad_ckpt,
        )
        logits: Tensor = outputs.logits
        loss = self.compute_loss(logits, labels)
        self.memory_states = outputs.states
        if self.memory_states is not None:
            # NOTE(review): the return value is discarded, so this relies on
            # the state object's `detach()` acting in place (for a plain
            # tensor it would be a no-op). Confirm against the model.
            self.memory_states.detach()

        assert loss is not None, "Loss is None!"
        forward_time = time.time() - forward_start_time
        cur_loss = loss.item()

        if self.args.debug:
            self.accelerator.print(f"{cur_loss = }")
            self.accelerator.print("################################################")
            inspect_params(model, accelerator)
            self.accelerator.print("################################################")
            self.accelerator.print("################### BACKWARD ###################")
            self.accelerator.print("################################################")

        # Backward pass
        # Gradient accumulation is handled automatically by the accelerator.
        backward_start_time = time.time()
        accelerator.backward(loss)

        if self.args.debug:
            print("################################################")
            inspect_params(model, accelerator)

        # Gradient clipping; DeepSpeed does gradient clipping itself. Clip only
        # on the micro-step where the accumulated gradients are synced.
        if (
            accelerator.distributed_type != DistributedType.DEEPSPEED
            and accelerator.sync_gradients
        ):
            grad_norm = accelerator.clip_grad_norm_(
                model.parameters(),
                max_norm=self.args.clip_grad,
            )  # type: ignore
        else:
            grad_norm = torch.tensor(0.0)

        if self.args.debug:
            accelerator.print("################################################")
            accelerator.print(
                "################### optimizer.step() ###################"
            )
            accelerator.print("################################################")

        optimizer.step()
        if accelerator.sync_gradients and not accelerator.optimizer_step_was_skipped:
            # The accelerator is built with step_scheduler_with_optimizer=False
            # because letting Accelerate step the scheduler advanced it once
            # per process (N times too fast with N GPUs). Step it manually,
            # only on micro-steps where an optimizer update really happened.
            lr_scheduler.step()

        if self.args.debug:
            accelerator.print("### After lr_scheduler.step()")
            inspect_params(model, accelerator)
            accelerator.print("################################################")
            accelerator.print(input_ids.tolist())
            accelerator.print(labels.tolist())

        optimizer.zero_grad()
        backward_time = time.time() - backward_start_time

        # Log
        iter_time = time.time() - iter_start_time
        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        cur_lr = self._last_lr(lr_scheduler)
        # Plain float: the TensorBoard tracker drops non-scalar (Tensor) values.
        grad_norm_value = float(grad_norm) if grad_norm is not None else 0.0

        if self.args.log_interval != 0 and cur_step % self.args.log_interval == 0:

            mem_used = int(torch.cuda.memory_allocated() / 2 ** 20)

            if self.memory_states is None:
                state_mean = 0.0
            else:
                state_mean = self.memory_states[0][0].mean()

            print_info = f"it {cur_step} |"
            print_info += f" loss {cur_loss:.4f} |"
            print_info += f" lr {cur_lr:.3e} |"
            if accelerator.distributed_type != DistributedType.DEEPSPEED:
                print_info += f" grad_norm {grad_norm_value:.3e} |"
            print_info += f" it_time {iter_time:.3f} |"
            print_info += f" fw_time {forward_time:.3f} |"
            print_info += f" bw_time {backward_time:.3f} |"
            print_info += f" mem_used {mem_used}MB |"
            print_info += f" s_mean {state_mean:.3e} |"
            self.logger.log(f"[{time_str}] {print_info}")

        # For tensorboard
        log_info = {
            "Loss/train": cur_loss,
            "Optimizer/lr": cur_lr,
            "Optimizer/grad_norm": grad_norm_value,
            "Iter/time": iter_time,
        }
        accelerator.log(log_info, step=cur_step)

        # Parameter inspection
        if (
            self.args.inspect_interval is not None
            and self.args.inspect_interval > 0
            and (cur_step + 1) % self.args.inspect_interval == 0
            and accelerator.process_index == 0
        ):
            inspect_params(model, accelerator)

    def save_args(self, output_path: Path):
        """Write ``self.args.as_dict()`` as indented, key-sorted UTF-8 JSON to `output_path`."""
        output_path = Path(output_path)
        output_path.parent.mkdir(exist_ok=True, parents=True)
        with open(output_path, "w", encoding="utf8") as fout:
            json.dump(
                self.args.as_dict(),
                fout,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
            )

    def handle_load(self, dataloader: DataLoader):
        """Restore accelerator state from ``args.resume_path`` (if set) and set `self.start_step`.

        `dataloader` is currently unused. Batches are not skipped on resume,
        because Accelerate's `skip_first_batches` needs its wrapped dataloader,
        which is not used. Only the step counter is offset by
        ``args.load_start_step``.
        """
        # Default: start from step 0 (no checkpoint, or no start step given).
        self.start_step = 0

        resume_path = self.args.resume_path
        if resume_path is None or resume_path in ["", 'none']:
            return

        assert Path(
            resume_path
        ).exists(), (
            f"The path to load checkpoint from does not exist: {self.args.resume_path}"
        )
        self.accelerator.print(f"Loading checkpoint from: {self.args.resume_path}...")
        self.accelerator.load_state(self.args.resume_path, strict=False)
        print("Done loading checkpoint")
        self.model.tie_weights()

        if self.args.load_start_step is not None:
            self.accelerator.print(
                f"Skipping first {self.args.load_start_step} batches..."
            )
            # The checkpoint was saved at the end of a training step.
            self.start_step = self.args.load_start_step
        else:
            self.accelerator.print(
                "A checkpoint was loaded, but load_start_step was not set,"
                " so we will loop through the training data from step 0."
                " Make sure this is what you want!"
            )

    def train(self):
        """Run the training loop, then save the final model to ``<output_dir>/ckpt_final``.

        Training stops when the dataloader is exhausted or, if
        ``args.stop_when_end == 1``, after ``args.n_train_steps`` optimizer
        steps (``n_train_steps * grad_accum`` batches). A checkpoint is saved
        every ``args.save_interval`` optimizer steps when that is positive.
        """
        assert self.optimizer is not None
        assert self.lr_scheduler is not None, "LR scheduler is None."
        self.save_args(self.output_dir / 'args.json')
        self.model.train()

        self.logger.log(
            "Finished setting up tokenizer, model, optimizer, and LR scheduler."
        )
        self.logger.log("#####################################################")
        self.logger.log(f"#    # params: {get_param_count(self.model)}")
        self.logger.log(f"#    # non-e params: {get_non_embed_param_count(self.model)}")  # noqa
        self.logger.log("#####################################################")
        self.logger.log(json.dumps(self.args.as_dict(), indent=2, sort_keys=True))

        dataloader = self.get_training_dataloader()

        self.logger.log("### Preparing modules for accelerate...")
        if self.accelerator.state.deepspeed_plugin is not None:
            self.accelerator.state.deepspeed_plugin.deepspeed_config[
                "train_micro_batch_size_per_gpu"
            ] = 1
        # The prepared dataloader is discarded and the raw one is iterated;
        # `MixedDataset` is given the process index/count, so presumably it
        # already shards the data per process.
        _, model, optimizer, lr_scheduler = self.accelerator.prepare(
            dataloader,
            self.model,
            self.optimizer,
            self.lr_scheduler,
        )

        # Load from a checkpoint (if requested) and set `self.start_step`.
        self.handle_load(dataloader)

        self.logger.log("Data loader preparation done.")
        try:
            if self.args.debug:
                print("#########################################################")
                inspect_params(model, self.accelerator)
                print("#########################################################")
            # Both limits are counted in batches (micro-steps).
            n_train_steps = self.args.n_train_steps * self.args.grad_accum
            save_interval = self.args.save_interval * self.args.grad_accum
            self.logger.log("====== Start training ======")
            self.logger.log(f"# train steps: {self.args.n_train_steps}")
            self.logger.log(f"# train batches: {n_train_steps}")
            self.logger.log(f"Save interval: {save_interval}")
            self.logger.log('============================')
            for cur_step, batch in enumerate(dataloader, start=self.start_step):
                # This context handles gradient accumulation automatically.
                with self.accelerator.accumulate(model):
                    self.train_step(
                        self.accelerator, model, optimizer, lr_scheduler, cur_step, batch
                    )
                    # `cur_step` is 0-based, so this many batches are done.
                    n_seen_batches = cur_step + 1

                    # Save checkpoints. The name uses the optimizer step, which
                    # differs from the batch count by a factor of grad_accum.
                    if (
                        self.output_dir is not None
                        and save_interval > 0
                        and n_seen_batches % save_interval == 0
                    ):
                        cur_opt_step = n_seen_batches // self.args.grad_accum
                        ckpt_dir = self.output_dir / f"ckpt_{cur_opt_step}"
                        self.save_ckpt(ckpt_dir)

                    # Stop once the configured number of batches is done
                    # (`>=` also covers resuming past the limit).
                    if self.args.stop_when_end == 1 and n_seen_batches >= n_train_steps:
                        self.accelerator.print(
                            f"Reached max training iter ({n_train_steps}), stopping..."
                        )
                        break

                # During debug (for NaN losses), exit after the first iteration.
                if self.args.debug:
                    raise SystemExit
        except Exception as e:
            self.accelerator.print(f"Error in training loop: {e}")
            raise

        self.accelerator.print("====== Training finished ======")
        self.accelerator.print(f"Saving to {self.output_dir}")
        self.accelerator.save_model(model, self.output_dir / 'ckpt_final')
        self.accelerator.end_training()

    def save_ckpt(self, ckpt_dir: Path):
        """Save the full accelerator state and ``args.json`` into `ckpt_dir`."""
        ckpt_dir.mkdir(exist_ok=True, parents=True)
        self.accelerator.print(f"Saving checkpoint to: {ckpt_dir}")
        self.accelerator.save_state(str(ckpt_dir))

        self.save_args(ckpt_dir / 'args.json')

    @torch.no_grad()
    def evaluate_per_token(self, n_steps: int = 50, device='cuda'):
        """Measure the mean next-token loss at each sequence position.

        Runs `self.model` (called with ``input_ids`` only, i.e. without carried
        states) on up to `n_steps` training batches (leftovers discarded). It
        computes the unreduced cross-entropy of predicting token t+1 at
        position t, then averages per position over all sequences, skipping
        positions whose input id is 0 (presumably padding — confirm against the
        tokenizer).

        Args:
            n_steps: maximum number of batches to evaluate.
            device: device the batches are moved to; `self.model` is assumed
                to already live there.

        Returns:
            ``(per_token_loss, stacked_input_ids)``: a float32 CPU tensor of
            shape (T,) with the per-position mean loss (NaN where no position
            has a non-zero input id), and the (N * B, T) input ids fed to the
            model, where T is the batch sequence length minus one.

        The model's previous train/eval mode is restored on return.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            dataloader = self.get_training_dataloader(discard_leftover=True)
            all_losses = []
            all_input_ids = []

            print('====== Start evaluation of per-token loss ======')
            print(f"Device: {device}")
            print('================================================')

            for cur_step, batch in enumerate(dataloader):
                if cur_step == n_steps:
                    break
                print(f"Step: {cur_step}/{n_steps}")
                # Inputs drop the last token and targets drop the first, so
                # position t is scored against token t+1.
                input_ids: Tensor = batch["input_ids"][:, :-1].to(torch.long).to(device=device)  # (B, T)
                all_input_ids.append(input_ids)
                labels: Tensor = batch["labels"][:, 1:].to(torch.long).to(device=device)  # (B, T)
                outputs = self.model(input_ids=input_ids)
                logits = outputs.logits  # (B, T, V)
                B, T, V = logits.shape
                # `reshape`, not `view`: the sliced labels are non-contiguous
                # whenever `.to(device)` does not copy (e.g. device='cpu').
                loss = F.cross_entropy(
                    logits.reshape(B * T, V), labels.reshape(B * T), reduction='none'
                )
                all_losses.append(loss.view(B, T).to(dtype=torch.float32))

            stacked_input_ids = torch.cat(all_input_ids, dim=0)  # (N * B, T)
            stacked_losses = torch.cat(all_losses, dim=0)  # (N * B, T)

            # Per-position mean over non-zero input ids (0 / 0 -> NaN when a
            # position has none, as with the mean of an empty selection).
            mask = stacked_input_ids != 0
            loss_sums = stacked_losses.masked_fill(~mask, 0.0).sum(dim=0)
            token_counts = mask.sum(dim=0)
            per_token_loss = (loss_sums / token_counts).to(dtype=torch.float32).cpu()  # (T)
            print(per_token_loss.shape)
            return per_token_loss, stacked_input_ids
        finally:
            self.model.train(was_training)
