import os
import random
import click
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    pipeline
)
import torch
import numpy as np

import utils
from helper import Helper

import lm_eval
from lm_eval.models.huggingface import HFLM
import functools
import copy
import json
import time
import tqdm

import logging
logger = logging.getLogger(__name__)

from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention, LlamaDecoderLayer
import matplotlib.pyplot as plt


def setup_seed(seed):
    """Seed every random number generator this script touches.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU and
    all CUDA devices), then asks cuDNN to pick deterministic kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    

# Linear projections of a Llama decoder layer, grouped by sub-block. The tuple
# orders matter: _set_model_param_bias enumerates bias tensors in this order on
# every call, so init and update must agree.
_MLP_PROJS = ("gate_proj", "up_proj", "down_proj")
_ATTN_PROJS = ("q_proj", "k_proj", "v_proj", "o_proj")
_LINEAR_PROJS = _MLP_PROJS + _ATTN_PROJS


def _model_profile(model_path):
    """Return the hard-coded per-checkpoint settings used by ``cli``.

    The checkpoint is recognised by a substring of ``model_path``, checked in
    the order Llama-2-7b, Llama-2-13b, llama-3-8B.

    Returns:
        dict with keys
          ``data_dir``           root directory for this model's artefacts,
          ``model_params``       parameter count used for the compression-rate report,
          ``dump_dest``          directory the SVD factors are written to,
          ``hidden_states_dest`` directory for the reference hidden states,
          ``fisher_info_file``   Fisher-info file name inside ``data_dir``
                                 (``None``: derive it from the model name),
          ``skipped_layers``     layer indices excluded from rank allocation,
          ``pruned_layers``      layer indices passed to ``Helper.apply_to_model``.

    Raises:
        click.BadParameter: if ``model_path`` is empty or names no supported checkpoint.
    """
    if not model_path:
        raise click.BadParameter("a model path is required", param_hint="'--model'")
    if 'Llama-2-7b' in model_path:
        return {
            'data_dir': '/data/7b',
            'model_params': 7000842240,
            'dump_dest': '/data/7b/usv',
            'hidden_states_dest': '/data/7b/hidden_states',
            'fisher_info_file': None,
            'skipped_layers': [0, 1, 30, 31],
            'pruned_layers': list(range(2, 30)),
        }
    if 'Llama-2-13b' in model_path:
        return {
            'data_dir': '/data/13b',
            'model_params': 13343959040,
            'dump_dest': '/data/13b/usv',
            'hidden_states_dest': '/data/13b/hidden_states',
            'fisher_info_file': None,
            'skipped_layers': [0, 1, 39],
            'pruned_layers': list(range(2, 38)),
        }
    if 'llama-3-8B' in model_path:
        return {
            'data_dir': '/data/8b',
            'model_params': 8030269440,
            'dump_dest': '/data/8b/usv_2048',
            'hidden_states_dest': '/data/8b/hidden_states',
            'fisher_info_file': 'llama-3-8B_calib_fisher_info_wiki_2048_4096.pt',
            'skipped_layers': [0, 1, 30, 31],
            'pruned_layers': list(range(2, 30)),
        }
    raise click.BadParameter(
        f"unsupported model {model_path!r}; expected a path containing "
        "'Llama-2-7b', 'Llama-2-13b' or 'llama-3-8B'",
        param_hint="'--model'",
    )


def _kv_dim(model, model_path):
    """Return the k_proj/v_proj width used in the rank budget.

    Only llama-3-8B is treated as grouped-query attention here: its k/v width is
    hidden_size scaled by num_key_value_heads / num_attention_heads (1024).
    Every other model uses hidden_size.
    """
    config = model.config
    if 'llama-3-8B' in model_path:
        return config.hidden_size / (config.num_attention_heads / config.num_key_value_heads)
    return config.hidden_size


def _compute_desired_ranks(all_fisher_info, skipped_layers, m, n, kv_n, target_rate):
    """Allocate a low-rank budget to every projection in proportion to its Fisher info.

    Fisher keys are parsed as ``<a>.<b>.<layer_idx>. ... .<proj>``: the layer index
    is the third dot-separated field and the projection name is the last one.
    MLP and attention projections get separate budgets. Each budget is
    ``target_rate`` times the sum of the per-projection "break-even" ranks
    ``rows*cols/(rows+cols)``.

    Args:
        all_fisher_info: mapping of projection name -> Fisher tensor (mean is used).
        skipped_layers: integer layer indices to leave out entirely.
        m: intermediate (MLP) size.
        n: hidden size.
        kv_n: k_proj/v_proj width (see ``_kv_dim``).
        target_rate: fraction of the full-rank budget to keep.

    Returns:
        dict mapping layer index (string) -> {proj: (rank, value)}. ``rank`` is a
        multiple of 32. A rank of 0 means the break-even rank was reached and the
        projection should stay full rank; in that case ``value`` is the
        projection's element count. Otherwise ``value`` is its Fisher share
        within its budget.
    """
    # Mean Fisher info per projection. Anything that is not an MLP projection is
    # accumulated into the attention sum.
    fisher_proj_dict = {}
    fisher_mlp_sum = fisher_attn_sum = 0
    for name, fisher_info in all_fisher_info.items():
        layer_idx = int(name.split(".")[2])
        if layer_idx in skipped_layers:
            print(layer_idx, name, 'fisher sum skipped')
            continue
        fisher_proj = torch.mean(fisher_info).item()
        fisher_proj_dict[name] = fisher_proj
        if name.split(".")[-1] in _MLP_PROJS:
            fisher_mlp_sum += fisher_proj
        else:
            fisher_attn_sum += fisher_proj

    # Second dimension of each attention projection (first is always n).
    attn_dims = {"q_proj": n, "o_proj": n, "k_proj": kv_n, "v_proj": kv_n}

    total_mlp_rank = total_attn_rank = 0
    for name in fisher_proj_dict:
        suffix = name.split(".")[-1]
        if suffix in _MLP_PROJS:
            total_mlp_rank += int(((m * n) / (m + n)))
        elif suffix in attn_dims:
            dim = attn_dims[suffix]
            total_attn_rank += int(((n * dim) / (n + dim)))

    target_mlp_k = int(total_mlp_rank * target_rate)
    target_attn_k = int(total_attn_rank * target_rate)

    desired_ranks = {}
    for name, fisher_proj in fisher_proj_dict.items():
        suffix = name.split('.')[-1]
        if suffix not in _MLP_PROJS:
            continue
        value = fisher_proj / fisher_mlp_sum
        desired_rank = int(target_mlp_k * value) // 32 * 32
        if desired_rank >= (m * n) / (m + n):
            print(name, desired_rank, 'full rank set 0')
            desired_rank = 0
            value = m * n
        desired_ranks.setdefault(name.split('.')[2], {})[suffix] = (desired_rank, value)
    for name, fisher_proj in fisher_proj_dict.items():
        suffix = name.split('.')[-1]
        if suffix not in attn_dims:
            continue
        value = fisher_proj / fisher_attn_sum
        dim = attn_dims[suffix]
        desired_rank = int(target_attn_k * value) // 32 * 32
        if desired_rank >= (n * dim) / (n + dim):
            print(name, desired_rank, 'full ranks set 0')
            desired_rank = 0
            value = n * dim
        desired_ranks.setdefault(name.split('.')[2], {})[suffix] = (desired_rank, value)
    return desired_ranks


def _whitening_factors(raw_scaling_diag_matrix, device, name):
    """Return ``(S, S_inv)``: the float32 Cholesky factor of the input and its inverse.

    If the Cholesky factorisation fails, the input is shifted **in place** by
    ``(1e-3 - smallest eigenvalue) * I`` and factorised again. If inversion fails,
    ``1e-3 * I`` is added to the factor before inverting. Warnings mention ``name``.
    """
    try:
        scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix).float().to(device)
    except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
        print(name, "Warning: eigen scaling_diag_matrix is not positive!")
        if torch.isnan(raw_scaling_diag_matrix).any():
            print("Warning: scaling_diag_matrix contains NaN!")
        elif torch.isinf(raw_scaling_diag_matrix).any():
            print("Warning: scaling_diag_matrix contains Inf!")
        if not torch.equal(raw_scaling_diag_matrix, raw_scaling_diag_matrix.T):
            print("Warning: scaling_diag_matrix is not a symmetric matrix!")
        eigenvalues = torch.linalg.eigvalsh(raw_scaling_diag_matrix)
        raw_scaling_diag_matrix += (- eigenvalues[0] + 1e-3) * torch.eye(raw_scaling_diag_matrix.shape[0]).to(device)
        scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix).float().to(device)

    try:
        scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
    except RuntimeError:
        print(name, "Warning: scaling_diag_matrix is not full rank!")
        scaling_diag_matrix += 1e-3 * torch.eye(scaling_diag_matrix.shape[0]).to(device)
        scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix).to(device)
    return scaling_diag_matrix, scaling_matrix_inv


def _register_hidden_state_hooks(model, desired_ranks, sink, log_skipped=False):
    """Attach forward hooks that copy hidden states into ``sink`` (on the CPU).

    A hook is attached to every ``LlamaDecoderLayer`` whose index (last name
    component) is a key of ``desired_ranks``, stored under that string index.
    Another is attached to ``model.norm``, stored under ``'model.norm'``.

    Returns:
        list of hook handles; the caller must ``remove()`` them after the forward.
    """
    def hook(key, module, inp, out):
        # Decoder layers may return a tuple whose first item is the hidden states;
        # modules such as the final norm return the tensor itself, where ``[0]``
        # would silently keep only the first sequence of the batch.
        hidden = out[0] if isinstance(out, tuple) else out
        sink[key] = hidden.cpu()

    handles = []
    for name, module in model.named_modules():
        if not (isinstance(module, LlamaDecoderLayer) or name == 'model.norm'):
            continue
        layer_idx = name.split('.')[-1]
        if name != 'model.norm' and layer_idx not in desired_ranks:
            if log_skipped:
                print(f"bias param: {name} {layer_idx} not in desired ranks")
            continue
        key = name if name == 'model.norm' else layer_idx
        handles.append(module.register_forward_hook(functools.partial(hook, key)))
    return handles


def _collect_hidden_states(model, batch_calib_loader, hidden_states_dest, desired_ranks, target_sample_cnt=256):
    """Run the (uncompressed) model on calibration batches and save hooked hidden states.

    Batch ``i`` is saved to ``<hidden_states_dest>/<i>_hidden_staes.pth``; at most
    ``target_sample_cnt`` batches are saved. Runs without autograd because these
    are only used as regression targets.
    """
    sample_cnt = 0
    for batch in batch_calib_loader:
        sample_data = {}
        handles = _register_hidden_state_hooks(model, desired_ranks, sample_data, log_skipped=True)
        batch_input_ids = torch.concat(batch["batch_input_ids"]).to(model.device)
        try:
            with torch.no_grad():
                model(batch_input_ids)
        finally:
            for handle in handles:
                handle.remove()

        torch.save(sample_data, os.path.join(hidden_states_dest, f"{sample_cnt}_hidden_staes.pth"))
        print(f"{sample_cnt} saved.")
        sample_cnt += 1
        if sample_cnt >= target_sample_cnt:
            print(sample_data.keys())
            break
    print('saved.')


def _set_model_param_bias(model, desired_ranks, params_bias, params_init=False):
    """Attach per-projection bias tensors to Llama MLP / attention modules.

    A bias is attached as ``<proj prefix>_bias`` (e.g. ``gate_bias``, ``q_bias``)
    for every projection whose layer is in ``desired_ranks`` with a non-zero rank.
    How the bias is consumed happens outside this file (presumably in ``Helper``'s
    forward hooks — TODO confirm).

    With ``params_init=True``, new zero float32 ``Parameter``s are created and
    returned, and their bfloat16 copies are attached. Otherwise the bfloat16 copies
    of ``params_bias`` are re-attached in the same enumeration order; call this
    after every optimizer step so the modules see the updated values.
    """
    params_bias_init = []
    param_bias_cnt = 0
    if params_init:
        config = model.config
        kv_dim = int(config.hidden_size / (config.num_attention_heads / config.num_key_value_heads))
    for name, module in model.named_modules():
        if not isinstance(module, (LlamaMLP, LlamaAttention)):
            continue
        layer_idx = name.split(".")[-2]
        if layer_idx not in desired_ranks:
            if params_init:
                print(f"bias param: {name} {layer_idx} not in desired ranks")
            continue
        suffix_list = _MLP_PROJS if isinstance(module, LlamaMLP) else _ATTN_PROJS
        for suffix in suffix_list:
            if suffix not in desired_ranks[layer_idx] or desired_ranks[layer_idx][suffix][0] == 0:
                if params_init:
                    print(f"bias param: {name} {suffix} not in desired ranks layer")
                continue
            attr_name = f"{suffix.split('_')[0]}_bias"
            if params_init:
                if suffix in ("k_proj", "v_proj"):
                    bias_dim = kv_dim
                elif suffix in ("gate_proj", "up_proj"):
                    bias_dim = model.config.intermediate_size
                else:
                    bias_dim = model.config.hidden_size
                params = torch.nn.Parameter(torch.zeros(bias_dim), requires_grad=True)
                setattr(module, attr_name, params.to(torch.bfloat16))
                params_bias_init.append(params)
            else:
                setattr(module, attr_name, params_bias[param_bias_cnt].to(torch.bfloat16))
                param_bias_cnt += 1
    if params_init:
        return params_bias_init
    print('set model params bais done')


@click.command()
@click.option("-m", "--model", type=click.Path(file_okay=True), help="path to model file", default=None)
def cli(**kwargs):
    """Low-rank compress a Llama checkpoint, align it with trained biases, and evaluate it.

    Paths and hyper-parameters are hard-coded per checkpoint (see ``_model_profile``).
    Steps:

    1. Load the model and tokenizer from ``--model``.
    2. Compute per-projection Fisher information and allocate target ranks
       (``target_rate`` = 0.7 of the full-rank budget).
    3. Collect X X^T statistics through ``Helper`` on WikiText-2 prompts. Multiply
       each weight by the Cholesky factor of those statistics, SVD it, and save the
       factors plus the rank-k ``utils.get_uv`` outputs (``.wu`` / ``.wv``).
    4. Save reference hidden states, apply the compressed layers via
       ``Helper.apply_to_model``, and train bias tensors with AdamW so hooked hidden
       states match the references (MSE).
    5. Report perplexity, zero-shot lm-eval accuracy and MMLU.
    """
    args = utils.EasyDict(**kwargs)
    print(args)

    model_path = args.model
    # Resolve per-model settings before the expensive load so unsupported models
    # fail immediately instead of with a NameError much later.
    profile = _model_profile(model_path)

    model = AutoModelForCausalLM.from_pretrained(
            model_path,
            cache_dir=None,
            device_map="auto",
            quantization_config=None,
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",
    )
    print("Model created")

    # Tokenizer
    if 'llama-3' in model_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            cache_dir=None,
            padding_side="right",
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            cache_dir=None,
            padding_side="right",
            use_fast=False, # Fast tokenizer giving issues.
            tokenizer_type='llama' if 'ama' in model_path else None, # Needed for HF name change
        )
    if tokenizer._pad_token is None:
        utils.smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=utils.DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'ama' in model_path or isinstance(tokenizer, LlamaTokenizer):
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
                "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
                "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
                "unk_token": tokenizer.convert_ids_to_tokens(
                    model.config.pad_token_id if model.config.pad_token_id and model.config.pad_token_id != -1 else tokenizer.pad_token_id
                ),
        })
    print("Tokenizer created")

    model.eval()
    setup_seed(2025)

    helper_params = {'intermediate_size': model.config.intermediate_size,
                     'hidden_size': model.config.hidden_size}
    helper = Helper(model, torch.bfloat16, **helper_params)
    print("Helper created")

    ###### Compute the rank of each linear layer using Fisher Info ######
    model_name = model.config._name_or_path.split("/")[-1]
    target_modules = ["q_proj", "k_proj", "o_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
    # NOTE(review): this calibration loader is the Llama-2-7b one for every model
    # (including llama-3-8B, which uses a different tokenizer) — confirm intended.
    calib_loader = torch.load('/data/7b/wikitext2_Llama-2-7b-hf_2048_4096_3.pt')
    # No layers are skipped while collecting Fisher info.
    utils.calib_fisher_info(model, target_modules, [], calib_loader, model_name)

    fisher_info_path = os.path.join(
        profile['data_dir'],
        profile['fisher_info_file'] or f"{model_name}_calib_fisher_info_wiki_2048_4096.pt",
    )
    all_fisher_info = torch.load(fisher_info_path, map_location="cpu")

    m, n = model.config.intermediate_size, model.config.hidden_size
    target_rate = 0.7
    desired_ranks = _compute_desired_ranks(
        all_fisher_info, profile['skipped_layers'], m, n, _kv_dim(model, model_path), target_rate
    )

    ###### Uniform rank of each linear layer ######
    # Disabled alternative to the Fisher-based allocation above, kept for ablations.
    """
    model_name = model.config._name_or_path.split("/")[-1]
    if 'opt' in model_name:
        M = model.config.hidden_size
        N = model.config.ffn_dim
        layers_num = model.config.num_hidden_layers
        target_rate = 0.7
        mlp_rank = int(M * N * target_rate / (M + N))
        attn_rank = int(M * M * target_rate / (M + M))
        config = {"fc1": mlp_rank, "fc2": mlp_rank, 
                "q_proj": attn_rank,  "k_proj": attn_rank, 
                "v_proj": attn_rank, "out_proj": attn_rank}
        desired_ranks = {}
        for layer_idx in range(layers_num):
            for suffix in ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]:
                if f'{layer_idx}' not in desired_ranks.keys():
                    desired_ranks[f'{layer_idx}'] = {suffix: (config[suffix], None)}
                else:
                    desired_ranks[f'{layer_idx}'][suffix] = (config[suffix], None)
    else:
        M = model.config.hidden_size
        N = model.config.intermediate_size
        K = model.config.hidden_size / (model.config.num_attention_heads / model.config.num_key_value_heads)    # 1024
        layers_num = model.config.num_hidden_layers
        target_rate = 0.7
        mlp_rank = int(M * N * target_rate / (M + N))
        attn_rank = int(M * M * target_rate / (M + M))
        attn_rank_k_v = int(K * M * target_rate / (K + M))
        config = {"gate_proj": mlp_rank, "up_proj": mlp_rank, "down_proj": mlp_rank, 
                "q_proj": attn_rank,  "k_proj": attn_rank_k_v, 
                "v_proj": attn_rank_k_v, "o_proj": attn_rank}
        desired_ranks = {}
        for layer_idx in range(layers_num):
            for suffix in ["gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "v_proj", "o_proj"]:
                if f'{layer_idx}' not in desired_ranks.keys():
                    desired_ranks[f'{layer_idx}'] = {suffix: (config[suffix], None)}
                else:
                    desired_ranks[f'{layer_idx}'][suffix] = (config[suffix], None)
    """

    # ###### data whiten & decomposition  ######
    active_params = model_params = profile['model_params']
    dump_dest = profile['dump_dest']
    hidden_states_dest = profile['hidden_states_dest']
    # Previously hard-coded to /data/13b for every model; now follows the model's data dir.
    raw_scaling_dest = os.path.join(profile['data_dir'], 'raw_scaling_diag_matrix')
    for directory in (dump_dest, hidden_states_dest, raw_scaling_dest):
        os.makedirs(directory, exist_ok=True)

    # Calibration data (also written to disk for reuse).
    wiki_train_dataset = utils.get_wikitext2(256, 3, 2048, tokenizer, 'wiki')
    with open('/data/wiki_256_2048.json', 'w') as json_file:
        json.dump(wiki_train_dataset, json_file, ensure_ascii=False, indent=4)
    with open('/data/wiki_256_2048.json', 'r') as file:
        prompts = json.load(file)

    # Compute XX^T: while ``helper`` is active, forward passes over the prompts are
    # expected to accumulate the per-module statistics read below (set by Helper,
    # not visible here). max_length equals the prompt length, so no new tokens are needed.
    generation_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
    t_start_time = time.time()
    with helper:
        for text in prompts:
            prompt_token_count = len(generation_pipeline.tokenizer.encode(text, return_tensors="pt")[0])
            generation_pipeline(text, max_length=int(prompt_token_count), pad_token_id=tokenizer.eos_token_id, truncation=True)
    t_duration = time.time() - t_start_time
    print(f"Collect training data costs avg: {t_duration/len(prompts): .5f} s, all: {t_duration/60: .2f} min, {t_duration: .5f} s. ")
    print('Collect training data Done')

    # Record XX^T
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]
        if suffix not in _LINEAR_PROJS:
            continue
        layer_idx = int(name.split(".")[-3])
        raw_scaling_diag_matrix = getattr(module, f'raw_scaling_diag_matrix_{layer_idx}')
        torch.save(raw_scaling_diag_matrix, os.path.join(raw_scaling_dest, f"{name}.raw_scaling_diag_matrix"))
        print(name, 'raw scaling diag matrix saved')

    # Remove factors / biases left over from a previous run.
    for filename in os.listdir(dump_dest):
        if filename.split('.')[-1] not in ['wu', 'wv', 'bias', 'trunc_w']:
            continue
        file_path = os.path.join(dump_dest, filename)
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f'{file_path} deletion.')

    cuda = torch.device('cuda')
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]
        if suffix not in _LINEAR_PROJS:
            continue
        layer_idx = int(name.split(".")[-3])
        if f'{layer_idx}' not in desired_ranks:
            continue

        raw_scaling_diag_matrix = getattr(module, f'raw_scaling_diag_matrix_{layer_idx}').double().to(model.device)
        scaling_diag_matrix, scaling_matrix_inv = _whitening_factors(raw_scaling_diag_matrix, model.device, name)

        W = module.weight.float()
        if W.device != scaling_diag_matrix.device:
            scaling_diag_matrix = scaling_diag_matrix.to(W.device)
        W_scale = torch.matmul(W, scaling_diag_matrix)

        u, s, v = torch.linalg.svd(W_scale, full_matrices=False)
        torch.save(u, os.path.join(dump_dest, f"{name}.u"))
        torch.save(s, os.path.join(dump_dest, f"{name}.s"))
        print(name, 'u s v saved.', dump_dest)

        # Undo the whitening on the right factor.
        if v.device != scaling_matrix_inv.device:
            v = v.to(scaling_matrix_inv.device)
        v_inv = v @ scaling_matrix_inv
        torch.save(v_inv, os.path.join(dump_dest, f"{name}.v_inv"))
        print(name, 'v_inv saved.', dump_dest)

        k = desired_ranks[f'{layer_idx}'][suffix][0]
        if k == 0:
            print(f"{name} skipped due to near full rank. \n")
            continue
        # Use the in-memory factors (moved to the default CUDA device) instead of
        # re-reading the files that were just written.
        wu, wv = utils.get_uv(u.to(cuda), s.to(cuda), v_inv.to(cuda), k)
        torch.save(wu, os.path.join(dump_dest, f"{name}.wu"))
        torch.save(wv, os.path.join(dump_dest, f"{name}.wv"))
        print(f"{name} {k} tok/last-k wu wv saved.")
        active_params -= module.weight.numel() - wv.numel() - wu.numel()
    print(f"Compression rate: {1 - active_params/model_params:.4f}")

    ###### Linear Layer Compen. ######
    data_name = "wikitext2"
    batch_calib_loader = utils.get_calib_data_fisher(data_name, tokenizer, model_name, 4, 2048, 256, 3)

    # Reference hidden states of the uncompressed model.
    _collect_hidden_states(model, batch_calib_loader, hidden_states_dest, desired_ranks, 2048)

    # Bias parameters and their optimizer.
    params_bias = _set_model_param_bias(model, desired_ranks, None, True)
    adw_optim = torch.optim.AdamW(params_bias, weight_decay=0.001, lr=0.005)
    Epoch_Num = 1
    ablation_sample_cnt = 2048
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(adw_optim, T_max=Epoch_Num*ablation_sample_cnt)

    # Add inference hooks (apply the compressed layers).
    helper.apply_to_model(False, False, dump_dest, profile['pruned_layers'], desired_ranks, model)

    utils.clear_torch_cache()
    print('Appling Done')

    model.eval()

    ppl = utils.eval_ppl(model, tokenizer)
    print('ppl (zero): ', ppl)

    # Only params_bias is optimised. Freeze the base weights so backward neither
    # computes nor accumulates their gradients (zero_grad only clears params_bias).
    # Bias gradients are unaffected.
    model.requires_grad_(False)

    # Align compressed hidden states with the saved references.
    best_ppl = 1e7
    best_params_bias = []
    for _ in range(Epoch_Num):
        sample_cnt = 0
        for batch in batch_calib_loader:
            pred_sample_data = {}
            handles = _register_hidden_state_hooks(model, desired_ranks, pred_sample_data)
            batch_input_ids = torch.concat(batch["batch_input_ids"]).to(model.device)
            try:
                model(batch_input_ids)
            finally:
                for handle in handles:
                    handle.remove()

            label_sample_data = torch.load(os.path.join(hidden_states_dest, f"{sample_cnt}_hidden_staes.pth"))
            sample_cnt += 1
            if pred_sample_data.keys() != label_sample_data.keys():
                raise RuntimeError(
                    f"hooked hidden states {sorted(pred_sample_data)} do not match "
                    f"saved references {sorted(label_sample_data)}"
                )
            batch_loss = 0
            for layer_idx in label_sample_data:
                batch_loss += torch.nn.functional.mse_loss(
                    pred_sample_data[layer_idx].float(), label_sample_data[layer_idx].float()
                )
            print('batch total loss: ', batch_loss, 'batch cnt', sample_cnt)
            pred_sample_data = label_sample_data = None
            utils.clear_torch_cache()

            # backward
            adw_optim.zero_grad()
            batch_loss.backward()
            adw_optim.step()

            # Re-attach the updated biases (bfloat16 copies) to the modules.
            _set_model_param_bias(model, desired_ranks, params_bias, False)

            # Evaluate ppl periodically and keep the best biases.
            if sample_cnt % 30 == 0:
                ppl = utils.eval_ppl(model, tokenizer)
                print('evaluaton ppl: ', ppl, ' batch cnt: ', sample_cnt)
                if ppl < best_ppl:
                    best_ppl = ppl
                    best_params_bias = copy.deepcopy(params_bias)
                    print('best params saved')

            lr_scheduler.step()

    model.eval()
    setup_seed(42)

    # NOTE(review): best_params_bias is saved but not loaded back into the model,
    # so the evaluations below use the biases from the last training step. It is
    # an empty list if fewer than 30 batches were processed.
    torch.save(best_params_bias, f'/data/{model_name}_{target_rate}_{Epoch_Num}_{ablation_sample_cnt}_params_bias_wiki.pth')
    print('Beast params saved')

    # Evaluate perplexity
    ppl = utils.eval_ppl(model, tokenizer)
    print('ppl: ', ppl)

    # Evaluate lm eval accuracy
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
    task_names = ['piqa', 'hellaswag', 'boolq', 'winogrande', 'arc_easy', 'arc_challenge', 'openbookqa']
    results = lm_eval.simple_evaluate(hflm, tasks=task_names, num_fewshot=0, batch_size=8)[
        'results'
    ]
    print(results)
    metric_vals = {task: round(result.get(utils.TASK_METRIC_MAP[task]), 4) for task, result in results.items()}
    acc_avg = utils.calculate_avg_accuracy(task_names, results)
    metric_vals['average'] = round(acc_avg, 4)
    print(metric_vals)

    # Evaluate mmlu accuracy
    utils.eval_mmlu(model, tokenizer, 5, "data/mmlu-data")
    print('Eval MMLU Done \n')

    
if __name__ == "__main__":
    # Script entry point: invoking the click command lets click parse the command
    # line (e.g. `-m /path/to/model`) and run `cli`.
    cli()
