import torch
import transformers
from transformers import AutoTokenizer
import sys
import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import io
import json
import contextlib

sys.path.append(os.path.join(os.getcwd(), "mamba_peft/src/"))
from mamba_peft.src.peft import PeftModel
from transformers import MambaForCausalLM

from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("MambaPEFT")
class MambaEvalWrapper(HFLM):
    """lm-eval model wrapper for a Mamba checkpoint with a PEFT adapter attached.

    Registered under the name ``"MambaPEFT"`` so that it can be selected with
    ``--model MambaPEFT`` when driving the lm-eval command line.
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(self,
                 pretrained="state-spaces/mamba-130m-hf",
                 peft_weights=None,
                 max_length=2048,
                 batch_size=None,
                 device="cuda",
                 dtype=torch.float32,
                 trust_remote_code=False):
        """Build the wrapper and load base model, adapter and tokenizer.

        Args:
            pretrained: Hugging Face id or path of the base Mamba checkpoint.
            peft_weights: Path to the PEFT adapter handed to
                ``PeftModel.from_pretrained`` in ``_create_model``.
            max_length: Maximum sequence length reported to the harness.
            batch_size: Evaluation batch size; 64 is used when ``None``.
            device: Device the model is moved to (string or ``torch.device``).
            dtype: Forwarded to ``HFLM.__init__``. NOTE(review):
                ``_create_model`` ignores it and always ends up in float32.
            trust_remote_code: Forwarded to ``HFLM.__init__``.
        """
        # Has to be assigned before super().__init__(): _create_model() reads
        # it, and that hook is expected to run inside the base constructor.
        self.peft_weights = peft_weights
        super().__init__(pretrained=pretrained,
                         tokenizer="EleutherAI/gpt-neox-20b",
                         max_length=max_length,
                         # Forward the requested device so that _create_model()
                         # places the model there instead of on HFLM's default.
                         device=str(device),
                         dtype=dtype,
                         trust_remote_code=trust_remote_code)

        self._batch_size = int(batch_size) if batch_size is not None else 64
        self._max_length = max_length
        self._device = torch.device(device)

        # Replace the tokenizer set up by the base class with one whose pad
        # token is the EOS token.
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size

    def _create_model(self, pretrained: str, dtype="float32", **kwargs) -> None:
        """Load the base model, attach the PEFT adapter and store it in ``self._model``.

        The base weights are loaded as float16, the adapter is attached from
        ``self.peft_weights``, and the combined model is cast to float32 and
        moved to ``self._device``. ``dtype`` and ``**kwargs`` are accepted for
        signature compatibility but are not used.
        """
        model = MambaForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )

        self._model = PeftModel.from_pretrained(
            model,
            self.peft_weights,
            torch_dtype=torch.float32,
        )
        print(model)
        self._model.config.use_cache = False
        self._model.float()
        self._model.to(self._device)

    @property
    def batch_size(self):
        """Batch size used for evaluation (set in ``__init__``)."""
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generation is not implemented for this wrapper."""
        raise NotImplementedError()


# === Loss Landscape Helpers ===

def get_flattened_params(model):
    """Return every trainable parameter of ``model`` as one flat 1-D tensor.

    Parameters are visited in ``model.parameters()`` order and those with
    ``requires_grad == False`` are skipped, which is the same order and filter
    used by ``set_flattened_params``.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        A newly allocated 1-D tensor holding the concatenated parameter values.
    """
    # reshape(-1) rather than view(-1): view() raises for parameters whose
    # storage is not contiguous (e.g. created from a transposed tensor).
    return torch.cat([p.data.reshape(-1) for p in model.parameters() if p.requires_grad])

