#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch.nn as nn

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
from peft import get_peft_model, LoraConfig, TaskType

import numpy as np

from opacus.validators import ModuleValidator
from opacus.accountants.utils import get_noise_multiplier
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus import PrivacyEngine
from opacus.accountants.prv import PRVAccountant

from math import sqrt
from torch.nn.functional import cosine_similarity

import random

import wandb
import argparse
# Authenticate with Weights & Biases up front; the run itself is created by wandb.init further down.
wandb.login()


# In[2]:


def get_config_from_args():
    """
    Parse the experiment configuration from the command line.

    Unrecognised arguments are silently ignored (``parse_known_args``). Boolean
    options are True only for the string ``true`` (case-insensitive); any other
    value parses as False.

    Returns:
        argparse.Namespace with one attribute per option below. Note that
        ``--noise-multiplier`` is exposed as ``noise_multiplier``.
    """
    def as_bool(value):
        return str(value).lower() == 'true'

    parser = argparse.ArgumentParser(description="Set config variables via command line arguments.")
    add = parser.add_argument

    # Optimisation / batching.
    add('--epochs', type=int, default=20)
    add('--physical_batch_size', type=int, default=500)
    add('--gas', type=int, default=4)
    add('--lr', type=float, default=5e-4)
    # Privacy budget and clipping.
    add('--epsilon', type=float, default=6.7)
    add('--delta', type=float, default=1e-5)
    add('--max_grad_norm', type=float, default=10.0)
    add('--max_whole_grad_norm', type=float, default=1.0)
    add('--device', type=str, default="")
    # Denoising / task selection.
    add('--denoise_ratio_threshold', type=float, default=1.02)
    add('--dataset_name', type=str, default="sst2")
    add('--apply_svd_denoising', type=as_bool, default=True)
    add('--tot_clipping', type=as_bool, default=False)
    add('--model_name', type=str, default="roberta-base")
    add('--correct_norm', type=as_bool, default=False)
    add('--num_steps', type=int, default=0)
    add('--evaluation_steps', type=int, default=0)
    add('--noise-multiplier', type=float, default=None, help='If set, overrides epsilon/delta settings and uses this noise multiplier directly.')

    known_args, _ignored = parser.parse_known_args()
    return known_args

args = get_config_from_args()
epochs = args.epochs
physical_batch_size = args.physical_batch_size
gas = args.gas
# Logical (privacy-accounted) batch size: `gas` physical batches of at most
# `physical_batch_size` examples are accumulated per optimizer step.
batch_size = physical_batch_size * gas
lr = args.lr
epsilon = args.epsilon
delta = args.delta
max_grad_norm = args.max_grad_norm
max_whole_grad_norm = args.max_whole_grad_norm
# An empty --device (the default) is not a valid torch device string and would make the
# first `.to(device)` call fail; fall back to CUDA when available, otherwise CPU.
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
denoise_ratio_threshold = args.denoise_ratio_threshold
dataset_name = args.dataset_name
apply_svd_denoising = args.apply_svd_denoising
tot_clipping = args.tot_clipping
print("tot_clipping:", tot_clipping)
model_name = args.model_name
correct_norm = args.correct_norm
print("correct_norm:", correct_norm)
num_steps = args.num_steps
evaluation_steps = args.evaluation_steps


# In[ ]:


def get_num_labels(dataset_name):
    """
    Return the number of classification labels for a supported GLUE task.

    Raises:
        ValueError: if ``dataset_name`` is not one of sst2, qqp, qnli or mnli.
    """
    binary_tasks = ("sst2", "qqp", "qnli")
    if dataset_name in binary_tasks:
        return 2
    if dataset_name == "mnli":
        return 3
    raise ValueError(f"Unsupported dataset: {dataset_name}")


# In[ ]:


# Number of output classes for the chosen GLUE task (raises ValueError for unsupported tasks).
num_labels = get_num_labels(dataset_name)


# In[5]:


# Pretrained backbone with a sequence-classification head sized to `num_labels`, plus its tokenizer.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# In[6]:


