from peft import LoraConfig, PeftModel
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForSequenceClassification,
)
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from torch.optim import Adam

import argparse
import json
from tqdm import trange, tqdm

from modules.data_collator import RewardScoreDataCollatorWithPadding

from utils.data_processing import preprocess_dataset
from utils.env_management import save_config
import os

from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn.decomposition import PCA

import numpy as np

# Configure Weights & Biases through environment variables.
os.environ["WANDB_PROJECT"] = "finetuning-historical"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
os.environ["WANDB_TAGS"] = "embedding"


# Command line: the only option is the path of the JSON config file.
parser = argparse.ArgumentParser(description="Train Reward")
parser.add_argument(
    "--config_file",
    type=str,
    default="configs/reward-base.json",
    help="config file",
)
args = parser.parse_args()

# Read the config inside a context manager so the file handle is closed
# deterministically rather than whenever the garbage collector gets to it.
# JSON is UTF-8 by specification, so do not depend on the locale's encoding.
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)

# The cache dir entry is a template that may reference other config keys.
os.environ["WANDB_CACHE_DIR"] = config["wandb_cache_dir"].format(**config)

# NOTE(review): presumably records the config used for this "embedding" run;
# see utils.env_management.save_config for the exact behaviour.
save_config(config, "embedding")

# Frequently used config entries, pulled out as module-level names.
embedding_config = config["embedding"]
device = config["device"]  # torch device string, e.g. "cuda" for GPU or "cpu" for CPU
preprocess = config["preprocess"]
base_dir = config["base_dir"]

base_model = config["base_model"]


def save_embeddings_x_y(xs, ys, path):
    """Save embeddings and their targets as two files sharing one prefix.

    The objects are written with ``torch.save`` to ``<path>_xs.pt`` and
    ``<path>_ys.pt``.

    Args:
        xs: Object stored in ``<path>_xs.pt`` (the embeddings).
        ys: Object stored in ``<path>_ys.pt`` (the targets).
        path: File prefix. Its parent directory is created when missing; a
            bare prefix without a directory part writes to the current
            working directory.
    """
    parent_dir = os.path.dirname(path)
    # A prefix such as "emb" has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(xs, path + "_xs.pt")
    torch.save(ys, path + "_ys.pt")


def load_embeddings_x_y(path):
    """Read back a cached ``(xs, ys)`` pair stored under the prefix *path*.

    Args:
        path: File prefix; the data is expected in ``<path>_xs.pt`` and
            ``<path>_ys.pt``.

    Returns:
        A tuple ``(xs, ys)`` with the objects returned by ``torch.load``.

    Raises:
        FileNotFoundError: If ``<path>_xs.pt`` is not on disk.
    """
    xs_file = path + "_xs.pt"
    ys_file = path + "_ys.pt"
    if os.path.exists(xs_file):
        return torch.load(xs_file), torch.load(ys_file)
    raise FileNotFoundError(f"{path} does not exist")


