import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

import argparse
import gc
from itertools import product
from pathlib import Path
from time import time

import numpy as np
import pandas as pd
import torch
import transformers
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

transformers.logging.set_verbosity_error()
from bert_score import score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils.general import (
    compute_all_token_embedding_grad_emb,
    extract_hidden_states_ids
)
from src.utils.logger import Logging, format_time_minutes
from src.utils.tokenize_dataset import (
    DatasetCollection,
    extract_tokenized_prompts,
    random_tokenized_prompts
)
from src.utils.utils import num_layers, replace_last_norm, set_seed

# TODO: Handle log path and name better
# Progress-line template. Its seven placeholders are filled positionally by
# the `logger.after_backprop(...)` call in `find_prompt`: iteration, iteration
# budget, loss, gradient norm, largest embedding norm, token accuracy and
# elapsed time. The leading '\r' returns to the start of the line.
_AFTER_BACKPROP_TEMPLATE = (
    '\rIters [{:,}/{:,}]: Loss: {:.2e} - Gradient norm: {:.2e} - Emb. Norm: {:.2e} - '
    'Accuracy: {:.2f} - Time: {}'
)

# Module-wide logger. NOTE(review): `Logging` presumably exposes every key of
# `defined` as a method of the same name (it is called as
# `logger.after_backprop` below) -- confirm in src.utils.logger.
logger = Logging(
    log_path='./logs/',
    log_name='SIP-It',
    defined={'after_backprop': _AFTER_BACKPROP_TEMPLATE},
)

def _nearest_token_ids(embeddings, embedding_matrix):
    """Return, for each row of `embeddings`, the index of the closest row of
    `embedding_matrix` under the Euclidean norm, as a list of Python ints."""
    return [
        int(torch.argmin(torch.norm(embedding_matrix - x, dim=1)))
        for x in embeddings
    ]


