import argparse
from tokenize import Double
import torch
from datasets import Dataset
import pyarrow as pa
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
import numpy as np

# Lookup tables keyed by the short model name given on the command line.

# Location of each base checkpoint.
MODEL_PATH_MAP = {
    "llama2_13b": "../../LLMs/LLAMA/Llama-2-13b-chat-hf",
    "qwen3_14b": "../../LLMs/QWEN/Qwen3-14B",
}

# Location of each edited checkpoint (not referenced elsewhere in this file).
Attack_MODEL_PATH_MAP = {
    "llama2_13b": "./Edited_Model/llama2_13b",
    "qwen3_14b": "./Edited_Model/qwen3_14b",
}

# Class that handles the model editing and analysis
class ModelEditor:
    """Collects hidden states, derives per-layer subspaces and removes them.

    The workflow supported by this class is:

    1. ``_get_all_layers_hidden_states`` extracts one vector per text and layer.
    2. ``get_preference_matrix`` turns paired human/machine vectors into a
       difference matrix with the human-mean direction projected out.
    3. ``get_human_like_space`` keeps the leading right-singular vectors of
       that matrix as the layer's basis (stored in ``self.human_like_spaces``).
    4. ``model_edit`` (weights) or ``rep_edit`` (activations) subtracts the
       component that lies in that basis.
    """

    def __init__(self, model, tokenizer, var_threshold, device="cuda", batch_size=32, output_dir='./plots'):
        """
        Initializes the ModelEditor.

        Args:
            model: The transformer model to be edited.
            tokenizer: Tokenizer for the model. Its pad token is overwritten
                with the EOS token so that batches can be padded.
            var_threshold: Cumulative-variance cut-off used by
                ``get_human_like_space`` when choosing the basis size.
            device: Accepted for backward compatibility but not used; the
                device is always taken from ``model.device``.
            batch_size: Batch size for processing data.
            output_dir: Directory to save the generated plots (created if
                it does not exist).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = self.model.device
        self.batch_size = batch_size
        self.var_threshold = var_threshold
        self.human_like_spaces = {}  # layer index -> basis tensor
        self.output_dir = output_dir
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        os.makedirs(self.output_dir, exist_ok=True)

    def reload_model(self, model):
        """Swap in another model and refresh the cached device."""
        self.model = model
        self.device = self.model.device

    def _get_lm_head(self):
        """
        Retrieves the language model head layer from the model.

        Returns:
            ``self.model.lm_head`` or ``None`` when the model has no such attribute.
        """
        return self.model.lm_head if hasattr(self.model, 'lm_head') else None

    @staticmethod
    def _last_token_positions(attention_mask):
        """
        Return, for every row, the index of the last position whose mask is non-zero.

        Multiplying the mask by the position index and taking the row maximum
        works for left- and right-padded batches alike. A row with an all-zero
        mask yields index 0.
        """
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        return (attention_mask.to(torch.long) * positions).max(dim=1).values

    def _get_all_layers_hidden_states(self, texts, start_layer, end_layer):
        """
        Get hidden states for all requested layers at once (last real token only).

        Args:
            texts: List of texts to process.
            start_layer: First index (inclusive) into ``outputs.hidden_states``.
            end_layer: Last index (exclusive) into ``outputs.hidden_states``.

        Returns:
            A dictionary mapping each layer index to a CPU tensor with one row
            per text.

        Note:
            Indices that are not present in ``outputs.hidden_states`` are
            skipped; concatenating such an empty layer (or an empty ``texts``)
            makes ``torch.cat`` raise.
        """
        layer_range = range(start_layer, end_layer)
        hidden_states_dict = {layer: [] for layer in layer_range}

        for i in tqdm(range(0, len(texts), self.batch_size), desc="Processing batches"):
            batch_texts = texts[i:i + self.batch_size]
            inputs = self.tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            all_hidden_states = outputs.hidden_states

            # With padding enabled the final sequence position can be a pad
            # token, so locate the last real token of every row via the mask.
            last_positions = None
            if "attention_mask" in inputs:
                last_positions = self._last_token_positions(inputs["attention_mask"])

            for layer in layer_range:
                if layer < len(all_hidden_states):
                    hidden = all_hidden_states[layer]
                    if last_positions is None:
                        last_hidden = hidden[:, -1, :]
                    else:
                        # Layers may live on different devices, so build the
                        # index tensors on the device of this layer's output.
                        rows = torch.arange(hidden.shape[0], device=hidden.device)
                        last_hidden = hidden[rows, last_positions.to(hidden.device)]
                    hidden_states_dict[layer].append(last_hidden.detach().cpu())

            del outputs
            torch.cuda.empty_cache()

        # Concatenate all batches for each layer
        for layer in hidden_states_dict:
            hidden_states_dict[layer] = torch.cat(hidden_states_dict[layer], dim=0)

        return hidden_states_dict

    def _project_to_vocab_space(self, hidden_vectors):
        """
        Project hidden vectors to the vocabulary space using the LM head.

        The vectors are first moved to the device and dtype of the head's
        parameters, because the hidden states collected by this class are
        stored on the CPU.

        Raises:
            ValueError: If the model does not have an ``lm_head``.
        """
        lm_head = self._get_lm_head()
        if lm_head is None:
            raise ValueError("The model does not have an lm_head.")

        reference = next(lm_head.parameters(), None)
        if reference is not None:
            hidden_vectors = hidden_vectors.to(device=reference.device, dtype=reference.dtype)

        return lm_head(hidden_vectors)

    def _get_top_k_tokens(self, logits, k=20):
        """
        Get the top-k tokens of the FIRST row of ``logits``.

        Args:
            logits: Logits corresponding to the vocabulary, one row per example.
            k: Number of top tokens to return (clamped to the vocabulary size).

        Returns:
            top_k_tokens: List of top-k token strings.
            top_k_values: Corresponding probabilities for the top-k tokens.
        """
        k = min(k, logits.shape[-1])
        probs = torch.softmax(logits, dim=-1)
        top_k_values, top_k_indices = torch.topk(probs, k, dim=-1)
        top_k_tokens = self.tokenizer.convert_ids_to_tokens(top_k_indices[0].tolist())
        return top_k_tokens, top_k_values[0].tolist()

    def project_and_plot(self, human_vecs, machine_vecs, layer, top_k=20, max_sentences=5):
        """
        Project human and machine hidden vectors to vocabulary space and plot
        the top-k tokens for the first `max_sentences` sentences of each type.

        The figure is written to
        ``<output_dir>/layer_<layer>_top_<top_k>_per_sentence.png``.

        Args:
            human_vecs: Tensor [num_human, hidden_dim], hidden states for human texts.
            machine_vecs: Tensor [num_machine, hidden_dim], hidden states for machine texts.
            layer: The layer index of the model for analysis.
            top_k: Number of top tokens to display for each sentence.
            max_sentences: Number of sentences (per type) to plot.
        """
        # only take first N sentences
        human_vecs = human_vecs[:max_sentences]
        machine_vecs = machine_vecs[:max_sentences]

        # concatenate, project once, then split back
        hidden_states = torch.cat([human_vecs, machine_vecs], dim=0)
        logits = self._project_to_vocab_space(hidden_states)
        human_logits = logits[:len(human_vecs)]
        machine_logits = logits[len(human_vecs):]

        # 2 rows (Human/Machine) x max_sentences columns; squeeze=False keeps
        # the axes array two-dimensional even for a single column.
        fig, axes = plt.subplots(
            2, max_sentences, figsize=(4*max_sentences, 8), sharey=True, squeeze=False
        )

        for row, (label, group_logits) in enumerate((("Human", human_logits), ("Machine", machine_logits))):
            for col, sentence_logits in enumerate(group_logits):
                tokens, values = self._get_top_k_tokens(sentence_logits.unsqueeze(0), top_k)
                ax = axes[row, col]
                ax.barh(tokens, values)
                ax.set_title(f"{label} {col+1}")
                ax.invert_yaxis()

        # global title
        fig.suptitle(f"Top-{top_k} tokens per sentence (Layer {layer})", fontsize=14)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        # save figure
        plot_filename = os.path.join(
            self.output_dir, f"layer_{layer}_top_{top_k}_per_sentence.png"
        )
        fig.savefig(plot_filename)
        plt.close(fig)
        print(f"Saved plot for layer {layer} to {plot_filename}")

    def plot_mean_tokens(self, human_vecs, machine_vecs, layer, top_k=8):
        """
        Plot top-k tokens for the mean representation of human and machine texts.

        The figure is written to ``<output_dir>/layer_<layer>_mean_top_<top_k>.png``.

        Args:
            human_vecs: Tensor [num_human, hidden_dim], hidden states for human texts.
            machine_vecs: Tensor [num_machine, hidden_dim], hidden states for machine texts.
            layer: The layer index.
            top_k: Number of top tokens to display.
        """
        # Compute mean hidden states
        human_mean = human_vecs.mean(dim=0, keepdim=True)
        machine_mean = machine_vecs.mean(dim=0, keepdim=True)

        # Project to vocabulary space
        logits_h = self._project_to_vocab_space(human_mean)
        logits_m = self._project_to_vocab_space(machine_mean)

        # Get top-k tokens
        tokens_h, values_h = self._get_top_k_tokens(logits_h, top_k)
        tokens_m, values_m = self._get_top_k_tokens(logits_m, top_k)

        # Plot
        fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)

        axes[0].barh(tokens_h, values_h)
        axes[0].invert_yaxis()
        axes[0].set_title("Human Mean Representation")

        axes[1].barh(tokens_m, values_m)
        axes[1].invert_yaxis()
        axes[1].set_title("Machine Mean Representation")

        fig.suptitle(f"Top-{top_k} tokens for mean representations (Layer {layer})", fontsize=14)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        # Save figure
        plot_filename = os.path.join(self.output_dir, f"layer_{layer}_mean_top_{top_k}.png")
        fig.savefig(plot_filename)
        plt.close(fig)
        print(f"Saved mean plot for layer {layer} to {plot_filename}")

    def get_preference_matrix(self, human_hidden_states, machine_hidden_states, layer):
        """
        Compute preference matrix for a specific layer.

        The result is ``machine - human`` (row by row) with the component
        along the mean human vector removed from every row.

        Args:
            human_hidden_states: Tensor with one row per human text.
            machine_hidden_states: Tensor with one row per machine text,
                paired row-wise with ``human_hidden_states``.
            layer: Layer index, used only in the warning message.

        Returns:
            The centered difference matrix on the CPU. NaN/Inf entries are
            replaced by 0 / +-1e6 after a warning is printed.
        """
        human_hidden_states = human_hidden_states.to(self.device)
        machine_hidden_states = machine_hidden_states.to(self.device)
        diff = machine_hidden_states - human_hidden_states
        human_mean = human_hidden_states.mean(dim=0, keepdim=True)

        # Squared norm of the mean vector; skip the projection when it is ~0
        # to avoid dividing by zero.
        denominator = human_mean @ human_mean.T
        if denominator.abs() < 1e-10:
            proj = torch.zeros_like(diff)
        else:
            proj = (diff @ human_mean.T) / denominator * human_mean

        centered_diff = diff - proj

        # Check for numerical issues
        if torch.isnan(centered_diff).any() or torch.isinf(centered_diff).any():
            print(f"Warning: NaN or INF detected in centered_diff for layer {layer}")
            centered_diff = torch.nan_to_num(centered_diff, nan=0.0, posinf=1e6, neginf=-1e6)

        return centered_diff.cpu()

    def get_human_like_space(self, centered_diff, layer):
        """
        Compute human-like space with robust SVD.

        The basis consists of the first ``k`` right-singular vectors, where
        ``k`` is one more than the number of leading components whose
        cumulative share of the squared singular values is
        ``<= self.var_threshold``.

        Args:
            centered_diff: Matrix returned by ``get_preference_matrix``.
            layer: Layer index under which the basis is stored.

        Returns:
            The basis on the CPU. A copy on ``self.device`` is kept in
            ``self.human_like_spaces[layer]``.
        """
        if centered_diff.device != self.device:
            centered_diff = centered_diff.to(self.device)
        try:
            # First try on the model device with double precision
            _, S, Vh = torch.linalg.svd(centered_diff.double(), full_matrices=False)
        except torch.linalg.LinAlgError:
            # Fallback to CPU if the first attempt fails
            print(f"SVD failed on GPU for layer {layer}, trying CPU...")
            _, S, Vh = torch.linalg.svd(centered_diff.cpu().double(), full_matrices=False)
            S, Vh = S.to(self.device), Vh.to(self.device)

        S_squared = S ** 2
        total_variance = torch.sum(S_squared)

        if total_variance <= 0 or torch.isnan(total_variance):
            print(f"Warning: Invalid total variance for layer {layer}, using identity basis")
            k = min(centered_diff.shape[1], 10)
            basis = torch.eye(centered_diff.shape[1], device=self.device)[:k]
        else:
            cumulative_variance = torch.cumsum(S_squared, dim=0) / total_variance
            k = torch.sum(cumulative_variance <= self.var_threshold).item() + 1
            k = min(k, centered_diff.shape[1])
            basis = Vh[:k]

        print(f'Layer {layer}: basis shape {basis.shape}')
        self.human_like_spaces[layer] = basis
        return basis.cpu()

    def load_human_like_space(self, path):
        """
        Load a previously saved ``{layer: basis}`` mapping from ``path``.

        SECURITY NOTE: ``torch.load`` unpickles the file; only load files
        from a trusted source.
        """
        # Load onto the CPU; model_edit / rep_edit move each basis where needed.
        self.human_like_spaces = torch.load(path, map_location="cpu")

    def model_edit(self, layer):
        """
        Remove the stored subspace of ``layer`` from that layer's MLP weight, in place.

        The first of ``mlp.up_proj``, ``mlp.fc2``, ``mlp.w2`` that exists is
        edited as ``W <- W - W @ B^T @ B`` with ``B`` the stored basis. The
        arithmetic is done in float32 and the result is copied back into the
        parameter, so models held in half precision are edited as well.

        Raises:
            ValueError: If no basis is stored for ``layer`` or the decoder
                layers cannot be located on the model.
        """
        if layer not in self.human_like_spaces:
            raise ValueError("Call get_human_like_space or load_human_like_space first.")

        basis = self.human_like_spaces[layer]

        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            model_layer = self.model.transformer.h[layer]
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
            model_layer = self.model.model.layers[layer]
        elif hasattr(self.model, 'layers'):
            model_layer = self.model.layers[layer]
        else:
            raise ValueError(
                f"Cannot locate decoder layers on model of type {self.model.__class__.__name__}."
            )

        if not hasattr(model_layer, "mlp"):
            print(f"Layer {layer} has no 'mlp', skipping edit.")
            return

        mlp = model_layer.mlp

        target = None
        for attr_name in ("up_proj", "fc2", "w2"):
            if hasattr(mlp, attr_name):
                target = getattr(mlp, attr_name)
                break
        if target is None:
            print(f"Layer {layer}.mlp has no 'up_proj', 'fc2' or 'w2', skipping edit.")
            return

        param = target.weight
        # NOTE(review): the right-multiplication assumes the weight's last
        # dimension matches the basis width - confirm for 'fc2' / 'w2'.
        weight = param.data.to(torch.float32)
        basis = basis.to(weight.device, dtype=torch.float32)
        edited = weight - weight @ basis.T @ basis
        # Write back explicitly: for non-float32 parameters the float32
        # tensor above is a copy, so an in-place update on it would be lost.
        param.data.copy_(edited.to(param.dtype))
        print(f"Edited model weights at layer {layer}.")

    def rep_edit(self, hidden_states, layer, alpha=0.5):
        """
        Subtract ``alpha`` times the projection of ``hidden_states`` onto the layer's basis.

        Args:
            hidden_states: Tensor whose last dimension matches the basis width.
            layer: Key into ``self.human_like_spaces`` (``KeyError`` if absent).
            alpha: Strength of the removal; 1 removes the full projection.

        Returns:
            A new tensor ``hidden_states - alpha * hidden_states @ B^T @ B``.
        """
        basis = self.human_like_spaces[layer]
        # Match the activations' dtype so half/double inputs do not fail in matmul.
        basis = basis.to(hidden_states.device, dtype=hidden_states.dtype)

        proj = float(alpha) * (hidden_states @ basis.T @ basis)
        return hidden_states - proj


# Load the dataset for analysis
def load_data(path, num_samples=None):
    """
    Read paired human / machine texts and their prompts.

    A path ending in ".jsonl" is parsed line by line (blank lines are
    ignored, missing keys become empty strings); any other path is handed
    to ``load_from_disk`` and its 'chosen', 'rejected' and 'prompt' columns
    are used.

    Args:
        path: Location of the JSONL file or of the dataset saved on disk.
        num_samples: If truthy, keep only this many leading samples.

    Returns:
        human: Texts taken from the 'chosen' field.
        machine: Texts taken from the 'rejected' field.
        prompt: Texts taken from the 'prompt' field.
    """
    if not path.endswith(".jsonl"):
        dataset = load_from_disk(path)
        human = dataset['chosen']
        machine = dataset['rejected']
        prompt = dataset['prompt']
    else:
        with open(path, "r", encoding="utf-8") as handle:
            records = [json.loads(raw) for raw in handle if raw.strip()]
        human = [record.get("chosen", "") for record in records]
        machine = [record.get("rejected", "") for record in records]
        prompt = [record.get("prompt", "") for record in records]

    # None and 0 both mean "no limit".
    if not num_samples:
        return human, machine, prompt
    return human[:num_samples], machine[:num_samples], prompt[:num_samples]

def main(args):
    """
    Run the pipeline: collect hidden states, build per-layer spaces, optionally edit.

    Steps:
        1. Load paired texts from ``<args.data_dir>/train``.
        2. Load the tokenizer and model named by ``args.model_name``.
        3. For every layer in ``[start_layer, start_layer + num_edit_layers)``
           build the basis with ``ModelEditor`` and append the timing to
           ``<args.Time_path>/Editor_Time.txt``.
        4. Save all bases to ``<human_like_space_dir>/<model_name>/human_like_space.pt``.
        5. If ``args.model_edit`` is set, edit the weights, convert the model
           to half precision and save it together with the tokenizer under
           ``<output_model_dir>/<model_name>``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section).

    Raises:
        ValueError: If no samples were loaded, the model's layer list cannot
            be located, or the requested layer range is invalid.
    """
    train_path = os.path.join(args.data_dir, 'train')
    print("Loading data...")
    human_texts, machine_texts, _ = load_data(train_path, num_samples=args.num_samples)
    # Fail early with a clear message instead of an IndexError on the example below.
    if len(human_texts) == 0 or len(machine_texts) == 0:
        raise ValueError(f"No samples loaded from {train_path}")
    print(f"Data loaded: human_texts={len(human_texts)}, machine_texts={len(machine_texts)}\n**Example**\n", 
          human_texts[0].replace("\n", ' '), machine_texts[0].replace("\n", ' '))

    model_path = MODEL_PATH_MAP[args.model_name]
    print('Loading model...')
    tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto',trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    edited_model_path = os.path.join(args.output_model_dir, args.model_name)
    space_path = os.path.join(args.human_like_space_dir, args.model_name)
    os.makedirs(edited_model_path, exist_ok=True)
    os.makedirs(space_path, exist_ok=True)
    os.makedirs(args.Time_path, exist_ok=True)
    time_save_path=os.path.join(args.Time_path,'Editor_Time.txt')

    # Locate the list of decoder layers; the attribute layout differs per architecture.
    if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        total_layers = len(model.transformer.h)
        print(f"Using transformer.h, total layers: {total_layers}")
    elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
        total_layers = len(model.model.layers)
        print(f"Using model.layers, total layers: {total_layers}")
    elif hasattr(model, 'layers'):
        total_layers = len(model.layers)
        print(f"Using layers, total layers: {total_layers}")
    else:
        # Without a layer count the range check below cannot run, so stop here.
        raise ValueError(f"Cannot determine number of layers, model structure: {model.__class__.__name__}")
    end_layer = args.start_layer + args.num_edit_layers
    if args.start_layer < 0 or end_layer > total_layers:
        raise ValueError(f"Requested layers [{args.start_layer}, {end_layer}) exceed total layers {total_layers}")

    print("=== Model Layers ===")
    for name, module in model.named_modules():
        print(name, type(module))

    editor = ModelEditor(model, tokenizer, var_threshold=args.var_threshold, batch_size=args.batch_size, output_dir=args.output_plots_dir)

    print('======== Start Editing =========')
    space_path = os.path.join(space_path, 'human_like_space.pt')
    st = time.time()
    human_hidden_dict = editor._get_all_layers_hidden_states(human_texts, args.start_layer, end_layer)
    machine_hidden_dict = editor._get_all_layers_hidden_states(machine_texts, args.start_layer, end_layer)

    all_human_like_spaces = {}

    for layer in range(args.start_layer, end_layer):
        layer_st = time.time()
        print(f"**Processing Layer {layer}**")

        human_hidden = human_hidden_dict[layer]
        machine_hidden = machine_hidden_dict[layer]

        print('Getting preference matrix...')
        centered_diff = editor.get_preference_matrix(human_hidden, machine_hidden, layer)

        print('Getting human-like space...')
        basis = editor.get_human_like_space(centered_diff, layer)
        all_human_like_spaces[layer]=basis

        layer_et = time.time()
        print(f"Layer {layer} processed in {layer_et - layer_st:.2f} seconds\n")

        # Log the time taken for constructing the space for this layer
        with open(time_save_path, "a", encoding="utf-8") as f:
            text = f'{args.data_dir}: Constructed space for {args.model_name} on layer {layer} using {args.num_samples} samples in {layer_et- layer_st} seconds.\n'
            f.write(text)

    # Save all human-like spaces after all layers have been processed
    torch.save(all_human_like_spaces, space_path)
    print(f"Saved all human-like spaces to {space_path}")

    if args.model_edit:
        me_st = time.time()
        print('Editing model weights...')

        # Edit the model weights for each layer based on the human-like space
        for layer in range(args.start_layer, end_layer):
            print(f"Editing layer {layer}...")
            editor.model_edit(layer)

        me_et = time.time()

        # Log the time taken for model editing
        with open(time_save_path, "a", encoding="utf-8") as f:
            text = f'{args.data_dir}: Edited {args.model_name} for layers {args.start_layer}-{end_layer} using {args.num_samples} samples in {me_et - me_st} seconds.\n'
            f.write(text)

    print('All layers processed and model edited!')

    if args.model_edit:
        # The edited checkpoint is stored in half precision.
        model.half()
        model.save_pretrained(edited_model_path)
        tokenizer.save_pretrained(edited_model_path)
        print(f"Saved edited model to {edited_model_path}")

    # Final time logging (covers hidden-state extraction, space construction, editing and saving)
    et = time.time()
    print("Total Edit Time:", et - st)
    with open(time_save_path, "a", encoding="utf-8") as f:
        text = f'Space construction and model editing on {args.model_name} from layer {args.start_layer} to {end_layer} with {args.num_samples} samples(alphs={args.var_threshold}) took {et - st} seconds.\n'
        f.write(text)
    

# Main entry point for script execution
if __name__ == "__main__":
    # Command-line options in declaration order: (flag, add_argument keyword arguments).
    cli_options = (
        ("--data_dir", {"type": str, "default": '../llm_replication_project/data/dpo_openwebtext_processed/'}),
        ("--model_name", {"type": str, "default": "llama2_13b"}),
        ("--num_samples", {"type": int, "default": 500}),
        ("--start_layer", {"type": int, "default": 0}),
        ("--num_edit_layers", {"type": int, "default": 40}),
        ("--var_threshold", {"type": float, "default": 0.9}),
        ("--alpha", {"type": float, "default": 1}),
        ("--batch_size", {"type": int, "default": 2}),
        # Whether to apply model edits or not
        ("--model_edit", {"action": 'store_true'}),
        # Directory to save plots
        ("--output_plots_dir", {"type": str, "default": './plots'}),
        # Directory to save edited model
        ("--output_model_dir", {"type": str, "default": './edited_models'}),
        # Directory to save human-like spaces
        ("--human_like_space_dir", {"type": str, "default": './space'}),
        # Directory that receives the timing log
        ("--Time_path", {"type": str, "default": './Final_res'}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    main(args)
