import torch
from tqdm import tqdm
import pdb
from utils import train_utils
from utils import lowrank_layers


def replace_with_lowrank_linear(model, args, svd_info=None):
    """
    Replace the model's ``torch.nn.Linear`` layers with low-rank layers, in place.

    This is used before training. Only ``args.layer_type == 'gumbel'`` is
    supported; it wraps each selected layer in ``lowrank_layers.LowrankLinear``
    (cast to the original weight's dtype and device). Layers whose qualified
    name contains ``'lm_head'`` are never replaced.

    Args:
        model: module whose linear sub-layers are replaced in place.
        args: namespace providing ``only_compress`` (comma-separated name
            substrings; only linear layers whose qualified name contains one of
            them are replaced, and an empty string means "replace all"),
            ``layer_type``, ``init_frac``, ``alpha``, ``start_tau``,
            ``bias_init``, ``mask_eval_type`` and ``debug`` (when true, stop
            after 11 replacements).
        svd_info: optional mapping from a layer's qualified name to the value
            passed to ``LowrankLinear`` as its svd vector. When None/empty,
            ``None`` is passed instead.

    Raises:
        NotImplementedError: if ``args.layer_type`` is not ``'gumbel'``.
    """
    # Map every module to its qualified name, and record for each Linear the
    # parent module and attribute name needed to swap it out with setattr.
    full_name_dict = {module: name for name, module in model.named_modules()}
    linear_info = {}
    modules = [model]
    while modules:
        submodule = modules.pop()
        for name, raw_linear in submodule.named_children():
            if isinstance(raw_linear, torch.nn.Linear):
                linear_info[raw_linear] = {
                    "father": submodule,
                    "name": name,
                    "full_name": full_name_dict[raw_linear],
                }
            else:
                modules.append(raw_linear)

    total_modules = sum(1 for _ in model.named_modules())

    # Empty entries (e.g. from a trailing comma) are dropped: '' is a substring
    # of every name and would otherwise silently select all layers. An empty
    # pattern list means "compress every layer".
    compress_patterns = [p.strip() for p in args.only_compress.split(',') if p.strip()]

    replaced = 0
    for name, module in tqdm(model.named_modules(), total=total_modules, desc='Replacing Linear with Low-Rank Layers', mininterval=5):
        if module not in linear_info:
            continue

        if 'lm_head' in name:
            print('Ignored low-rank decomposition on logits layer')
            continue

        if compress_patterns and not any(pattern in name for pattern in compress_patterns):
            print(f'Skipped {name} compression, keep_compress={args.only_compress}')
            continue

        info = linear_info[module]
        svd_vector = svd_info[info['full_name']] if svd_info else None

        if args.layer_type == 'gumbel':
            new_module = lowrank_layers.LowrankLinear(module, args.init_frac, svd_vector, args.alpha, tau=args.start_tau,
                                        bias_init=args.bias_init, mask_eval_type=args.mask_eval_type,
                                        ).to(module.weight.dtype).to(module.weight.device)
        else:
            raise NotImplementedError(f"Unsupported layer_type {args.layer_type} in replace_linear_with_svd")

        # Replacing an existing child key does not change the parent's module
        # dict size, so continuing the named_modules() iteration is safe.
        setattr(info["father"], info["name"], new_module)
        del linear_info[module]
        torch.cuda.empty_cache()

        replaced += 1
        if replaced > 10 and args.debug:
            break

    torch.cuda.empty_cache()
    print('Replaced linear layers with low-rank layers.')


