import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

import argparse
import gc
from itertools import product
from pathlib import Path
from time import time

import numpy as np
import pandas as pd
import torch
import transformers
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

transformers.logging.set_verbosity_error()
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils.general import (
    extract_hidden_states,
    extract_hidden_states_iterative,
    extract_hidden_states_ids, 
    format_token
)
from src.utils.logger import Logging, format_time_minutes
from src.utils.tokenize_dataset import (
    DatasetCollection,
    extract_tokenized_prompts,
    random_tokenized_prompts
)
from src.utils.utils import (
    num_layers, 
    replace_last_norm,
    set_seed
)

# TODO: Handle log path and name better
# Format of the per-step progress line emitted via `logger.after_backprop`
# (called once per search iteration in find_token). The leading '\r' is
# presumably there so each step overwrites the previous terminal line.
_AFTER_BACKPROP_FORMAT = (
    '\r{}[{:5d}/{:5d}]: Loss: {:.2e} - Gradient norm: {:.2e} - Token: {:15s} - Emb Norm: {:.2e} - Time: {}'
)

logger = Logging(
    log_path='./logs/',
    log_name='SIP-It',
    defined={'after_backprop': _AFTER_BACKPROP_FORMAT},
)

class ExhaustiveOptimizer:
    """Do-nothing optimizer used by the gradient-free (baseline) search.

    The search loop calls ``step`` on its optimizer every iteration; in
    baseline mode there are no parameters to update, so construction and
    ``step`` both accept any arguments and do nothing.
    """

    def __init__(self, *_args, **_kwargs):
        return None

    def step(self, *_args, **_kwargs):
        return None


