import os
import json
import ast
import random
from argparse import ArgumentParser
from tqdm import tqdm
import pickle as pkl
import numpy as np

import torch
from transformers import (
    AutoTokenizer,
    set_seed,
)
from lorax.custom_hf import (
    LoraxLlamaForCausalLM,
    LoraxQwen2ForCausalLM,
    wrap_linear,
)

from lorax.utils import set_no_grad
# Training-split JSON file for each dataset name accepted by --datasets.
DATASET_PATH = dict(
    elife="/root/lorax/data/elife/train.json",
    cochrane="/root/lorax/data/cochrane/train.json",
    genetics="/root/lorax/data/plos_genetics/train.json",
)


def collect_repr(model, samples, source_repr_for_probe, target_repr_for_probe, verbose=True):
    """Run ``model`` over every text pair in ``samples`` without gradients.

    For each ``(text1, text2)`` pair the model is called twice: once on
    ``text1`` with ``representations_for_probe=source_repr_for_probe`` and
    once on ``text2`` with ``representations_for_probe=target_repr_for_probe``.
    The model outputs themselves are discarded; nothing is returned.

    NOTE(review): presumably the wrapped model fills the two dicts in place
    during the forward pass -- confirm against ``lorax.custom_hf``.

    Relies on the module-level ``tokenizer`` and ``device`` globals that the
    ``__main__`` section creates before calling this function.

    Args:
        model: Model to evaluate; it is switched to eval mode first.
        samples: Iterable of ``(text1, text2)`` pairs.
        source_repr_for_probe: Passed to the forward call for ``text1``.
        target_repr_for_probe: Passed to the forward call for ``text2``.
        verbose: When true, print the first pair once before processing it.
    """

    def _build_inputs(text):
        # Tokenize without special tokens; one unpadded sequence, so the
        # attention mask is all ones.
        token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
        mask = torch.ones_like(token_ids)
        return {
            "input_ids": token_ids.to(device),
            "attention_mask": mask.to(device),
        }

    model.eval()
    show_first_pair = verbose
    for text1, text2 in tqdm(samples, desc="Processing positive conversations"):
        if show_first_pair:
            print(f"Text1: {text1}")
            print(f"Text2: {text2}")
            show_first_pair = False

        # Build both inputs before any forward pass, then run source first
        # and target second.
        passes = (
            (_build_inputs(text1), source_repr_for_probe),
            (_build_inputs(text2), target_repr_for_probe),
        )

        with torch.no_grad():
            for model_inputs, probe_sink in passes:
                model(
                    **model_inputs,
                    lora_B_idx=None,
                    contrastive_targets=[],
                    representations=None,
                    representations_for_probe=probe_sink,
                )


