import re
import torch
import subprocess
from ..config import cfg

# Model names that, per this list's name, are run across multiple GPUs.
# NOTE(review): no consumer is visible in this section of the file; confirm usage.
MULTIGPUS_MODEL_NAME_LIST = [
    "llama-2-70b",
]

def list_available_gpus():
    """Reset and fill the per-GPU CUDA stream tables stored in ``cfg``.

    ``cfg["custom_cuda_streams"]`` maps every CUDA device index to a newly
    created ``torch.cuda.Stream`` on that device, and
    ``cfg["default_cuda_streams"]`` maps it to ``torch.cuda.default_stream``
    for that device. Both tables are reset to empty dicts first, so they stay
    empty when CUDA is not available.
    """
    cfg["custom_cuda_streams"] = {}
    cfg["default_cuda_streams"] = {}

    if not torch.cuda.is_available():
        print("CUDA is not available. No GPU detected.")
        return

    device_count = torch.cuda.device_count()
    print(f"Number of CUDA Devices: {device_count}")
    for device_index in range(device_count):
        # One extra stream per device, alongside that device's default stream.
        cfg["custom_cuda_streams"][device_index] = torch.cuda.Stream(device=device_index)
        cfg["default_cuda_streams"][device_index] = torch.cuda.default_stream(device=device_index)