def update_layer_with_svd_output(current_layer, svd_vector, rank_pct, n_iter=2):
    """
    Overwrite a linear layer's weight with its truncated-SVD reconstruction.

    This is used to benchmark SVD/ASVD fixed rate baselines. The weight is
    approximated with ``torch.svd_lowrank`` at rank
    ``int(min(out_features, in_features) * rank_pct)``. The product
    ``U diag(E) V^T`` is written back in place, in the layer's original dtype.

    Args:
        current_layer: the ``torch.nn.Linear`` to update.
        svd_vector: optional scaling vector broadcast over the weight's columns
            (dim 1, i.e. the input features). The weight is multiplied by it
            before the SVD, and the scaling is divided back out of ``V``
            afterwards (ASVD-style). It is NOT modified.
        rank_pct: fraction of the full rank to keep.
        n_iter: number of subspace iterations passed to ``torch.svd_lowrank``.

    Raises:
        ValueError: if ``current_layer`` is not a ``torch.nn.Linear``.
    """
    if not isinstance(current_layer, torch.nn.Linear):
        raise ValueError(f"Expected input to be of instance nn.Linear, got {type(current_layer)}")

    dtype = current_layer.weight.dtype
    rank = int(min(current_layer.weight.shape) * rank_pct)

    # The result only replaces the weight data, so no autograd graph is needed.
    with torch.no_grad():
        weight = current_layer.weight.float()
        if svd_vector is not None:
            # Out-of-place add: the caller's tensor (e.g. an svd_info entry shared
            # across calls) must not change. The epsilon prevents division by 0.
            svd_vector = (svd_vector + 1e-6).to(weight.device)
            weight = weight * svd_vector.unsqueeze(0)

        U, E, V = torch.svd_lowrank(weight, q=rank, niter=n_iter)

        if svd_vector is not None:
            # Undo the column scaling: W ~= U diag(E) (V / s)^T
            V = V / svd_vector.unsqueeze(1)

        assert len(E.shape) == 1, 'Expected singular values to have only one dimension'

        W_new = (U * E.unsqueeze(0)) @ V.T

    # Ensure that the shape of the new weights matches the original weight shape
    assert current_layer.weight.data.shape == W_new.shape, \
        f"Shape mismatch: original weights shape {current_layer.weight.data.shape}, new weights shape {W_new.shape}"

    current_layer.weight.data = W_new.to(dtype)

def replace_linear_with_svd_naiive(model, args, svd_info=None, compression_dict=None):
    """
    Apply a fixed-rate SVD/ASVD baseline to the model's linear layers, in place.

    This is used to benchmark SVD/ASVD fixed rate baselines. Each
    ``torch.nn.Linear`` whose qualified name does not contain ``'lm_head'`` is
    handled as follows:

    * if ``compression_dict`` is non-empty, the layer is replaced by
      ``lowrank_layers.SVDLinear`` with rank
      ``compression_dict[layer_idx][leaf_name]``. Here ``layer_idx`` is the third
      dot-separated component of the qualified name and ``leaf_name`` is the last
      one. Layers missing from the dict are left untouched.
    * otherwise, if ``args.param_ratio`` is not None, the weight is overwritten
      by its truncated-SVD reconstruction. The rank ``r`` is chosen so that
      ``r * (in + out) == param_ratio * in * out`` (truncated to an integer
      downstream).

    Args:
        model: module whose linear layers are processed in place.
        args: namespace providing ``param_ratio``, ``alpha`` (used only on the
            compression_dict path) and ``debug`` (when true, stop once more than
            10 linear layers have been visited).
        svd_info: optional mapping from a layer's qualified name to the vector
            passed on as ``svd_vector``. None/empty means plain SVD.
        compression_dict: optional ``{layer_idx: {leaf_name: rank}}`` mapping.

    Raises:
        AssertionError: on the compression_dict path, if the third component of a
            layer's qualified name is not numeric.
    """
    # Map every module to its qualified name, and record for each Linear the
    # parent module and attribute name needed to swap it out with setattr.
    full_name_dict = {module: name for name, module in model.named_modules()}
    linear_info = {}
    modules = [model]
    while modules:
        submodule = modules.pop()
        for name, raw_linear in submodule.named_children():
            if isinstance(raw_linear, torch.nn.Linear):
                linear_info[raw_linear] = {
                    "father": submodule,
                    "name": name,
                    "full_name": full_name_dict[raw_linear],
                }
            else:
                modules.append(raw_linear)

    total_modules = sum(1 for _ in model.named_modules())

    visited = 0
    for name, module in tqdm(model.named_modules(), total=total_modules, desc='Replacing Linear with Low-Rank Layers', mininterval=5):
        if 'lm_head' in name:
            print('Ignored low-rank decomposition on logits layer')
            continue

        if module not in linear_info:
            continue

        info = linear_info[module]
        svd_vector = svd_info[info['full_name']] if svd_info else None

        if compression_dict:
            name_parts = name.split('.')
            layer_idx = name_parts[2]
            found_name = name_parts[-1]

            assert layer_idx.isnumeric(), f"layer_idx: {layer_idx} is not numeric"

            if layer_idx in compression_dict and found_name in compression_dict[layer_idx]:
                rank = compression_dict[layer_idx][found_name]
                new_module = lowrank_layers.SVDLinear(module, rank, svd_vector, args.alpha)
                setattr(info["father"], info["name"], new_module)
                print(f"Used compression_dict, {layer_idx}, {found_name}, rank: {rank}", new_module)

        elif args.param_ratio is not None:
            m, n = module.in_features, module.out_features
            # r * (m + n) parameters for the factorization vs. m * n dense.
            rank = args.param_ratio * m * n / (m + n)
            rank_pct = rank / min(m, n)
            print(f"Using compression rate: {args.param_ratio}, m={m}, n={n}, hence r={rank}")
            update_layer_with_svd_output(module, svd_vector, rank_pct, n_iter=2)

        visited += 1
        if visited > 10 and args.debug:
            break

    print('Replaced linear layers with low-rank layers.')
    
