import random, torch
import os, yaml
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pickle
import warnings
import argparse
from time import time_ns
# Globally silence every FutureWarning and UserWarning for the rest of the process (presumably to quiet library
# deprecation notices). NOTE: this also hides any UserWarning/FutureWarning raised by this script itself.
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=UserWarning)
from utils.at_utils import ValNodeTokensHDF5, TreeDataset_HDF5
from transformers import TrainingArguments, Trainer
def set_seed(seed):
    """Seed every random number generator this script touches.

    Seeds Python's ``random`` module, NumPy's global generator and torch (CPU and all CUDA devices),
    and asks cuDNN to select deterministic kernels.
    """
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
from evaluate import load
from peft import LoraConfig, get_peft_model
import anytree as at
from h5py import File
from random import shuffle

# Seed Python's `random`, NumPy and torch (CPU + CUDA) before any model weights are initialized or datasets shuffled
set_seed(16)

from utils.loading_utils import load_value_model, load_for_anthropic, load_for_summary
from utils.search_utils import parse_value_model_iter

A = argparse.ArgumentParser()
A.add_argument("--dataset", type=str, choices=["anthropic", "summary"], default="anthropic")
# Required: every later step (train/val trees, token file) is read from here, and a missing value used to surface only
# as an opaque TypeError from os.path.join(None, ...)
A.add_argument("--data_dir", type=str, required=True, help="Directory containing the data for training; can be any data which is in the format used for training value models.")
A.add_argument("--output_dir", type=str, default="training_output/")
A.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train the value model for. Default is 1.")
A.add_argument("--num_training_steps", type=int, default=None, help="Number of training batches to use (After this many batches, the learning rate will be zero). If not set, it will be automatically calculated based on the dataset size and batch size.")
A.add_argument("--init_checkpoint", type=str, default=None, help="Path to the checkpoint from which to initialize the value model. If none is set, it will be initialized from TinyLlama 1.1 with a randomly initialized value head.")
A.add_argument("--batch_size", type=int, default=16)
A.add_argument("--lr", type=float, default=4e-5, help="Learning rate for training the value model.")
A.add_argument("--weight_decay", type=float, default=2e-3, help="Weight decay for the optimizer.")
A.add_argument("--no_warmup", action="store_true", help="If set, the learning rate will not be warmed up at the beginning of training. This is useful for iterative training where the model has already been trained.")
A.add_argument("--grad_accumulation_steps", type=int, default=1, help="Number of gradient accumulation steps to use. Default is 1 (no accumulation).")
A.add_argument("--value_model_dir", type=str, default="value_models/", help="Directory containing the value models for the objectives; it is assumed that the directory contains a subdirectory for each objective used, with the subdirectories containing the trained value model weights")
A.add_argument("--value_model_iter", type=str, default="0,0,0", help="The iteration of the value model to use; this is used to load the correct checkpoint from the value_model_dir")
args = A.parse_args()
data_dir = args.data_dir
output_dir = args.output_dir

# --no_warmup is forced on regardless of the command line: no need for warmup when performing distillation
args.no_warmup = True

if args.dataset == "anthropic":
    loaded_assets = load_for_anthropic(include_gen_model=False, include_inputs=False, include_rewards=False, base_model_type="llama")
elif args.dataset == "summary":
    loaded_assets = load_for_summary(include_gen_model=False, include_inputs=False, include_rewards=False, base_model_type="llama")
tokenizer = loaded_assets["gen_tokenizer"]
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# Work out which objectives have a teacher *before* building the student, so the student gets exactly one value head
# per teacher (the summary dataset has only two objectives, and objectives can be disabled via --value_model_iter).
# The order of teacher_model_paths fixes which student head corresponds to which objective.
dataset_valid_objectives = ["help", "harm", "humor"] if args.dataset == "anthropic" else ["summarization", "faithful"]
iter_nums = parse_value_model_iter(args.value_model_iter, dataset_valid_objectives)
teacher_model_paths = [f"{args.value_model_dir}/iter_{iter_num}/{objective}" for objective, iter_num in iter_nums.items() if iter_num is not None]
if not teacher_model_paths:
    raise ValueError(f"No teacher value models selected by --value_model_iter={args.value_model_iter!r}; at least one objective needs an iteration.")