# LoRA adapters (rank 16, lora_alpha 16) on every module whose name matches one of
# `target_modules`: attention query/key/value plus "intermediate.dense" and "output.dense".
# NOTE(review): "output.dense" presumably matches both the attention-output projection and
# the FFN output layer — confirm with model.print_trainable_parameters().
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=16,
    lora_alpha=16,
    target_modules=["query", "value", "key","intermediate.dense", "output.dense"]
)
# Wrap the model with the adapters; presumably PEFT freezes the base weights so that only
# adapter (and, for SEQ_CLS, classifier-head) parameters keep requires_grad — TODO confirm.
model = get_peft_model(model, lora_config)


# In[ ]:


def get_train_val_splits(dataset_name):
    """
    Load the train and validation splits of a GLUE task from ``nyu-mll/glue``.

    MNLI is evaluated on its ``validation_mismatched`` split; every other supported
    task uses ``validation``. All tasks train on ``train``.

    Args:
        dataset_name: one of sst2, mnli, qqp, qnli.

    Returns:
        (train_ds, valid_ds) as returned by ``datasets.load_dataset``.

    Raises:
        ValueError: if ``dataset_name`` is not supported (checked before any download).
    """
    validation_splits = {
        "sst2": "validation",
        "mnli": "validation_mismatched",
        "qqp": "validation",
        "qnli": "validation",
    }
    if dataset_name not in validation_splits:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    train_ds = load_dataset("nyu-mll/glue", dataset_name, split="train")
    valid_ds = load_dataset("nyu-mll/glue", dataset_name, split=validation_splits[dataset_name])
    return train_ds, valid_ds


# In[ ]:


def get_tokenizer(dataset_name, model_tokenizer):
    """
    Build a batched preprocessing function for ``datasets.Dataset.map``.

    The returned function feeds the task's text column(s) positionally to
    ``model_tokenizer`` (one column for sst2, a sentence pair otherwise), truncating and
    padding every example to exactly 128 tokens.

    Raises:
        ValueError: if ``dataset_name`` is not one of sst2, qqp, qnli or mnli.
    """
    text_columns_by_task = {
        "sst2": ("sentence",),
        "qqp": ("question1", "question2"),
        "qnli": ("question", "sentence"),
        "mnli": ("premise", "hypothesis"),
    }
    if dataset_name not in text_columns_by_task:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    text_columns = text_columns_by_task[dataset_name]

    def preprocess_function(examples):
        texts = [examples[column] for column in text_columns]
        return model_tokenizer(*texts, truncation=True, padding="max_length", max_length=128)

    return preprocess_function

def get_redundant_columns(dataset_name):
    """
    Return the raw columns to drop after tokenization: ``idx`` plus the task's text columns.

    A new list is returned on every call.

    Raises:
        ValueError: if ``dataset_name`` is not one of sst2, qqp, qnli or mnli.
    """
    text_columns = {
        "sst2": ["sentence"],
        "qqp": ["question1", "question2"],
        "qnli": ["question", "sentence"],
        "mnli": ["premise", "hypothesis"],
    }
    if dataset_name not in text_columns:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    return ["idx"] + text_columns[dataset_name]


# In[9]:


# Load the raw GLUE splits for the selected task.
train_ds, valid_ds = get_train_val_splits(dataset_name)

# Reuse `tokenizer`, already loaded from `model_name` together with the model; loading
# it from the hub a second time here was redundant.
preprocess_function = get_tokenizer(dataset_name, tokenizer)
redundant_columns = get_redundant_columns(dataset_name)

# Tokenize to fixed-length (128) inputs and drop the raw text/idx columns.
toked_train_ds = train_ds.map(preprocess_function, batched=True)
toked_train_ds = toked_train_ds.remove_columns(column_names=redundant_columns)

toked_valid_ds = valid_ds.map(preprocess_function, batched=True)
toked_valid_ds = toked_valid_ds.remove_columns(column_names=redundant_columns)

