import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import re
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import json

from .eval import eval_ppl
from .data import get_loaders

class LinearLSPUV(torch.nn.Module):
    """
    Linear layer whose weight is stored as a low-rank product plus a matrix S.

    The layer computes ``input @ V' @ U'.T + input @ S.T (+ bias)``, i.e. it
    acts like a dense linear layer with weight ``U' @ V'.T + S``.

    Parameters are allocated with ``torch.empty`` and are therefore
    uninitialised: the code that builds the layer is expected to fill
    ``U_prime``, ``V_prime``, ``S`` (and ``bias``) before use. ``S`` is
    created with ``requires_grad=False``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: number of columns of ``U_prime`` and ``V_prime``.
        bias: when true an ``out_features`` bias vector is added to the output.
    """
    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features, self.out_features, self.rank = in_features, out_features, rank

        # Low-rank factors: U' is (out_features, rank), V' is (in_features, rank).
        self.U_prime = torch.nn.Parameter(torch.empty((out_features, rank)))
        self.V_prime = torch.nn.Parameter(torch.empty((in_features, rank)))
        # Full-size (out_features, in_features) component, excluded from gradients.
        self.S = torch.nn.Parameter(
            torch.empty((out_features, in_features)), requires_grad=False
        )
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = torch.nn.Parameter(torch.empty(out_features))

    def forward(self, input):
        """Return ``input @ V' @ U'.T + input @ S.T``, plus the bias when present."""
        # Going through the rank-sized intermediate avoids forming U' @ V'.T.
        low_rank_term = (input @ self.V_prime) @ self.U_prime.t()
        s_term = input @ self.S.t()
        result = low_rank_term + s_term
        if self.bias is None:
            return result
        return result + self.bias

def prepare_calibration_data(tokenizer, nsamples, seed, seqlen, device):
    """
    Load the 'c4' training split through ``get_loaders`` and stack it into one tensor.

    The first element of every item yielded by the loader is kept, the pieces
    are concatenated along dimension 0, and the result is moved to ``device``.
    ``tokenizer``, ``nsamples``, ``seed`` and ``seqlen`` are forwarded to
    ``get_loaders`` unchanged.

    NOTE(review): each item is presumably an (input_ids, target) pair of
    shape (1, seqlen) tensors -- confirm against ``get_loaders``.
    """
    from .data import get_loaders

    train_loader, _ = get_loaders(
        'c4', nsamples=nsamples, seed=seed, seqlen=seqlen, tokenizer=tokenizer
    )
    input_chunks = [sample[0] for sample in train_loader]
    return torch.cat(input_chunks, dim=0).to(device)

def _replace_submodule(model, qualified_name, new_module):
    """Swap the submodule found at the dotted path ``qualified_name`` for ``new_module``."""
    *parent_path, attr = qualified_name.split('.')
    parent = model
    for part in parent_path:
        parent = getattr(parent, part)
    setattr(parent, attr, new_module)


def _select_low_rank_factors(U, sigma, Vh, keep):
    """
    Build the factors of the retained part of an SVD.

    ``keep`` is a boolean mask over the singular values. The returned factors
    satisfy ``U' @ V'.T == U[:, keep] @ diag(sigma[keep] + 1e-8) @ Vh[keep, :]``;
    each factor carries the square root of the (slightly offset) singular
    values. When nothing is kept, a single all-zero column is returned so the
    factors still have a valid shape.

    Returns:
        ``(U_prime, V_prime, rank)`` with ``U_prime`` of shape (out, rank) and
        ``V_prime`` of shape (in, rank).
    """
    rank = int(keep.sum().item())
    if rank == 0:
        U_prime = torch.zeros(U.size(0), 1, device=U.device, dtype=U.dtype)
        V_prime = torch.zeros(Vh.size(1), 1, device=U.device, dtype=U.dtype)
        return U_prime, V_prime, 1

    # Multiplying by the float mask first promotes half-precision singular
    # values to float32 before the square root is taken.
    sigma_selected = (sigma * keep.to(torch.float32))[keep]
    sqrt_sigma = torch.sqrt(sigma_selected + 1e-8)

    # Columns of U and rows of Vh are indexed by singular value, so the scale
    # must broadcast as (1, r) against U[:, keep] but as (r, 1) against Vh[keep, :].
    U_prime = U[:, keep] * sqrt_sigma.unsqueeze(0)           # (out_features, r)
    V_prime = (sqrt_sigma.unsqueeze(1) * Vh[keep, :]).t()    # (in_features, r)
    return U_prime, V_prime, rank