def set_flattened_params(model, new_flat):
    """Write a flat vector back into the trainable parameters of ``model``.

    Consecutive slices of ``new_flat`` are copied in place into each parameter
    that has ``requires_grad`` set, following ``model.parameters()`` order.
    Frozen parameters are left untouched.

    Args:
        model: A ``torch.nn.Module`` whose trainable parameters are overwritten.
        new_flat: 1-D tensor laid out like the output of ``get_flattened_params``.
    """
    trainable = (param for param in model.parameters() if param.requires_grad)
    offset = 0
    for param in trainable:
        end = offset + param.numel()
        # Reshape the slice to the parameter's shape, then move it to the
        # parameter's device before the in-place copy.
        chunk = new_flat[offset:end].view_as(param).to(param.device)
        param.data.copy_(chunk)
        offset = end

def random_direction_like(param_vector):
    """Sample a Gaussian random vector shaped like ``param_vector`` and scale it to unit norm.

    Args:
        param_vector: Tensor whose shape, dtype and device the sample mirrors.

    Returns:
        A tensor of standard-normal noise divided by its own norm.
    """
    noise = torch.randn_like(param_vector)
    length = noise.norm()
    return noise.div(length)

def _arc_easy_accuracy(table_text):
    """Extract the ARC-Easy accuracy from the results table printed by lm-eval."""
    for line in table_text.splitlines():
        stripped = line.strip()
        if not stripped.startswith("|"):
            continue
        cells = [cell.strip() for cell in stripped.split("|") if cell.strip()]
        if not cells or cells[0] != "arc_easy":
            continue
        # Locate the value relative to the "acc" metric cell so the parse does
        # not depend on how many columns this table layout has; the first
        # numeric cell after it is the accuracy.
        if "acc" in cells:
            for cell in cells[cells.index("acc") + 1:]:
                try:
                    return float(cell)
                except ValueError:
                    continue
        # Fallback: fixed position used for the 9-column table layout.
        return float(cells[6])
    raise ValueError("Cannot find arc_easy result line in output")


def evaluate_arc_easy_cli(model_args: str, limit: int = 20):
    """Run the lm-eval CLI on ARC-Easy and return ``1 - accuracy`` as a loss.

    The CLI is invoked in-process with ``--model MambaPEFT``; its stdout is
    captured and the accuracy is parsed from the printed results table.

    NOTE(review): the CLI receives only ``model_args``, so it evaluates a model
    built from those arguments rather than any model object the caller holds.

    Args:
        model_args: Value passed to ``--model_args`` (e.g. ``"peft_weights=..."``).
        limit: Number of ARC-Easy examples to evaluate (``--limit``).

    Returns:
        ``1.0 - accuracy`` as a float.

    Raises:
        ValueError: If no ``arc_easy`` row is found in the CLI output.
    """
    cli_argv = [
        "dummy",
        "--model", "MambaPEFT",
        "--tasks", "arc_easy",
        "--model_args", model_args,
        "--limit", str(limit),
    ]

    buffer = io.StringIO()
    saved_argv = sys.argv
    try:
        # cli_evaluate() reads its options from sys.argv; swap them in only
        # for the duration of the call so the caller's argv is not clobbered.
        sys.argv = cli_argv
        with contextlib.redirect_stdout(buffer):
            cli_evaluate()
    finally:
        sys.argv = saved_argv

    acc_value = _arc_easy_accuracy(buffer.getvalue())

    # Treat loss = 1.0 - accuracy
    return 1.0 - acc_value