# The collator pads to 128 and (per transformers) renames "label" to "labels", which the
# training/eval loops read.
dc = DataCollatorWithPadding(tokenizer=tokenizer, max_length=128, padding="max_length")

# train_dl is later replaced by the Poisson-sampled loader returned by make_private;
# its batch_size defines the logical batch size used for privacy accounting.
train_dl = DataLoader(toked_train_ds, collate_fn=dc, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(toked_valid_ds, collate_fn=dc, batch_size=batch_size)


# In[10]:


# Let Opacus replace any submodules it does not support for DP training; presumably a
# no-op for LayerNorm-based transformers — TODO confirm with ModuleValidator.validate.
model = ModuleValidator.fix(model)


# In[ ]:


# Per-example inclusion probability of a logical batch; used for privacy accounting.
sample_rate = batch_size / len(toked_train_ds)
# Total number of logical optimizer steps: --num_steps when given, otherwise
# `epochs` passes over train_dl.
if num_steps != 0:
    step_limit = num_steps
else:
    step_limit = epochs * len(train_dl)


# In[12]:

if args.noise_multiplier is None:
    # Calibrate the noise multiplier so that `step_limit` steps at `sample_rate`
    # spend (epsilon, delta) under the PRV accountant.
    sigma = get_noise_multiplier(
        target_epsilon=epsilon,
        target_delta=delta,
        sample_rate=sample_rate,
        steps=step_limit,
        accountant='prv',
    )
else:
    # --noise-multiplier overrides the (epsilon, delta) calibration.
    sigma = args.noise_multiplier
    print(f"Using provided noise multiplier: {sigma}")
print(f"Sigma: {sigma}")
# Separate PRV accountant advanced by train_loop once per logical step, used to report
# the epsilon spent so far.
privacy_accountant = PRVAccountant()


# In[13]:


# Start the W&B run and record the full experiment configuration, so a run can be
# reproduced from its dashboard entry alone.
logger = wandb.init(entity="dadsetan", project="svd-dp", tags=["script", 'iclr-abstract'], config={
    "epochs": epochs,
    "batch_size": batch_size,
    # batch_size == physical_batch_size * gas; log both factors too.
    "physical_batch_size": physical_batch_size,
    "gas": gas,
    "learning_rate": lr,
    "epsilon": epsilon,
    "delta": delta,
    "max_grad_norm": max_grad_norm,
    "model_name": model_name,
    "device": device,
    "sigma": sigma,
    "apply_svd_denoising": apply_svd_denoising,
    "denoise_ratio_threshold": denoise_ratio_threshold,
    "tot_clipping": tot_clipping,
    "max_whole_grad_norm": max_whole_grad_norm,
    "dataset_name": dataset_name,
    "correct_norm": correct_norm,
    "num_steps": num_steps,
    "step_limit": step_limit,
    "evaluation_steps": evaluation_steps,
    # Non-None when sigma came from --noise-multiplier rather than (epsilon, delta).
    "noise_multiplier_override": args.noise_multiplier,
})
print(logger.name)


# In[14]:


# AdamW over all model parameters; parameters that never receive a .grad (presumably the
# frozen non-LoRA weights) are skipped by the optimizer step.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.01)#, betas=(0.7, 0.99))
# optimizer = torch.optim.SGD(model.parameters(), lr=lr)
model = model.train()



# Summed (not averaged) loss, matching loss_reduction="sum" passed to make_private below.
criterion = nn.CrossEntropyLoss(reduction="sum")

# Wrap model/optimizer/criterion/data loader for DP training: ghost clipping
# (grad_sample_mode="ghost", which is why the criterion is passed in and a wrapped
# criterion is returned), per-sample clipping to max_grad_norm, noise multiplier sigma,
# and Poisson-sampled logical batches built from train_dl.
privacy_engine = PrivacyEngine()
private_model, private_optimizer, private_criterion, private_train_dataloader = (
    privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_dl,
        noise_multiplier=sigma,
        criterion=criterion,
        max_grad_norm=max_grad_norm,
        grad_sample_mode="ghost",
        loss_reduction="sum",
        poisson_sampling=True,
    )
)

