import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

import argparse
import gc
from itertools import product
from pathlib import Path
from time import time

import numpy as np
import pandas as pd
import torch
import transformers
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

transformers.logging.set_verbosity_error()
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils.general import (
    extract_hidden_states,
    extract_hidden_states_iterative,
    extract_hidden_states_ids, 
    format_token
)
from src.utils.logger import Logging, format_time_minutes
from src.utils.tokenize_dataset import (
    DatasetCollection,
    extract_tokenized_prompts,
    random_tokenized_prompts
)
from src.utils.utils import (
    num_layers, 
    replace_last_norm,
    set_seed
)

# TODO: Handle log path and name better
# Module-level logger shared by the functions below.
# The 'after_backprop' template is filled by find_token() with, in order:
# a progress prefix, the current step, the total number of steps, the loss,
# the gradient norm, the decoded candidate token, the embedding norm and the
# elapsed time. The leading '\r' moves the cursor back to the start of the
# line before the text is written.
# NOTE(review): presumably every key of `defined` becomes a logger method of
# the same name (find_token calls `logger.after_backprop`) - confirm in
# src.utils.logger.
logger = Logging(
    log_path='./logs/', 
    log_name='SIP-It',
    defined={
        'after_backprop': '\r{}[{:5d}/{:5d}]: Loss: {:.2e} - Gradient norm: {:.2e} - Token: {:15s} - Emb Norm: {:.2e} - Time: {}'
    }
)


