"""gpt2_train_fromhf.py trains a language model (GPT-2, LSTM or Mamba, selected in the YAML config) on a specified Hugging Face dataset. It stores a checkpoint every CHECKPOINT_INTERVAL (default 100) batches of data. Each checkpoint records 'epoch', 'epoch_step' and 'global_step', which track how much of the training data has already been covered. The load_checkpoint function restores the model as well as a dataloader over only the untrained part of the dataset.

Fresh runs shuffle the dataset with a fixed seed (see prepare_dataloader); the dataloader rebuilt by load_checkpoint is not reshuffled.
"""
import math
import os
import random
from collections import namedtuple
from typing import Any, List, Optional

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from datasets import Dataset, concatenate_datasets, load_dataset
from lstm import LSTMLayer
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AdamW, AutoTokenizer, GPT2Config, GPT2LMHeadModel,
                          get_scheduler)
from wordlevel_tokenizer import TrainableWordTokenizer

# Experiment configuration: every setting below is read from this YAML file.
yaml_path = 'configs/template_childes_lstm.yaml'
with open(yaml_path, 'r') as file:
    hyperparameters = yaml.safe_load(file)

_training_cfg = hyperparameters['training']

# Seed every random number generator the script relies on.
seed = _training_cfg['seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if using multi-GPU

# Module-level constants shared by the functions below.
CHECKPOINT_DIR = _training_cfg['checkpoint_dir']
OUTPUT_DIR = _training_cfg['output_dir']
BATCH_SIZE = _training_cfg['batch_size_per_gpu']
_model_cfg = hyperparameters['model']
NAME = _model_cfg['name']
_tokenizer_cfg = _model_cfg['tokenizer']
MAX_LENGTH = _tokenizer_cfg['model_max_length']
CYCLE_MODE = _training_cfg['training_cycle_unit']
CYCLE_VALUE = _training_cfg['training_cycle_value']
LEARNING_RATE = _training_cfg['lr']
WARMUP_STEPS = _training_cfg['warmup_steps']
# Weight decay is optional in the config; 0.01 is used when it is absent.
W_DECAY = _training_cfg.get('weight_decay', 0.01)
CHECKPOINT_INTERVAL = _training_cfg['checkpoint_every']
DATASET_PATH = _training_cfg['path']
IS_TRAIN_ALL = _training_cfg['is_train_all']
TRAIN_SPLIT = _training_cfg['train_split']
WANDB_PROJ = _training_cfg['wandb_project']
WANDB_EXP_NAME = _training_cfg['wandb_exp_name']
WANDB_LOG_EVERY = _training_cfg['wandb_log_every']
TOKENIZER_TYPE = _tokenizer_cfg['tokenizer_type']
VOCAB_FILE = _tokenizer_cfg['vocab_file']


class GPT2Embeddings(nn.Module):
    """Token embeddings with an optional input projection and learned position embeddings."""

    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
                 word_embed_proj_dim=None, device=None, dtype=None):
        """Create the embedding tables.

        When ``word_embed_proj_dim`` is given, tokens are embedded in that smaller
        dimension and projected up to ``embed_dim`` by a bias-free linear layer.
        Position embeddings are only created when ``max_position_embeddings > 0``.
        """
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)

        uses_projection = word_embed_proj_dim is not None
        token_dim = word_embed_proj_dim if uses_projection else embed_dim
        self.word_embeddings = nn.Embedding(vocab_size, token_dim, padding_idx=padding_idx,
                                            **tensor_opts)
        if uses_projection:
            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
                                        **tensor_opts)
        else:
            self.project_in = None

        self.max_position_embeddings = max_position_embeddings
        if max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
                                                    **tensor_opts)

    def forward(self, input_ids, position_ids=None):
        """Embed token ids and, if enabled, add position embeddings.

        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen); defaults to 0..seqlen-1 when omitted
        """
        _, seqlen = input_ids.shape
        hidden = self.word_embeddings(input_ids)
        if self.project_in is not None:
            hidden = self.project_in(hidden)

        # Without a position table the token embeddings are returned as they are.
        if self.max_position_embeddings <= 0:
            return hidden

        if position_ids is None:
            position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
        return hidden + self.position_embeddings(position_ids)