def compute_grad_and_elim(
    embeddings: tuple[torch.Tensor, torch.Tensor],
    model: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Score a candidate prefix in continuous and in discrete form.

    Two forward passes are made:

    * a continuous pass on ``inputs_embeds`` where only the last position's
      embedding is a fresh leaf tensor, so its gradient can be extracted;
    * a discrete, gradient-free pass on the token ids.

    Args:
        embeddings: ``(cont_embeddings, disc_tokens)``. ``cont_embeddings`` is
            indexed as ``[:, -1:, :]`` (batch, seq, dim); ``disc_tokens`` is
            the matching token-id tensor passed as ``input_ids``. Only batch
            element 0 is read from the outputs -- assumes batch size 1.
        model: model whose forward accepts ``inputs_embeds`` / ``input_ids``
            together with ``output_hidden_states=True``.
        layer_idx: index into ``outputs.hidden_states``.
        h_target: target hidden state for the last position.

    Returns:
        ``(grad, loss_disc)`` where ``grad`` is the gradient of the
        mean-reduced MSE between the continuous last-position hidden state and
        ``h_target`` w.r.t. the last input embedding (batch and sequence dims
        squeezed), and ``loss_disc`` is the sum-reduced MSE of the discrete
        tokens' last hidden state, as a 0-d tensor (callers use
        ``torch.isnan`` / ``.item()`` on it).
    """
    device = next(model.parameters()).device

    cont_embeddings = embeddings[0].to(device)
    disc_tokens = embeddings[1].to(device)
    h_target = h_target.to(device)

    # Cut the prefix out of any upstream graph and make the last embedding a
    # new leaf, so backward() only produces a gradient for that position.
    fixed_embs = cont_embeddings.clone().detach()
    last_emb = fixed_embs[:, -1:, :].clone().requires_grad_(True)
    inputs_embeds_cont = torch.cat([fixed_embs[:, :-1, :], last_emb], dim=1)

    outputs = model(
        inputs_embeds=inputs_embeds_cont,
        output_hidden_states=True
    )
    h_last_cont = outputs.hidden_states[layer_idx][0, -1, :]

    with torch.no_grad():
        outputs = model(
            input_ids=disc_tokens,
            output_hidden_states=True
        )
    h_last_disc = outputs.hidden_states[layer_idx][0, -1, :].detach()

    # The continuous loss (whose gradient drives the update) is mean-reduced;
    # the reported discrete loss is sum-reduced. They are on different scales.
    loss_cont = torch.nn.functional.mse_loss(h_last_cont, h_target, reduction='mean')
    loss_disc = torch.nn.functional.mse_loss(h_last_disc, h_target, reduction='sum')

    loss_cont.backward()
    return last_emb.grad.squeeze(0, 1), loss_disc


def _exclude_and_find_nearest(
    candidates: torch.Tensor,
    exclude_id: int,
    point: torch.Tensor,
) -> int:
    """Mask row ``exclude_id`` of ``candidates`` in place and return the index of the row nearest to ``point``.

    Masked rows are set to ``inf``, so their distance is ``inf`` and they are
    not picked again while any finite row remains.
    """
    candidates[exclude_id] = float('inf')
    # Pure lookup: no autograd graph is needed even though `point` may require grad.
    with torch.no_grad():
        distances = torch.norm(candidates - point, dim=1)
    return int(torch.argmin(distances))


def find_token(
    token_idx,
    embedding_matrix,
    discovered_embeddings, discovered_ids,
    model, tokenizer, layer_idx, h_target,
    lr,
    scheduler: bool = False,
    baseline: bool = False,
):
    """Search the vocabulary for the token at position ``token_idx``.

    Gradient mode (``baseline=False``) runs SGD on a continuous embedding,
    using the gradient from ``compute_grad_and_elim``. After every step it
    moves to the nearest not-yet-tried vocabulary row, and every 50 steps it
    snaps the continuous embedding onto that row. Baseline mode evaluates
    candidates without gradients. The candidate rows are shuffled and the
    embedding is never updated, so each token id is tried at most once, in
    an order determined by the random shuffle.

    Args:
        token_idx: position being recovered (row of ``h_target``).
        embedding_matrix: input-embedding weight, one row per token id.
        discovered_embeddings: embeddings of the already-recovered prefix.
        discovered_ids: token ids of the already-recovered prefix.
        model: language model queried for hidden states.
        tokenizer: only used to decode the current candidate for logging.
        layer_idx: index into the model's ``hidden_states``.
        h_target: per-position target hidden states, indexed by ``token_idx``.
        lr: SGD learning rate (gradient mode).
        scheduler: if True (gradient mode only), wrap SGD in ReduceLROnPlateau.
        baseline: use the gradient-free exhaustive search.

    Returns:
        ``(token_id, embedding, steps)``. ``embedding`` is a standalone copy
        of ``embedding_matrix[token_id]`` and ``steps`` is the number of
        iterations used. Returns ``[None, None, None]`` if the loss or the
        gradient became NaN.
    """
    # Take the device from the model rather than from a module-level global.
    device = next(model.parameters()).device
    vocab_size = embedding_matrix.size(0)
    target = h_target[token_idx]

    # Private candidate table; tried rows are overwritten with inf in place.
    copy_embedding_matrix = embedding_matrix.clone().detach().requires_grad_(False)
    if baseline:
        perm = torch.randperm(vocab_size)
        copy_embedding_matrix = copy_embedding_matrix[perm]

    token_id = torch.randint(0, vocab_size, (1,)).item()
    embedding = copy_embedding_matrix[token_id].clone().requires_grad_(True)

    optimizer = ExhaustiveOptimizer() if baseline else torch.optim.SGD([embedding], lr=lr)
    # ReduceLROnPlateau only accepts a real torch optimizer; the baseline has none.
    lr_scheduler = None
    if scheduler and not baseline:
        lr_scheduler = ReduceLROnPlateau(
            optimizer, 'min', factor=0.99, threshold=lr / 100, patience=200
        )

    initial_desc = f'Token [{token_idx + 1:2d}/{h_target.size(0):2d}]'
    final_timestep = vocab_size
    start_time = time()

    for i in range(vocab_size):
        input_ids = torch.tensor(discovered_ids + [token_id]).unsqueeze(0)

        if baseline:
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids.to(device),
                    output_hidden_states=True
                )
            h_pred = outputs.hidden_states[layer_idx][0, -1, :].detach()
            loss = torch.nn.functional.mse_loss(h_pred, target.to(device), reduction='sum')
            # No gradient in baseline mode; shaped like `embedding` so it can
            # be assigned to `embedding.grad` below.
            grad_oracle = torch.zeros_like(embedding)
        else:
            input_embeddings = torch.stack(
                discovered_embeddings + [embedding]
            ).unsqueeze(0)
            grad_oracle, loss = compute_grad_and_elim(
                (input_embeddings, input_ids),
                model=model,
                layer_idx=layer_idx,
                h_target=target
            )

        if torch.isnan(loss) or torch.isnan(grad_oracle).any():
            return [None] * 3

        grad_norm = grad_oracle.norm().item()
        curr_token = tokenizer.decode([token_id], skip_special_tokens=True)
        logger.after_backprop(
            initial_desc,
            i + 1,
            vocab_size,
            loss.item(),
            grad_norm,
            format_token(curr_token, length=15),
            embedding.norm().item(),
            format_time_minutes(time() - start_time)
        )

        # Accept the candidate on a (near-)exact hidden-state match, or, in
        # gradient mode, when the gradient has vanished.
        if loss.item() < 1e-5 or (not baseline and grad_norm < 1e-12):
            final_timestep = i + 1
            break

        # Clip the gradient to unit norm before the optimizer step.
        if grad_norm > 1:
            grad_oracle = grad_oracle / grad_norm
        embedding.grad = grad_oracle

        optimizer.step(lambda: loss)
        if lr_scheduler is not None:
            lr_scheduler.step(loss)

        token_id = _exclude_and_find_nearest(copy_embedding_matrix, token_id, embedding)

        # Periodically snap the continuous embedding onto the current candidate row.
        if not baseline and (i + 1) % 50 == 0:
            embedding.data = copy_embedding_matrix[token_id].clone()

    logger.new_line()

    # Standalone copy from the original (unpermuted) matrix. A row view of
    # `copy_embedding_matrix` would keep the whole per-call copy alive for as
    # long as the caller holds it, and in baseline mode it would not be the
    # embedding of `token_id`.
    return token_id, embedding_matrix[token_id].detach().clone(), final_timestep


def find_prompt(
    model, tokenizer, layer_idx, h_target,
    lr, scheduler: bool = False,
    baseline: bool = False
):
    """Recover a prompt one position at a time from its target hidden states.

    Each row of ``h_target`` is the target for one position. A 1-D
    ``h_target`` is treated as a single position. Positions are solved left to
    right with ``find_token``, each conditioned on the tokens already found.

    Returns:
        ``(total_seconds, token_ids, steps_per_token, seconds_per_token)`` on
        success, or ``[None, None, None, None]`` as soon as any position fails.
    """
    vocab_embeddings = model.get_input_embeddings().weight
    targets = h_target.unsqueeze(0) if h_target.dim() == 1 else h_target

    found_embeddings, found_ids = [], []
    steps_per_token, seconds_per_token = [], []

    t_begin = time()
    for position in range(targets.size(0)):
        t_token = time()
        token_id, token_embedding, steps = find_token(
            position, vocab_embeddings,
            found_embeddings, found_ids,
            model, tokenizer, layer_idx, targets,
            lr, scheduler, baseline
        )
        elapsed = time() - t_token

        # find_token signals divergence with a None embedding.
        if token_embedding is None:
            return [None] * 4

        found_embeddings.append(token_embedding)
        found_ids.append(token_id)
        steps_per_token.append(steps)
        seconds_per_token.append(elapsed)

        # Release per-token temporaries before the next search.
        gc.collect()
        torch.cuda.empty_cache()

    total_seconds = time() - t_begin
    return total_seconds, found_ids, steps_per_token, seconds_per_token


def inversion_attack(
    input_ids: torch.Tensor, 
    model: AutoModelForCausalLM, 
    tokenizer: AutoTokenizer, 
    layer_idx: int,
    lr: float,
    seed: int = 8, 
    scheduler: bool = False,
    baseline: bool = False
):
    """Invert the hidden states of ``input_ids`` back to tokens and report the result.

    Args:
        input_ids: token ids of the prompt to attack; row 0 is the sequence
            compared against the reconstruction.
        model: language model whose hidden states are inverted.
        tokenizer: used for progress logging during the search.
        layer_idx: hidden-state layer that is inverted.
        lr: learning rate for the gradient-based search.
        seed: RNG seed set before the attack.
        scheduler: enable the learning-rate scheduler in the search.
        baseline: use the gradient-free exhaustive search instead.

    Returns:
        ``(match, inversion_time, timesteps, times)``, or
        ``(False, None, None, None)`` if the search failed or diverged.
        ``match`` is True only if the reconstruction equals ``input_ids[0]``
        exactly, including its length.
    """
    set_seed(seed)
    h_target = extract_hidden_states_iterative(input_ids, model, layer_idx)

    inversion_time, discovered_ids, timesteps, times = find_prompt(
        model, tokenizer, layer_idx, h_target,
        lr, scheduler, baseline
    )

    if discovered_ids is None:
        print('Inversion failed or diverged with the given parameters.')
        return False, None, None, None

    # Whole-list comparison: zip() would silently ignore a length mismatch
    # and could report a match for a partial reconstruction.
    match = input_ids[0].tolist() == discovered_ids
    print(f'Original {"==" if match else "!="} Reconstructed')
    print(f'Invertion time: {inversion_time:.2f} seconds')
    print(f'Average Timesteps: {np.mean(timesteps):.2f}')

    return match, inversion_time, timesteps, times



def parse_args():
    """Parse the command line for the inversion experiment.

    Returns:
        argparse.Namespace with the attributes ``input``, ``output``, ``seed``,
        ``max_prompts``, ``step``, ``total_lengths``, ``id``, ``learning_rate``,
        ``scheduler``, ``baseline``, ``baseline_max_len``, ``layer``, ``rank``
        and ``pct``.
    """
    cli = argparse.ArgumentParser(description='Run inversion attack with given configuration.')

    # (option strings, add_argument keyword arguments), in --help order.
    option_table = [
        (('-i', '--input'), dict(
            type=str, required=True,
            help='Path to the dataset directory. If it does not exists, a new dataset will be created and saved there.')),
        (('-o', '--output'), dict(
            type=str, required=True,
            help='Path to the output CSV file.')),
        (('--seed',), dict(
            type=int, default=8,
            help='Random seed to use.')),
        (('-n', '--max-prompts'), dict(
            type=int, default=1000,
            help='Maximum amount of prompts to use.')),
        (('--step',), dict(
            type=int, default=20,
            help='Step for token length.')),
        (('--total-lengths',), dict(
            type=int, default=10,
            help='Count of different lengths to include.')),
        (('--id', '--model-id'), dict(
            type=str, default='roneneldan/TinyStories-1M',
            help='Name of HF model to use.')),
        (('--learning-rate',), dict(
            type=float, default=1.0,
            help='Learning rate (step sizes) to use.')),
        (('--scheduler',), dict(
            action='store_true',
            help='Flag for whether to employ a ReduceOnPlateu LR Scheduler')),
        (('--baseline',), dict(
            action='store_true',
            help='Flag for whether to use the random search algorithm')),
        (('--baseline-max-len',), dict(
            type=int, default=20,
            help='Max prompt length when doing exhaustive search.')),
        (('--layer',), dict(
            type=int, default=-1,
            help='Model layer to invert.')),
        (('--rank',), dict(
            type=int, default=-1,
            help='Rank of process w.r.t. the indices to process.')),
        (('--pct',), dict(
            type=float, default=0.0,
            help='Percentage of the dataset to process.')),
    ]

    for option_strings, options in option_table:
        cli.add_argument(*option_strings, **options)

    return cli.parse_args()


if __name__ == '__main__':

    # TODO: Type hint functions

    # Script entry point: load (or build) the tokenized prompt datasets, run
    # the inversion attack on every selected prompt and write one CSV row per
    # prompt. NOTE: names bound here (e.g. `device`) are module globals; check
    # whether other code in the file reads them before moving this block into
    # a function.

    args = parse_args()

    model_id = args.id

    input_path = Path(args.input)
    output_dir = Path(args.output)

    seed = args.seed
    n = args.max_prompts
    step = args.step
    lengths = args.total_lengths

    learning_rate = args.learning_rate
    scheduler = args.scheduler
    baseline = args.baseline
    baseline_max_len = args.baseline_max_len

    layer = args.layer

    rank = args.rank
    pct = args.pct
    # Sharded mode: this process only handles its `pct` slice of each dataset.
    sharded = rank >= 0 and pct > 0

    csv_name = f'results-{rank}-{pct:.2f}' if sharded else 'results'
    # Only exhaustive-search (baseline) runs are limited by --baseline-max-len,
    # so only their file name carries that suffix.
    if baseline:
        csv_name = f'{csv_name}-exhaustive-{baseline_max_len}'
    output_file = output_dir / f'{csv_name}.csv'
    output_dir.mkdir(parents=True, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map=device,
        trust_remote_code=True
    )

    # Map a negative --layer to an absolute index (-1 -> total_layers).
    total_layers = num_layers(model_id)
    if layer < 0:
        layer = total_layers + layer + 1

    replace_last_norm(model_id, model)

    # Freeze the weights but keep autograd on: the search differentiates
    # w.r.t. the input embedding only.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    torch.set_grad_enabled(True)
    tokenizer.pad_token = tokenizer.eos_token

    if input_path.exists():
        datasets = DatasetCollection.load(input_path)
    else:
        # About 10% of the prompts come from random_tokenized_prompts and the
        # rest from data/wikipedia. Each share is split evenly across the
        # `lengths` prompt lengths step, 2*step, ..., lengths*step.
        random_data_samples = int(0.1 * n)
        meaningful_data_samples = n - random_data_samples

        tokenized_datasets = [
            extract_tokenized_prompts(
                dataset_path=Path('data/wikipedia'),
                text_column='text',
                tokenizer=tokenizer,
                prompt_tokens=step * (i + 1),
                max_prompts=meaningful_data_samples // lengths,
                batch_size=64,
                seed=seed
            )
            for i in range(lengths)
        ] + [
            random_tokenized_prompts(
                tokenizer=tokenizer,
                prompt_tokens=step * (i + 1),
                max_prompts=random_data_samples // lengths,
                seed=seed
            )
            for i in range(lengths)
        ]
        datasets = DatasetCollection('sip-it-data', tokenized_datasets)
        datasets.save(input_path)

    if sharded:
        i1 = int(pct * len(datasets.datasets[0]) * rank)
        i2 = int(pct * len(datasets.datasets[0]) * (rank + 1))
        print(f'Processing indices [{i1}, {i2}]')

    write_header = True
    for dataset_idx, dataset in enumerate(datasets):
        if baseline and dataset.prompt_tokens > baseline_max_len:
            continue

        # Sample indices this process is responsible for (None = all).
        shard = (
            range(int(pct * len(dataset) * rank), int(pct * len(dataset) * (rank + 1)))
            if sharded else None
        )

        results = []
        for sample_idx, token_ids in enumerate(dataset):
            if shard is not None and sample_idx not in shard:
                continue

            print(f'Dataset: {dataset_idx + 1}, Sample: {sample_idx + 1}, Length: {len(token_ids)}') # type: ignore

            match, time_taken, timesteps, times = inversion_attack(
                token_ids.unsqueeze(0).to(device), # type: ignore 
                model, tokenizer, layer,
                learning_rate, seed, scheduler, baseline
            )

            results.append({
                'dataset': datasets.dataset_names[dataset_idx],
                'index': sample_idx,
                'layer': layer,
                'learning_rate': learning_rate,
                'token_length': len(token_ids), # type: ignore
                'match': match,
                'inversion_time': time_taken if match else -1,
                'timesteps': '_'.join([str(x) for x in timesteps]) if match else '', # type: ignore
                'times': '_'.join([f'{x:.2f}' for x in times]) if match else '' # type: ignore
            })

        # Writing an empty batch would produce a header-less junk file and
        # flip `write_header`, leaving every later append without a header.
        if not results:
            continue

        pd.DataFrame(results).to_csv(
            output_file,
            mode='w' if write_header else 'a',
            header=write_header,
            index=False
        )
        write_header = False