def prune_rpca(args, model, tokenizer, device):
    """
    Compress the selected linear layers of ``model`` in place, one block at a time.

    Block indices are the first ``.<digits>.`` component found in the names of
    the model's ``torch.nn.Linear`` modules. For every block, each linear
    module whose name contains one of the comma-separated keys in
    ``args.linear_layers`` is processed as follows:

    1. its weight is decomposed by ``RPCA_gpu`` into ``L + S`` (in float32 on
       ``device``), and ``L`` is factorised by SVD;
    2. keep-probabilities for every singular value and every entry of ``S``
       are initialised from ``args.init_singular_value_prob`` /
       ``args.init_sparse_prob`` and tuned by ``optimize_masks``;
    3. probabilities >= 0.5 are kept, and the module is replaced by a
       ``LinearLSPUV`` holding the resulting factors and masked ``S``.

    The model itself is not moved; every replacement module lives on the
    device and in the dtype of the weight it replaces.

    Args:
        args: namespace providing ``linear_layers``, ``save``, ``nsamples``,
            ``seed``, ``rpca_mu``, ``rpca_lambda``, ``tau_multiplier``,
            ``use_nuclear_norm``, ``rpca_tol``, ``rpca_max_iter``,
            ``init_singular_value_prob``, ``init_sparse_prob`` and
            ``eval_after_each_layer`` (plus whatever ``optimize_masks`` and
            ``eval_ppl`` read).
        model: model to compress; must expose ``seqlen``.
        tokenizer: forwarded to data loading, ``optimize_masks`` and ``eval_ppl``.
        device: device used for RPCA and for mask optimisation.

    Side effects:
        When ``args.save`` is set, the directory is created and
        ``eval_metrics_after_each_layer.json`` is written into it (it holds
        one perplexity per block when ``args.eval_after_each_layer`` is set,
        otherwise an empty object).
    """
    # Do not move the model to device here

    layers_to_process = args.linear_layers.split(",")

    if args.save:
        os.makedirs(args.save, exist_ok=True)

    eval_metrics = {}

    # Collect the block indices appearing in the names of linear modules.
    pattern = re.compile(r'\.(\d+)\.')
    layer_indices = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            matches = pattern.findall(name)
            if matches:
                layer_indices.add(int(matches[0]))
    layer_indices = sorted(layer_indices)
    total_layers = len(layer_indices)

    calibration_data = prepare_calibration_data(tokenizer, args.nsamples, args.seed, model.seqlen, device)

    for idx, layer_idx in enumerate(layer_indices):
        print(f"\nProcessing layer {layer_idx} ({idx + 1}/{total_layers})")

        layer_tag = f'.{layer_idx}.'
        layer_modules = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
            and any(key in name for key in layers_to_process)
            and layer_tag in name
        ]
        if not layer_modules:
            # Nothing selected in this block: there is nothing to optimise,
            # replace or re-evaluate.
            continue

        # Per-module decomposition and mask probabilities, keyed by module name.
        s_params = {}

        for name, module in tqdm(layer_modules, desc=f"RPCA on modules in layer {layer_idx}"):
            bias = module.bias
            weight_device = module.weight.device
            weight_dtype = module.weight.dtype

            # Private float32 copy on the compute device.
            weight = module.weight.data.to(device=device, dtype=torch.float32, copy=True)

            rpca = RPCA_gpu(
                weight,
                mu=args.rpca_mu,
                lambda_=args.rpca_lambda,
                tau_multiplier=args.tau_multiplier,
                use_nuclear_norm=args.use_nuclear_norm
            )
            L, S = rpca.fit(
                tol=args.rpca_tol,
                max_iter=args.rpca_max_iter,
                verbose=True
            )

            # Everything that is stored lives next to the original weight.
            L = L.to(weight_device)
            S = S.to(weight_device)

            with torch.no_grad():
                U, sigma, Vh = torch.linalg.svd(L.to(torch.float32), full_matrices=False)

            s_params[name] = {
                'U': U.to(weight_dtype),
                'sigma': sigma.to(weight_dtype),
                'Vh': Vh.to(weight_dtype),
                # Keep-probability of every singular value / entry of S.
                's_sigma': torch.full_like(sigma, args.init_singular_value_prob, device=weight_device),
                'S': S,
                's_S': torch.full_like(S, args.init_sparse_prob, device=weight_device),
                'bias': bias,
                'weight_device': weight_device,
                'weight_dtype': weight_dtype,
                'm': weight.shape[0],
                'n': weight.shape[1]
            }

            # Clean up to save memory
            del weight, L
            torch.cuda.empty_cache()

        print(f"Starting Policy Gradient Optimization for layer {layer_idx}")
        optimize_masks(args, model, tokenizer, device, s_params, calibration_data)

        # Freeze the learned probabilities into hard masks and install the
        # final compressed modules.
        for name, module in layer_modules:
            params = s_params[name]
            bias = params['bias']
            weight_device = params['weight_device']
            weight_dtype = params['weight_dtype']

            # The optimiser may hand the probabilities back on the compute
            # device; bring them next to the factors they mask.
            keep_sigma = params['s_sigma'].to(weight_device) >= 0.5
            keep_S = params['s_S'].to(weight_device) >= 0.5

            U_prime, V_prime, rank_L = _select_low_rank_factors(
                params['U'], params['sigma'], params['Vh'], keep_sigma
            )
            S_masked = params['S'] * keep_S.float()

            new_module = LinearLSPUV(module.in_features, module.out_features, rank_L, bias is not None)
            new_module.U_prime.data = U_prime.to(device=weight_device, dtype=weight_dtype)
            new_module.V_prime.data = V_prime.to(device=weight_device, dtype=weight_dtype)
            new_module.S.data = S_masked.to(device=weight_device, dtype=weight_dtype)
            if bias is not None:
                new_module.bias.data = bias.data.to(device=weight_device, dtype=weight_dtype)

            # Ensure S does not update during fine-tuning
            new_module.S.requires_grad = False
            new_module.to(weight_device)

            _replace_submodule(model, name, new_module)

            # Clean up to save memory
            del params['U'], params['Vh'], params['sigma'], params['S'], params['s_S']
            torch.cuda.empty_cache()

        if args.eval_after_each_layer:
            print(f"Evaluating the model after processing layer {layer_idx}...")
            model.eval()
            with torch.no_grad():
                ppl_test = eval_ppl(args, model, tokenizer, device)
            print(f"Perplexity after processing layer {layer_idx}: {ppl_test}")
            eval_metrics[f"layer_{layer_idx}"] = {
                'perplexity': float(ppl_test)
            }

    if args.save:
        eval_metrics_path = os.path.join(args.save, 'eval_metrics_after_each_layer.json')
        with open(eval_metrics_path, 'w') as f:
            json.dump(eval_metrics, f, indent=4)