# Move the wrapped model to the target device and make sure it is in training mode.
private_model = private_model.to(device)
private_model = private_model.train()


# In[ ]:


def find_largest_positive_lambda(target, n, m, sigma):
    """
    Find the largest positive lambda such that

        sqrt((lambda + n*sigma^2/lambda) * (lambda + m*sigma^2/lambda)) = target.

    Substituting z = lambda^2 gives the quadratic
    z^2 + ((n + m)*sigma^2 - target^2) * z + n*m*sigma^4 = 0, whose larger root yields
    lambda = sqrt(z). The left-hand side is minimised at lambda^2 = sigma^2*sqrt(n*m),
    where it equals sigma*(sqrt(n) + sqrt(m)), so a positive solution exists only when
    target exceeds that value.

    Args:
        target: observed singular value (a Python number or 0-d tensor; converted to
            float once, so no further tensor ops or device syncs happen here).
        n, m: matrix dimensions (the formula is symmetric in them).
        sigma: noise scale per matrix entry (the caller in this file passes
            noise_multiplier * max_grad_norm).

    Returns:
        (lambda, ratio) as Python floats, where
        ratio = target / (sigma * (sqrt(n) + sqrt(m))). lambda is 0.0 when there is no
        positive solution (ratio <= 1). For sigma <= 0 (no noise), lambda is the target
        itself (clamped at 0) and ratio is +inf.
    """
    target = float(target)
    sigma = float(sigma)
    if sigma <= 0.0:
        # Noise-free limit of the equation: lambda == target; avoids dividing by zero.
        return max(target, 0.0), float("inf")

    ratio = target / (sigma * (sqrt(m) + sqrt(n)))
    if ratio <= 1:
        return 0.0, ratio

    # Quadratic in z = lambda^2 with a = 1; b < 0 here because ratio > 1.
    b = (n + m) * sigma ** 2 - target ** 2
    c = n * m * sigma ** 4
    discriminant = b ** 2 - 4 * c
    if discriminant < 0:
        # Mathematically impossible when ratio > 1; guards against rounding.
        return 0.0, ratio

    # Larger root; -b > 0, so there is no cancellation.
    z1 = (-b + sqrt(discriminant)) / 2
    return (sqrt(z1) if z1 > 0 else 0.0), ratio

def component_sim(lambda_hat, n, m, sigma):
    """
    Shrinkage factor applied to a recovered singular value ``lambda_hat``:

        (lambda^4 - n*m*sigma^4) / (lambda^2 * sqrt((lambda^2 + n*sigma^2) * (lambda^2 + m*sigma^2)))

    It is 0 at lambda^2 = sigma^2 * sqrt(n*m) and tends to 1 as lambda grows.
    NOTE(review): presumably the expected alignment between the clean and observed
    singular-vector pairs — confirm against the derivation.
    """
    lam_sq = lambda_hat**2
    numerator = lambda_hat**4 - n * m * sigma**4
    left_term = lam_sq + n * sigma**2
    right_term = lam_sq + m * sigma**2
    denominator = lam_sq * sqrt(left_term * right_term)
    return numerator / denominator