def find_prompt(
    model, tokenizer, layer_idx, h_target, true_ids,
    lr, scheduler: bool = False, baseline: bool = False,
):
    """Search for token ids whose embeddings reproduce `h_target`.

    Starting from random token ids, a continuous copy of their embeddings is
    updated with SGD and, after every step, projected back onto the nearest
    rows of the model's input embedding matrix. The gradient is always
    evaluated at the projected (discrete) embeddings, not at the continuous
    ones.

    Args:
        model: Model providing `get_input_embeddings()`; also forwarded to
            `compute_all_token_embedding_grad_emb`.
        tokenizer: Used only to decode the final token ids.
        layer_idx: Forwarded to `compute_all_token_embedding_grad_emb`.
        h_target: Target tensor; a 1-D tensor is treated as a single row. Its
            first dimension sets the number of tokens searched for.
        true_ids: Ground-truth ids; only `true_ids[0]` is read, and only for
            the accuracy shown in the progress log.
        lr: SGD learning rate.
        scheduler: If true, wrap the optimizer in `ReduceLROnPlateau`.
        baseline: If true, the vanishing-gradient stop criterion is disabled
            and only the loss threshold can end the search early.

    Returns:
        `(elapsed_seconds, decoded_string, iterations_used, token_ids)`, or
        four `None` values if the loss or the gradient became NaN.
    """
    embedding_matrix = model.get_input_embeddings().weight

    if h_target.dim() == 1:
        h_target = h_target.unsqueeze(0)

    n_tokens = h_target.size(0)
    token_ids = torch.randint(0, embedding_matrix.size(0), (n_tokens,))
    # Continuous iterate updated by the optimizer. Only the selected rows are
    # cloned (the original cloned the whole matrix before indexing).
    embeddings = embedding_matrix[token_ids].detach().clone().requires_grad_(True)
    # Discrete iterate: the current token embeddings the gradient is taken at.
    temp_embeddings = embedding_matrix[token_ids].clone().detach().requires_grad_(False)

    optimizer = torch.optim.SGD([embeddings], lr=lr)
    # Kept under its own name so the boolean `scheduler` flag is not rebound.
    lr_scheduler = None
    if scheduler:
        lr_scheduler = ReduceLROnPlateau(
            optimizer, 'min', factor=0.99, threshold=lr / 100, patience=200
        )

    max_iters = embedding_matrix.size(0) * h_target.size(0) // 4
    final_timestep = max_iters

    start_time = time()
    for i in range(max_iters):
        # NOTE(review): presumably returns the gradient of the hidden-state
        # matching loss w.r.t. the given embeddings, plus the loss -- confirm
        # in src.utils.general.
        grad_oracle, loss = compute_all_token_embedding_grad_emb(
            temp_embeddings,
            model=model,
            layer_idx=layer_idx,
            h_target=h_target
        )

        if torch.isnan(loss) or torch.isnan(grad_oracle).any():
            # Must match the arity of the normal return: callers unpack four
            # values and test the last one for None.
            return None, None, None, None

        grad_norm = grad_oracle.norm().item()

        logger.after_backprop(
            i + 1,
            max_iters,
            loss.item(),
            grad_norm,
            max([embedding.norm().item() for embedding in embeddings]),
            sum([int((x == y).item()) for x, y in zip(token_ids, true_ids[0])]) / len(token_ids),
            format_time_minutes(time() - start_time)
        )

        # Stop on a (near) exact match, or -- outside baseline mode -- once
        # the gradient has vanished.
        if loss.item() < 1e-5 or (not baseline and grad_norm < 1e-12):
            final_timestep = i + 1
            break

        # Clip the gradient to at most unit norm.
        if grad_norm > 1:
            grad_oracle = grad_oracle / grad_norm

        embeddings.grad = grad_oracle

        optimizer.step(lambda: loss)
        if lr_scheduler is not None:
            lr_scheduler.step(loss)

        # Project the continuous iterate onto the vocabulary.
        token_ids = _nearest_token_ids(embeddings, embedding_matrix)
        temp_embeddings = embedding_matrix[token_ids].clone().detach().requires_grad_(False)

        # Every 100 iterations, snap the continuous iterate onto the
        # projected embeddings.
        if (i + 1) % 100 == 0:
            embeddings.data = temp_embeddings.data

    logger.new_line()

    token_ids = _nearest_token_ids(embeddings, embedding_matrix)

    final_string = tokenizer.decode(token_ids, skip_special_tokens=True)

    return time() - start_time, final_string, final_timestep, token_ids


def inversion_attack(
    input_ids: torch.Tensor, 
    model: AutoModelForCausalLM, 
    tokenizer: AutoTokenizer, 
    layer_idx: int,
    lr: float,
    seed: int = 8, 
    scheduler: bool = False,
    baseline: bool = False
):
    """Try to recover `input_ids` from the model's hidden states.

    Seeds the RNGs, extracts the target with `extract_hidden_states_ids`, runs
    `find_prompt` against it and prints whether the ids were recovered.

    Args:
        input_ids: Prompt token ids; row 0 is compared with the recovered ids.
        model: Model under attack.
        tokenizer: Tokenizer used to decode the recovered ids.
        layer_idx: Layer whose hidden states are inverted.
        lr: Learning rate for `find_prompt`.
        seed: Seed passed to `set_seed` before the search.
        scheduler: Forwarded to `find_prompt`.
        baseline: Forwarded to `find_prompt`.

    Returns:
        `(match, inversion_time, final_string, discovered_ids, iters)`. If the
        search diverged, `match` is False and the other four are None.
    """
    set_seed(seed)
    h_target = extract_hidden_states_ids(input_ids, model, layer_idx)

    invertion_time, final_string, iters, discovered_ids = find_prompt(
        model, tokenizer, layer_idx, h_target, input_ids,
        lr, scheduler, baseline=baseline,
    )

    if discovered_ids is None:
        print('Inversion failed or diverged with the given parameters.')
        return False, None, None, None, None

    match = all(x == y for x, y in zip(input_ids[0], discovered_ids))
    print(f'Original {"==" if match else "!="} Reconstructed')
    print(f'Invertion time: {invertion_time:.2f} seconds')

    return match, invertion_time, final_string, discovered_ids, iters