class SimpleLMHeadModelNoFFN(nn.Module):
    """Language model built from embeddings, a stack of mixer layers and a tied LM head.

    Hidden states flow embeddings -> dropout -> each mixer layer in turn -> output
    projection; there is no feed-forward block between the mixer layers.
    """

    # Return container for forward(). Defined once at class level: building a new
    # namedtuple class on every forward call is slow and gives each result its own type.
    CausalLMOutput = namedtuple('CausalLMOutput', ['loss', 'logits'])

    def __init__(self, d_model: int, vocab_size: int,
                 layer=LSTMLayer, max_position_embeddings=-1,
                 n_layer=1,
                 embed_dropout: float = 0.1,
                 pad_vocab_size_multiple: int = 1,
                 device=None, dtype=None, **kwargs) -> None:
        """Build the embeddings, the mixer stack and the output projection.

        Args:
            d_model: hidden size used by the embeddings, every mixer layer and the head.
            vocab_size: number of rows in the embedding table / output logits.
            layer: class of one mixer layer; it is called as ``layer(d_model, device=device)``.
            max_position_embeddings: size of the learned position table; values <= 0
                disable position embeddings.
            n_layer: number of mixer layers stacked one after another.
            embed_dropout: dropout probability applied to the embeddings.
            pad_vocab_size_multiple: accepted for interface compatibility; not used here.
            device: device passed to the embeddings, mixer layers and LM head.
            dtype: dtype passed to the embeddings and the LM head.
            **kwargs: accepted and ignored.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.vocab_size = vocab_size

        self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings, **factory_kwargs)
        self.embed_dropout = nn.Dropout(embed_dropout)

        self.mixer = nn.ModuleList([layer(d_model, device=device) for _ in range(n_layer)])

        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
        self.tie_weights()
        self.device = device

    def tie_weights(self):
        """Tie the LM head to the word embeddings so both share one weight tensor."""
        self.lm_head.weight = self.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, state=None, return_hidden_outputs=False, labels=None, **kwargs):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: token ids, (batch, seqlen).
            position_ids: optional position ids forwarded to the embedding layer.
            state: accepted for interface compatibility; not used here.
            return_hidden_outputs: when True, also collect detached CPU copies of the
                embedding output and of every mixer layer's output.
            labels: optional target ids, same shape as ``input_ids``. The logits at
                position t are scored against ``labels[:, t + 1]``.
            **kwargs: extra arguments (e.g. ``attention_mask``) are accepted and ignored,
                so padded positions are not masked out of the loss.

        Returns:
            A 3-tuple ``(CausalLMOutput(loss, logits), hidden_outputs, None)`` where
            ``loss`` is None without labels and ``hidden_outputs`` is None unless
            ``return_hidden_outputs`` is True.
        """
        hidden_states = self.embed_dropout(self.embeddings(input_ids, position_ids=position_ids))

        hidden_outputs = [hidden_states.detach().cpu()] if return_hidden_outputs else None

        for mixer in self.mixer:
            hidden_states = mixer(hidden_states)
            if return_hidden_outputs:
                hidden_outputs.append(hidden_states.detach().cpu())

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Drop the last prediction and the first label so position t predicts token t + 1.
            shifted_logits = logits[:, :-1, :]
            shifted_labels = labels[:, 1:]
            loss = nn.CrossEntropyLoss()(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_labels.reshape(-1),
            )

        return self.CausalLMOutput(loss=loss, logits=logits), hidden_outputs, None


def step_num(dataset: Dataset):
    """Return the total number of optimizer steps for the configured training cycle.

    In 'steps' mode this is CYCLE_VALUE itself; otherwise CYCLE_VALUE counts epochs
    and is multiplied by the number of batches per epoch.
    """
    if CYCLE_MODE == 'steps':
        return CYCLE_VALUE
    batches_per_epoch = math.ceil(len(dataset) / BATCH_SIZE)
    return batches_per_epoch * CYCLE_VALUE


def epoch_num(dataset: Dataset):
    """Return the number of epochs needed for the configured training cycle.

    In 'epochs' mode this is CYCLE_VALUE itself; otherwise CYCLE_VALUE counts steps
    and is divided by the number of batches per epoch, rounding up.
    """
    if CYCLE_MODE == 'epochs':
        return CYCLE_VALUE
    batches_per_epoch = math.ceil(len(dataset) / BATCH_SIZE)
    return math.ceil(CYCLE_VALUE / batches_per_epoch)