def svd_shrinkage(noisy, sigma):
    """
    Denoise a noisy 2-D matrix by singular-value shrinkage.

    Each singular value s of ``noisy`` (visited in descending order) is mapped to
    lambda = find_largest_positive_lambda(s, n, m, sigma). That value is then scaled by
    component_sim(lambda, n, m, sigma). The first singular value with no positive lambda
    ends the scan, because every smaller one fails too. Only the top-k components
    survive.

    Args:
        noisy: 2-D tensor (in this file, a noisy gradient).
        sigma: per-entry noise scale, forwarded to the helpers.

    Returns:
        (denoised, k, biggest_ratio):
        - denoised: the rank-k reconstruction, with the same dtype and device as
          ``noisy``; all zeros when k == 0.
        - k: the number of kept components.
        - biggest_ratio: the largest signal-to-noise-edge ratio reported by
          find_largest_positive_lambda, or 0.0 for an empty matrix.
    """
    # Reduced SVD; torch.linalg.svd returns S in descending order and Vh = V^T.
    U, S, Vh = torch.linalg.svd(noisy, full_matrices=False)
    n, m = noisy.shape
    kept_values = []
    similarities = []
    ratios = []
    for s in S:
        new_s, ratio = find_largest_positive_lambda(s, n, m, sigma)
        ratios.append(ratio)
        if not new_s > 0:
            # Remaining (smaller) singular values are below the noise edge as well.
            break
        kept_values.append(new_s)
        similarities.append(component_sim(new_s, n, m, sigma))

    biggest_ratio = max(ratios) if ratios else 0.0
    k = len(kept_values)
    if k == 0:
        return torch.zeros_like(noisy), k, biggest_ratio
    # Match noisy's dtype: the default float32 would break the matmul for float64 inputs.
    shrunk = torch.tensor(
        [value * weight for value, weight in zip(kept_values, similarities)],
        dtype=noisy.dtype,
        device=noisy.device,
    )
    return U[:, :k] @ torch.diag(shrunk) @ Vh[:k, :], k, biggest_ratio

def sim(x, y):
    """Cosine similarity between two tensors, treating each as one flat vector (0-d tensor result)."""
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    return cosine_similarity(flat_x, flat_y, dim=0)


# In[ ]:


class StepLimitReached(Exception):
    """Raised by ``train_loop`` once the global logical-step counter reaches ``step_limit``."""


class EvaluationRequested(Exception):
    """Raised by ``train_loop`` every ``evaluation_steps`` logical steps to hand control back for validation."""