def _query_gpu_name() -> str:
    """Return the stripped ``nvidia-smi`` GPU-name output, or "" if it cannot run.

    ``nvidia-smi --query-gpu=... --format=csv,noheader`` prints one row per GPU,
    so on a multi-GPU machine the result holds several newline-separated names.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # e.g. nvidia-smi not installed: report it and fall back to "unknown GPU"
        # instead of leaving the name unbound.
        print(f"An error occurred: {e}")
        return ""
    gpu_name = result.stdout.strip()
    print(f"GPU Name: {gpu_name}")
    return gpu_name


def _version_as_float(version: object) -> float:
    """Convert a CUDA/cuDNN version (str, int or None) to a float for comparisons.

    ``None`` (reported by torch when CUDA or cuDNN is unavailable) and
    unparsable values map to 0.0, so the minimum-version checks fail closed
    instead of raising ``TypeError``. Only the leading ``major[.minor]`` part is
    used, e.g. "12.4.1" -> 12.4 and 8902 -> 8902.0.
    """
    if version is None:
        return 0.0
    match = re.match(r"\d+(?:\.\d+)?", str(version))
    return float(match.group(0)) if match else 0.0


def _expand_default_prune_method(prune_method: str) -> str:
    """Append the option suffixes implied by a ``*default*`` prune method name.

    The families are checked in order (flap, wandasp, probe); the first one
    found in the name decides the suffix. Names without "default" are returned
    unchanged.
    """
    if "default" not in prune_method:
        return prune_method
    for family, suffix in (
        ("flap", "-calib-flapratio-bias"),
        ("wandasp", "-calib"),
        ("probe", "-calib-ema-respick"),
    ):
        if family in prune_method:
            return prune_method + suffix
    return prune_method


def _parse_skip_layers(prune_method: str, default: list) -> list:
    """Return the layer indices selected by a ``skip<a>-<b>[-<c>-<d>...]`` token.

    The first ``skip...`` token is split on "-" and consumed as inclusive
    (start, end) pairs, e.g. "skip4-6-10-11" -> [4, 5, 6, 10, 11]. Minus signs
    are lost in the split and an unpaired trailing number is ignored, so
    "skip-1" (or "skip5") yields []. Returns ``default`` when there is no token.
    """
    if "skip" not in prune_method:
        return default
    match = re.findall(r'skip(-?\d+(-\d+)*)', prune_method)
    print('match: ', match)
    if not match:
        return default
    # findall returns (group1, group2) tuples; group 1 is the whole number run.
    numbers = [int(num) for num in match[0][0].split("-") if num]
    skip_layers = []
    for start, end in zip(numbers[0::2], numbers[1::2]):
        skip_layers.extend(range(start, end + 1))
    return skip_layers


def process_control():
    """Populate the global ``cfg`` from hardware probes and ``cfg["control"]``.

    Records CUDA/cuDNN versions and the GPU name, sets up per-device CUDA
    streams via ``list_available_gpus``, and derives the run settings
    (``model_name``, ``batch_size``, ``max_seq_len``, ``prune_*``, ``mode``,
    ``ema_momentum``, ``cust_tgt_modules``, ``skip_layers`` and assorted flags)
    from ``cfg["control"]``. Prints the final ``cfg``.

    Raises:
        KeyError: if a required ``cfg["control"]`` key is missing.
        ValueError: if ``cust_tgt_modules`` is "default" and the model name
            contains neither "llama" nor "opt".
    """
    # Missing CUDA/cuDNN is recorded as version 0.0 rather than crashing.
    cfg["cudatoolkit_version"] = _version_as_float(torch.version.cuda)
    cfg["cudnn_version"] = _version_as_float(torch.backends.cudnn.version())

    gpu_name = _query_gpu_name()
    cfg["gpu_name"] = gpu_name
    # nvidia-smi prints one name per GPU; match on the first so multi-GPU hosts
    # still get their model-specific tuning.
    primary_gpu_name = gpu_name.splitlines()[0].strip() if gpu_name else ""
    list_available_gpus()
    cfg["data_type"] = torch.float16
    cfg["data_type_max"] = torch.finfo(cfg["data_type"]).max
    cfg["data_type_min"] = torch.finfo(cfg["data_type"]).min

    # tc = tensor core. RTX 4090 uses a multiple of 8 when CUDA >= 11 and
    # cuDNN >= 7630; everything else (A100 included) uses 64.
    # This could instead be decided dynamically in each layer.
    if cfg["data_type"] in (torch.float16, torch.float32):
        tensor_cores_usable = (
            cfg["cudatoolkit_version"] >= 11 and cfg["cudnn_version"] >= 7630
        )
        if tensor_cores_usable and primary_gpu_name == 'NVIDIA GeForce RTX 4090':
            cfg["tc_multiple"] = 8
        else:
            cfg["tc_multiple"] = 64

    control = cfg["control"]
    cfg["model_name"] = control["model_name"]
    model_name = cfg["model_name"]
    cfg["batch_size"] = int(control["batch_size"])
    cfg["max_seq_len"] = int(control["max_seq_len"])
    cfg["prune_metric"] = control["prune_metric"]
    cfg["prune_method"] = _expand_default_prune_method(control["prune_method"])
    cfg["prune_ratio"] = control["prune_ratio"]

    cfg["mode"] = control["mode"]

    # EMA momentum defaults to 0.99; an "ema<float>" token (e.g. "ema0.9") in
    # the prune method overrides it.
    cfg["ema_momentum"] = 0.99
    ema_match = re.search(r"ema(\d+\.\d+)", cfg["prune_method"])
    if ema_match:
        cfg["ema_momentum"] = float(ema_match.group(1))

    requested_modules = control["cust_tgt_modules"].split('+')
    cfg["cust_tgt_modules"] = requested_modules
    if requested_modules == ["default"]:
        if "llama" in model_name:
            family = "llama"
        elif "opt" in model_name:
            family = "opt"
        else:
            raise ValueError('Not valid model name')
        # Copy so later edits to cfg cannot mutate the shared module-level mapping.
        cfg["cust_tgt_modules"] = list(
            TRANSFORMERS_MODELS_TO_ERI_TARGET_MODULES_MAPPING[family]
        )
    elif "llama" in model_name or "opt" in model_name:
        # Accept "-" in place of "_" (e.g. "gate-proj" -> "gate_proj").
        cfg["cust_tgt_modules"] = [module.replace("-", "_") for module in requested_modules]

    cfg["calibration_stage"] = False
    # Default skip_layers is the first three layers: [0, 1, 2].
    cfg["skip_layers"] = _parse_skip_layers(cfg["prune_method"], default=[0, 1, 2])

    cfg["cur_batch_index"] = -1

    cfg["qk_prune_way"] = "whole"
    cfg["vo_prune_way"] = "whole"

    if model_name not in cfg:
        cfg[model_name] = {}
    cfg[model_name]["shuffle"] = {"train": False, "test": False}

    cfg["logger_detailed_info"] = False
    cfg["onlyprobe"] = False
    cfg["onlyprobeinfo"] = False
    cfg["test_speed"] = False

    cfg["asyncintra_on_diff_gpu"] = False
    cfg["pad_tokens"] = None
    print("cfg: ", cfg, flush=True)


def make_data_name():
    """Split ``cfg["control"]["data_name"]`` and resolve its dataset metadata.

    The control value is "<name>" or "<name>-<subset>". Exactly two
    "-"-separated parts give (data_name, subset_name); any other count keeps
    only the first part and uses subset "none". Sets ``cfg["data_name"]`` and
    ``cfg["subset_name"]``. For the "clm" and "csr" tasks it additionally sets
    ``cfg["hf_data_name"]``, ``cfg["hf_subset_name"]``, ``cfg["text_column"]``
    and ``cfg["label_column"]`` from the table below.

    Raises:
        KeyError: (clm/csr tasks only) for an unknown dataset or subset key.
    """
    parts = cfg["control"]["data_name"].split("-")
    if len(parts) == 2:
        dataset_key, subset_key = parts
    else:
        dataset_key, subset_key = parts[0], "none"
    cfg["data_name"] = dataset_key
    cfg["subset_name"] = subset_key

    if cfg["task_name"] not in ["clm", "csr"]:
        return

    def spec(hf_subset, text_column, label_column):
        # One subset entry: HF config name plus the text/label column settings.
        return {
            "subset_name": hf_subset,
            "text_column": text_column,
            "label_column": label_column,
        }

    # dataset key -> {"data_name": HF dataset id, "subset_name_dict": {subset key: spec}}
    data_name_dict = {
        "c4": {
            "data_name": "c4",
            "subset_name_dict": {"none": spec(None, None, None)},
        },
        "wikitext": {
            "data_name": "wikitext",
            "subset_name_dict": {"2v1": spec("wikitext-2-raw-v1", ["text"], None)},
        },
        "boolq": {
            "data_name": 'google/boolq',
            "subset_name_dict": {"none": spec(None, ["hardcode"], "hardcode")},
        },
        "piqa": {
            "data_name": "piqa",
            "subset_name_dict": {"none": spec(None, ["hardcode"], "hardcode")},
        },
        "siqa": {
            "data_name": "social_i_qa",
            "subset_name_dict": {"none": spec(None, ["hardcode"], "hardcode")},
        },
        "arc": {
            "data_name": "ai2_arc",
            "subset_name_dict": {
                "e": spec("ARC-Easy", ["hardcode"], "hardcode"),
                "c": spec("ARC-Challenge", ["hardcode"], "hardcode"),
            },
        },
        # NOTE(review): hellaswag and winogrande use the bare string "hardcode" as
        # text_column, unlike the ["hardcode"] lists above -- confirm this is intended.
        "hellaswag": {
            "data_name": 'Rowan/hellaswag',
            "subset_name_dict": {"none": spec(None, "hardcode", "hardcode")},
        },
        "winogrande": {
            "data_name": "winogrande",
            "subset_name_dict": {"none": spec("winogrande_debiased", "hardcode", "hardcode")},
        },
        "obqa": {
            "data_name": "openbookqa",
            "subset_name_dict": {"none": spec("main", ["hardcode"], "hardcode")},
        },
    }

    dataset_entry = data_name_dict[dataset_key]
    cfg["hf_data_name"] = dataset_entry["data_name"]
    subset_entry = dataset_entry["subset_name_dict"][subset_key]
    cfg["hf_subset_name"] = subset_entry["subset_name"]
    cfg["text_column"] = subset_entry["text_column"]
    cfg["label_column"] = subset_entry["label_column"]


# Default target Linear sub-modules per model family, used when
# cust_tgt_modules is "default" (the family is matched as a substring of the
# model name).
TRANSFORMERS_MODELS_TO_ERI_TARGET_MODULES_MAPPING = dict(
    # OPT: attention projections (q/v/k/out_proj) and feed-forward fc1/fc2.
    opt=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
    # LLaMA: attention projections (q/v/o/k_proj) and MLP gate/up/down_proj.
    # An MLP-only alternative is ["gate_proj", "up_proj", "down_proj"].
    llama=["q_proj", "v_proj", "o_proj", "k_proj", "gate_proj", "up_proj", "down_proj"],
)




# gpt2 layer
'''
key:  transformer.h.3 <class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>
key:  transformer.h.3.ln_1 <class 'torch.nn.modules.normalization.LayerNorm'>
key:  transformer.h.3.attn <class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>
key:  transformer.h.3.attn.c_attn <class 'transformers.pytorch_utils.Conv1D'>
key:  transformer.h.3.attn.c_proj <class 'transformers.pytorch_utils.Conv1D'>
key:  transformer.h.3.attn.attn_dropout <class 'torch.nn.modules.dropout.Dropout'>
key:  transformer.h.3.attn.resid_dropout <class 'torch.nn.modules.dropout.Dropout'>
key:  transformer.h.3.ln_2 <class 'torch.nn.modules.normalization.LayerNorm'>
key:  transformer.h.3.mlp <class 'transformers.models.gpt2.modeling_gpt2.GPT2MLP'>
key:  transformer.h.3.mlp.c_fc <class 'transformers.pytorch_utils.Conv1D'>
key:  transformer.h.3.mlp.c_proj <class 'transformers.pytorch_utils.Conv1D'>
key:  transformer.h.3.mlp.act <class 'transformers.activations.NewGELUActivation'>
key:  transformer.h.3.mlp.dropout <class 'torch.nn.modules.dropout.Dropout'>
'''


# opt 1.3b layer

'''
Model sizes: 125M, 350M, 1.3B, 2.7B, 6.7B, 13B, 30B, 66B, 175B
Selected: 6.7B, 13B, 30B, 66B
key:  model.decoder.layers.0 <class 'transformers.models.opt.modeling_opt.OPTDecoderLayer'>
key:  model.decoder.layers.0.self_attn <class 'transformers.models.opt.modeling_opt.OPTAttention'>
key:  model.decoder.layers.0.self_attn.k_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.decoder.layers.0.self_attn.v_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.decoder.layers.0.self_attn.q_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.decoder.layers.0.self_attn.out_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.decoder.layers.0.activation_fn <class 'torch.nn.modules.activation.ReLU'>
key:  model.decoder.layers.0.self_attn_layer_norm <class 'torch.nn.modules.normalization.LayerNorm'>
key:  model.decoder.layers.0.fc1 <class 'torch.nn.modules.linear.Linear'>
key:  model.decoder.layers.0.fc2 <class 'torch.nn.modules.linear.Linear'>
key:  model.decoder.layers.0.final_layer_norm <class 'torch.nn.modules.normalization.LayerNorm'>
'''

# llama-2-7b layer
'''
Model sizes: 7b, 13b, 65b
key:  model.layers.0 <class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'>
key:  model.layers.0.self_attn <class 'transformers.models.llama.modeling_llama.LlamaAttention'>
key:  model.layers.0.self_attn.q_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.self_attn.k_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.self_attn.v_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.self_attn.o_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.self_attn.rotary_emb <class 'transformers.models.llama.modeling_llama.LlamaRotaryEmbedding'>
key:  model.layers.0.mlp <class 'transformers.models.llama.modeling_llama.LlamaMLP'>
key:  model.layers.0.mlp.gate_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.mlp.up_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.mlp.down_proj <class 'torch.nn.modules.linear.Linear'>
key:  model.layers.0.mlp.act_fn <class 'transformers.activations.SiLUActivation'>
key:  model.layers.0.input_layernorm <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>
key:  model.layers.0.post_attention_layernorm <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>
'''
