from peft import LoraConfig, PeftModel
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForSequenceClassification,
)
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from torch.optim import Adam

import argparse
import json
from tqdm import trange, tqdm

from modules.data_collator import RewardScoreDataCollatorWithPadding

from utils.data_processing import preprocess_dataset
from utils.env_management import save_config
import os

from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn.decomposition import PCA

import numpy as np

# Weights & Biases settings, exported through the process environment.
os.environ.update(
    {
        "WANDB_PROJECT": "finetuning-historical",  # name your W&B project
        "WANDB_LOG_MODEL": "checkpoint",  # log all model checkpoints
        "WANDB_TAGS": "embedding",
    }
)


# Command-line interface: the JSON config to use, plus a flag that restricts
# the run to the test split.
parser = argparse.ArgumentParser(description="Train Reward")
parser.add_argument(
    "--config_file",
    type=str,
    default="configs/reward-base.json",
    help="config file",
)
parser.add_argument(
    "--test_only",
    action="store_true",
)
args = parser.parse_args()

# Read the config through a context manager so the file handle is closed
# right away instead of being left to the garbage collector. JSON is UTF-8,
# so the encoding is pinned rather than taken from the locale.
with open(args.config_file, encoding="utf-8") as config_file:
    config = json.load(config_file)

# Config values may reference other top-level config keys as {placeholders}.
os.environ["WANDB_CACHE_DIR"] = config["wandb_cache_dir"].format(**config)

save_config(config, "embedding")

# Frequently used config entries, bound to module-level names.
embedding_config = config["embedding"]
device = config["device"]  # device handed to `.to(...)` calls, e.g. "cpu"
preprocess = config["preprocess"]
base_dir = config["base_dir"]

base_model = config["base_model"]


def save_embeddings_xs(xs_chosen, xs_rejected, score_chosen, score_rejected, path):
    """Write the four embedding/score objects to disk under a shared prefix.

    Each object is stored with ``torch.save`` in its own file, named ``path``
    followed by ``_xs_chosen.pt``, ``_xs_rejected.pt``, ``_score_chosen.pt``
    and ``_score_rejected.pt`` respectively.

    Args:
        xs_chosen: Object saved to ``<path>_xs_chosen.pt``.
        xs_rejected: Object saved to ``<path>_xs_rejected.pt``.
        score_chosen: Object saved to ``<path>_score_chosen.pt``.
        score_rejected: Object saved to ``<path>_score_rejected.pt``.
        path: File prefix (not a directory). Its parent directory is created
            if it does not exist yet.
    """
    # A bare prefix such as "emb" has no directory component, and
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when there actually is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(xs_chosen, path + "_xs_chosen.pt")
    torch.save(xs_rejected, path + "_xs_rejected.pt")
    torch.save(score_chosen, path + "_score_chosen.pt")
    torch.save(score_rejected, path + "_score_rejected.pt")
    print(
        f"Saved embeddings to {path}_xs_chosen.pt, _xs_rejected.pt, _score_chosen.pt, _score_rejected.pt"
    )


def load_embeddings_xs(path):
    """Load the four objects stored under the file prefix ``path``.

    Args:
        path: File prefix; the files read are ``path`` followed by
            ``_xs_chosen.pt``, ``_xs_rejected.pt``, ``_score_chosen.pt`` and
            ``_score_rejected.pt``.

    Returns:
        Tuple ``(xs_chosen, xs_rejected, score_chosen, score_rejected)`` as
        returned by ``torch.load`` for each file, in that order.

    Raises:
        FileNotFoundError: If ``<path>_xs_chosen.pt`` does not exist.
    """
    suffixes = (
        "_xs_chosen.pt",
        "_xs_rejected.pt",
        "_score_chosen.pt",
        "_score_rejected.pt",
    )
    # Only the first file is probed up front; it stands in for the whole set.
    if os.path.exists(path + suffixes[0]):
        return tuple(torch.load(path + suffix) for suffix in suffixes)
    raise FileNotFoundError(f"{path} does not exist")


def _last_token_hidden_states(model, inputs):
    """Return, per row, the final-layer hidden state at the last prompt token.

    Args:
        model: Model called as ``model(**inputs, output_hidden_states=True)``;
            its ``config.pad_token_id`` is used to locate padding.
        inputs: Tokenizer output providing at least ``input_ids``.

    Returns:
        Detached CPU tensor with one hidden-state vector per input row.
    """
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        final_layer = outputs.hidden_states[-1]
        input_ids = inputs["input_ids"]
        # Position of the token just before the first pad token in each row.
        # argmax returns the first maximal index, so a row without any pad
        # token (all zeros -> index 0) and a row that starts with a pad token
        # both yield -1, which the modulo below wraps to the final position.
        # NOTE(review): if the pad token id also occurs inside a prompt, the
        # first occurrence wins - confirm this is the intended pooling.
        sequence_lengths = (
            torch.eq(input_ids, model.config.pad_token_id).int().argmax(-1) - 1
        )
        sequence_lengths = sequence_lengths % input_ids.shape[-1]
        # Get the hidden state (corresponding to the last token)
        last_hidden_state = final_layer[range(len(final_layer)), sequence_lengths]
        return last_hidden_state.detach().cpu()


