import os
import warnings
import argparse

import torch
import numpy as np
from tqdm import tqdm
from scipy import stats
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils import set_seed, auto_or_float, InverseTriangleDistribution


# Sign-change methods for which `--unchange_sign_rescale_rate auto` has a formula.
AUTO_RESCALE_METHODS = ("inverse", "inverse_neg", "random", "masking")


def parse_args(argv=None):
    """Parse the command line in two passes and validate option combinations.

    The first pass (``parse_known_args``) only determines ``--distribution`` so
    that the distribution-specific options can be registered before the final,
    strict parse.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-pt", "--pretrained_model_path", type=str, required=True)
    parser.add_argument("-ft", "--finetuned_model_path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--v_change_rate", type=float, default=1)
    parser.add_argument("--distribution", type=str, default="normal", choices=["normal", "uniform", "degenerate", "inverse_triangle", "scaling", "imbalance_scaling", "shuffle", "bin", "gradient_imbalance_scaling"])
    parser.add_argument("--output_dir", type=str, default="./generated_models")
    parser.add_argument("--sign_change_method", type=str, default="inverse", choices=["inverse", "random", "inverse_neg", "shuffle", "inverse_neg_rescale_neg", "masking", "magnitude_drop"])
    parser.add_argument("--sign_change_rate", type=float, default=0)
    parser.add_argument("--unchange_sign_rescale_rate", type=auto_or_float, default=1)
    parser.add_argument("--changed_sign_rescale_rate", type=float, default=1)
    parser.add_argument("--huge_model", action="store_true", default=False)
    temp_args, _ = parser.parse_known_args(argv)

    if temp_args.distribution == 'normal':
        parser.add_argument('--mean', type=auto_or_float, default=0)
        parser.add_argument('--std', type=auto_or_float, required=True, help='Standard deviation of the normal distribution')
    elif temp_args.distribution == 'uniform':
        parser.add_argument('--low', type=auto_or_float, default=0, help='Lower bound of the uniform distribution')
        parser.add_argument('--high', type=auto_or_float, required=True, help='Upper bound of the uniform distribution')
    elif temp_args.distribution == 'degenerate':
        parser.add_argument('--value', type=auto_or_float, required=True, help='Value of the degenerate distribution')
    elif temp_args.distribution == 'inverse_triangle':
        parser.add_argument('--left', type=float, required=True, help='Lower bound of the inverse triangle distribution')
        parser.add_argument('--right', type=float, required=True, help='Upper bound of the inverse triangle distribution')
        parser.add_argument('--peak', type=float, required=True, help='Peak of the inverse triangle distribution')
        parser.add_argument('--left_area', type=float, default=0.5, help='Area of the left triangle')
    elif temp_args.distribution == 'scaling':
        parser.add_argument('--scale', type=float, required=True, help='Scaling factor')
    elif temp_args.distribution == 'imbalance_scaling':
        parser.add_argument('--neg_scale', type=float, required=True, help='Scaling factor for negative values')
        parser.add_argument('--pos_scale', type=float, required=True, help='Scaling factor for positive values')
    elif temp_args.distribution == 'gradient_imbalance_scaling':
        parser.add_argument('--neg_scale', type=float, required=True, help='Scaling factor for negative values')
        parser.add_argument('--pos_scale', type=float, required=True, help='Scaling factor for positive values')
        parser.add_argument('--gradient_path', type=str, required=True)
        parser.add_argument('--gradient_name', type=str, required=True)
    elif temp_args.distribution == "bin":
        parser.add_argument('--bins', type=int, required=True, help='Number of bins')
        parser.add_argument('--value_type', type=str, default="mean", choices=["mean", "median"], help='Type of value to use for binning')
    args = parser.parse_args(argv)

    # Validate with parser.error rather than `assert`, which is stripped under -O.
    if args.distribution.startswith("imbalance") and args.sign_change_rate != 0:
        parser.error("imbalance value distribution does not support sign_change_rate")
    # Reject "auto" early for methods without a formula, instead of failing
    # deep inside the parameter loop after both models have been loaded.
    if (
        args.sign_change_rate > 0
        and args.unchange_sign_rescale_rate == "auto"
        and args.sign_change_method not in AUTO_RESCALE_METHODS
    ):
        parser.error(
            f"--unchange_sign_rescale_rate auto is not supported for "
            f"--sign_change_method {args.sign_change_method}"
        )
    return args


def build_model_save_path(args):
    """Return the output directory name that encodes the run's settings.

    The layout is ``<output_dir>/<finetuned model name>/<sign prefix><distribution>_<params>_<seed>``.
    """
    if args.distribution == 'normal':
        distribution_params = f"{args.mean}_{args.std}"
    elif args.distribution == 'uniform':
        distribution_params = f"{args.low}_{args.high}"
    elif args.distribution == 'degenerate':
        distribution_params = f"{args.value}"
    elif args.distribution == 'inverse_triangle':
        distribution_params = f"{args.left}_{args.right}_{args.peak}_{args.left_area}"
    elif args.distribution == 'scaling':
        distribution_params = f"{args.scale}"
    elif args.distribution == 'imbalance_scaling':
        distribution_params = f"{args.neg_scale}_{args.pos_scale}"
    elif args.distribution == 'gradient_imbalance_scaling':
        distribution_params = f"{args.gradient_name}_{args.neg_scale}_{args.pos_scale}"
    elif args.distribution == 'bin':
        distribution_params = f"{args.bins}_{args.value_type}"
    else:  # 'shuffle' takes no parameters
        distribution_params = ""
    if args.v_change_rate < 1:
        distribution_params += f"_{args.v_change_rate}"

    sign_prefix = f"{args.sign_change_method}_sign_{args.sign_change_rate}_{args.unchange_sign_rescale_rate}_{args.changed_sign_rescale_rate}_" if args.sign_change_rate > 0 else "sign_"
    # normpath strips a trailing separator; otherwise basename("models/ft/")
    # is "" and the per-model directory level silently disappears.
    model_name = os.path.basename(os.path.normpath(args.finetuned_model_path))
    return os.path.join(
        args.output_dir,
        model_name,
        f"{sign_prefix}{args.distribution}_{distribution_params}_{args.seed}"
    )


def resample_magnitude(delta, args, gradient=None):
    """Build the replacement values for one delta tensor.

    Only the absolute value of the result is used by the caller; the sign is
    chosen separately by ``resample_sign``.

    Args:
        delta: Fine-tuned minus pre-trained weights for one parameter.
        args: Parsed arguments; ``args.distribution`` selects the strategy.
        gradient: Tensor multiplied element-wise with ``delta``; only read for
            ``gradient_imbalance_scaling``.

    Returns:
        A new tensor with ``delta``'s shape that does not share memory with it.
    """
    dist = args.distribution
    if dist == 'normal':
        mean, std = args.mean, args.std
        if mean == 'auto' or std == 'auto':
            fitted_mean, fitted_std = stats.norm.fit(delta.flatten().detach().numpy())
            # Compare against 'auto' instead of isinstance(..., float): the
            # default `--mean 0` is an int and must not be replaced by the fit.
            mean = fitted_mean if mean == 'auto' else mean
            std = fitted_std if std == 'auto' else std
        return torch.normal(float(mean), float(std), delta.size())
    if dist == "uniform":
        high = np.percentile(delta.flatten().detach().numpy(), 90) if args.high == "auto" else args.high
        low = np.percentile(delta.flatten().detach().numpy(), 10) if args.low == "auto" else args.low
        return torch.rand(delta.size()) * (high - low) + low
    if dist == "degenerate":
        if args.value == "auto_no_zero":
            # Mean magnitude of the non-zero entries; zero entries stay zero.
            nonzero = delta != 0
            new_value = torch.full_like(delta, 0)
            new_value[nonzero] = torch.abs(delta[nonzero]).mean().item()
            return new_value
        value = torch.abs(delta).mean().item() if args.value == "auto" else args.value
        return torch.full_like(delta, value)
    if dist == "inverse_triangle":
        # InverseTriangleDistribution comes from the project's utils module;
        # presumably sample(n) returns n draws -- TODO confirm.
        inverse = InverseTriangleDistribution(args.left, args.right, args.peak, args.left_area)
        return torch.tensor(inverse.sample(delta.numel())).reshape(delta.size())
    if dist == "scaling":
        return delta * args.scale
    if dist == "imbalance_scaling":
        # Take both masks from the original delta so an entry is scaled by
        # exactly one factor, even when a factor is negative.
        negative, positive = delta < 0, delta > 0
        new_value = delta.clone()
        new_value[negative] *= args.neg_scale
        new_value[positive] *= args.pos_scale
        return new_value
    if dist == "gradient_imbalance_scaling":
        delta_grad = delta * gradient
        new_value = delta.clone()
        new_value[delta_grad < 0] *= args.neg_scale
        new_value[delta_grad > 0] *= args.pos_scale
        return new_value
    if dist == "shuffle":
        # Indexing with a permutation copies, so delta itself is untouched.
        return delta.view(-1)[torch.randperm(delta.numel())].view(delta.size())
    if dist == "bin":
        flat = torch.abs(delta).view(-1).detach().numpy()
        edges = np.linspace(flat.min().item(), flat.max().item(), args.bins + 1)
        # The maximum lands in bin `bins + 1`, hence one extra slot below.
        bin_indices = np.digitize(flat, edges)
        reduce_fn = np.mean if args.value_type == "mean" else np.median
        bin_values = np.full(args.bins + 1, np.nan, dtype=flat.dtype)
        for i in range(1, args.bins + 2):
            members = flat[bin_indices == i]
            if members.size:  # empty bins are never looked up; leave NaN
                bin_values[i - 1] = reduce_fn(members)
        return torch.from_numpy(bin_values[bin_indices - 1]).view(delta.size())
    return delta.clone()


def restore_unchanged_values(new_value, delta, v_change_rate):
    """Copy a random ``1 - v_change_rate`` fraction of ``delta`` back into ``new_value`` in place."""
    if v_change_rate < 1:
        keep_value_num = int(delta.numel() * (1 - v_change_rate))
        keep_value_indices = np.random.choice(delta.numel(), keep_value_num, replace=False)
        new_value.view(-1)[keep_value_indices] = delta.view(-1)[keep_value_indices]


def _sample_negative_indices(delta, change_sign_num, sign_change_rate, param_name):
    """Pick ``change_sign_num`` random negative entries of ``delta``.

    When fewer negative entries exist than requested, all of them are picked
    and the effective rate is lowered accordingly (with a warning).

    Returns:
        ``(chosen_indices, negative_indices, effective_sign_change_rate)``.
    """
    negative_indices = np.flatnonzero((delta.view(-1) < 0).numpy())
    if change_sign_num > len(negative_indices):
        sign_change_rate = len(negative_indices) / delta.numel()
        warnings.warn(f"{param_name}: change_sign_num {change_sign_num} is larger than the number of negative indices {len(negative_indices)}, change_sign_num is set to {len(negative_indices)} and sign_change_rate is set to {sign_change_rate}")
        change_sign_num = len(negative_indices)
    chosen = np.random.choice(negative_indices, change_sign_num, replace=False)
    return chosen, negative_indices, sign_change_rate


def resample_sign(delta, new_value, args, param_name=""):
    """Choose the sign and the per-entry rescale factor for one delta tensor.

    ``args`` is never modified: the effective sign-change rate and the resolved
    "auto" rescale rate are computed per tensor, so one tensor's clamping does
    not leak into the tensors processed after it.

    Args:
        delta: Fine-tuned minus pre-trained weights for one parameter.
        new_value: Replacement values; modified in place only by the
            ``inverse_neg_rescale_neg`` method.
        args: Parsed arguments (sign_change_* and *_rescale_rate options).
        param_name: Parameter name, used in warning messages only.

    Returns:
        ``(new_sign, sign_rescale_mask)``, both shaped like ``delta``.

    Raises:
        ValueError: If the unchanged-sign rescale rate is "auto" for a method
            that has no formula for it.
    """
    new_sign = torch.sign(delta)
    if not args.sign_change_rate > 0:
        return new_sign, torch.full_like(new_sign, 1)

    method = args.sign_change_method
    sign_change_rate = args.sign_change_rate
    change_sign_num = int(delta.numel() * sign_change_rate)
    change_sign_indices = np.random.choice(delta.numel(), change_sign_num, replace=False)
    flat_sign = new_sign.view(-1)  # view: writes go through to new_sign

    if method == "inverse":
        flat_sign[change_sign_indices] *= -1
    elif method == "random":
        flat_sign[change_sign_indices] = (torch.randint(0, 2, (change_sign_num,)) * 2 - 1).to(new_sign.dtype)
    elif method == "inverse_neg":
        change_sign_indices, _, sign_change_rate = _sample_negative_indices(
            delta, change_sign_num, sign_change_rate, param_name
        )
        flat_sign[change_sign_indices] *= -1
    elif method == "inverse_neg_rescale_neg":
        change_sign_indices, negative_indices, sign_change_rate = _sample_negative_indices(
            delta, change_sign_num, sign_change_rate, param_name
        )
        flat_sign[change_sign_indices] *= -1
        # Rescale the negatives that kept their sign, clamped to [0.1, 10].
        unchanged_indices = np.setdiff1d(negative_indices, change_sign_indices)
        flat_value = new_value.view(-1)
        denominator = torch.sum(flat_value[unchanged_indices])
        # A zero denominator (e.g. no unchanged negatives) would give a NaN
        # or inf rate and poison the weights; skip the rescale in that case.
        if denominator != 0:
            rate = 2 * torch.sum(flat_value[change_sign_indices]) / denominator + 1
            flat_value[unchanged_indices] *= min(max(rate, 0.1), 10)
    elif method == "masking":
        flat_sign[change_sign_indices] = 0
    elif method == "magnitude_drop":
        # kthvalue requires k >= 1; tiny tensors can round down to 0.
        if change_sign_num > 0:
            # NOTE(review): the threshold is taken over the signed delta, so
            # this zeroes the most negative entries rather than the smallest
            # magnitudes -- confirm whether torch.abs(delta) was intended.
            # The rescale mask below is also set on the random indices drawn
            # above, not on the entries that are zeroed here.
            threshold = torch.kthvalue(delta.view(-1), change_sign_num).values
            new_sign[delta <= threshold] = 0
    # NOTE(review): the "shuffle" method has no implementation here; signs
    # stay unchanged and only the rescale mask below is applied.

    unchanged_rescale_rate = args.unchange_sign_rescale_rate
    if unchanged_rescale_rate == "auto":
        if method in ("inverse", "inverse_neg"):
            unchanged_rescale_rate = (1 + sign_change_rate * args.changed_sign_rescale_rate) / (1 - sign_change_rate)
        elif method in ("random", "masking"):
            unchanged_rescale_rate = 1 / (1 - sign_change_rate)
        else:
            raise ValueError(
                f"unchange_sign_rescale_rate 'auto' is not supported for "
                f"sign_change_method {method!r}"
            )

    sign_rescale_mask = torch.full_like(new_sign, unchanged_rescale_rate)
    sign_rescale_mask.view(-1)[change_sign_indices] = args.changed_sign_rescale_rate
    return new_sign, sign_rescale_mask


def main():
    """Perturb the fine-tuning delta of a model and save the resulting model.

    For every parameter the delta (fine-tuned minus pre-trained) is replaced by
    resampled magnitudes and signs, added back onto the pre-trained weights,
    and written into the pre-trained model, which is then saved together with
    the fine-tuned model's tokenizer. Two signed mean distances between the new
    and the fine-tuned weights are printed and appended to
    ``distance_history.txt``.
    """
    args = parse_args()
    set_seed(args.seed)
    args.model_save_path = build_model_save_path(args)

    if os.path.isdir(args.model_save_path) and os.listdir(args.model_save_path):
        warnings.warn(f"{args.model_save_path} already exists and is not empty, skipping...")
        return

    model_pt = AutoModelForCausalLM.from_pretrained(
        args.pretrained_model_path,
        device_map="cpu",
        trust_remote_code=True,
    )
    model_ft = AutoModelForCausalLM.from_pretrained(
        args.finetuned_model_path,
        device_map="cpu",
        trust_remote_code=True,
    )

    gradients = None
    if args.distribution == "gradient_imbalance_scaling":
        # Both models live on the CPU, so map the gradients there as well.
        # SECURITY: torch.load unpickles the file; only load trusted paths.
        gradients = torch.load(args.gradient_path, map_location="cpu")

    total_value = 0
    total_value2 = 0
    total_param = 0

    # state_dict() builds a fresh dict on every call; build it once.
    ft_state = model_ft.state_dict()
    with torch.no_grad():
        for param_name, pt_param in tqdm(model_pt.named_parameters()):
            ft_param = ft_state[param_name]
            delta = ft_param - pt_param

            # Order matters for seeded reproducibility: values, then kept
            # values, then signs.
            gradient = gradients[param_name] if gradients is not None else None
            new_value = resample_magnitude(delta, args, gradient)
            restore_unchanged_values(new_value, delta, args.v_change_rate)
            new_sign, sign_rescale_mask = resample_sign(delta, new_value, args, param_name)

            new_param = pt_param + torch.abs(new_value) * sign_rescale_mask * new_sign
            total_value += ((new_param - ft_param) * torch.sign(ft_param - pt_param)).sum()
            total_value2 += ((ft_param - new_param) * torch.sign(new_param - pt_param)).sum()
            total_param += pt_param.numel()

            pt_param.data.copy_(new_param)

    distance = (total_value / total_param).item()
    print(distance)
    distance2 = (total_value2 / total_param).item()
    print(distance2)

    with open("distance_history.txt", "a") as f:
        f.write(f"{args.model_save_path} {args.finetuned_model_path} {distance}\n")
        f.write(f"{args.finetuned_model_path} {args.model_save_path} {distance2}\n")

    tokenizer = AutoTokenizer.from_pretrained(args.finetuned_model_path, trust_remote_code=True)

    print(f"saving model at {args.model_save_path}...")
    tokenizer.save_pretrained(save_directory=args.model_save_path)
    model_pt.save_pretrained(save_directory=args.model_save_path)


if __name__ == "__main__":
    main()
        