num_objectives = len(teacher_model_paths)

from utils.gen_utils import LlamaValueModel
model = LlamaValueModel.from_pretrained("TinyLlama/TinyLlama_v1.1", num_labels=num_objectives, problem_type='regression', torch_dtype=torch.float32)
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
model.model.padding_idx = tokenizer.pad_token_id
model.model.config.pad_token_id = tokenizer.pad_token_id

# We only want to add a new PEFT adapter if we are not loading from a checkpoint (i.e. we are training from scratch rather
# than doing iterative training)
if args.init_checkpoint is None:
    peft_config = LoraConfig(
        r=256,
        lora_alpha=512,
        lora_dropout=0.2,
        bias="none",
        task_type="SEQ_CLS",
        # Include the QKV projections and the feed-forward layers
        target_modules=target_modules,
        # Also keep the classifier unfrozen
        modules_to_save=["score"],
        inference_mode=False
    )
    model = get_peft_model(model, peft_config)
else:
    # NOTE(review): nothing in this script loads args.init_checkpoint, so the student below is the base TinyLlama model
    # with no LoRA adapter (i.e. full fine-tuning). Make that visible instead of silently ignoring the flag.
    print(f"WARNING: --init_checkpoint={args.init_checkpoint} is not loaded by this script; "
          "training starts from the base TinyLlama weights without a LoRA adapter.")

# Load each teacher with the dataset actually being trained on (this was previously hard-coded to "anthropic")
teacher_models = [load_value_model(dataset=args.dataset, checkpoint=path) for path in teacher_model_paths]

# Extract the trees containing the training data. Note that the ValNodeTokensHDF5 class must be imported for this to work
# We need a dictionary mapping prompt names to their respective trees in order to easily access the tokens from the hdf5 file
def _load_split_roots(data_dir, split):
    """Load every pickled tree in ``<data_dir>/<split>`` into a ``{prompt_name: root}`` dict.

    The prompt name is the part of the file name before the first underscore; if two files share a prompt name, the
    later one in sorted order wins. Files are visited in sorted order because ``os.listdir`` order is arbitrary, and
    the dict order feeds the dataset's seeded shuffle. NOTE: ``pickle.load`` can execute arbitrary code, so only point
    --data_dir at trusted data.
    """
    split_dir = os.path.join(data_dir, split)
    roots = {}
    for file_name in sorted(os.listdir(split_dir)):
        if not file_name.endswith(".pkl"):
            continue
        prompt_name = file_name.split("_")[0]
        with open(os.path.join(split_dir, file_name), "rb") as f:
            roots[prompt_name] = pickle.load(f)
    return roots


train_roots = _load_split_roots(data_dir, "train")
val_roots = _load_split_roots(data_dir, "val")

tokens_file = os.path.join(data_dir, "all_tokens.hdf5")
if not os.path.exists(tokens_file):
    raise FileNotFoundError(f"Tokens file {tokens_file} does not exist. Please ensure it is present in the data directory.")