def get_embeddings_x_y(dataloader, part="train"):
    """Return last-token embeddings and reward targets for one dataset split.

    The result is cached on disk under the ``save_embeddings_path`` template
    of the embedding config (formatted with ``part`` and the global config).
    When the cache exists it is returned directly; otherwise every batch of
    ``dataloader`` is run through the module-level ``model`` and the result
    is saved before being returned.

    For each sequence the embedding is the final-layer hidden state at the
    position just before the first pad token. Sequences with no pad token,
    or whose first token is the pad token, use the last position instead.

    Args:
        dataloader: Iterable of batches providing ``input_ids``,
            ``attention_mask`` and the configured reward column as tensors.
        part: Split name used to build the cache path.

    Returns:
        A tuple ``(xs, ys)``: the per-batch embeddings and reward values,
        each concatenated along dimension 0 and held on the CPU.
    """
    # Format the cache location once; it is used for load, log and save.
    cache_path = embedding_config["save_embeddings_path"].format(
        part=part, **config
    )
    try:
        return load_embeddings_x_y(cache_path)
    except FileNotFoundError:
        print(
            f"File not found, generating embeddings for {part} and saving to {cache_path}"
        )

    reward_column = embedding_config["reward_column_name"]
    xs = []
    ys = []
    for batch in tqdm(dataloader, desc="get_embeddings"):
        # Release cached GPU blocks between batches through the public API;
        # the private torch._C hook is not a supported interface and has no
        # place in CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        input_ids = batch["input_ids"].to(device)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": batch["attention_mask"].to(device),
        }
        # Inference only: no autograd graph is needed for the hidden states.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            final_layer = outputs.hidden_states[-1]

            # argmax over the pad mask gives the first pad position (0 when
            # there is none); "- 1" steps back to the last real token, and
            # the modulo maps the resulting -1 to the last position.
            sequence_lengths = (
                torch.eq(input_ids, model.config.pad_token_id).int().argmax(-1)
                - 1
            )
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(final_layer.device)
            batch_index = torch.arange(
                final_layer.shape[0], device=final_layer.device
            )
            # One hidden vector per sequence, taken at its selected position.
            last_hidden_state = final_layer[batch_index, sequence_lengths]

            xs.append(last_hidden_state.detach().cpu())
            ys.append(batch[reward_column].detach().cpu())
    xs = torch.cat(xs, dim=0)
    ys = torch.cat(ys, dim=0)
    save_embeddings_x_y(xs, ys, cache_path)
    return xs, ys


# Tokenizer of the base model, with the padding token taken from the config.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

# A chat template given in the config replaces the tokenizer's own one.
if "chat_template" not in config:
    print("No chat template provided in config file using default")
else:
    tokenizer.chat_template = config["chat_template"]

# Reuse the preprocessed dataset when preprocessing is switched off; otherwise
# build it from the raw dataset and store it at the preprocessed location.
if not preprocess:
    ratings = load_from_disk(config["data_path_preprocessed"].format(**config))
else:
    ratings = preprocess_dataset(
        load_from_disk(config["data_path"].format(**config)),
        tokenizer,
        config["max_input_length"],
        config["messages_template"],
    )
    ratings.save_to_disk(config["data_path_preprocessed"].format(**config))

print(ratings)

# Location of the fine-tuned weights; the config entry is a template that is
# filled in with the other config values. Only used below when has_peft is set.
checkpoint = embedding_config["model"].format(**config)


# Load the base model with a single-output regression head and move it to the
# configured device.
model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,
    problem_type="regression",
    torch_dtype=config["torch_dtype"],
).to(device)

# The tokenizer was given a pad token from the config above, so size the
# embedding table to the tokenizer's current length.
model.resize_token_embeddings(len(tokenizer))

# get_embeddings_x_y reads model.config.pad_token_id to locate the last
# non-pad token of each sequence.
model.config.pad_token_id = tokenizer.pad_token_id


# When the checkpoint is a PEFT adapter, attach it to the base model and then
# merge it, so `model` is rebound to the result of merge_and_unload().
if embedding_config["has_peft"]:
    model = PeftModel.from_pretrained(
        model, checkpoint, torch_dtype=config["torch_dtype"], adapter_name="default"
    )

    model = model.merge_and_unload()

# Collator shared by both splits.
data_collator = RewardScoreDataCollatorWithPadding(
    tokenizer=tokenizer, torch_dtype=config["torch_dtype"]
)

# Names of the dataset splits used for training and validation.
train_set = embedding_config["train_set"]
valid_set = embedding_config["valid_set"]


def _make_loader(split_name):
    """Build a DataLoader over ``ratings[split_name]`` with the shared settings."""
    return DataLoader(
        ratings[split_name],
        batch_size=embedding_config["training_batch_size"],
        collate_fn=data_collator,
    )


dataloader = _make_loader(train_set)
valid_dataloader = _make_loader(valid_set)

# Compute (or load from cache) the embeddings and targets of both splits.
xs, ys = get_embeddings_x_y(dataloader, part="train")
valid_xs, valid_ys = get_embeddings_x_y(valid_dataloader, part="validation")