def _render_chat_prompts(conversations):
    """Turn one collated batch of conversations into a list of prompt strings.

    ``conversations`` is a sequence of per-turn dicts whose values are indexed
    per sample (presumably the result of the default DataLoader collation -
    verify against the dataset). The batch size is taken from the ``"role"``
    entry of the first turn.
    """
    prompts = []
    for i in range(len(conversations[0]["role"])):
        # Rebuild the i-th sample's conversation as a list of message dicts.
        chat_i = [{k: v[i] for k, v in turn.items()} for turn in conversations]
        try:
            prompt = tokenizer.apply_chat_template(
                chat_i,
                tokenize=False,
            )
        except Exception:
            # Deliberate best-effort: if the chat template cannot be applied,
            # fall back to a plain "role: content" transcript.
            prompt = "\n".join(
                f"{message['role']}: {message['content']}" for message in chat_i
            )
        prompts.append(prompt)
    return prompts


def _embed_conversations(conversations):
    """Tokenize one collated batch of conversations and embed it."""
    model_inputs = tokenizer(
        _render_chat_prompts(conversations),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    ).to(device)
    return _last_token_hidden_states(model, model_inputs)


def get_embeddings_x_y(dataloader, part="train"):
    """Return embeddings and scores for the chosen/rejected pairs of a split.

    If a cached copy exists at the path built from
    ``embedding_config["save_embeddings_path"]`` it is loaded and returned.
    Otherwise every batch of ``dataloader`` is rendered to prompts, tokenized
    and run through the module-level ``model``; the results are concatenated,
    saved to that path and returned.

    Relies on the module-level ``model``, ``tokenizer``, ``device``,
    ``config`` and ``embedding_config``.

    Args:
        dataloader: Iterable of batches exposing the keys ``"chosen"``,
            ``"rejected"``, ``"chosen_score"`` and ``"rejected_score"``.
        part: Split name, passed as the ``part`` field when formatting the
            cache path and used in the log message.

    Returns:
        Tuple ``(xs_chosen, xs_rejected, score_chosen, score_rejected)``.
        When freshly computed, each element is a CPU tensor concatenated
        along dimension 0 over all batches.
    """
    # Resolve the cache prefix once; the lookup, the log message and the
    # final save all need it.
    embeddings_path = embedding_config["save_embeddings_path"].format(
        part=part, **config
    )
    try:
        return load_embeddings_xs(embeddings_path)
    except FileNotFoundError:
        print(
            f"File not found, generating embeddings for {part} and saving to {embeddings_path}"
        )

    xs_chosen = []
    xs_rejected = []
    score_chosen = []
    score_rejected = []
    for batch in tqdm(dataloader, desc="get_embeddings"):
        xs_chosen.append(_embed_conversations(batch["chosen"]))
        xs_rejected.append(_embed_conversations(batch["rejected"]))
        score_chosen.append(batch["chosen_score"].detach().cpu())
        score_rejected.append(batch["rejected_score"].detach().cpu())

    xs_chosen = torch.cat(xs_chosen, dim=0)
    xs_rejected = torch.cat(xs_rejected, dim=0)
    score_chosen = torch.cat(score_chosen, dim=0)
    score_rejected = torch.cat(score_rejected, dim=0)
    save_embeddings_xs(
        xs_chosen,
        xs_rejected,
        score_chosen,
        score_rejected,
        embeddings_path,
    )
    return xs_chosen, xs_rejected, score_chosen, score_rejected


# Tokenizer of the base model; the pad token is taken from the config.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

# A chat template given in the config replaces the tokenizer's own one.
if "chat_template" not in config:
    print("No chat template provided in config file using default")
else:
    tokenizer.chat_template = config["chat_template"]

# Either reuse a previously preprocessed dataset, or preprocess the raw one
# and store the result so that later runs can load it directly.
if not preprocess:
    ratings = load_from_disk(config["data_path_preprocessed"].format(**config))
else:
    ratings = preprocess_dataset(
        load_from_disk(config["data_path"].format(**config)),
        tokenizer,
        config["max_input_length"],
        config["messages_template"],
        process_messages=False,
    )
    ratings.save_to_disk(config["data_path_preprocessed"].format(**config))

print(ratings)

# Location of the fine-tuned adapter; only used when `has_peft` is set.
checkpoint = embedding_config["model"].format(**config)

# Base model with a single-output regression head.
model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,
    problem_type="regression",
    torch_dtype=config["torch_dtype"],
)
model = model.to(device)

# Keep the embedding table size and the padding id in line with the tokenizer.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

if embedding_config["has_peft"]:
    # Load the adapter on top of the base model, then merge it in and drop
    # the PEFT wrapper so the rest of the script works with a plain model.
    model = PeftModel.from_pretrained(
        model,
        checkpoint,
        torch_dtype=config["torch_dtype"],
        adapter_name="default",
    ).merge_and_unload()


# Names of the dataset splits used for training and validation.
train_set = embedding_config["train_set"]
valid_set = embedding_config["valid_set"]

# One loader per split, all with the same batch size; the test split is
# always read from the "test" key.
dataloader, valid_dataloader, test_dataloader = (
    DataLoader(ratings[split], batch_size=embedding_config["training_batch_size"])
    for split in (train_set, valid_set, "test")
)

# Each call yields (chosen embeddings, rejected embeddings, chosen scores,
# rejected scores); note that the `*ys` names hold the rejected embeddings.
if not args.test_only:
    xs, ys, score_chosen, score_rejected = get_embeddings_x_y(dataloader, "train")
    (
        valid_xs,
        valid_ys,
        valid_score_chosen,
        valid_score_rejected,
    ) = get_embeddings_x_y(valid_dataloader, "validation")
(
    test_xs,
    test_ys,
    test_score_chosen,
    test_score_rejected,
) = get_embeddings_x_y(test_dataloader, "test")