def split_text_into_chunks(text, chunk_size=512):
    """Break ``text`` into pieces of at most ``chunk_size`` whitespace-separated words.

    Words inside each piece are re-joined with single spaces.
    """
    tokens = text.split()
    chunks = []
    for start in range(0, len(tokens), chunk_size):
        window = tokens[start:start + chunk_size]
        chunks.append(' '.join(window))
    return chunks


def get_shuffle_indices(N: int, torch_seed: int = seed) -> list:
    """Return a permutation of ``range(N)`` drawn from a generator seeded with ``torch_seed``."""
    generator = torch.Generator()
    generator.manual_seed(torch_seed)
    permutation = torch.randperm(N, generator=generator)
    return permutation.tolist()


def split_dataset(dataset, chunk_size=512):
    """Re-chunk every record's 'text' and return the chunks as a new Dataset.

    The result has two columns: 'text' holds one chunk of at most ``chunk_size``
    words per row, and 'index' numbers the rows consecutively from 0.
    """
    chunk_texts = []
    for record in dataset:
        chunk_texts.extend(split_text_into_chunks(record['text'], chunk_size))

    chunk_ids = list(range(len(chunk_texts)))
    return Dataset.from_dict({'index': chunk_ids, 'text': chunk_texts})


def save_checkpoint(model: torch.nn.Module, optimizer: AdamW, scheduler: Any, epoch: int, epoch_step: int, global_step: int):
    """Write model, optimizer, scheduler, progress counters and RNG states to CHECKPOINT_DIR.

    The file is named ``checkpoint_{epoch}_{global_step}.pt``.
    """
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Store the weights of the underlying model when it is wrapped (e.g. in DataParallel).
    unwrapped = model.module if hasattr(model, 'module') else model
    state = {
        'model_state_dict': unwrapped.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
        'epoch_step': epoch_step,
        'global_step': global_step,
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state_all(),
    }

    checkpoint_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_{epoch}_{global_step}.pt')
    torch.save(state, checkpoint_path)
    print(f'Checkpoint saved to {checkpoint_path}')