def train_loop(model, criterion, optimizer, train_dl, stats=None, log_unprivate=False, step_stats=None):
    """
    Run one pass over ``train_dl`` with DP updates, optional SVD denoising of 2-D
    gradients, and optional whole-gradient clipping.

    Physical micro-batches (at most ``physical_batch_size`` examples, via Opacus'
    ``BatchMemoryManager``) are forwarded and backpropagated. Once
    ``optimizer.pre_step()`` returns True, the code treats this as the end of a logical
    batch whose ``.grad`` tensors hold the privatized gradient. Each trainable parameter
    is then handled as follows:

    * 2-D gradients go through ``svd_shrinkage`` with noise scale
      ``sigma * optimizer.max_grad_norm``. The denoised matrix replaces ``param.grad``
      when ``apply_svd_denoising`` is set and the largest singular-value ratio exceeds
      ``denoise_ratio_threshold``. If ``correct_norm`` is set, it is rescaled to the
      noisy gradient's norm.
    * Other gradients are left unchanged.

    If ``tot_clipping`` is set, the whole gradient is then clipped to
    ``max_whole_grad_norm``. The update is applied with the wrapped (non-DP) optimizer and
    the module-level PRV accountant is advanced. Metrics go to wandb and to
    ``stats``/``step_stats``.

    Args:
        model: the Opacus-wrapped model.
        criterion: loss function, expected to return the summed loss over the batch.
        optimizer: Opacus DP optimizer. Uses ``pre_step``, ``max_grad_norm``,
            ``original_optimizer``, ``param_groups`` and ``zero_grad``.
        train_dl: data loader to iterate (the private loader from ``make_private``).
        stats: list receiving one dict per 2-D parameter per logical step. A fresh list
            is used when None.
        log_unprivate: if True, also compare gradients against ``param.summed_grad``.
            That is presumably Opacus' clipped, pre-noise sum: non-private, analysis only.
        step_stats: list receiving one dict per logical step. A fresh list is used when
            None.

    Raises:
        StepLimitReached: when the global ``tot_step`` reaches ``step_limit``.
        EvaluationRequested: every ``evaluation_steps`` logical steps (if > 0).

    Reads module-level globals: tot_step (incremented), privacy_accountant, sigma,
    sample_rate, delta, device, physical_batch_size, step_limit, evaluation_steps, logger,
    apply_svd_denoising, denoise_ratio_threshold, correct_norm, tot_clipping and
    max_whole_grad_norm.
    """
    global tot_step
    global privacy_accountant
    # Fresh lists per call instead of mutable defaults shared across calls.
    if stats is None:
        stats = []
    if step_stats is None:
        step_stats = []
    model.train()
    log_loss = 0.0
    log_batchsize = 0
    with BatchMemoryManager(
        data_loader=train_dl,
        max_physical_batch_size=physical_batch_size,
        optimizer=optimizer,
    ) as memory_safe_train_dl:
        loop = tqdm(memory_safe_train_dl, desc="new Epoch ")
        stop_flag = "continue"
        for batch in loop:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = criterion(logits, labels)
            loss.backward()
            log_loss += loss.item()
            log_batchsize += input_ids.shape[0]

            # True only at the end of a logical batch; earlier micro-batches just accumulate.
            if optimizer.pre_step():
                tot_step += 1
                # Per-entry noise scale used both by the shrinkage and by the logged ratio.
                noise_scale = sigma * optimizer.max_grad_norm
                with torch.no_grad():
                    denoised_norm_sq = 0.0
                    noisy_norm_sq = 0.0
                    if log_unprivate:
                        signal_norm_sq = 0.0
                        signal_denoised_inner = 0.0
                        signal_noisy_inner = 0.0

                    for name, param in model.named_parameters():
                        if not param.requires_grad:
                            continue
                        # Statistics of the noisy gradient, taken before any denoising.
                        noisy_norm_sq += torch.norm(param.grad).item() ** 2
                        if log_unprivate:
                            signal_norm_sq += torch.norm(param.summed_grad).item() ** 2
                            signal_noisy_inner += torch.dot(param.summed_grad.flatten(), param.grad.flatten()).item()

                        if param.grad.ndim == 2:
                            noisy_svs = torch.linalg.svdvals(param.grad)
                            m = param.grad.shape[0]
                            n = param.grad.shape[1]
                            denoised, k, biggest_ratio = svd_shrinkage(param.grad, noise_scale)
                            max_ratio = noisy_svs[0].item() / (noise_scale * (sqrt(n) + sqrt(m))) if sigma > 0.0 else np.inf
                            stat = {
                                "logical_step": tot_step,
                                "param": name,
                                "noisy_svs": noisy_svs.cpu().numpy(),
                                "n": n,
                                "m": m,
                                "k": k,
                                "max_ratio": max_ratio,
                            }
                            wandb_log = {
                                "logical_step": tot_step,
                                f"{name}_max_ratio": max_ratio,
                            }

                            if log_unprivate:
                                signal_svs = torch.linalg.svdvals(param.summed_grad)
                                param_original_sim = sim(param.grad, param.summed_grad).item()
                                param_denoised_sim = sim(denoised, param.summed_grad).item()
                                stat.update({
                                    "signal_svs": signal_svs.cpu().numpy(),
                                    "original_sim": param_original_sim,
                                    "denoised_sim": param_denoised_sim,
                                })
                                wandb_log.update({
                                    f"{name}_original_sim": param_original_sim,
                                    f"{name}_denoised_sim": param_denoised_sim,
                                    f"{name}_sim_improvement": param_denoised_sim - param_original_sim,
                                })

                            # bool() because biggest_ratio may be a 0-d tensor.
                            apply_denoising = bool(biggest_ratio > denoise_ratio_threshold and apply_svd_denoising)
                            if apply_denoising:
                                scale = 1.0
                                if correct_norm:
                                    shrunk_norm = torch.norm(denoised)
                                    # An all-zero denoised matrix (k == 0, possible if the
                                    # threshold is below 1) would otherwise give 0 * inf = NaN.
                                    if shrunk_norm > 0:
                                        scale = torch.norm(param.grad) / shrunk_norm
                                param.grad = denoised * scale
                            stat["denoised"] = apply_denoising
                            wandb_log[f"{name}_denoised"] = apply_denoising
                            stats.append(stat)
                            logger.log(wandb_log)

                        # Statistics of the gradient actually used (denoised or unchanged),
                        # measured before any whole-gradient clipping below.
                        denoised_norm_sq += torch.norm(param.grad).item() ** 2
                        if log_unprivate:
                            signal_denoised_inner += torch.dot(param.summed_grad.flatten(), param.grad.flatten()).item()

                    if tot_clipping:
                        grad_norm_sq = 0.0
                        for param in model.parameters():
                            if param.requires_grad:
                                grad_norm_sq += param.grad.norm().item() ** 2
                        total_norm = sqrt(grad_norm_sq)
                        # A zero gradient needs no clipping (and must not divide by zero).
                        whole_clip_factor = min(1.0, max_whole_grad_norm / total_norm) if total_norm > 0 else 1.0
                        for param in model.parameters():
                            if param.requires_grad:
                                param.grad.mul_(whole_clip_factor)

                # Apply the (possibly denoised/clipped) gradient with the wrapped optimizer,
                # bypassing the DP optimizer's own step.
                optimizer.original_optimizer.step()
                privacy_accountant.step(noise_multiplier=sigma, sample_rate=sample_rate)

                noisy_norm = sqrt(noisy_norm_sq)
                denoised_norm = sqrt(denoised_norm_sq)

                step_log = {
                    "loss": log_loss / log_batchsize,
                    "batch_size": log_batchsize,
                    "lr": optimizer.param_groups[0]["lr"],
                    "logical_step": tot_step,
                    "noisy_norm": noisy_norm,
                    "denoised_norm": denoised_norm,
                    "privacy_spent_epsilon": privacy_accountant.get_epsilon(delta) if sigma > 0.0 else 0.0,
                }

                if log_unprivate:
                    signal_norm = sqrt(signal_norm_sq)
                    noisy_sim = signal_noisy_inner / (noisy_norm * signal_norm) if noisy_norm > 0 and signal_norm > 0 else 0
                    denoised_sim = signal_denoised_inner / (denoised_norm * signal_norm) if denoised_norm > 0 and signal_norm > 0 else 0
                    step_log.update({
                        "signal_norm": signal_norm,
                        "signal_noisy_inner": signal_noisy_inner,
                        "signal_denoised_inner": signal_denoised_inner,
                        "noisy_sim": noisy_sim,
                        "denoised_sim": denoised_sim,
                        "step_improvement": denoised_sim - noisy_sim,
                    })

                logger.log(step_log)
                step_stats.append(step_log)

                # "stop" takes precedence over "evaluation" when both apply.
                if evaluation_steps > 0 and tot_step % evaluation_steps == 0:
                    stop_flag = "evaluation"
                if tot_step == step_limit:
                    stop_flag = "stop"
                log_loss = 0.0
                log_batchsize = 0
            optimizer.zero_grad()
            # Gradients are cleared before control leaves via an exception.
            if stop_flag == "stop":
                raise StepLimitReached("Step limit reached")
            elif stop_flag == "evaluation":
                raise EvaluationRequested("Evaluation requested")
            loop.set_postfix(loss=loss.item()/input_ids.shape[0], lr=optimizer.param_groups[0]["lr"])