def get_compression_layers(model):
    """
    Collect the parameters that control singular-value selection.

    Every parameter whose qualified name contains ``'E_train'`` is returned, in
    ``model.named_parameters()`` order. These parameters are used for the
    compression loss.
    """
    return [
        tensor
        for qualified_name, tensor in model.named_parameters()
        if 'E_train' in qualified_name
    ]

@torch.no_grad()
def get_compression_metadata(model):
    """
    Summarize the current compression state of every low-rank layer in ``model``.

    A submodule counts as low-rank if it has an ``E_train`` attribute. The search
    does not descend into such modules. Each layer yields one dict with:

    - ``layer_idx``: third dot-separated component of the layer's qualified name
      (must be numeric)
    - ``layer_name``: the layer's attribute name within its parent
    - ``param_ratio``: ``topk * (in + out) / (in * out)``
    - ``in_features`` / ``out_features``
    - ``length``, ``weights``: size and values of ``E_train``
    - ``topk``: rounded sum of the evaluation-mode mask
    - ``mask``: the evaluation-mode mask as a list

    Returns:
        list[dict]: one entry per low-rank layer, in depth-first traversal order.

    Raises:
        AssertionError: if a layer's ``layer_idx`` is not numeric.
    """
    qualified_names = {mod: qual for qual, mod in model.named_modules()}

    found = {}
    pending = [model]
    while pending:
        parent = pending.pop()
        for child_name, child in parent.named_children():
            if not hasattr(child, 'E_train'):
                pending.append(child)
                continue
            found[child] = {
                "name": child_name,
                "full_name": qualified_names[child],
            }

    logs = []
    for layer, meta in found.items():
        idx = meta['full_name'].split('.')[2]
        assert idx.isnumeric(), f'layer_idx not numeric: {idx}'

        sv_weights = layer.E_train.detach()
        eval_mask = layer.calculate_mask(is_training=False)
        kept = int(round(eval_mask.sum().item()))
        fan_in, fan_out = layer.in_features, layer.out_features
        ratio = kept * (fan_in + fan_out) / (fan_in * fan_out)

        logs.append({
            'layer_idx': idx,
            'layer_name': meta['name'],
            'param_ratio': ratio,
            'in_features': fan_in,
            'out_features': fan_out,
            'length': len(sv_weights),
            'weights': sv_weights.tolist(),
            'topk': kept,
            'mask': eval_mask.tolist(),
        })

    return logs