def parse_args():
    """Build the command-line interface and parse `sys.argv`.

    Returns:
        The `argparse.Namespace` with attributes `input`, `output`, `seed`,
        `max_prompts`, `step`, `total_lengths`, `id`, `learning_rate`,
        `scheduler`, `baseline`, `baseline_max_len`, `layer`, `rank`, `pct`.
    """
    cli = argparse.ArgumentParser(description='Run inversion attack with given configuration.')

    # (flags, add_argument keyword options), registered in this order.
    option_table = (
        (('-i', '--input'), dict(
            type=str, required=True,
            help='Path to the dataset directory. If it does not exists, a new dataset will be created and saved there.',
        )),
        (('-o', '--output'), dict(
            type=str, required=True,
            help='Path to the output CSV file.',
        )),
        (('--seed',), dict(
            type=int, default=8,
            help='Random seed to use.',
        )),
        (('-n', '--max-prompts'), dict(
            type=int, default=1000,
            help='Maximum amount of prompts to use.',
        )),
        (('--step',), dict(
            type=int, default=20,
            help='Step for token length.',
        )),
        (('--total-lengths',), dict(
            type=int, default=10,
            help='Count of different lengths to include.',
        )),
        # The first long flag decides the attribute name, hence `args.id`.
        (('--id', '--model-id'), dict(
            type=str, default='roneneldan/TinyStories-1M',
            help='Name of HF model to use.',
        )),
        (('--learning-rate',), dict(
            type=float, default=1.0,
            help='Learning rate (step sizes) to use.',
        )),
        (('--scheduler',), dict(
            action='store_true',
            help='Flag for whether to employ a ReduceOnPlateu LR Scheduler',
        )),
        (('--baseline',), dict(
            action='store_true',
            help='Flag for whether to use the random search algorithm',
        )),
        (('--baseline-max-len',), dict(
            type=int, default=20,
            help='Max prompt length when doing exhaustive search.',
        )),
        (('--layer',), dict(
            type=int, default=-1,
            help='Model layer to invert.',
        )),
        (('--rank',), dict(
            type=int, default=-1,
            help='Rank of process w.r.t. the indices to process.',
        )),
        (('--pct',), dict(
            type=float, default=0.0,
            help='Percentage of the dataset to process.',
        )),
    )

    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