# In[ ]:


def test_loop(model, criterion, test_dl, stats=None):
    """
    Evaluate ``model`` on ``test_dl`` and record loss, accuracy and privacy spent so far.

    Per-batch losses from ``criterion`` are summed and divided by the number of examples,
    so ``criterion`` is expected to return a summed (not averaged) loss. Results are
    printed, appended to ``stats``, and logged to wandb under ``eval_*`` keys together
    with the global ``tot_step``.

    Args:
        model: model to evaluate; left in eval mode on return.
        criterion: loss function (summed reduction expected).
        test_dl: data loader yielding dicts with ``input_ids``, ``attention_mask`` and
            ``labels``.
        stats: list receiving one result dict per call. The logged ``validation_step`` is
            its length after appending. A fresh list is used when None.

    Raises:
        ValueError: if ``test_dl`` yields no examples.
    """
    global tot_step
    global privacy_accountant
    # Fresh list per call instead of a mutable default shared across calls.
    if stats is None:
        stats = []
    with torch.no_grad():
        model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        for batch in tqdm(test_dl):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = criterion(logits, labels)
            total_loss += loss.item()
            total_correct += (logits.argmax(dim=-1) == labels).sum().item()
            total_samples += labels.size(0)
        if total_samples == 0:
            raise ValueError("test_loop: evaluation data loader yielded no samples")
        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples
        # With sigma == 0 the accountant is not queried and zero spend is reported.
        privacy_spent_epsilon = privacy_accountant.get_epsilon(delta) if sigma > 0.0 else 0.0
        print(f"loss: {avg_loss}, accuracy: {accuracy}, privacy spent: {privacy_spent_epsilon}")
        stats.append({
            "loss": avg_loss,
            "accuracy": accuracy,
            "privacy_spent_epsilon": privacy_spent_epsilon,
            "logical_step": tot_step,
        })
        logger.log({
            "eval_loss": avg_loss,
            "eval_accuracy": accuracy,
            "privacy_spent_epsilon": privacy_spent_epsilon,
            "validation_step": len(stats),
            "logical_step": tot_step,
        })