class CompressionCalculator:
    """
    Utilities for tracking compression of a model's low-rank layers.

    Provides:

    - the overall compression rate
    - the compression loss
    - the singular-value keep ratio
    - a total-variation loss on the masks

    Low-rank layers are the submodules whose class name starts with ``'Lowrank'``.
    Each is expected to expose ``E_train``, ``in_features``, ``out_features`` and
    ``calculate_mask(is_training=...)``.
    """

    def __init__(self, model, total_params):
        """
        Args:
            model: model to inspect. Its parameters and modules are scanned once here.
            total_params: stored as ``self.total_params`` (not read by the methods
                of this class).
        """
        # Parameters of the low-rank factorizations are recognized by these name
        # fragments. Everything else is counted in ``params1``, the part of the
        # model that compression does not change.
        lowrank_keys = ('UE', 'V_t', 'E_train')
        self.params1 = sum(
            param.numel()
            for name, param in model.named_parameters()
            if not any(key in name for key in lowrank_keys)
        )

        # Match on the class name rather than on str(module). An nn.Module's repr
        # starts with its class name anyway, and building the repr of every
        # module (the root's repr is the whole model) is needlessly expensive.
        self.lowrank_layers = [
            module
            for _, module in model.named_modules()
            if type(module).__name__.startswith('Lowrank')
        ]

        self.total_params = total_params

    def get_compression(self):
        """
        Return the parameter count after compression divided by the dense count.

        A layer of rank ``r`` costs ``r * (in + out)`` parameters. If that is not
        below 99% of the dense ``in * out``, the layer is counted as dense,
        because factorizing it would not save parameters in practice.
        """
        compressed_total = 0
        dense_total = 0
        for module in self.lowrank_layers:
            with torch.no_grad():
                rank = module.calculate_mask(is_training=False).sum().item()

            lowrank_cost = rank * (module.in_features + module.out_features)
            dense_cost = module.in_features * module.out_features

            if lowrank_cost / dense_cost < 0.99:
                compressed_total += lowrank_cost
            else:
                compressed_total += dense_cost

            dense_total += dense_cost

        return (self.params1 + compressed_total) / (self.params1 + dense_total)

    def get_compression_loss(self):
        """
        Return the mean over low-rank layers of ``E_train.mean()``.

        Each term is moved to the device of the running sum, so layers spread
        across devices can be combined. Returns ``0.`` if there are no low-rank
        layers.
        """
        if not self.lowrank_layers:
            return 0.
        loss = 0.
        for module in self.lowrank_layers:
            comp_loss = module.E_train.mean()
            if isinstance(loss, torch.Tensor):
                comp_loss = comp_loss.to(loss.device)
            loss += comp_loss
        return loss / len(self.lowrank_layers)

    def get_sv_ratio(self):
        """
        Return the average fraction of singular values kept by the training-mode masks.

        Returns ``0.`` if there are no low-rank layers.
        """
        if not self.lowrank_layers:
            return 0.
        keep_ratio = 0.
        for module in self.lowrank_layers:
            keep_ratio += module.calculate_mask(is_training=True).mean().item()
        return keep_ratio / len(self.lowrank_layers)

    def get_tv_loss(self):
        """
        Return the mean of ``train_utils.calculate_tv_loss`` over the training-mode masks.

        Returns ``0.`` if there are no low-rank layers.
        """
        if not self.lowrank_layers:
            return 0.
        loss = 0.
        for module in self.lowrank_layers:
            param_mask = module.calculate_mask(is_training=True)
            if isinstance(loss, torch.Tensor):
                param_mask = param_mask.to(loss.device)
            loss += train_utils.calculate_tv_loss(param_mask)
        return loss / len(self.lowrank_layers)


if __name__ == '__main__':
    # Demo: load a small chat model and run the fixed-rate SVD baseline on it
    # (debug=True limits how many layers are processed). Then call
    # convert_model.replace_with_compressed_layer and print the returned config.
    from transformers import AutoModel
    import convert_model

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model = AutoModel.from_pretrained(model_name)

    class Args:
        def __init__(self):
            self.init_frac = 0.3
            self.ignore_compress = ''
            self.layer_type = 'gumbel'
            self.only_compress = ''
            self.ignore_first_layer = False
            self.debug = True
            # replace_linear_with_svd_naiive reads these. Without them the demo
            # fails with AttributeError.
            self.rank_pct = 1.0
            self.param_ratio = 0.5

    args = Args()

    replace_linear_with_svd_naiive(model, args)
    replacement_config = convert_model.replace_with_compressed_layer(model)
    print(f"\nConfig: {replacement_config}")
