import os
import torch
import contextlib
from typing import Dict, Union, List

from transformers import LlamaPreTrainedModel, OPTPreTrainedModel

from utils import HELPER_SUPPORT_MODEL_LIST, HELPER_SUPPORT_MODEL_TYPES

from hook import (
    add_collect_data_hook,
    remove_training_hook,
    add_inference_hook,
)
import utils


class Helper(contextlib.ContextDecorator):
    """Context manager / decorator around a supported model.

    Entering the context calls ``add_collect_data_hook`` on the model and
    leaving it calls ``remove_training_hook``.  ``apply_to_model`` replaces
    selected projection sub-modules with factor tensors loaded from disk
    (registered as buffers) and then calls ``add_inference_hook``.
    """

    def __init__(self, model: "HELPER_SUPPORT_MODEL_TYPES", compute_type, **kwargs):
        """Store the model and the sizes later handed to the collection hook.

        Args:
            model: Model instance; must be an instance of
                ``HELPER_SUPPORT_MODEL_LIST``.
            compute_type: Stored as-is on the instance.
            **kwargs: Must contain ``hidden_size`` and ``intermediate_size``.

        Raises:
            NotImplementedError: If ``model`` is not a supported model type.
            KeyError: If a required keyword argument is missing.
        """
        # Validate before touching any attribute so that an unsupported object
        # is reported as such instead of failing on e.g. a missing ``.device``.
        if not isinstance(model, HELPER_SUPPORT_MODEL_LIST):
            raise NotImplementedError("Unsupported model")

        self.model = model
        self.device = model.device
        self.compute_type = compute_type
        self.hidden_size = kwargs["hidden_size"]
        self.intermediate_size = kwargs["intermediate_size"]
        self.training_data: Dict[str, Dict[str, List[torch.Tensor]]] = {}

    def __enter__(self):
        """Install the data-collection hook and return this helper."""
        self.model_last_layer = add_collect_data_hook(
            self.model, self.training_data, self.intermediate_size, self.hidden_size
        )
        # Returning ``self`` makes ``with Helper(...) as helper:`` usable.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the hook installed on entry; exceptions are not suppressed."""
        remove_training_hook(self.model)

    @staticmethod
    def _infer_device() -> torch.device:
        """Pick the CUDA device with the largest ``total - allocated`` memory.

        ``allocated`` is what ``torch.cuda.memory_allocated`` reports for the
        device.  Falls back to the CPU when CUDA is unavailable or no device
        is visible.  The process's current CUDA device is left untouched.
        """
        cpu = torch.device("cpu")
        if not torch.cuda.is_available():
            return cpu

        device_indices = range(torch.cuda.device_count())
        if not device_indices:
            return cpu

        def free_bytes(index: int) -> int:
            total = torch.cuda.get_device_properties(index).total_memory
            return total - torch.cuda.memory_allocated(index)

        # ``max`` keeps the first device on ties, i.e. the lowest index.
        best_index = max(device_indices, key=free_bytes)
        return torch.device(f"cuda:{best_index}")

    @staticmethod
    def _keeps_dense_weight(desired_rank_pref, layer_idx, suffix) -> bool:
        """Return True when ``suffix`` of layer ``layer_idx`` is not replaced.

        That is the case when the layer's entry in ``desired_rank_pref`` has
        no item for ``suffix`` or the first element of that item equals 0.

        Raises:
            KeyError: If ``desired_rank_pref`` has no entry for the layer
                (layers are keyed by their index rendered as a string).
        """
        layer_pref = desired_rank_pref[f"{layer_idx}"]
        return bool(suffix not in layer_pref or layer_pref[suffix][0] == 0)

    @staticmethod
    def _mark_dense(module, layer_idx, suffix) -> None:
        """Flag ``suffix`` on ``module`` as still using its original weight."""
        module.register_buffer(f"{suffix}_use", torch.Tensor([True]))
        print(f"{layer_idx} {suffix} is not compressed.")

    @staticmethod
    def _swap_in_factors(module, name, suffix, buffer_stem, dump_dest, wu, wv, log_tag) -> None:
        """Replace sub-module ``suffix`` of ``module`` by two factor buffers.

        Loads ``<dump_dest>/<name>.<suffix>.<wu>`` and ``...<wv>``, registers
        their transposes (cast to bfloat16) as ``<buffer_stem>_weight_SVh_top``
        and ``<buffer_stem>_weight_U_top`` respectively, and deletes the
        original sub-module.
        """
        module.register_buffer(f"{suffix}_use", torch.Tensor([False]))

        # NOTE(security): torch.load unpickles the file; ``dump_dest`` must only
        # contain files produced by a trusted source.
        u = torch.load(
            os.path.join(dump_dest, f"{name}.{suffix}.{wu}"),
            map_location=Helper._infer_device(),
        )
        v = torch.load(
            os.path.join(dump_dest, f"{name}.{suffix}.{wv}"),
            map_location=Helper._infer_device(),
        )
        print(log_tag, name, suffix, u.shape, v.shape, u.device, v.device)

        # NOTE(review): the "U" buffer is filled from ``v`` and the "SVh" buffer
        # from ``u``; presumably this matches what the inference hook expects --
        # confirm against the code that writes ``dump_dest``.
        module.register_buffer(f"{buffer_stem}_weight_U_top", v.t().to(torch.bfloat16))
        module.register_buffer(f"{buffer_stem}_weight_SVh_top", u.t().to(torch.bfloat16))

        # Drop the original projection, then the temporaries, so the cache
        # clear below can actually release their memory.
        delattr(module, suffix)
        del u, v
        utils.clear_torch_cache()

    @staticmethod
    def _register_llama_factors(dump_dest, pruned_layer_idx_list, desired_rank_pref, model, wu, wv) -> None:
        """Swap factors into the MLP / attention modules of pruned Llama layers."""
        from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention

        mlp_suffixes = ("gate_proj", "up_proj", "down_proj")
        attn_suffixes = ("q_proj", "k_proj", "o_proj", "v_proj")

        for name, module in model.named_modules():
            if isinstance(module, LlamaMLP):
                suffixes, log_tag = mlp_suffixes, 'mlp get u v:  '
            elif isinstance(module, LlamaAttention):
                suffixes, log_tag = attn_suffixes, 'attn get u v: '
            else:
                continue

            # The layer index is the second-to-last component of the module
            # name (the last one is the MLP / attention attribute itself).
            layer_idx = int(name.split(".")[-2])
            if layer_idx not in pruned_layer_idx_list:
                continue

            for suffix in suffixes:
                if Helper._keeps_dense_weight(desired_rank_pref, layer_idx, suffix):
                    Helper._mark_dense(module, layer_idx, suffix)
                    continue
                # Buffers are named after the projection without "_proj",
                # e.g. "gate_proj" -> "gate_weight_U_top".
                buffer_stem = suffix.split("_")[0]
                Helper._swap_in_factors(
                    module, name, suffix, buffer_stem, dump_dest, wu, wv, log_tag
                )

    @staticmethod
    def _register_opt_factors(dump_dest, pruned_layer_idx_list, desired_rank_pref, model, wu, wv) -> None:
        """Swap factors into the decoder layers / attention of pruned OPT layers."""
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention

        for name, module in model.named_modules():
            if isinstance(module, OPTDecoderLayer):
                # The decoder layer's own name ends with its index.
                layer_idx = int(name.split(".")[-1])
                suffixes = ("fc1", "fc2")
            elif isinstance(module, OPTAttention):
                layer_idx = int(name.split(".")[-2])
                suffixes = ("q_proj", "k_proj", "out_proj", "v_proj")
            else:
                continue

            if layer_idx not in pruned_layer_idx_list:
                continue

            for suffix in suffixes:
                u_path = os.path.join(dump_dest, f"{name}.{suffix}.{wu}")
                # A missing factor file also means "keep the dense weight"; it
                # is checked first so the preference lookup is skipped then.
                if not os.path.exists(u_path) or Helper._keeps_dense_weight(
                    desired_rank_pref, layer_idx, suffix
                ):
                    Helper._mark_dense(module, layer_idx, suffix)
                    continue
                # OPT buffers keep the full suffix, e.g. "fc1_weight_U_top".
                Helper._swap_in_factors(
                    module, name, suffix, suffix, dump_dest, wu, wv, 'get u v:  '
                )

    def apply_to_model(self, use_trunc, use_bias, dump_dest, pruned_layer_idx_list, desired_rank_pref,
                       model: "HELPER_SUPPORT_MODEL_TYPES", optim_bias=True, wu='wu', wv='wv'):
        """Replace projections of the pruned layers by stored factors.

        For every projection of every layer listed in ``pruned_layer_idx_list``
        a ``<suffix>_use`` buffer is registered (``[True]`` when the original
        weight is kept, ``[False]`` when it was replaced).  Replaced
        projections are deleted from their parent module and their factors are
        registered as bfloat16 buffers.  Finally ``add_inference_hook`` is
        called on the model.

        Args:
            use_trunc: Forwarded to ``add_inference_hook``.
            use_bias: Forwarded to ``add_inference_hook``.
            dump_dest: Directory holding the factor files, named
                ``<module name>.<suffix>.<wu|wv>``.
            pruned_layer_idx_list: Layer indices (ints) to process.
            desired_rank_pref: Mapping ``str(layer_idx) -> {suffix: seq}``; a
                projection is replaced only if present with ``seq[0] != 0``.
            model: The model to modify in place.
            optim_bias: Forwarded to ``add_inference_hook``.
            wu: File-name extension of the first factor.
            wv: File-name extension of the second factor.
        """
        if isinstance(model, LlamaPreTrainedModel):
            self._register_llama_factors(
                dump_dest, pruned_layer_idx_list, desired_rank_pref, model, wu, wv
            )

        if isinstance(model, OPTPreTrainedModel):
            self._register_opt_factors(
                dump_dest, pruned_layer_idx_list, desired_rank_pref, model, wu, wv
            )

        add_inference_hook(optim_bias, use_trunc, use_bias, pruned_layer_idx_list, model)
        