def load_checkpoint(checkpoint_path: str, dataset: Dataset, tokenizer, starting_over: bool = False):
    """Load a checkpoint and rebuild the training state from it.

    Args:
        checkpoint_path: path of a file written by ``save_checkpoint``.
        dataset: the full tokenized training dataset.
        tokenizer: tokenizer whose length determines the model's vocabulary size.
        starting_over: when True only the model weights are restored and the
            optimizer and scheduler start fresh; when False their saved states are
            loaded as well.

    Returns:
        ``(model, optimizer, scheduler, dataloader, metadata)``. ``dataloader``
        iterates only the part of ``dataset`` after the checkpointed position, and
        ``metadata`` holds 'epoch', 'epoch_step', 'global_step', 'rng_state' and
        'cuda_rng_state' from the checkpoint.

    Raises:
        ValueError: if NAME is not one of the supported model names.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load onto the CPU first: a checkpoint saved on a GPU machine then also loads on
    # a CPU-only host, and the stored RNG states stay CPU tensors.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Rebuild the architecture before loading the weights into it.
    if NAME == 'gpt2-base':
        model = GPT2LMHeadModel(config=GPT2Config())
        model.resize_token_embeddings(len(tokenizer))
    elif NAME == 'LSTM':
        model = SimpleLMHeadModelNoFFN(
            layer=LSTMLayer,
            d_model=768,
            vocab_size=len(tokenizer),
            n_layer=4,
            max_position_embeddings=1024,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
    elif NAME == 'MAMBA':
        # NOTE(review): train() builds its fresh MAMBA config without attn_cfg;
        # confirm the two configurations are meant to differ.
        mamba_config = MambaConfig(
            d_model=768, d_intermediate=0, n_layer=12, vocab_size=len(tokenizer),
            ssm_cfg={'layer': 'Mamba2'}, attn_cfg={'causal': True, 'num_heads': 12},
            rms_norm=True, residual_in_fp32=True,
            fused_add_norm=True, pad_vocab_size_multiple=16, tie_embeddings=True)
        model = MambaLMHeadModel(mamba_config)
    else:
        raise ValueError(f'Unsupported model name: {NAME!r}')

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    # Wrap in DataParallel if multiple GPUs are available
    if torch.cuda.device_count() > 1:
        print(f'Using {torch.cuda.device_count()} GPUs for resumed model')
        model = torch.nn.DataParallel(model)

    # NOTE(review): train() creates its fresh optimizer with betas=(0.9, 0.95), while
    # this one uses the library defaults unless the saved state is loaded below.
    # Confirm that difference is intended for the starting_over case.
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=W_DECAY)
    if not starting_over:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = get_scheduler(
        'linear', optimizer=optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=step_num(dataset)
    )
    if not starting_over:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Recompute the unused portion of the dataset: skip the samples covered by the
    # batches already trained in the checkpointed epoch.
    # NOTE(review): this presumes get_shuffle_indices reproduces the order the shuffled
    # DataLoader used before the checkpoint - verify against prepare_dataloader.
    full_indices = get_shuffle_indices(len(dataset), seed)
    start_index = checkpoint['epoch_step'] * BATCH_SIZE
    remaining_indices = full_indices[start_index:]
    unused_subset = dataset.select(remaining_indices)

    print(f'now size: {len(unused_subset)} instead of full {len(dataset)}')
    # The subset is already in its final order, so the loader must not reshuffle it.
    dataloader = DataLoader(unused_subset, batch_size=BATCH_SIZE, shuffle=False)

    metadata = {
        'epoch': checkpoint['epoch'],
        'epoch_step': checkpoint['epoch_step'],
        'global_step': checkpoint['global_step'],
        'rng_state': checkpoint['rng_state'],
        'cuda_rng_state': checkpoint['cuda_rng_state']
    }
    return model, optimizer, scheduler, dataloader, metadata


def tokenize_function(example, tokenizer):
    """Tokenize the 'text' field, truncating and padding every sequence to MAX_LENGTH."""
    texts = example['text']
    return tokenizer(texts, truncation=True, padding='max_length', max_length=MAX_LENGTH)


def prepare_dataloader(dataset: Dataset, batch_size: int):
    """Return a shuffling DataLoader over ``dataset`` whose generator is seeded with ``seed``."""
    shuffle_generator = torch.Generator()
    shuffle_generator.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=shuffle_generator)


def load_merge_dataset(path: str, load_all: bool = True, load_split: Optional[List[str]] = None):
    """Load a dataset and merge all of its splits, or only the named ones, into one Dataset.

    Args:
        path: dataset path or name passed to ``load_dataset``.
        load_all: when True every split is merged and ``load_split`` is ignored.
        load_split: names of the splits to merge when ``load_all`` is False. A single
            split may also be given as a plain string.

    Returns:
        The concatenation of the selected splits.

    Raises:
        ValueError: if ``load_all`` is False and no split is named, or if a named
            split does not exist in the dataset.
    """
    dataset_dict = load_dataset(path)
    if load_all:
        # Merge all splits into one Dataset
        return concatenate_datasets(list(dataset_dict.values()))

    if load_split is None:
        requested_splits = []
    elif isinstance(load_split, str):
        # A bare string would otherwise be iterated character by character.
        requested_splits = [load_split]
    else:
        requested_splits = list(load_split)

    if not requested_splits:
        raise ValueError('load_split must name at least one split when load_all is False')

    missing_splits = [split for split in requested_splits if split not in dataset_dict]
    if missing_splits:
        raise ValueError(f'Specified splits not found in the dataset: {missing_splits}')
    # Merge only the specified splits
    return concatenate_datasets([dataset_dict[split] for split in requested_splits])


def save_model_safely(model, output_dir):
    """Save the final model to ``output_dir``, whether or not it is wrapped for multi-GPU use.

    A wrapper exposing ``.module`` (e.g. DataParallel) is unwrapped first. Models that
    provide ``save_pretrained`` are saved through it. Plain ``nn.Module`` models without
    that method (such as SimpleLMHeadModelNoFFN) have their ``state_dict`` written to
    ``<output_dir>/pytorch_model.bin`` instead of failing with an AttributeError.
    """
    model_to_save = model.module if hasattr(model, 'module') else model

    if hasattr(model_to_save, 'save_pretrained'):
        model_to_save.save_pretrained(output_dir)
        return

    # Fallback for models that do not implement save_pretrained.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))


def _latest_checkpoint_name(checkpoint_dir):
    """Return the name of the checkpoint file with the highest global step, or None.

    Only ``*.pt`` files whose name ends in ``_<int>.pt`` are considered, which matches
    the ``checkpoint_{epoch}_{global_step}.pt`` names written by ``save_checkpoint``.
    A missing or empty directory yields None.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    best_step, best_name = None, None
    for file_name in os.listdir(checkpoint_dir):
        if not file_name.endswith('.pt'):
            continue
        try:
            step = int(file_name[:-3].split('_')[-1])
        except ValueError:
            # Not a checkpoint file name; ignore it instead of aborting the run.
            continue
        if best_step is None or step >= best_step:
            best_step, best_name = step, file_name
    return best_name