if __name__ == "__main__":
    # Script entry point: load a Lorax-wrapped causal LM, run it over matched
    # (positive) and mismatched (negative) source/target pairs, and pickle the
    # collected per-module representations together with their 1/0 labels.

    parser = ArgumentParser()
    parser.add_argument("--model_name", type=str)
    # Parsed as a Python literal, e.g. --datasets "['elife', 'cochrane']";
    # each entry must be a key of DATASET_PATH.
    parser.add_argument("--datasets", type=ast.literal_eval)
    args = parser.parse_args()

    set_seed(42)

    # NOTE: `device` and `tokenizer` are read as globals by collect_repr(),
    # so these names must not change.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model_kwargs = {
        "pretrained_model_name_or_path": args.model_name,
        "trust_remote_code": False,
        "torch_dtype": torch.bfloat16,
        "device_map": device,
    }

    # Pick the Lorax model class from the model name; both branches request
    # flash-attention 2 through the config.
    if "Qwen2" in args.model_name:
        config = LoraxQwen2ForCausalLM.config_class.from_pretrained(args.model_name)
        config._attn_implementation = "flash_attention_2"
        model = LoraxQwen2ForCausalLM.from_pretrained(
            **model_kwargs,
            config=config
        )
    elif "Llama" in args.model_name:
        config = LoraxLlamaForCausalLM.config_class.from_pretrained(args.model_name)
        config._attn_implementation = "flash_attention_2"
        model = LoraxLlamaForCausalLM.from_pretrained(
            **model_kwargs,
            config=config
        )
    else:
        raise ValueError(f"Model {args.model_name} not supported!")

    # NOTE(review): rank 0 / zero adapters presumably means the wrapping is
    # used only to capture representations -- confirm in lorax.custom_hf.
    lorax_config = {
        "lora_r": 0,
        "num_loras": 0,
    }
    wrap_linear(
        model=model,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        config=lorax_config,
    )

    # One empty list per probed projection per layer. The three dicts share
    # the same keys (same order) but must hold independent lists.
    num_hidden_layers = config.num_hidden_layers
    probe_modules = (
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    )
    probe_keys = [
        f"layers.{layer_idx}.{module}"
        for layer_idx in range(num_hidden_layers)
        for module in probe_modules
    ]
    source_repr_for_probe = {key: [] for key in probe_keys}
    target_repr_for_probe = {key: [] for key in probe_keys}
    representations_for_probe = {key: [] for key in probe_keys}

    # Load and concatenate the requested datasets, closing each file promptly.
    data = []
    for dataset in args.datasets:
        with open(DATASET_PATH[dataset], "r", encoding="utf-8") as f:
            data += json.load(f)

    if not data:
        raise ValueError("No records loaded from the requested datasets.")
    # The rejection loop below never terminates unless at least one record
    # has a different source, so fail fast instead of hanging.
    first_source = data[0]["source"]
    if all(d["source"] == first_source for d in data):
        raise ValueError(
            "Negative sampling needs at least two records with different sources."
        )

    positive_samples = []
    negative_samples = []
    for d in data:
        positive_samples.append(
            (
                d["source"],
                d["target"],
            )
        )
        # random sample a negative sample, exclude the same source
        negative_d = random.choice(data)
        while negative_d["source"] == d["source"]:
            negative_d = random.choice(data)
        negative_samples.append(
            (
                d["source"],
                negative_d["target"],
            )
        )

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the per-class sample size at the population size.
    num_samples_per_class = 1500
    positive_samples = random.sample(
        positive_samples, min(num_samples_per_class, len(positive_samples))
    )
    negative_samples = random.sample(
        negative_samples, min(num_samples_per_class, len(negative_samples))
    )
    print(f"Total conversations: {len(positive_samples) + len(negative_samples)}")

    # collect representations: positives first, then negatives, so the order
    # matches the labels list built below.
    collect_repr(
        model=model,
        samples=positive_samples,
        source_repr_for_probe=source_repr_for_probe,
        target_repr_for_probe=target_repr_for_probe,
    )
    collect_repr(
        model=model,
        samples=negative_samples,
        source_repr_for_probe=source_repr_for_probe,
        target_repr_for_probe=target_repr_for_probe,
    )

    # Pair the i-th source entry with the i-th target entry of each module
    # and join them along axis 0.
    # NOTE(review): assumes the model stored one array-like entry per forward
    # pass under every key -- confirm in lorax.custom_hf.
    for key in source_repr_for_probe:
        for source, target in zip(source_repr_for_probe[key], target_repr_for_probe[key]):
            representations_for_probe[key].append(np.concatenate((source, target), axis=0))

    labels = [1] * len(positive_samples) + [0] * len(negative_samples)

    # Save under a directory named after the last path component of the model
    # name; the path is relative to the current working directory.
    model_name = args.model_name.split("/")[-1]
    save_dir = f"../data/probes/{model_name}"
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "probes_training_data.pkl"), "wb") as f:
        pkl.dump(representations_for_probe, f)
    with open(os.path.join(save_dir, "labels.pkl"), "wb") as f:
        pkl.dump(labels, f)