def compute_grad_and_elim(
    embeddings: tuple[torch.Tensor, torch.Tensor],
    model: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the last-position embedding gradient and the discrete loss.

    Two forward passes are run through ``model``:

    * a differentiable pass on the continuous embeddings, used to obtain the
      gradient of the mean-reduced MSE between the hidden state of the last
      position and ``h_target`` with respect to the last position's
      embedding;
    * a no-grad pass on the token ids, used to obtain the sum-reduced MSE
      between the hidden state of the last position and ``h_target``.

    Only batch element 0 is read from the hidden states.

    Args:
        embeddings: ``(cont_embeddings, disc_tokens)``. ``cont_embeddings``
            is fed to the model as ``inputs_embeds`` and sliced as
            ``[:, -1:, :]``, so it must be 3-D. ``disc_tokens`` is fed to
            the model as ``input_ids``.
        model: Module called with ``inputs_embeds=`` / ``input_ids=`` and
            ``output_hidden_states=True``; its output must expose
            ``hidden_states``.
        layer_idx: Index into ``hidden_states`` selecting the layer to match.
        h_target: Target hidden state for the last position.

    Returns:
        ``(grad, loss_disc)`` where ``grad`` is the gradient w.r.t. the last
        position's embedding with its size-1 leading dimensions removed, and
        ``loss_disc`` is a 0-dim tensor (not a Python float).
    """
    # Run everything on whatever device the model lives on.
    device = next(model.parameters()).device

    cont_embeddings = embeddings[0].to(device)
    disc_tokens = embeddings[1].to(device)
    h_target = h_target.to(device)

    # Only the last position is a leaf that receives a gradient; the prefix
    # is detached so no graph is kept for already-fixed positions.
    fixed_embs = cont_embeddings.clone().detach()
    last_emb = fixed_embs[:, -1:, :].clone().requires_grad_(True)

    inputs_embeds_cont = torch.cat([fixed_embs[:, :-1, :], last_emb], dim=1)
    outputs = model(
        inputs_embeds=inputs_embeds_cont,
        output_hidden_states=True
    )
    h_last_cont = outputs.hidden_states[layer_idx][0, -1, :]

    # The discrete pass is only scored, never differentiated.
    with torch.no_grad():
        outputs = model(
            input_ids=disc_tokens,
            output_hidden_states=True
        )
    h_last_disc = outputs.hidden_states[layer_idx][0, -1, :].detach()

    # Mean reduction drives the gradient, sum reduction is the reported loss.
    loss_cont = torch.nn.functional.mse_loss(h_last_cont, h_target, reduction='mean')
    loss_disc = torch.nn.functional.mse_loss(h_last_disc, h_target, reduction='sum')

    loss_cont.backward()

    # Equivalent to squeezing dims (0, 1) at once, but also works on PyTorch
    # releases that only accept a single dimension in ``squeeze``.
    return last_emb.grad.squeeze(1).squeeze(0), loss_disc


def find_token(
    token_idx,
    embedding_matrix,
    discovered_embeddings, discovered_ids,
    model, tokenizer, layer_idx, h_target,
    lr,
):
    """Search the vocabulary for the token at position ``token_idx``.

    Starting from a randomly drawn token, each step
    (1) scores the current candidate with ``compute_grad_and_elim``,
    (2) takes one SGD step on a continuous embedding using the returned
    gradient (rescaled to unit norm when its norm exceeds 1),
    (3) rules the current candidate out, and
    (4) picks the not-yet-ruled-out vocabulary row closest to the updated
    embedding as the next candidate.
    Every 50 steps the continuous embedding is reset to the embedding of the
    current candidate.

    Args:
        token_idx: Position being recovered; indexes ``h_target``.
        embedding_matrix: Input embedding matrix, one row per token id.
        discovered_embeddings: Embeddings of the tokens already recovered.
        discovered_ids: Ids of the tokens already recovered.
        model: Model forwarded to ``compute_grad_and_elim``.
        tokenizer: Used only to decode the candidate for the progress line.
        layer_idx: Hidden-state layer forwarded to ``compute_grad_and_elim``.
        h_target: Target hidden states, indexed by position.
        lr: SGD learning rate.

    Returns:
        ``(token_id, embedding, final_timestep)``, or ``[None, None, None]``
        if the loss or the gradient becomes NaN. ``final_timestep`` is the
        1-based step at which the search stopped, or the vocabulary size if
        it never met the stopping criterion.
    """
    vocab_size = embedding_matrix.size(0)

    # Private copy: rows of rejected candidates are overwritten with +inf so
    # the nearest-row search can never select them again.
    copy_embedding_matrix = embedding_matrix.clone().detach().requires_grad_(False)

    token_id = torch.randint(0, vocab_size, (1,)).item()

    embedding = copy_embedding_matrix[token_id].clone().requires_grad_(True)

    optimizer = torch.optim.SGD([embedding], lr=lr)

    initial_desc = f'Token [{token_idx + 1:2d}/{h_target.size(0):2d}]'
    final_timestep = vocab_size
    start_time = time()

    for i in range(vocab_size):
        input_embeddings = torch.stack(
            discovered_embeddings + [embedding]
        ).unsqueeze(0)
        input_embeddings_disc = torch.tensor(
            discovered_ids + [token_id]
        ).unsqueeze(0)

        grad_oracle, loss = compute_grad_and_elim(
            (input_embeddings, input_embeddings_disc),
            model=model,
            layer_idx=layer_idx,
            h_target=h_target[token_idx]
        )

        if torch.isnan(loss) or torch.isnan(grad_oracle).any():
            if i > 0:
                # A '\r' progress line has already been written; terminate it
                # so that later output does not land on the same line.
                logger.new_line()
            return [None] * 3

        loss_value = loss.item()
        grad_norm = grad_oracle.norm().item()
        curr_token = tokenizer.decode([token_id], skip_special_tokens=True)

        emb_norm = embedding.norm().item()
        logger.after_backprop(
            initial_desc,
            i + 1,
            vocab_size,
            loss_value,
            grad_norm,
            format_token(curr_token, length=15),
            emb_norm,
            format_time_minutes(time() - start_time)
        )

        # Stop when the candidate reproduces the target, or when the gradient
        # has vanished and no further progress is possible.
        if loss_value < 1e-5 or grad_norm < 1e-12:
            final_timestep = i + 1
            break

        # Clip the gradient to at most unit norm.
        if grad_norm > 1:
            grad_oracle = grad_oracle / grad_norm

        embedding.grad = grad_oracle

        optimizer.step(lambda : loss)

        # The search below is pure bookkeeping; keep it out of autograd so no
        # graph over a (vocab x dim) tensor is built on every step.
        with torch.no_grad():
            copy_embedding_matrix[token_id] = float('inf')
            distances = torch.norm(copy_embedding_matrix - embedding, dim=1)
            token_id = int(torch.argmin(distances))

        if (i + 1) % 50 == 0:
            # Periodically restart the continuous embedding from the
            # embedding of the current candidate.
            snapped = copy_embedding_matrix[token_id].clone()
            embedding.data = snapped.data

    logger.new_line()

    # NOTE: if the loop runs to completion every row has been overwritten
    # with +inf, so the row returned here is non-finite in that case.
    return token_id, copy_embedding_matrix[token_id], final_timestep


def find_prompt(
    model, tokenizer, layer_idx, h_target, lr
):
    """Recover a prompt one position at a time from its hidden states.

    ``find_token`` is called once per row of ``h_target`` (a 1-D target is
    first promoted to a single row); each recovered token is appended to the
    prefix that is handed to the next call.

    Returns:
        ``(total_seconds, token_ids, steps_per_token, seconds_per_token)``,
        or ``[None, None, None, None]`` as soon as ``find_token`` reports a
        failure for some position.
    """
    vocab_embeddings = model.get_input_embeddings().weight

    # A single hidden-state vector is handled as a one-row batch of targets.
    targets = h_target.unsqueeze(0) if h_target.dim() == 1 else h_target

    found_embeddings, found_ids = [], []
    steps_per_token, seconds_per_token = [], []

    run_began = time()
    for position in range(targets.size(0)):
        token_began = time()
        outcome = find_token(
            position, vocab_embeddings,
            found_embeddings, found_ids,
            model, tokenizer, layer_idx, targets,
            lr
        )
        token_elapsed = time() - token_began

        token_id, token_embedding, steps = outcome
        if token_embedding is None:
            return [None] * 4

        found_embeddings.append(token_embedding)
        found_ids.append(token_id)
        steps_per_token.append(steps)
        seconds_per_token.append(token_elapsed)

        # Release memory held by the finished per-token search.
        gc.collect()
        torch.cuda.empty_cache()

    total_elapsed = time() - run_began
    return total_elapsed, found_ids, steps_per_token, seconds_per_token


def inversion_attack(
    input_ids: torch.Tensor, 
    model: AutoModelForCausalLM, 
    tokenizer: AutoTokenizer, 
    layer_idx: int,
    lr: float,
    seed: int = 8
):
    """Invert the hidden states of ``input_ids`` and compare with the input.

    The target hidden states are extracted from ``model`` at ``layer_idx``,
    then ``find_prompt`` tries to recover the token ids from them.

    Args:
        input_ids: Token ids of the prompt; row 0 is the sequence that the
            recovered ids are compared against.
        model: Model whose hidden states are inverted.
        tokenizer: Tokenizer forwarded to ``find_prompt``.
        layer_idx: Hidden-state layer to invert.
        lr: Learning rate forwarded to ``find_prompt``.
        seed: Seed passed to ``set_seed`` before anything else runs.

    Returns:
        ``(match, inversion_time, timesteps, times)``. ``match`` is ``True``
        only when the recovered ids equal ``input_ids[0]`` exactly. When the
        search fails the result is ``(False, None, None, None)``.
    """
    set_seed(seed)
    h_target = extract_hidden_states_iterative(input_ids, model, layer_idx)

    invertion_time, discovered_ids, timesteps, times = find_prompt(
        model, tokenizer, layer_idx, h_target, 
        lr
    )

    if discovered_ids is None:
        print('Inversion failed or diverged with the given parameters.')
        return False, None, None, None

    # One host transfer and a plain list comparison: sequences of different
    # length can never be reported as equal, and no per-token tensor->bool
    # conversion is needed.
    match = input_ids[0].tolist() == discovered_ids
    print(f'Original {"==" if match else "!="} Reconstructed')
    print(f'Invertion time: {invertion_time:.2f} seconds')
    print(f'Average Timesteps: {np.mean(timesteps):.2f}')

    return match, invertion_time, timesteps, times



def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options: ``-i/--input`` and ``-o/--output`` (both required), ``--seed``,
    ``--id/--model-id``, ``--learning-rate``, ``--layer`` and ``--rank``.

    Returns:
        The ``argparse.Namespace`` with attributes ``input``, ``output``,
        ``seed``, ``id``, ``learning_rate``, ``layer`` and ``rank``.
    """
    # (flags, keyword arguments) for each option, in registration order.
    option_table = (
        (
            ('-i', '--input'),
            dict(type=str, required=True, help='Path to the text file to invert.'),
        ),
        (
            ('-o', '--output'),
            dict(type=str, required=True, help='Path to the stats output file directory.'),
        ),
        (
            ('--seed',),
            dict(type=int, default=8, help='Random seed to use.'),
        ),
        (
            ('--id', '--model-id'),
            dict(type=str, default='roneneldan/TinyStories-1M', help='Name of HF model to use.'),
        ),
        (
            ('--learning-rate',),
            dict(type=float, default=1.0, help='Learning rate (step sizes) to use.'),
        ),
        (
            ('--layer',),
            dict(type=int, default=-1, help='Model layer to invert.'),
        ),
        (
            ('--rank',),
            dict(type=int, default=-1, help='Rank of process w.r.t. the indices to process.'),
        ),
    )

    cli = argparse.ArgumentParser(description='Run inversion attack with given configuration.')
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()


if __name__ == '__main__':
    # Script entry point: load the model, tokenize the input text file, run
    # the inversion attack on it and write a one-row CSV with the outcome.

    # TODO: Type hint functions

    args = parse_args()

    model_id = args.id

    input_path = Path(args.input)
    output_dir = Path(args.output)

    seed = args.seed

    learning_rate = args.learning_rate
    layer = args.layer

    rank = args.rank

    # NOTE: the file name uses the layer value exactly as given on the
    # command line (before negative values are resolved below), and a
    # non-negative rank is appended without a separator.
    csv_name = f'results-{layer}' + ('' if rank < 0 else str(rank))
    output_file = output_dir / f'{csv_name}.csv'
    output_dir.mkdir(parents=True, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=device,
        trust_remote_code=True
    )

    # Negative layers count from the end: -1 becomes `total_layers`.
    total_layers = num_layers(model_id)
    if layer < 0:
        layer = total_layers + layer + 1

    # Project helper from src.utils.utils; its effect is not visible here.
    replace_last_norm(model_id, model)

    # The weights are frozen, but autograd stays enabled because the attack
    # differentiates with respect to the input embeddings.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    torch.set_grad_enabled(True)
    tokenizer.pad_token = tokenizer.eos_token

    # Decode explicitly as UTF-8 so the text handed to the tokenizer does not
    # depend on the platform's locale encoding.
    long_string = input_path.read_text(encoding='utf-8')

    tokenized_csv = tokenizer(
        long_string, 
        return_tensors="pt",
        add_special_tokens=False,
    )

    input_ids = tokenized_csv['input_ids'][0]

    match, time_taken, timesteps, times = inversion_attack(
        input_ids.unsqueeze(0).to(device), # type: ignore 
        model, tokenizer, layer,
        learning_rate, seed
    )

    # Timing columns are only filled in when the reconstruction matched;
    # otherwise they hold -1 / empty strings.
    pd.DataFrame([{
            'dataset': str(input_path),
            'layer': layer,
            'token_length': len(input_ids), # type: ignore
            'match': match,
            'inversion_time': time_taken if match else -1,
            'timesteps': '_'.join([str(x) for x in timesteps]) if match else '', # type: ignore
            'times': '_'.join([f'{x:.2f}' for x in times]) if match else '' # type: ignore
    }]).to_csv(output_file, index=False)