def _build_fresh_model(tokenizer):
    """Create an untrained model of the architecture selected by NAME."""
    if NAME == 'gpt2-base':
        model = GPT2LMHeadModel(config=GPT2Config())
        model.resize_token_embeddings(len(tokenizer))
        return model
    if NAME == 'LSTM':
        return SimpleLMHeadModelNoFFN(
            d_model=768,
            vocab_size=len(tokenizer),
            n_layer=4,
            max_position_embeddings=1024,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
    if NAME == 'MAMBA':
        mamba_config = MambaConfig(
            d_model=768, d_intermediate=0, n_layer=12, vocab_size=len(tokenizer),
            ssm_cfg={'layer': 'Mamba2'}, rms_norm=True, residual_in_fp32=True,
            fused_add_norm=True, pad_vocab_size_multiple=16, tie_embeddings=True)
        return MambaLMHeadModel(mamba_config)
    raise ValueError(f'Unsupported model name: {NAME!r}')


def _batch_loss(model, batch):
    """Run one forward pass on ``batch`` and return the scalar next-token loss."""
    if NAME == 'LSTM':
        # The LSTM model returns (output, hidden_outputs, None); only the output is needed.
        outputs, _, _ = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['input_ids'])
        return outputs.loss.mean()
    if NAME == 'MAMBA':
        # The Mamba model is called without labels, so the shifted loss is computed here.
        logits = model(input_ids=batch['input_ids']).logits
        labels = batch['input_ids']
        shifted_logits = logits[:, :-1, :]
        shifted_labels = labels[:, 1:]
        return nn.CrossEntropyLoss()(shifted_logits.reshape(-1, shifted_logits.size(-1)), shifted_labels.reshape(-1))
    outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['input_ids'])
    return outputs.loss.mean()