def _install_sampled_module(model, name, params, U_prime, V_prime, S_masked):
    """Replace submodule ``name`` of ``model`` by a ``LinearLSPUV`` built from the given tensors."""
    weight_device = params['weight_device']
    weight_dtype = params['weight_dtype']
    bias = params['bias']

    # params['m'] / params['n'] are the rows / columns of the original weight,
    # i.e. out_features / in_features.
    new_module = LinearLSPUV(params['n'], params['m'], U_prime.shape[1], bias is not None)
    new_module.U_prime.data = U_prime.to(device=weight_device, dtype=weight_dtype)
    new_module.V_prime.data = V_prime.to(device=weight_device, dtype=weight_dtype)
    new_module.S.data = S_masked.to(device=weight_device, dtype=weight_dtype)
    if bias is not None:
        new_module.bias.data = bias.data.to(device=weight_device, dtype=weight_dtype)

    # Ensure S does not update during optimization
    new_module.S.requires_grad = False
    new_module.to(weight_device)

    *parent_path, attr = name.split('.')
    parent = model
    for part in parent_path:
        parent = getattr(parent, part)
    setattr(parent, attr, new_module)


def optimize_masks(args, model, tokenizer, device, s_params, calibration_data):
    """
    Optimize the Bernoulli parameters s_sigma and s_S using Policy Gradient Estimator.

    All keep-probabilities are packed into one flat tensor laid out as
    ``[s_sigma_1, s_S_1, s_sigma_2, s_S_2, ...]`` in ``s_params`` order and
    optimised with Adam. For every calibration batch:

    1. a binary mask is sampled from the current probabilities;
    2. every module named in ``s_params`` is replaced in ``model`` by a
       ``LinearLSPUV`` built from the masked singular values and masked ``S``;
    3. the next-token cross-entropy of the model on the batch is measured
       without gradients;
    4. the score-function gradient
       ``(loss - baseline) * (m - s) / (s * (1 - s) + 1e-8)`` is written to
       ``.grad``, an optimiser step is taken and the probabilities are clamped
       to [0, 1].

    After each pass over the data the achieved compression rate (one minus
    retained / original parameter count, accumulated over the batches of the
    pass) is compared with ``args.compression_target``; the loop stops early
    once it lies within ``args.compression_tolerance``.

    Args:
        args: namespace providing ``baseline_smoothing_factor``,
            ``pg_learning_rate``, ``val_samples`` (used as batch size),
            ``pg_max_iter``, ``pg_print_interval``, ``compression_target`` and
            ``compression_tolerance``.
        model: called as ``model(inputs)``; the result must expose ``.logits``.
            It is left holding the modules built from the last sampled masks.
        tokenizer: unused here; kept for interface compatibility.
        device: device on which the probabilities and masked factors are computed.
        s_params: mapping module name -> dict with keys ``U``, ``sigma``,
            ``Vh``, ``S``, ``s_sigma``, ``s_S``, ``bias``, ``weight_device``,
            ``weight_dtype``, ``m``, ``n``. Its ``s_sigma`` and ``s_S``
            entries are replaced by the optimised probabilities (same shape
            and device as the tensors passed in, clamped to [0, 1]).
        calibration_data: tensor of token ids, one sequence per row
            (NOTE(review): inferred from the label shift below -- confirm).

    Returns:
        None. An empty ``s_params`` is a no-op.
    """
    if not s_params:
        return

    # Moving-average baseline used to reduce the variance of the estimator.
    delta = 0.0
    beta = args.baseline_smoothing_factor

    # Record where each module's probabilities live inside the flat tensor.
    # The flat tensor is built in exactly this order so that slicing and
    # packing agree.
    layout = {}
    flat_parts = []
    offset = 0
    for name, params in s_params.items():
        s_sigma = params['s_sigma']
        s_S = params['s_S']
        mid = offset + s_sigma.numel()
        end = mid + s_S.numel()
        layout[name] = (offset, mid, end, s_sigma.shape, s_S.shape, s_sigma.device, s_S.device)
        flat_parts.append(s_sigma.detach().reshape(-1).to(device))
        flat_parts.append(s_S.detach().reshape(-1).to(device))
        offset = end

    s_combined = torch.cat(flat_parts).to(device)
    s_combined.requires_grad = True

    optimizer = torch.optim.Adam([s_combined], lr=args.pg_learning_rate)

    batch_size = args.val_samples  # You can adjust this as needed

    for pg_iter in range(1, args.pg_max_iter + 1):
        total_loss = 0.0
        total_params = 0
        total_retained_params = 0

        # Shuffle calibration data
        perm = torch.randperm(calibration_data.size(0))
        calibration_data = calibration_data[perm]

        for batch_idx in range(0, calibration_data.size(0), batch_size):
            inputs = calibration_data[batch_idx:batch_idx + batch_size].to(device)

            optimizer.zero_grad()

            # Sampling is not differentiable; the gradient is supplied by hand
            # below, so no autograd graph is needed anywhere in this block.
            with torch.no_grad():
                m_combined = torch.bernoulli(s_combined)

                for name, params in s_params.items():
                    start, mid, end, sigma_shape, S_shape, _, _ = layout[name]
                    m_sigma = m_combined[start:mid].view(sigma_shape)
                    m_S = m_combined[mid:end].view(S_shape)

                    U = params['U'].to(device)
                    sigma = params['sigma'].to(device)
                    Vh = params['Vh'].to(device)

                    # Split each (masked) singular value evenly between the
                    # two factors.
                    sqrt_sigma = torch.sqrt(sigma * m_sigma + 1e-8)
                    U_prime = U * sqrt_sigma.unsqueeze(0)          # (out_features, r)
                    V_prime = (sqrt_sigma.unsqueeze(1) * Vh).t()   # (in_features, r)
                    S_masked = params['S'].to(device) * m_S

                    _install_sampled_module(model, name, params, U_prime, V_prime, S_masked)

                    # Parameter bookkeeping for the compression rate.
                    m = params['m']
                    n = params['n']
                    bias_params = params['bias'].numel() if params['bias'] is not None else 0
                    k = int(m_sigma.sum().item())  # Number of singular values retained
                    retained_params_S = (m_S != 0).sum().item()

                    total_params += m * n + bias_params
                    total_retained_params += (m + n) * k + retained_params_S + bias_params

                    del U, Vh, sigma, U_prime, V_prime, S_masked

            torch.cuda.empty_cache()

            # Loss of the masked model on this batch (next-token prediction).
            model.eval()
            with torch.no_grad():
                outputs = model(inputs)
                logits = outputs.logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = inputs[..., 1:].contiguous()
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss_value = loss.item()
            total_loss += loss_value

            # The baseline is updated before it is used, so the advantage
            # below equals beta * (loss - previous baseline).
            delta = beta * delta + (1 - beta) * loss_value

            with torch.no_grad():
                # d/ds log Bernoulli(m; s) = (m - s) / (s * (1 - s)); the flat
                # mask and the flat probabilities share one layout, so a
                # single element-wise expression covers every module.
                s_combined.grad = (
                    (loss_value - delta)
                    * (m_combined - s_combined)
                    / (s_combined * (1 - s_combined) + 1e-8)
                )

            optimizer.step()

            # Keep the parameters valid probabilities.
            with torch.no_grad():
                s_combined.clamp_(0.0, 1.0)

        # Compute compression rate
        compression_rate = 1.0 - (total_retained_params / total_params)

        if pg_iter % args.pg_print_interval == 0:
            print(f"Iteration {pg_iter}/{args.pg_max_iter}, Loss: {total_loss:.4f}, Compression Rate: {compression_rate:.4f}")

        # Check if compression rate is within tolerance
        if abs(compression_rate - args.compression_target) <= args.compression_tolerance:
            print("Desired compression rate achieved.")
            break

    # Hand the optimised probabilities back in their original shape and on
    # their original device; clamp() returns fresh tensors that no longer
    # alias the optimiser state.
    flat_result = s_combined.detach()
    for name, params in s_params.items():
        start, mid, end, sigma_shape, S_shape, sigma_device, S_device = layout[name]
        params['s_sigma'] = flat_result[start:mid].clamp(0.0, 1.0).view(sigma_shape).to(sigma_device)
        params['s_S'] = flat_result[mid:end].clamp(0.0, 1.0).view(S_shape).to(S_device)