def draw_loss_landscape_via_lm_eval(args, model, tokenizer, alphas=(-1, 1), betas=(-1, 1), steps=21, save_path="loss_landscape_eval.png", loss_fn=None):
    """Plot a 2-D slice of the loss surface around the model's current weights.

    Two random unit directions are drawn in the space of trainable parameters.
    For every ``(alpha, beta)`` on a ``steps x steps`` grid the parameters are
    set to ``orig + alpha * dir_x + beta * dir_y``, a loss is measured, and the
    resulting grid is rendered as a 3-D surface saved to ``save_path``. The
    original parameters are restored afterwards, even if an evaluation fails.

    NOTE(review): the default loss goes through ``evaluate_arc_easy_cli``,
    which is given only ``peft_weights=<path>`` and therefore does not receive
    the perturbed ``model`` object. Pass ``loss_fn`` to score the in-memory
    model directly.

    Args:
        args: Namespace providing ``peft_weights`` and ``eval_limit`` (used
            only when ``loss_fn`` is ``None``).
        model: ``torch.nn.Module`` whose trainable parameters are perturbed.
        tokenizer: Unused; kept for interface compatibility.
        alphas: ``(min, max)`` range along the first direction.
        betas: ``(min, max)`` range along the second direction.
        steps: Number of grid points per axis.
        save_path: Output image path.
        loss_fn: Optional zero-argument callable returning the loss of the
            model in its current (perturbed) state.
    """
    if loss_fn is not None:
        evaluate = loss_fn
    else:
        def evaluate():
            return evaluate_arc_easy_cli(
                model_args=f"peft_weights={args.peft_weights}",
                limit=args.eval_limit
            )

    orig_params = get_flattened_params(model).clone()

    direction_x = random_direction_like(orig_params)
    direction_y = random_direction_like(orig_params)

    alpha_range = np.linspace(*alphas, steps)
    beta_range = np.linspace(*betas, steps)

    # loss_grid[i, j] holds the loss at (alpha_range[i], beta_range[j]).
    loss_grid = np.zeros((steps, steps))

    try:
        for i, alpha in enumerate(alpha_range):
            for j, beta in enumerate(beta_range):
                # Plain Python floats keep the arithmetic purely on the torch side.
                new_params = orig_params + float(alpha) * direction_x + float(beta) * direction_y
                set_flattened_params(model, new_params)
                loss_grid[i, j] = evaluate()
    finally:
        # Never leave the model in a perturbed state, even on failure.
        set_flattened_params(model, orig_params)

    # "ij" indexing makes A[i, j] == alpha_range[i] and B[i, j] == beta_range[j],
    # matching how loss_grid was filled; the default "xy" would transpose the plot.
    A, B = np.meshgrid(alpha_range, beta_range, indexing="ij")
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(A, B, loss_grid, cmap="viridis", edgecolor='none')
    ax.set_xlabel("alpha")
    ax.set_ylabel("beta")
    ax.set_zlabel("Loss")
    ax.set_title("Loss Landscape via lm_eval on ARC-Easy")
    plt.tight_layout()
    print(f"Saving loss landscape to {save_path}")
    plt.savefig(save_path)
    plt.close()


if __name__ == "__main__":
    # Script-specific flags are parsed here; everything else is left in
    # `unknown` for the lm-eval command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--draw-landscape", action="store_true", help="Draw loss landscape instead of normal evaluation")
    parser.add_argument("--peft-weights", type=str, required=True, help="Path to PEFT adapter weights")
    parser.add_argument("--save-path", type=str, default="loss_landscape_eval.png")
    parser.add_argument("--eval-limit", type=int, default=20, help="How many ARC-Easy examples to use per evaluation")
    args, unknown = parser.parse_known_args()

    if args.draw_landscape:
        wrapper = MambaEvalWrapper(peft_weights=args.peft_weights)
        model = wrapper._model
        tokenizer = wrapper.tokenizer
        draw_loss_landscape_via_lm_eval(args, model, tokenizer, save_path=args.save_path)
    else:
        # cli_evaluate() parses sys.argv itself and rejects options it does not
        # define, so hand it only the arguments this script did not consume.
        # The adapter path must still reach the model via
        # `--model_args peft_weights=...`.
        sys.argv = [sys.argv[0], *unknown]
        cli_evaluate()

# python loss_landscape.py --draw-landscape --peft-weights /gpfs/gibbs/pi/panda/dl2345/Research/MambaPEFT/language/commonsense_reasoning/results/lora_X
