import torch

import transformers
from transformers import AutoTokenizer
import sys
import os
sys.path.append(os.path.join(os.getcwd(), "mamba_peft/src/"))
from mamba_peft.src.peft import PeftModel
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import MambaForCausalLM

from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate

from tqdm import tqdm
import copy
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D


@register_model("MambaPEFT")
class MambaEvalWrapper(HFLM):
    """lm-eval-harness wrapper for a Mamba causal LM with optional PEFT adapter weights.

    Registered with the harness as ``"MambaPEFT"``. Always uses the
    EleutherAI/gpt-neox-20b tokenizer, reusing its EOS token as padding.
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(self,
                 pretrained="state-spaces/mamba-130m-hf",
                 peft_weights=None,
                 max_length=2048,
                 batch_size=None,
                 device="cuda",
                 dtype=torch.float32,
                 trust_remote_code=False):
        """
        Args:
            pretrained: Hub id or local path of the base Mamba checkpoint.
            peft_weights: Hub id or local path of the PEFT adapter to attach;
                ``None`` evaluates the base model without adapters.
            max_length: Maximum sequence length reported to the harness.
            batch_size: Evaluation batch size (anything ``int()`` accepts);
                defaults to 64 when ``None``.
            device: Device the model is placed on.
            dtype: Forwarded to ``HFLM``; the model itself is always run in
                float32 (see ``_create_model``).
            trust_remote_code: Forwarded to ``HFLM``. Note that
                ``_create_model`` always loads with ``trust_remote_code=True``.
        """
        # Must be set before super().__init__(), which calls _create_model().
        self.peft_weights = peft_weights
        batch_size = int(batch_size) if batch_size is not None else 64
        # device is forwarded so that HFLM sets self._device *before* it calls
        # _create_model(); otherwise the model is moved to HFLM's default device
        # and the reassignment of self._device below would not move it back.
        super().__init__(pretrained=pretrained,
                         tokenizer="EleutherAI/gpt-neox-20b",
                         max_length=max_length,
                         device=device,
                         batch_size=batch_size,
                         dtype=dtype,
                         trust_remote_code=trust_remote_code)

        self._batch_size = batch_size
        self._max_length = max_length
        self._device = torch.device(device)

        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        # The GPT-NeoX tokenizer is used with EOS doubling as the padding token.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size


    def _create_model(
        self,
        pretrained: str,
        dtype = "float32",
        # no `parallelize=True` options
        # no PEFT and quantization options
        # Mamba does not support arbitrary HF from_pretrained() args
        **kwargs,
    ) -> None:
        """Load the base Mamba model, attach PEFT adapters if configured, and place it on ``self._device``.

        ``dtype`` and any extra keyword arguments supplied by ``HFLM`` are
        accepted for interface compatibility but ignored: the model always ends
        up in float32. Stores the result in ``self._model``.
        """
        # Load directly in float32. Loading in float16 and upcasting with
        # .float() afterwards would permanently round any weights stored in
        # higher precision.
        # NOTE(review): trust_remote_code is hard-coded to True regardless of the
        # value the caller passed to __init__ — confirm this is intended.
        model = MambaForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=False,
            torch_dtype=torch.float32,
            trust_remote_code=True
        )

        if self.peft_weights is not None:
            self._model = PeftModel.from_pretrained(
                model,
                self.peft_weights,
                torch_dtype=torch.float32,
            )
        else:
            # No adapter configured: evaluate the base model as-is.
            self._model = model
        print(model)
        self._model.config.use_cache = False # Not fully implemented yet
        self._model.float()
        # self._device is assigned by HFLM.__init__ from `device` before this
        # method runs. NOTE(review): confirm ordering for the installed lm-eval version.
        self._model.to(self._device)

    @property
    def batch_size(self):
        """Batch size used by the harness (overrides HFLM's value)."""
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generation is not supported by this wrapper."""
        raise NotImplementedError("MambaEvalWrapper does not support generation tasks")

    def visualize_tc_lim(self, sample_text="The quick brown fox jumps over the lazy dog.",
                         save_path="tc_lim_visualization.png", **plot_kwargs):
        """
        Generate visualization of TC-LIM effect on loss landscape.

        Thin wrapper around ``plot_mamba_loss_landscape`` using this wrapper's
        model and tokenizer. Extra keyword arguments (e.g. ``n_points``,
        ``alpha_range``) are forwarded unchanged. Returns the matplotlib figure.
        """
        return plot_mamba_loss_landscape(
            model=self._model,
            tokenizer=self.tokenizer,
            sample_text=sample_text,
            save_path=save_path,
            **plot_kwargs,
        )


def _random_directions(model):
    """Draw two random perturbation directions, each scaled to unit global L2 norm.

    Returns a pair of dicts mapping parameter name -> random tensor shaped like
    that parameter. Parameters whose name contains ``'embed'`` are skipped.
    When any remaining parameter has ``requires_grad`` set, only those are
    used; otherwise all remaining parameters are used, so a fully frozen model
    (e.g. one loaded for inference) still gets non-empty directions.

    Raises:
        ValueError: if the model has no non-embedding parameters at all.
    """
    # Only sample from non-embedding weights to better see curvature effects
    candidates = [(name, param) for name, param in model.named_parameters()
                  if 'embed' not in name]
    selected = [(name, param) for name, param in candidates if param.requires_grad] or candidates
    if not selected:
        raise ValueError("model has no non-embedding parameters to perturb")

    direction1 = {}
    direction2 = {}
    for name, param in selected:
        direction1[name] = torch.randn_like(param)
        direction2[name] = torch.randn_like(param)

    for direction in (direction1, direction2):
        # Sum of squares in float32 so half-precision parameters cannot overflow it.
        norm = torch.sqrt(sum(torch.sum(t.float() ** 2) for t in direction.values()))
        for name in direction:
            direction[name] /= norm
    return direction1, direction2


def _render_landscape(alphas_grid, betas_grid, loss_grid_with_tc_lim, loss_grid_without_tc_lim,
                      alpha_range, beta_range):
    """Draw the sampled loss surface(s) as 3D plots on a new 18x6 figure and return it.

    A grid that is ``None`` was not sampled and is not drawn. When both grids
    are present, a third panel shows ``without - with``.
    """
    with_tc_lim = loss_grid_with_tc_lim is not None
    without_tc_lim = loss_grid_without_tc_lim is not None

    fig = plt.figure(figsize=(18, 6))

    if with_tc_lim and without_tc_lim:
        panels = [
            (131, loss_grid_without_tc_lim, cm.coolwarm, 'Loss Landscape without TC-LIM'),
            (132, loss_grid_with_tc_lim, cm.viridis, 'Loss Landscape with TC-LIM'),
            # Plot the difference (regularization effect)
            (133, loss_grid_without_tc_lim - loss_grid_with_tc_lim, cm.plasma,
             'Regularization Effect (Difference)'),
        ]
    elif with_tc_lim:
        panels = [(111, loss_grid_with_tc_lim, cm.viridis, 'Loss Landscape with TC-LIM')]
    else:
        panels = [(111, loss_grid_without_tc_lim, cm.coolwarm, 'Loss Landscape without TC-LIM')]

    for subplot_code, grid, cmap, title in panels:
        ax = fig.add_subplot(subplot_code, projection='3d')
        ax.plot_surface(alphas_grid, betas_grid, grid,
                        cmap=cmap, linewidth=0, antialiased=True)
        ax.set_title(title)
        ax.set_xlabel('Direction 1')
        ax.set_ylabel('Direction 2')
        ax.set_zlabel('Loss')

        # NOTE(review): this label sits at a fixed corner of the grid, not at a
        # region whose curvature was actually measured.
        max_height = ax.get_zlim()[1]
        ax.text(alpha_range[0], beta_range[0], max_height * 0.9,
                'High Curvature Region', color='red', fontsize=10)

    # Add title and explanation
    fig.suptitle('TC-LIM as Curvature-Aware Regularization for Mamba', fontsize=16)

    if with_tc_lim and without_tc_lim:
        fig.text(0.5, 0.01,
                 'TC-LIM provides adaptive regularization that responds to the curvature of the loss landscape.\n'
                 'Regions with high curvature (sharp peaks) receive stronger regularization,\n'
                 'resulting in a smoother optimization landscape.',
                 ha='center', fontsize=12, bbox=dict(facecolor='white', alpha=0.8))

    fig.tight_layout()
    return fig


def plot_mamba_loss_landscape(model, tokenizer, sample_text="The quick brown fox jumps over the lazy dog.",
                              alpha_range=(-0.5, 0.5), beta_range=(-0.5, 0.5), n_points=15,
                              with_tc_lim=True, without_tc_lim=True, save_path="tc_lim_visualization.png"):
    """
    Visualizes the loss landscape for a Mamba model by perturbing weights along random directions.

    Two random directions over the model's non-embedding weights are drawn.
    The language-modelling loss on ``sample_text`` is then evaluated at every
    point ``w + alpha * d1 + beta * d2`` of an ``n_points x n_points`` grid.
    If the model exposes an ``enable_tc_lim`` attribute, it is toggled to get
    the "with" / "without TC-LIM" surfaces. The original weights, and the
    original ``enable_tc_lim`` value, are restored afterwards, even if
    sampling fails.

    Args:
        model: Causal LM whose forward accepts ``labels=`` and returns an object
            with a ``.loss`` tensor, and which has a ``.device`` attribute.
        tokenizer: Callable returning a batch encoding with ``"input_ids"``.
        sample_text: Text the loss is evaluated on.
        alpha_range, beta_range: (low, high) step sizes along each direction.
        n_points: Grid resolution per axis.
        with_tc_lim, without_tc_lim: Which surfaces to sample; at least one
            must be True.
        save_path: Where to write the PNG; falsy to skip saving. A bare file
            name saves into the current working directory.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: if both ``with_tc_lim`` and ``without_tc_lim`` are False,
            or the model has no non-embedding parameters.
    """
    if not (with_tc_lim or without_tc_lim):
        raise ValueError("at least one of with_tc_lim / without_tc_lim must be True")

    # Save original weights
    orig_state_dict = copy.deepcopy(model.state_dict())
    direction1, direction2 = _random_directions(model)
    # Parameters to perturb, resolved once rather than at every grid point.
    perturbed = [(name, param) for name, param in model.named_parameters() if name in direction1]

    # Create grid for evaluation
    alphas = np.linspace(alpha_range[0], alpha_range[1], n_points)
    betas = np.linspace(beta_range[0], beta_range[1], n_points)
    alphas_grid, betas_grid = np.meshgrid(alphas, betas)

    # Initialize loss grids (indexed [beta_index, alpha_index] to match meshgrid)
    loss_grid_with_tc_lim = np.zeros_like(alphas_grid) if with_tc_lim else None
    loss_grid_without_tc_lim = np.zeros_like(alphas_grid) if without_tc_lim else None

    # Tokenize input
    inputs = tokenizer(sample_text, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]

    has_tc_lim_flag = hasattr(model, 'enable_tc_lim')
    original_tc_lim = getattr(model, 'enable_tc_lim', None)

    def compute_loss(use_tc_lim=True):
        # NOTE(review): if `model` is a wrapper (e.g. PeftModel) that only
        # forwards attribute *reads* to an inner module, this assignment may
        # not reach the module that reads the flag — confirm.
        if has_tc_lim_flag:
            model.enable_tc_lim = use_tc_lim
        with torch.no_grad():
            return model(input_ids, labels=input_ids).loss.item()

    # Sample loss landscape
    print("Sampling loss landscape...")
    try:
        for i, alpha in enumerate(tqdm(alphas)):
            for j, beta in enumerate(betas):
                # Overwrite weights in place with the perturbed values.
                with torch.no_grad():
                    for name, param in perturbed:
                        param.copy_(orig_state_dict[name]
                                    + float(alpha) * direction1[name]
                                    + float(beta) * direction2[name])

                if with_tc_lim:
                    loss_grid_with_tc_lim[j, i] = compute_loss(use_tc_lim=True)
                if without_tc_lim:
                    loss_grid_without_tc_lim[j, i] = compute_loss(use_tc_lim=False)
    finally:
        # Always undo the perturbation and the TC-LIM toggle, even if a forward
        # pass fails or the sweep is interrupted.
        model.load_state_dict(orig_state_dict)
        if has_tc_lim_flag:
            model.enable_tc_lim = original_tc_lim

    fig = _render_landscape(alphas_grid, betas_grid, loss_grid_with_tc_lim,
                            loss_grid_without_tc_lim, alpha_range, beta_range)

    # Save the figure if a path is provided
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # a bare file name has no directory to create
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")

    return fig


# Command-line entry point: load a Mamba model and save its TC-LIM loss-landscape plot.
def main():
    """Parse CLI arguments, load the Mamba (+ optional PEFT) model, and plot its loss landscape.

    The model is loaded in float32 with ``use_cache`` disabled and moved to
    ``--device``. Plotting and saving to ``--output`` are delegated to
    ``plot_mamba_loss_landscape``, which also reports where the figure was
    written.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Visualize TC-LIM effect on loss landscape')
    parser.add_argument('--pretrained', type=str, default="state-spaces/mamba-130m-hf",
                        help='Pretrained model name or path')
    parser.add_argument('--peft_weights', type=str, default=None,
                        help='PEFT adapter weights to load on top of the base model '
                             '(omit to visualize the base model alone)')
    parser.add_argument('--sample', type=str, default="The quick brown fox jumps over the lazy dog.",
                        help='Sample text to evaluate loss on')
    parser.add_argument('--output', type=str, default="tc_lim_visualization.png",
                        help='Output path for visualization')
    parser.add_argument('--n_points', type=int, default=15,
                        help='Grid resolution for visualization')
    parser.add_argument('--device', type=str, default="cuda",
                        help='Device to run the model on (e.g. cuda, cuda:1, cpu)')
    args = parser.parse_args()

    # Load tokenizer (EOS doubles as the padding token)
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Load base model directly in float32: loading in float16 and upcasting
    # later would permanently round weights stored at higher precision.
    print(f"Loading base model from {args.pretrained}...")
    model = MambaForCausalLM.from_pretrained(
        args.pretrained,
        load_in_8bit=False,
        torch_dtype=torch.float32,
        trust_remote_code=True
    )

    # Attach PEFT adapters when provided
    if args.peft_weights is not None:
        print(f"Loading PEFT weights from {args.peft_weights}...")
        model = PeftModel.from_pretrained(
            model,
            args.peft_weights,
            torch_dtype=torch.float32,
        )
    model.config.use_cache = False
    model.float()
    model.to(args.device)

    # Generate visualization (the plotting function reports the save location)
    print("Generating visualization...")
    plot_mamba_loss_landscape(
        model=model,
        tokenizer=tokenizer,
        sample_text=args.sample,
        n_points=args.n_points,
        save_path=args.output
    )


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()