class RPCA_gpu:
    """
    Low-rank and sparse matrix decomposition via RPCA on torch tensors.

    Splits a 2-D tensor ``D`` into ``L + S`` by alternating three updates:
    singular-value thresholding for ``L``, element-wise soft thresholding for
    ``S``, and an accumulation of the residual ``D - L - S`` into ``Y``.
    All work happens on the device of ``D``.

    Args:
        D: 2-D tensor to decompose.
        mu: penalty weight; when falsy it defaults to
            ``D.numel() / (4 * sum(|D|**2))``.
        lambda_: sparsity weight; when falsy it defaults to
            ``1 / sqrt(max(D.shape))``.
        tau_multiplier: extra factor applied to both thresholds.
        use_nuclear_norm: when false the singular-value thresholding step is
            skipped and ``L`` is set to ``D - S + Y / mu`` directly.
    """
    def __init__(self, D, mu=None, lambda_=None, tau_multiplier=1.0, use_nuclear_norm=True):
        self.D = D
        self.device = D.device
        self.use_nuclear_norm = use_nuclear_norm
        self.norm_D = torch.norm(D, p='fro').item()
        self.m, self.n = D.shape

        # Parameter initialization
        # Tensor division yields inf (not an exception) for an all-zero D;
        # fit() handles that case before mu is ever used.
        self.mu = mu or (self.m * self.n / (4 * self.norm_p(self.D, 2))).item()
        self.mu_inv = 1 / self.mu
        self.lambda_ = lambda_ or float(1 / np.sqrt(max(self.m, self.n)))
        self.tau_multiplier = tau_multiplier  # Multiplier for tau

    @staticmethod
    def norm_p(M, p):
        """Return ``sum(|M| ** p)`` (note: no p-th root is taken)."""
        return torch.sum(torch.pow(M.abs(), p))

    @staticmethod
    def shrink(M, tau):
        """Element-wise soft thresholding: shrink every entry of ``M`` towards zero by ``tau``."""
        return torch.sign(M) * F.relu(torch.abs(M) - tau)

    def svd_threshold(self, M, tau):
        """Soft-threshold the singular values of ``M`` by ``tau * tau_multiplier`` and rebuild the matrix."""
        tau = tau * self.tau_multiplier
        U, s, Vh = torch.linalg.svd(M, full_matrices=False)
        s_threshold = F.relu(s - tau)
        # Scaling the columns of U equals U @ diag(s_threshold) without
        # materialising the diagonal matrix.
        return (U * s_threshold) @ Vh

    def fit(self, tol=1e-7, max_iter=1000, verbose=False):
        """
        Run the decomposition and return ``(L, S)``.

        Iterates until the relative residual ``||D - L - S||_F / ||D||_F``
        drops below ``tol`` or ``max_iter`` iterations have run, and returns
        the iterate with the smallest residual seen. The result is also stored
        on ``self.L`` / ``self.S``.

        An all-zero ``D`` yields all-zero ``L`` and ``S`` immediately. If no
        iterate could be ranked (``max_iter == 0`` or every residual is NaN),
        the last iterate is returned instead.

        Args:
            tol: relative residual below which the iteration stops.
            max_iter: maximum number of iterations.
            verbose: print the residual and the rank of ``L`` periodically.
        """
        D = self.D
        S = torch.zeros_like(D)
        Y = torch.zeros_like(D)
        L = torch.zeros_like(D)

        if self.norm_D == 0:
            # The relative residual is undefined for a zero matrix, whose
            # decomposition is trivially zero plus zero.
            self.L, self.S = L, S
            return self.L, self.S

        mu_inv = self.mu_inv
        sparse_tau = mu_inv * self.lambda_ * self.tau_multiplier  # Apply tau_multiplier
        converged = False

        # Track best L and S with minimum error
        best_L = None
        best_S = None
        min_error = float('inf')

        for i in range(max_iter):
            # Update L
            if self.use_nuclear_norm:
                L = self.svd_threshold(D - S + mu_inv * Y, mu_inv)
            else:
                L = D - S + mu_inv * Y

            # Update S
            S = self.shrink(D - L + mu_inv * Y, sparse_tau)

            # Update Y
            Z = D - L - S
            Y = Y + self.mu * Z

            # Check convergence
            err = torch.norm(Z, p='fro').item() / self.norm_D

            # L and S are rebound to freshly allocated tensors on every
            # iteration, so keeping a reference is enough (no copy needed).
            if err < min_error:
                min_error = err
                best_L = L
                best_S = S

            if verbose and ((i % 50) == 0 or i == 1 or i >= max_iter - 1 or err <= tol):
                rank_L = torch.linalg.matrix_rank(L).item()
                print(f'Iteration: {i+1}; Error: {err:0.4e}; Rank of L: {rank_L}')

            if err < tol:
                converged = True
                break

        if not converged:
            print("RPCA did not converge within the maximum number of iterations.")

        if best_L is None:
            # No iterate was ever ranked; fall back to the current one so the
            # caller always receives tensors.
            best_L, best_S = L, S

        # Return the best L and S with the minimum observed error
        self.L = best_L
        self.S = best_S
        return self.L, self.S