def train(tokenizer_type: str = 'gpt2'):
    """Main training loop with checkpointing and dataset tracking.

    Loads and chunks the dataset, tokenizes it, then either resumes from the newest
    checkpoint in CHECKPOINT_DIR or builds a fresh model. Training logs to wandb every
    WANDB_LOG_EVERY steps, saves a checkpoint every CHECKPOINT_INTERVAL steps (and at
    steps 150 and 300), and finally saves the model to OUTPUT_DIR.

    Args:
        tokenizer_type: 'gpt2' for the pretrained GPT-2 tokenizer; any other value
            selects the word-level tokenizer built from VOCAB_FILE.

    Raises:
        ValueError: if NAME is not a supported model name, or a batch contains a
            token id outside the model's vocabulary.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the dataset and cut every record into word chunks.
    dataset = split_dataset(load_merge_dataset(DATASET_PATH, load_all=IS_TRAIN_ALL, load_split=TRAIN_SPLIT))
    print('Dataset example:')
    print(dataset[0]['text'])

    if tokenizer_type == 'gpt2':
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
    else:
        tokenizer = TrainableWordTokenizer(VOCAB_FILE)
    # Padding reuses the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token

    wandb.init(project=WANDB_PROJ, name=WANDB_EXP_NAME, resume='allow')

    # Tokenize the dataset
    tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    # Resume only when the checkpoint directory actually contains a checkpoint.
    latest_checkpoint = _latest_checkpoint_name(CHECKPOINT_DIR)
    if latest_checkpoint is not None:
        print(f'Resuming from checkpoint: {latest_checkpoint}')
        checkpoint_path = os.path.join(CHECKPOINT_DIR, latest_checkpoint)
        # starting_over=True restores the weights but starts a fresh optimizer and schedule.
        model, optimizer, scheduler, dataloader, metadata = load_checkpoint(
            checkpoint_path, tokenized_dataset, tokenizer, starting_over=True)

        start_epoch = metadata['epoch']
        global_block_no = metadata['global_step']
        epoch_step = metadata['epoch_step']
        torch.set_rng_state(metadata['rng_state'])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state_all(metadata['cuda_rng_state'])
    else:
        print(f'Starting training of {NAME} from scratch.')
        model = _build_fresh_model(tokenizer)
        model.to(device)
        if torch.cuda.device_count() > 1:
            print(f'Using {torch.cuda.device_count()} GPUs with DataParallel')
            model = torch.nn.DataParallel(model)

        optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=W_DECAY, betas=(0.9, 0.95))
        scheduler = get_scheduler(
            'linear', optimizer=optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=step_num(dataset)
        )
        dataloader = prepare_dataloader(tokenized_dataset, BATCH_SIZE)
        start_epoch = 0
        global_block_no = 0
        epoch_step = 0

    # The first epoch of this run uses `dataloader` (possibly only the untrained
    # remainder of a resumed epoch); every later epoch uses the complete dataset.
    complete_dataloader = prepare_dataloader(tokenized_dataset, BATCH_SIZE)

    # Look through a DataParallel wrapper so the vocabulary size is found for every
    # model type regardless of the number of GPUs.
    base_model = model.module if hasattr(model, 'module') else model
    vocab_size = base_model.vocab_size if NAME == 'LSTM' else base_model.config.vocab_size

    # Training loop
    model.train()

    effective_epochs = epoch_num(dataset)
    for epoch in range(start_epoch, effective_epochs):
        if global_block_no == 0:
            # Keep a checkpoint of the untrained model.
            save_checkpoint(model, optimizer, scheduler, epoch, epoch_step, global_block_no)
        epoch_loss = 0.0
        # Batches trained in this epoch during this run; unlike epoch_step it does not
        # include batches completed before a resume, so the average below stays correct.
        batches_trained = 0
        epoch_loader = dataloader if epoch == start_epoch else complete_dataloader
        progress_bar = tqdm(epoch_loader, desc=f'Epoch {epoch + 1}/{effective_epochs}')

        for batch in progress_bar:
            # Explicit check instead of assert so it also runs under `python -O`.
            max_input_id = batch['input_ids'].max()
            if max_input_id >= vocab_size:
                raise ValueError(f'Error: Input ID = {max_input_id} exceeds vocab size={vocab_size} on {global_block_no}')

            batch = {key: value.to(device) for key, value in batch.items()}

            try:
                # Forward pass
                loss = _batch_loss(model, batch)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                # clip_grad_norm_ returns the total L2 norm of all gradients measured
                # before clipping, which is the value logged as 'grad_norm'.
                total_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0))

                optimizer.step()
                scheduler.step()

                # Log loss
                loss_value = loss.item()
                epoch_loss += loss_value
                batches_trained += 1
                progress_bar.set_postfix({'loss': loss_value})

                # Increment global block number
                global_block_no += 1
                epoch_step += 1

                # wandb
                if global_block_no % WANDB_LOG_EVERY == 0:
                    wandb.log({'batch_loss': loss_value, 'learning_rate': optimizer.param_groups[0]['lr'], 'epoch': epoch + 1, 'step_in_epoch': epoch_step, 'grad_norm': total_norm})

                # Save checkpoint
                if global_block_no % CHECKPOINT_INTERVAL == 0 or global_block_no in [0, 150, 300]:
                    save_checkpoint(model, optimizer, scheduler, epoch, epoch_step, global_block_no)
            except Exception as e:
                print(f'error: {e}')
                # Bare raise keeps the original traceback.
                raise

            # stop if it's 'steps' mode, at last epoch, and enough steps are trained (no need to finish the epoch)
            if CYCLE_MODE == 'steps' and epoch == effective_epochs - 1 and global_block_no == CYCLE_VALUE:
                print('Enough steps are trained.')
                print(f'Epoch {epoch + 1} completed.')
                break

        # An epoch with no batches (e.g. resumed exactly at its end) has no average.
        average_loss = epoch_loss / batches_trained if batches_trained else float('nan')
        print(f'Epoch {epoch + 1} completed. Average loss: {average_loss}')
        epoch_step = 0

    # Save the final model
    save_model_safely(model, OUTPUT_DIR)
    print(f'Model saved to {OUTPUT_DIR}.')
    wandb.finish()


# Script entry point: start training with the tokenizer type read from the YAML config.
if __name__ == '__main__':
    train(TOKENIZER_TYPE)