# NOTE: The order in which the teacher models are listed determines which value head in the student model corresponds to which objective
class TreeDataset_Distill(TreeDataset_HDF5):
    """Root-to-leaf path dataset whose per-token regression targets come from frozen teacher value models.

    Every leaf of every tree becomes one example. The example concatenates the pad-stripped token arrays along the
    root-to-leaf path, and each teacher predicts one value per non-prompt token (tokens that do not belong to the root
    node). Teacher i's predictions form column i of the labels.

    Args:
        root_dict: mapping of prompt name -> ValNodeTokensHDF5 tree root. Each root's ``name`` is overwritten with its
            key, and that name is then used as the group name in the HDF5 tokens file.
        tokens_file: path to the ``.hdf5`` file holding the token arrays.
        pad_token_id: token id stripped from every stored token array.
        teacher_models: non-empty sequence of teacher models; each is switched to eval mode.
    """

    def __init__(self, root_dict, tokens_file, pad_token_id=32000, teacher_models=None):
        # None instead of a shared mutable [] default
        teacher_models = [] if teacher_models is None else list(teacher_models)
        assert isinstance(tokens_file, str) and tokens_file.endswith(".hdf5"), "tokens_file must be a path to an HDF5 file"
        assert len(teacher_models) > 0, "At least one teacher model must be provided"
        # The base-class __init__ is not called; all state used by this class is set up directly here.
        self.tokens_file = File(tokens_file, "r")
        self.node_list = []
        # Create a list of all leaf nodes from all trees
        for prompt_name, root in root_dict.items():
            root.name = prompt_name
            assert type(root) is ValNodeTokensHDF5
            self.node_list.extend(node for node in at.PreOrderIter(root) if node.is_leaf)
        self.pad_token_id = pad_token_id
        # Uses the global `random` state (seeded by set_seed at module level)
        shuffle(self.node_list)
        self.teacher_models = teacher_models
        for teacher_model in self.teacher_models:
            teacher_model.eval()

    def __len__(self):
        return len(self.node_list)

    def __getitem__(self, idx):
        """Return one tokenized root-to-leaf path together with its per-token teacher targets.

        Returns a dict with:
            input_ids: int64 tensor [S] of the concatenated path tokens (on the first teacher's device).
            attention_mask: all-ones tensor [S].
            labels: tensor [N, T] of teacher values, N = number of non-prompt tokens, T = number of teachers.
            first_non_prompt_indx: number of root (prompt) tokens, i.e. the position of the first labelled token.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        node = self.node_list[idx]
        dataset_name = node.root.name
        # One stored token array per node on the root-to-leaf path; padding is stripped from each
        segments = [torch.Tensor(self.tokens_file[dataset_name][indx]) for indx in node.get_tokens_indx_path()]
        segments = [seg[seg != self.pad_token_id].unsqueeze(0) for seg in segments]  # each [1, k]
        num_prompt_tokens = segments[0].shape[1]
        num_non_prompt_tokens = sum(seg.shape[1] for seg in segments) - num_prompt_tokens
        node_sequence = torch.cat(segments, dim=1).to(device=self.teacher_models[0].device, dtype=torch.int64)  # [1, S]
        attention_mask = torch.ones_like(node_sequence)

        # Get the targets for every non-prompt token from the teacher models
        with torch.no_grad():
            per_teacher = []
            for teacher_model in self.teacher_models:
                outputs = teacher_model(input_ids=node_sequence,
                                        attention_mask=attention_mask,
                                        num_logits_to_return=num_non_prompt_tokens)
                # Flatten to [N] whatever singleton dims the teacher returns
                # (assumes one value per returned position — TODO confirm for every teacher type)
                per_teacher.append(outputs.logits.reshape(-1).cpu())
            # Stack teachers along a new last axis -> [N, T]. The old stack+squeeze+transpose produced [1, N] for a
            # single teacher and failed outright when N == T == 1.
            values = torch.stack(per_teacher, dim=1)
        # We could try to correct the targets based on the rewards from the reward models, but that is not done yet

        sequence = node_sequence[0]  # [S]
        return {"input_ids": sequence, "attention_mask": torch.ones_like(sequence), "labels": values, "first_non_prompt_indx": num_prompt_tokens}

from transformers import DataCollatorWithPadding
class TokenRegressionCollatorND:
    """Collate distillation examples into padded model inputs plus [B, Lmax, D] regression targets.

    Token fields are padded by ``DataCollatorWithPadding``. Each example's labels (one row per labelled token) are
    written left-aligned into row i of the label tensor, and ``labels_mask`` marks the filled positions. The labels
    are therefore not aligned with token positions; the trainer re-aligns them using ``first_non_prompt_indx``.

    Args:
        tokenizer: tokenizer used to pad the token fields.
        num_targets: number of regression targets D. If None, it is inferred from the last dimension of the first
            example's labels.
        ignore_special: if True and the padded batch has a ``special_tokens_mask``, those positions are masked out.
        pad_value: fill value for label positions that have no target.
    """

    def __init__(self, tokenizer, num_targets=3, ignore_special=True, pad_value=0.0):
        self.base = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
        self.num_targets = num_targets
        self.ignore_special = ignore_special
        self.pad_value = pad_value

    def __call__(self, features):
        # Pop labels first so the base collator only sees token-level fields
        label_rows = []
        for feature in features:
            labs = torch.as_tensor(feature.pop("labels"), dtype=torch.float32)
            if labs.dim() == 1:
                labs = labs.unsqueeze(-1)  # [Li] -> [Li, 1]
            label_rows.append(labs)  # each: [Li, D]
        batch = self.base(features)  # pads input_ids, attention_mask, special_tokens_mask, etc.
        batch_size = len(label_rows)
        max_len = batch["input_ids"].size(1)
        num_targets = self.num_targets if self.num_targets is not None else label_rows[0].size(-1)

        labels = torch.full((batch_size, max_len, num_targets), self.pad_value, dtype=torch.float32)
        mask = torch.zeros((batch_size, max_len), dtype=torch.bool)
        for i, labs in enumerate(label_rows):
            n = min(labs.size(0), max_len)
            labels[i, :n, :] = labs[:n, :]
            # The labels start at position 0; the padding for the prompt tokens ends up at the end of the row
            mask[i, :n] = True

        # optionally ignore special tokens
        if self.ignore_special and "special_tokens_mask" in batch:
            mask &= ~batch["special_tokens_mask"].bool()

        batch["labels"] = labels            # [B, Lmax, D]
        batch["labels_mask"] = mask         # [B, Lmax]
        return batch

class TokenRegressionTrainerND(Trainer):
    """Trainer that regresses per-token, D-dimensional teacher values with a masked MSE loss.

    Expects batches shaped like TokenRegressionCollatorND's output:
    - ``labels`` [B, Lmax, D], left-aligned per example;
    - ``labels_mask`` [B, Lmax];
    - ``first_non_prompt_indx`` [B], the token position of each example's first labelled token.

    Args:
        dimension_weights: optional per-target weights (length D), multiplied into the squared error before it is
            averaged over targets.
    """

    def __init__(self, *args, dimension_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dimension_weights = None
        if dimension_weights is not None:
            self.dimension_weights = torch.tensor(dimension_weights, dtype=torch.float32)

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """Masked MSE between the student's per-token values and the teacher targets.

        ``num_items_in_batch`` is accepted for Trainer compatibility but unused; the loss is a mean over this batch's
        valid label positions.

        Raises:
            ValueError: if ``inputs`` has no ``first_non_prompt_indx``.
        """
        labels = inputs.pop("labels")              # [B, Lmax, D]
        labels_mask = inputs.pop("labels_mask")    # [B, Lmax]
        # Not a model input, so it must be removed before the forward pass
        first_non_prompt_indx = inputs.pop("first_non_prompt_indx", None)  # [B]
        if first_non_prompt_indx is None:
            raise ValueError("inputs must include 'first_non_prompt_indx' (produced by TreeDataset_Distill and kept by TokenRegressionCollatorND)")
        batch_size = labels.size(0)

        # Number of real (unpadded) labels per example
        num_labels_per_example = labels_mask.sum(dim=1)  # [B]
        max_num_labels = int(num_labels_per_example.max().item())

        # Ask the model for logits at every input position so the ones we need can be gathered below
        outputs = model(**inputs, num_logits_to_return=inputs["input_ids"].size(1))
        logits = outputs.logits  # [B, S, D]
        device = logits.device

        # The teacher targets only cover non-prompt tokens, and the collator left-aligns them in every row, while the
        # non-prompt tokens of example i start at token position first_non_prompt_indx[i]. So gather student logits at
        # positions first_non_prompt_indx[i] + k for k < max_num_labels; anything past an example's own label count is
        # padding and is masked out below.
        # NOTE(review): this assumes input_ids are right-padded (token positions start at 0 in every row); confirm
        # tokenizer.padding_side.
        first_non_prompt_indx = first_non_prompt_indx.to(device=device, dtype=torch.long)
        offsets = torch.arange(max_num_labels, device=device).unsqueeze(0).expand(batch_size, -1)
        token_idx = first_non_prompt_indx.unsqueeze(1) + offsets  # [B, max_num_labels]
        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(token_idx)
        # Padded label slots can point past the end of the sequence; zero-pad the logits so indexing stays valid
        needed_len = max_num_labels + int(first_non_prompt_indx.max().item())
        if needed_len > logits.shape[1]:
            logits = torch.nn.functional.pad(logits, (0, 0, 0, needed_len - logits.shape[1]), value=0.0)
        logits = logits[batch_idx, token_idx]  # [B, max_num_labels, D]

        # Drop label columns that are padding for every example so labels line up with the gathered logits
        keep = labels_mask.any(dim=0)
        labels = labels[:, keep, :]
        labels_mask = labels_mask[:, keep]

        # squared error per dim
        sq_err = (logits - labels) ** 2            # [B, max_num_labels, D]
        if self.dimension_weights is not None:
            sq_err = sq_err * self.dimension_weights.to(device).view(1, 1, -1)
        per_token = sq_err.mean(dim=-1)            # [B, max_num_labels] (average over D)

        eps = 1e-8
        loss = (per_token * labels_mask).sum() / (labels_mask.sum().clamp_min(1) + eps)
        return (loss, outputs) if return_outputs else loss

# Right now we don't bother with evaluation
def compute_metrics(eval_pred, inputs=None):
    """Placeholder metrics hook for the Trainer.

    Checks that predictions and labels are both rank 3 ([B, L, D]) and raises AssertionError otherwise. It does not
    compute anything yet: it returns ``{"mse": 0}`` plus one ``mse_dim_<i>`` entry equal to 0.0 for each of the D
    target dimensions. ``inputs`` is accepted and ignored.
    """
    predictions = eval_pred.predictions
    targets = eval_pred.label_ids
    if isinstance(predictions, tuple):
        # Some models return (logits, ...); only the logits matter here
        predictions = predictions[0]

    assert predictions.ndim == 3, f"preds should be [B, L, D], got {predictions.shape}"
    assert targets.ndim == 3, f"labels should be [B, L, D], got {targets.shape}"
    num_dims = predictions.shape[2]

    metrics = {"mse": 0}
    for dim in range(num_dims):
        metrics[f"mse_dim_{dim}"] = 0.0
    return metrics

dataset_tr = TreeDataset_Distill(train_roots, tokens_file, pad_token_id=tokenizer.pad_token_id, teacher_models=teacher_models)
# NOTE(review): dataset_val is built (opening the tokens file and shuffling with the global RNG) but is never passed to
# the trainer, since evaluation is currently disabled.
dataset_val = TreeDataset_Distill(val_roots, tokens_file, pad_token_id=tokenizer.pad_token_id, teacher_models=teacher_models)

# Need to compute the number of optimizer steps in order to set the learning rate scheduler (the LR reaches zero after
# this many steps)
num_train_epochs = args.num_epochs
batch_size = args.batch_size
if args.num_training_steps is None:
    # Clamp to at least one step per epoch: with a dataset smaller than one effective batch the floor division gives 0,
    # and a zero-step linear schedule holds the LR at 0 for the entire run.
    steps_per_epoch = max(1, len(dataset_tr) // (batch_size * args.grad_accumulation_steps))
    num_training_steps = steps_per_epoch * num_train_epochs
else:
    num_training_steps = args.num_training_steps

train_args = TrainingArguments(
    num_train_epochs=num_train_epochs,
    evaluation_strategy = "no", # Currently no evaluation
    save_strategy = "steps",
    save_steps = 50 if args.dataset == 'anthropic' else 200,
    save_total_limit=5,
    learning_rate=args.lr,
    lr_scheduler_type='linear',
    lr_scheduler_kwargs={'num_warmup_steps':100 if not args.no_warmup else 0, 'num_training_steps':num_training_steps},
    optim='adafactor',
    gradient_accumulation_steps=args.grad_accumulation_steps,
    per_device_train_batch_size=batch_size,
    report_to="tensorboard",
    logging_dir=output_dir,
    logging_steps=10,
    weight_decay=args.weight_decay,
    output_dir=output_dir,
    ddp_find_unused_parameters=False, # Recommended for performance
    bf16=(args.dataset == 'summary'), # Use mixed precision training with summary dataset
    remove_unused_columns=False, # Need this because we use extra dataset features for the collator
    )

# Save all arguments to a yaml file: the HF TrainingArguments merged with the command-line args (command-line values win
# on key collisions such as output_dir)
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "training_args.yaml"), "w", encoding="utf-8") as f:
    # to_dict() is TrainingArguments' own serializer. Enum fields become plain values and token fields are masked,
    # so the file holds plain data instead of python/object tags.
    yaml.dump(train_args.to_dict() | vars(args), f)
training_logfile = open(os.path.join(output_dir, "training_log.txt"), "w", encoding="utf-8")


def log(message):
    """Print ``message`` and append it (newline-terminated) to training_log.txt, flushing immediately."""
    print(message)
    training_logfile.write(message + "\n")
    training_logfile.flush()

from transformers.optimization import get_linear_schedule_with_warmup
class TrainerWithLinearWarmupSchedule(TokenRegressionTrainerND):
    """TokenRegressionTrainerND whose LR schedule is a linear warmup followed by linear decay to zero.

    The warmup length and total step count come from ``args.lr_scheduler_kwargs`` ("num_warmup_steps",
    "num_training_steps") rather than from the step count the Trainer computes itself.
    """

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): Ignored; ``self.args.lr_scheduler_kwargs["num_training_steps"]`` is used instead.
            optimizer (torch.optim.Optimizer, optional): Optimizer to schedule; defaults to ``self.optimizer``.

        Returns:
            The (possibly already existing) learning rate scheduler.
        """
        if self.lr_scheduler is None:
            # Honor an explicitly passed optimizer, as the base Trainer does; previously it was silently ignored
            target_optimizer = self.optimizer if optimizer is None else optimizer
            self.lr_scheduler = get_linear_schedule_with_warmup(
                target_optimizer, num_warmup_steps=self.args.lr_scheduler_kwargs["num_warmup_steps"],
                num_training_steps=self.args.lr_scheduler_kwargs["num_training_steps"])
            self._created_lr_scheduler = True
        return self.lr_scheduler