if __name__ == '__main__':
    # Script entry point: parse the CLI, load the model and the prompt
    # datasets (building and saving them if the input path does not exist),
    # run the inversion attack on each selected prompt and write one CSV row
    # of metrics per prompt.
    #
    # This is deliberately kept as flat module-level code rather than a
    # main() function, so the configuration names below stay module globals
    # for any code in this module that reads them as such.

    # TODO: Type hint functions

    args = parse_args()

    model_id = args.id

    input_path = Path(args.input)
    output_dir = Path(args.output)

    seed = args.seed
    n = args.max_prompts
    step = args.step
    lengths = args.total_lengths

    learning_rate = args.learning_rate
    scheduler = args.scheduler
    baseline = args.baseline
    baseline_max_len = args.baseline_max_len

    layer = args.layer

    rank = args.rank
    pct = args.pct

    # Sharding is active only with a non-negative rank AND a positive fraction.
    sharded = rank >= 0 and pct > 0

    csv_name = f'results-{rank}-{pct:.2f}' if sharded else 'results'
    # NOTE(review): the '-exhaustive-' suffix is added when --baseline is NOT
    # set, which looks inverted relative to the flag's help text. Behaviour is
    # left as-is because downstream tooling may depend on these file names --
    # confirm the intent.
    csv_name = csv_name if baseline else f'{csv_name}-exhaustive-{baseline_max_len}'
    output_file = output_dir / f'{csv_name}.csv'
    output_dir.mkdir(parents=True, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=device,
        trust_remote_code=True
    )

    # A negative layer counts from the end; -1 maps to `total_layers`.
    total_layers = num_layers(model_id)
    if layer < 0:
        layer = total_layers + layer + 1

    replace_last_norm(model_id, model)

    # The model weights are frozen; gradients are still needed w.r.t. inputs.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    torch.set_grad_enabled(True)
    tokenizer.pad_token = tokenizer.eos_token

    if input_path.exists():
        datasets = DatasetCollection.load(input_path)
    else:
        # 10% of the prompts are random token sequences, the rest are drawn
        # from the text dataset; both are split evenly over the lengths.
        random_data_samples = int(0.1 * n)
        meaningful_data_samples = n - random_data_samples

        tokenized_datasets = [
            extract_tokenized_prompts(
                dataset_path=Path('data/wikipedia'),
                text_column='text',
                tokenizer=tokenizer,
                prompt_tokens=step * (i + 1),
                max_prompts=meaningful_data_samples // lengths,
                batch_size=64,
                seed=seed
            )
            for i in range(lengths)
        ] + [
            random_tokenized_prompts(
                tokenizer=tokenizer,
                prompt_tokens=step * (i + 1),
                max_prompts=random_data_samples // lengths,
                seed=seed
            )
            for i in range(lengths)
        ]
        datasets = DatasetCollection('sip-it-data', tokenized_datasets)
        datasets.save(input_path)

    if sharded:
        i1 = int(pct * len(datasets.datasets[0]) * rank)
        i2 = int(pct * len(datasets.datasets[0]) * (rank + 1))
        print(f'Processing indices [{i1}, {i2}]')

    # Built once; the smoothing function is the same for every sample.
    smooth_fn = SmoothingFunction().method1

    write_header = True
    for dataset_idx, dataset in enumerate(datasets):

        # Datasets with prompts longer than the limit are skipped entirely.
        if dataset.prompt_tokens > baseline_max_len:
            continue

        # Index window of this process within the current dataset.
        shard_indices = None
        if sharded:
            shard_indices = range(
                int(pct * len(dataset) * rank),
                int(pct * len(dataset) * (rank + 1))
            )

        results = []
        for sample_idx, token_ids in enumerate(dataset):
            if shard_indices is not None and sample_idx not in shard_indices:
                continue

            print(f'Dataset: {dataset_idx + 1}, Sample: {sample_idx + 1}, Length: {len(token_ids)}') # type: ignore

            match, invertion_time, recovered_prompt, discovered_ids, iters = inversion_attack(
                token_ids.unsqueeze(0).to(device), # type: ignore 
                model, tokenizer, layer,
                learning_rate, seed, scheduler, baseline
            )

            if recovered_prompt is None:
                # The attack failed or diverged: there is no text to score.
                # Record the sample with empty metrics instead of crashing on
                # `None.split()`.
                bleu_score = None
                bert_precision = bert_recall = bert_f1 = None
            else:
                real_prompt = tokenizer.decode(token_ids, skip_special_tokens=True)

                reference = [real_prompt.split()]
                candidate = recovered_prompt.split()

                bleu_score = sentence_bleu(reference, candidate, smoothing_function=smooth_fn)

                P, R, F1 = score(
                    [recovered_prompt], 
                    [real_prompt], 
                    model_type='microsoft/deberta-v3-base', 
                    verbose=False
                )
                bert_precision = P.mean().item()
                bert_recall = R.mean().item()
                bert_f1 = F1.mean().item()

            results.append({
                'dataset': datasets.dataset_names[dataset_idx],
                'index': sample_idx,
                'layer': layer,
                'learning_rate': learning_rate,
                'token_length': len(token_ids), # type: ignore
                'match': match,
                'inversion_time': invertion_time,
                'iters': iters,
                'bleu': bleu_score,
                'bert_precision': bert_precision,
                'bert_recall': bert_recall,
                'bert_f1': bert_f1,
            })

        # Nothing processed for this dataset (e.g. an empty shard window):
        # writing an empty frame would emit a column-less first chunk and
        # suppress the header for every later append.
        if not results:
            continue

        # Flush once per dataset: the first write creates the file with a
        # header, later writes append rows only.
        pd.DataFrame(results).to_csv(
            output_file,
            mode='w' if write_header else 'a',
            header=write_header,
            index=False
        )
        write_header = False