from dataclasses import dataclass, field, asdict, fields
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments, BitsAndBytesConfig
from trl import SFTConfig
from peft import LoraConfig
import os
import json
from accelerate import Accelerator
import torch
from datetime import datetime, timedelta
from opacus.accountants import RDPAccountant


import pdb

# Define and parse arguments.
@dataclass
class FedArguments:
    """Command-line options for the federated training loop and its optional differential privacy.

    Parsed from the command line together with ``ScriptArguments`` by the module-level
    ``HfArgumentParser`` call. Field order, annotations and ``help`` metadata define the CLI,
    so they should be changed with care.
    """

    # --- Federated algorithm and client sampling ---
    fed_alg: Optional[str] = field(default="fedavg", metadata={"help": "the algorithm to use"})
    num_rounds: Optional[int] = field(default=500, metadata={"help": "the number of rounds"})
    # Total client count; the DP sampling-rate computation also splits dataset_sample evenly by it.
    num_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients"})
    # sample_clients / num_clients is used as the client-sampling fraction in the DP sampling rate.
    sample_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients to sample"})
    split_strategy: Optional[str] = field(default="iid", metadata={"help": "the split strategy"})
    # --- Algorithm-specific hyper-parameters (each only meaningful for the algorithms in its help text) ---
    prox_mu: Optional[float] = field(default=0.01, metadata={"help": "the mu parameter of FedProx"})
    fedopt_tau: Optional[float] = field(default=1e-3, metadata={"help": "the tau parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_eta: Optional[float] = field(default=1e-3, metadata={"help": "the global learning rate parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_beta1: Optional[float] = field(default=0.9, metadata={"help": "the beta1 parameter of FedYogi and FedAdam"})
    fedopt_beta2: Optional[float] = field(default=0.99, metadata={"help": "the beta2 parameter of FedYogi and FedAdam"})
    omega_rsvd: Optional[int] = field(default=36, metadata={"help": "the desired_projection_dim of random svd in fedavg_lora"})
    save_model_freq: Optional[int] = field(default=50, metadata={"help": "the frequency to save the model. 50 means save every 50 rounds"})
    # Dirichlet concentration (alpha) per the help text; presumably only used by non-iid split strategies — confirm.
    noniid_degree: Optional[float] = field(default=1, metadata={"help": "alpha_direchlet parameter"})
    # --- Differential privacy ---
    use_dp: Optional[bool] = field(default=False, metadata={"help": "use differential privacy guarantees"})
    dp_target_epsilon: Optional[float] = field(default=1, metadata={"help": "differential privacy budget epsilon"})
    dp_target_delta: Optional[float] = field(default=1e-5, metadata={"help": "differential privacy budget delta"})
    # When use_dp is set, this value is overwritten at import time by compute_noise_multiplier().
    dp_sigma: Optional[float] = field(default=0, metadata={"help": "the multiplier of noise in dp"})
    dp_clip_grad_norm: Optional[float] = field(default=1.0, metadata={"help": "max l2 norm of gradient per sample in dp"})

@dataclass
class ScriptArguments:
    """Command-line options for the model, dataset, local training and output locations.

    Parsed from the command line together with ``FedArguments`` by the module-level
    ``HfArgumentParser`` call. ``get_training_args`` and ``get_model_config`` turn these
    values into trainer and model-loading settings; ``save_config`` derives the run
    directory name from several of them.
    """

    # "NousResearch/Llama-2-13b-chat-hf" gets LoRA on seven projection modules; every other model gets q_proj/v_proj only.
    model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="lucasmccabe-lmi/CodeAlpaca-20k", metadata={"help": "the dataset name"}
    )
    # Forwarded as ``report_to`` of the training config.
    log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2e-5, metadata={"help": "the learning rate"})    # vicuna and alpaca use 2e-5
    # Forwarded as ``per_device_train_batch_size``; also enters the DP sampling-rate computation.
    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
    # Forwarded as ``max_seq_length`` of the training config.
    seq_length: Optional[int] = field(default=2048, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    # load_in_8bit and load_in_4bit are mutually exclusive: get_model_config raises ValueError if both are set.
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    # Parent directory for new runs; save_config replaces it with the run directory it selects.
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    # When true, save_config reuses the parent directory of checkpoint_dir instead of creating a new run directory.
    load_from_check_point: Optional[bool] = field(default=False, metadata={"help": "loading existed checkpoint from the output_dir"})
    checkpoint_dir: Optional[str] = field(default="output", metadata={"help": "the checkpoint dir"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=100, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=False, metadata={"help": "Use HF auth token to access the model"})   # token and use_auth_token cannot be used together
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    # Local training steps; also used as the number of DP-accounting steps when use_dp is set.
    max_steps: Optional[int] = field(default=10, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of updates steps before two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
    gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
    template: Optional[str] = field(default="alpaca", metadata={"help": "the template to use"})
    # NOTE(review): not forwarded to the SFTConfig built by get_training_args; confirm where it is applied.
    seed: Optional[int] = field(default=2023, metadata={"help": "the seed to use"})
    dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter of DPO"})
    # Sample count assumed by the DP sampling-rate computation (split evenly across clients) and used in run dir names.
    dataset_sample: Optional[int] = field(default=20000, metadata={"help": "the number of samples to use from the dataset"})
    local_data_dir: Optional[str] = field(default=None, metadata={"help": "the local data directory if you want to use downloaded data"})

parser = HfArgumentParser((ScriptArguments, FedArguments))
script_args, fed_args = parser.parse_args_into_dataclasses()

# ===== Define the LoraConfig =====
# LoRA adapters are only configured when --use_peft is given; otherwise peft_config stays None.
# The Llama-2-13b chat checkpoint adapts all seven listed projection modules,
# any other model adapts only q_proj and v_proj.
peft_config = None
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=(
            ["down_proj", "gate_proj", "v_proj", "up_proj", "k_proj", "o_proj", "q_proj"]
            if script_args.model_name_or_path == "NousResearch/Llama-2-13b-chat-hf"
            else ["q_proj", "v_proj"]
        ),
    )

# ===== Define the dp config =====
def compute_noise_multiplier(epsilon, delta, steps, sample_rate, sensitivity: float = 1.0, epsilon_tolerance: float = 0.01, max_iterations: int = 100):
    """Find a noise multiplier whose RDP-accounted privacy loss meets a target budget.

    The search runs in two phases, both using Opacus' ``RDPAccountant`` with a single
    history entry ``(sigma, sample_rate, steps)``:

    1. Doubling: starting from 20, sigma is doubled until the accounted epsilon is at or
       below ``epsilon``.
    2. Bisection: the bracket is narrowed until the accounted epsilon is within
       ``epsilon_tolerance`` below the target.

    Args:
        epsilon: target privacy budget epsilon; must be positive.
        delta: target privacy budget delta, passed to ``get_epsilon``.
        steps: number of noisy steps to account for.
        sample_rate: per-step sampling probability.
        sensitivity: accepted for backward compatibility; it does not affect the result.
        epsilon_tolerance: how far below ``epsilon`` the returned sigma's epsilon may be.
        max_iterations: cap on accountant evaluations per phase. It guarantees termination
            for unreachable targets or a stalled floating-point bisection.

    Returns:
        The smallest sigma found whose accounted epsilon does not exceed ``epsilon``. If
        bisection hits ``max_iterations``, the best (possibly more conservative) sigma found
        so far is returned; it still satisfies the budget.

    Raises:
        ValueError: if ``epsilon`` is not positive.
        RuntimeError: if no sigma reaching ``epsilon`` is found within ``max_iterations``
            doublings, e.g. because the target is below what the accountant can certify.
    """
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")

    accountant = RDPAccountant()

    def _epsilon_for(sigma):
        # Account for `steps` steps at this noise level only (history is replaced, not appended).
        accountant.history = [(sigma, sample_rate, steps)]
        return accountant.get_epsilon(delta=delta)

    # Phase 1: find an upper bracket sigma_high that satisfies the budget.
    sigma_low, sigma_high = 0.0, 10.0
    for _ in range(max_iterations):
        sigma_high = 2 * sigma_high
        eps_high = _epsilon_for(sigma_high)
        if eps_high <= epsilon:
            break
        # This sigma was still too small, so it is a valid lower bound for bisection.
        sigma_low = sigma_high
    else:
        raise RuntimeError(
            f"could not reach epsilon={epsilon} (delta={delta}) with noise multiplier "
            f"up to {sigma_high}; the target may be unreachable"
        )

    # Phase 2: bisect while keeping eps(sigma_high) <= epsilon < eps(sigma_low).
    for _ in range(max_iterations):
        if epsilon - eps_high <= epsilon_tolerance:
            break
        sigma = (sigma_low + sigma_high) / 2
        eps = _epsilon_for(sigma)
        if eps > epsilon:
            sigma_low = sigma
        else:
            sigma_high, eps_high = sigma, eps

    return sigma_high  # the adjusted noise multiplier


if fed_args.use_dp:
    # Sampling probability of one training example per local step: the fraction of
    # clients sampled, times a batch's share of one client's even split of dataset_sample.
    sampling_rate = (fed_args.sample_clients / fed_args.num_clients) * (
        script_args.batch_size / (script_args.dataset_sample / fed_args.num_clients)
    )
    # Kept as a module-level name; compute_noise_multiplier is called without it.
    sensitivity = fed_args.dp_clip_grad_norm
    # NOTE(review): only script_args.max_steps steps are accounted for. If the budget is
    # meant to cover all fed_args.num_rounds rounds, the step count presumably needs to
    # include them — confirm.
    fed_args.dp_sigma = compute_noise_multiplier(
        fed_args.dp_target_epsilon,
        fed_args.dp_target_delta,
        script_args.max_steps,
        sampling_rate,
    )


def get_config():
    """Return the module-level ``(script_args, fed_args, peft_config)`` built at import time."""
    config = (script_args, fed_args, peft_config)
    return config

# ===== Define the training arguments =====
def get_training_args(script_args, new_lr):
    """Build the ``SFTConfig`` for a local training run.

    Args:
        script_args: parsed ``ScriptArguments``. Its batch, step, logging, checkpointing,
            hub and sequence-length fields are copied into the config.
        new_lr: learning rate to use instead of ``script_args.learning_rate``
            (presumably supplied per round by the caller — confirm).

    Returns:
        An ``SFTConfig`` with a constant learning-rate schedule.
    """
    # NOTE(review): script_args.seed is not forwarded, so the config keeps the
    # TrainingArguments default seed; confirm this is intended.
    # NOTE(review): newer TRL releases rename ``max_seq_length`` to ``max_length``;
    # confirm against the pinned TRL version.
    options = {
        "output_dir": script_args.output_dir,
        "per_device_train_batch_size": script_args.batch_size,
        "gradient_accumulation_steps": script_args.gradient_accumulation_steps,
        "learning_rate": new_lr,
        "logging_steps": script_args.logging_steps,
        "num_train_epochs": script_args.num_train_epochs,
        "max_steps": script_args.max_steps,
        "report_to": script_args.log_with,
        "save_steps": script_args.save_steps,
        "save_total_limit": script_args.save_total_limit,
        "push_to_hub": script_args.push_to_hub,
        "hub_model_id": script_args.hub_model_id,
        "gradient_checkpointing": script_args.gradient_checkpointing,
        "lr_scheduler_type": "constant",
        "max_seq_length": script_args.seq_length,
    }
    return SFTConfig(**options)

def get_model_config(script_args):
    """Return ``(device_map, quantization_config, torch_dtype)`` for loading the base model.

    Without quantization all three values are ``None``. For 8-bit or 4-bit (NF4 with
    double quantization and bfloat16 compute) loading, a ``BitsAndBytesConfig`` is built,
    the whole model is mapped to this process's ``Accelerator().local_process_index``,
    and the dtype is ``torch.bfloat16``.

    Raises:
        ValueError: if both ``load_in_8bit`` and ``load_in_4bit`` are set.
    """
    use_8bit = script_args.load_in_8bit
    use_4bit = script_args.load_in_4bit
    if use_8bit and use_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    if not use_8bit and not use_4bit:
        return None, None, None

    if use_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=use_8bit)
    else:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=use_4bit,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    # Place the entire quantized model on this process's local device.
    device_map = {"": Accelerator().local_process_index}
    return device_map, quantization_config, torch.bfloat16

def save_config(script_args, fed_args):
    """Select the run's output directory, store it on ``script_args`` and dump both configs there.

    When resuming (``script_args.load_from_check_point``), the run directory is the parent
    of ``script_args.checkpoint_dir``. Otherwise a new directory is created under
    ``script_args.output_dir``. Its name encodes the main hyper-parameters plus a
    ``YYYYmmddHHMMSS`` timestamp, and the same name format is used when retrying after a
    collision.

    Side effects: may create directories, sets ``script_args.output_dir`` to the chosen
    directory, and writes ``args.json`` with keys ``"script_args"`` and ``"fed_args"``.
    """
    if script_args.load_from_check_point:
        print(script_args.checkpoint_dir)
        output_dir = os.path.dirname(script_args.checkpoint_dir)
    else:
        output_dir = _create_run_dir(script_args, fed_args)

    script_args.output_dir = output_dir

    with open(os.path.join(script_args.output_dir, "args.json"), "w", encoding="utf-8") as f:
        combined_dict = {
            "script_args": asdict(script_args),
            "fed_args": asdict(fed_args),
        }
        json.dump(combined_dict, f, indent=4)


def _create_run_dir(script_args, fed_args):
    """Create and return a new, uniquely named run directory under ``script_args.output_dir``."""
    dataset_name_split = os.path.basename(script_args.dataset_name)
    prefix = (
        f"{dataset_name_split}_{script_args.dataset_sample}_{fed_args.fed_alg}"
        f"_c{fed_args.num_clients}s{fed_args.sample_clients}"
        f"_i{script_args.max_steps}"
        f"_b{script_args.batch_size}a{script_args.gradient_accumulation_steps}"
        f"_l{script_args.seq_length}"
        f"_r{script_args.peft_lora_r}a{script_args.peft_lora_alpha}"
    )
    # os.mkdir below needs the parent to exist.
    if script_args.output_dir:
        os.makedirs(script_args.output_dir, exist_ok=True)

    stamp = datetime.now()
    while True:
        candidate = f"{script_args.output_dir}/{prefix}_{stamp.strftime('%Y%m%d%H%M%S')}"
        try:
            # EAFP: creating directly avoids the exists()/mkdir() race between concurrent runs.
            os.mkdir(candidate)
        except FileExistsError:
            # Name taken: advance the timestamp deterministically instead of busy-waiting on the clock.
            stamp += timedelta(seconds=1)
        else:
            return candidate


def update_dataclass_from_json(dataclass_instance, json_data):
    """
    Overwrite fields of ``dataclass_instance`` in place with values from ``json_data``.

    Keys that do not name a field of the dataclass are ignored; fields missing from
    ``json_data`` keep their current values. Values are assigned as-is, without type
    conversion, in the order the keys appear in ``json_data``.
    """
    valid_names = {fld.name for fld in fields(dataclass_instance)}
    matching = [(name, val) for name, val in json_data.items() if name in valid_names]
    for name, val in matching:
        setattr(dataclass_instance, name, val)

def load_and_update_config(json_file_path, script_args, fed_args):
    """Restore a saved run configuration and rebuild its LoRA config.

    Reads a JSON file with optional ``"script_args"`` and ``"fed_args"`` sections and
    copies matching fields onto the given dataclass instances in place. The caller's
    ``checkpoint_dir`` is preserved, and ``load_from_check_point`` is forced to True.

    Returns:
        ``(script_args, fed_args, peft_config)``. ``peft_config`` is ``None`` unless the
        restored ``use_peft`` is truthy.
    """
    with open(json_file_path, "r") as handle:
        saved = json.load(handle)

    # The saved config carries its own checkpoint_dir; keep the one the caller asked for.
    requested_checkpoint = script_args.checkpoint_dir
    for section, target in (("script_args", script_args), ("fed_args", fed_args)):
        update_dataclass_from_json(target, saved.get(section, {}))

    script_args.load_from_check_point = True
    script_args.checkpoint_dir = requested_checkpoint

    if not script_args.use_peft:
        return script_args, fed_args, None

    if script_args.model_name_or_path == "NousResearch/Llama-2-13b-chat-hf":
        modules = ["down_proj", "gate_proj", "v_proj", "up_proj", "k_proj", "o_proj", "q_proj"]
    else:
        modules = ["q_proj", "v_proj"]
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=modules,
    )
    return script_args, fed_args, peft_config

    

    