train_start_time = time_ns()
from transformers import TrainerCallback
class MilestoneTimerCallback(TrainerCallback):
    """Log wall-clock time elapsed since ``train_start_time`` every N optimizer steps, on the main process only.

    N is ``args.eval_steps``, falling back to ``args.logging_steps`` when eval_steps is unset. This covers the case
    where evaluation is disabled; previously ``global_step % None`` raised a TypeError there.
    """

    def on_step_end(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        # `args` is the Trainer's TrainingArguments here (it shadows the module-level argparse namespace)
        interval = args.eval_steps or args.logging_steps
        if interval and state.global_step % interval == 0:
            log(f"Reached step {state.global_step} after {(time_ns() - train_start_time)/1e9} seconds.")

trainer = TrainerWithLinearWarmupSchedule(
    model,
    train_args,
    train_dataset=dataset_tr,
    eval_dataset=None,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # callbacks=[MilestoneTimerCallback()], # TODO: bring this back when we have an eval dataset
    # One regression target per teacher model, in teacher order (which is also the order of the student's value heads)
    data_collator=TokenRegressionCollatorND(tokenizer, num_targets=len(teacher_models)),
    # dimension_weights=[1.0, 0.5, 2.0],  # optional
)

trainer.train()
log(f"Training completed in {(time_ns() - train_start_time) / 1e9} seconds.")

# Save the final model. No evaluation is run and load_best_model_at_end is not set, so this is the last training
# state, not a "best" checkpoint. Only the main process writes, to avoid concurrent writes to output_dir under
# multi-process training.
if trainer.is_world_process_zero():
    model.save_pretrained(output_dir)
training_logfile.close()