# In[ ]:


# Training driver. train_loop hands control back through exceptions:
# - EvaluationRequested (every `evaluation_steps` logical steps): run validation, after
#   which the while-loop calls train_loop again, starting a fresh pass over the loader;
# - StepLimitReached (at `step_limit`): leave the loop.
# When evaluation_steps == 0, validation instead runs after each complete pass.
stats = []  # per-parameter, per-logical-step denoising stats appended by train_loop
tot_step = 0  # global logical-step counter, incremented inside train_loop
step_stats = []  # one dict per logical step, appended by train_loop
validation_stats = []  # one dict per test_loop call
while True:
    try:
        try:
            # log_unprivate=True: extra diagnostics against param.summed_grad (analysis only).
            train_loop(private_model, private_criterion, private_optimizer, private_train_dataloader, stats=stats, log_unprivate=True, step_stats=step_stats)
        except EvaluationRequested:
            test_loop(private_model, private_criterion, valid_dl, stats=validation_stats)
        if evaluation_steps == 0:
            test_loop(private_model, private_criterion, valid_dl, stats=validation_stats)
    except StepLimitReached:
        print('step limit reached')
        break
# Final validation once the step limit has been reached.
test_loop(private_model, private_criterion, valid_dl, stats=validation_stats)


# In[ ]:


import json
import pickle

# Suffix shared by every artifact of this run. Hub model ids such as
# "FacebookAI/roberta-base" contain "/", which open() would treat as a directory
# separator, so it is replaced before building the filenames.
run_tag = (
    f"{dataset_name}_{model_name.replace('/', '_')}_{logger.name}_epochs{epochs}"
    f"_{'denoising' if apply_svd_denoising else 'no_denoising'}"
    f"_{'totclip' if tot_clipping else 'nototclip'}"
)

# Per-logical-step training metrics.
filename = f"step_stats_{run_tag}.json"
with open(filename, "w") as f:
    json.dump(step_stats, f)
print(f"Saved step_stats to {filename}")

# One entry per validation pass.
val_filename = f"validation_stats_{run_tag}.json"
with open(val_filename, "w") as f:
    json.dump(validation_stats, f)
print(f"Saved validation_stats to {val_filename}")

# Per-parameter denoising stats hold numpy arrays (noisy_svs/signal_svs), hence pickle.
stats_filename = f"stats_{run_tag}.pkl"
with open(stats_filename, "wb") as f:
    pickle.dump(stats, f)
print(f"Saved stats to {stats_filename}")

