import torch
import transformers
from transformers import AutoTokenizer
import sys
import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import io
import json
import contextlib
import glob
import uuid
import datetime
import copy
import shutil

# Make the project-local sources importable. The path is resolved against the
# current working directory (os.getcwd()), not this file's location, so the
# script must be launched from the repository root. This must run before the
# vim_peft_all import below, which presumably lives under mamba_peft/src/ --
# TODO confirm.
sys.path.append(os.path.join(os.getcwd(), "mamba_peft/src/"))
from transformers import MambaForCausalLM


from vim_peft_all import VimPEFTModel

# NOTE(review): debug output at import time; VimPEFTModel is not otherwise
# referenced in the visible code.
# NOTE(review): PeftModel and cli_evaluate are used further down but are not
# imported anywhere in the visible code -- presumably peft.PeftModel and
# lm_eval's CLI entry point; confirm and add the imports.
print(VimPEFTModel)




# === Loss Landscape Helpers ===

def get_flattened_params(model):
    """Return all trainable LoRA parameters of ``model`` as one flat 1-D tensor.

    A parameter is selected when it requires gradients and its qualified name
    contains ``"lora"`` (case-insensitive). Selected parameters are
    concatenated in ``model.named_parameters()`` order, which is the same
    selection rule and order that ``set_flattened_params`` uses to write them
    back.

    Args:
        model: A ``torch.nn.Module`` (typically a PEFT-wrapped model).

    Returns:
        A new 1-D tensor holding a detached copy of the selected values.

    Raises:
        ValueError: If no matching LoRA parameter is found.
    """
    # reshape (unlike view) also works for non-contiguous parameters, e.g. ones
    # created from a transposed tensor.
    params = [
        p.detach().reshape(-1)
        for n, p in model.named_parameters()
        if p.requires_grad and "lora" in n.lower()
    ]

    if not params:
        raise ValueError("No LoRA parameters found in the model!")

    # torch.cat always allocates a fresh tensor, so the result never aliases
    # the model's parameter storage.
    return torch.cat(params)


def set_flattened_params(model, new_flat):
    """Overwrite the trainable LoRA parameters of ``model`` from a flat vector.

    This is the inverse of ``get_flattened_params``. Parameters are selected
    with the same rule (``requires_grad`` and ``"lora"`` in the lower-cased
    name) and filled, in ``named_parameters()`` order, from consecutive slices
    of ``new_flat``. A short debug summary of the five largest mean absolute
    changes is printed.

    Args:
        model: A ``torch.nn.Module`` whose LoRA parameters are updated in place.
        new_flat: Tensor with exactly as many elements as the selected
            parameters hold in total; it is read in flattened order. Each slice
            is moved to its parameter's device, and ``copy_`` casts it to the
            parameter's dtype.

    Raises:
        ValueError: If ``new_flat`` has the wrong number of elements. The check
            happens before any parameter is modified.
    """
    targets = [
        (n, p) for n, p in model.named_parameters()
        if p.requires_grad and "lora" in n.lower()
    ]
    new_flat = new_flat.reshape(-1)
    expected = sum(p.numel() for _, p in targets)
    # Validate up front so a size mismatch can neither leave the model
    # half-updated nor silently ignore trailing values.
    if new_flat.numel() != expected:
        raise ValueError(
            f"new_flat has {new_flat.numel()} elements, but the selected LoRA "
            f"parameters hold {expected}"
        )

    pointer = 0
    param_change_magnitudes = []

    for n, p in targets:
        numel = p.numel()
        old_vals = p.detach()
        new_vals = new_flat[pointer:pointer + numel].reshape(p.shape).to(p.device)

        # Record the change magnitude and a small value sample *before*
        # overwriting (no full clone needed). Zero-sized tensors would give a
        # NaN mean, so they count as no change.
        param_diff = (new_vals - old_vals).abs().mean().item() if numel else 0.0
        param_change_magnitudes.append(
            (n, param_diff, old_vals.flatten()[:3].tolist(), new_vals.flatten()[:3].tolist()))

        p.data.copy_(new_vals)
        pointer += numel

    # Print top 5 changes for debugging
    param_change_magnitudes.sort(key=lambda x: x[1], reverse=True)
    print(f"Changed {len(targets)} parameters in total")
    print("Top 5 parameter changes:")
    for name, change, old, new in param_change_magnitudes[:5]:
        print(f"  {name}: diff={change:.6f}, old={old}, new={new}")


def random_direction_like(param_vector):
    """Sample a random unit-length direction shaped like ``param_vector``.

    Entries are drawn i.i.d. from a standard normal distribution with the
    shape, dtype and device of ``param_vector``. The sample is then divided by
    its L2 norm.
    """
    noise = torch.randn_like(param_vector)
    length = noise.norm()
    return noise / length


def _load_latest_result_json(output_dir):
    """Return the parsed contents of the newest results JSON under ``output_dir``.

    Files whose name starts with ``results`` are preferred. If none exist, any
    ``*.json`` file is considered. The search is recursive because, depending
    on the lm_eval version, results may be written directly into
    ``output_dir`` or into a subdirectory of it.

    Raises:
        FileNotFoundError: If no JSON file exists under ``output_dir``.
    """
    json_files = glob.glob(os.path.join(output_dir, "**", "*.json"), recursive=True)
    results_files = [f for f in json_files if os.path.basename(f).startswith("results")]
    candidates = sorted(results_files or json_files, key=os.path.getmtime)
    if not candidates:
        raise FileNotFoundError(f"No JSON result found in {output_dir}")
    with open(candidates[-1], encoding="utf-8") as f:
        return json.load(f)


def evaluate_hellaswag_via_file(model_args: str, limit: int = 40, output_dir="hellaswag_temp_output"):
    """Evaluate a model on HellaSwag through the lm_eval CLI and return a proxy loss.

    The CLI is driven in-process by temporarily replacing ``sys.argv``; the
    caller's ``sys.argv`` is restored afterwards. Results are written to a
    fresh ``<output_dir>_<uuid4>`` directory, read back, and that directory is
    removed, even when evaluation fails.

    Args:
        model_args: Passed verbatim to ``--model_args`` (e.g.
            ``"pretrained=...,peft=..."``).
        limit: Passed to ``--limit``.
        output_dir: Prefix for the temporary results directory.

    Returns:
        ``1.0 - acc_norm`` for HellaSwag, or ``1.0`` if the ``"acc_norm,none"``
        metric is missing from the results.

    Raises:
        FileNotFoundError: If the evaluation wrote no JSON results file.
    """
    # A unique directory per call keeps results from earlier runs from being
    # picked up by mistake.
    output_dir = f"{output_dir}_{uuid.uuid4()}"
    os.makedirs(output_dir, exist_ok=True)

    print(f"Evaluating model with args: {model_args}")
    print(f"Current time: {datetime.datetime.now().isoformat()}")
    print(f"Using output directory: {output_dir}")

    # Delete any cache to ensure a fresh evaluation.
    # NOTE(review): this removes the whole directory named by
    # LM_HARNESS_CACHE_PATH (default: lm_eval/cache/.cache relative to the CWD).
    cache_path = os.environ.get("LM_HARNESS_CACHE_PATH", os.path.join("lm_eval", "cache", ".cache"))
    if os.path.exists(cache_path):
        print(f"Deleting cache at {cache_path}")
        shutil.rmtree(cache_path, ignore_errors=True)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    saved_argv = sys.argv
    sys.argv = [
        "dummy",
        "--model", "MambaPEFT",
        "--tasks", "hellaswag",
        "--model_args", model_args,
        "--limit", str(limit),
        "--output_path", output_dir,
    ]
    try:
        # NOTE(review): cli_evaluate is not imported in the visible code;
        # presumably lm_eval's CLI entry point -- confirm and import it.
        cli_evaluate()
        result = _load_latest_result_json(output_dir)
    finally:
        # Restore the caller's argv and always remove the temporary results.
        sys.argv = saved_argv
        shutil.rmtree(output_dir, ignore_errors=True)

    acc = result["results"].get("hellaswag", {}).get("acc_norm,none")
    if acc is None:
        print("Warning: 'acc_norm,none' not found. Defaulting to 1.0 loss.")
        return 1.0

    return 1.0 - acc  # proxy loss


def test_extreme_parameter_change(args):
    """Sanity-check that perturbing a LoRA weight changes the HellaSwag result.

    Steps:
      1. Load ``state-spaces/mamba-130m-hf`` plus the adapter at
         ``args.peft_weights`` and save an untouched copy.
      2. Multiply the first trainable LoRA parameter by 100 and save that as a
         second copy.
      3. Score both copies with ``evaluate_hellaswag_via_file``
         (``args.eval_limit`` examples) and print the losses and their
         difference.

    Both temporary checkpoint directories are removed afterwards, also on
    failure.

    Returns:
        Tuple ``(orig_loss, mod_loss)``.
    """
    print("\n===== TESTING EXTREME PARAMETER CHANGE =====\n")

    original_dir = f"original_model_{uuid.uuid4()}"
    modified_dir = f"modified_model_{uuid.uuid4()}"
    try:
        print("Loading base model...")
        base_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", trust_remote_code=True)
        print("Loading PEFT model...")
        # NOTE(review): PeftModel is not imported in the visible code -- confirm.
        peft_model = PeftModel.from_pretrained(base_model, args.peft_weights, torch_dtype=torch.float32)
        peft_model.to("cuda")
        peft_model.eval()

        os.makedirs(original_dir, exist_ok=True)
        print(f"Saving original model to {original_dir}")
        peft_model.save_pretrained(original_dir)

        print("Getting parameters before modification...")
        modified = False
        for name, param in peft_model.named_parameters():
            if "lora" in name.lower() and param.requires_grad:
                print(f"Original parameter {name}: {param.data.flatten()[:5]}")
                print(f"Making extreme change to {name}")
                # Only the first matching parameter is scaled. An all-zero
                # tensor stays all-zero under scaling, so check the printout.
                param.data = param.data * 100
                print(f"Modified parameter {name}: {param.data.flatten()[:5]}")
                modified = True
                break
        if not modified:
            print("Warning: no trainable LoRA parameter found; both checkpoints are identical.")

        os.makedirs(modified_dir, exist_ok=True)
        print(f"Saving modified model to {modified_dir}")
        peft_model.save_pretrained(modified_dir)

        # Free GPU memory before lm_eval loads its own copies of the model.
        peft_model = peft_model.cpu()
        del peft_model
        del base_model
        torch.cuda.empty_cache()

        print("\nEvaluating original model...")
        orig_loss = evaluate_hellaswag_via_file(
            model_args=f"pretrained=state-spaces/mamba-130m-hf,peft={original_dir}",
            limit=args.eval_limit
        )

        print("\nEvaluating modified model...")
        mod_loss = evaluate_hellaswag_via_file(
            model_args=f"pretrained=state-spaces/mamba-130m-hf,peft={modified_dir}",
            limit=args.eval_limit
        )
    finally:
        # Remove both checkpoints even if loading, saving or evaluation failed.
        shutil.rmtree(original_dir, ignore_errors=True)
        shutil.rmtree(modified_dir, ignore_errors=True)

    print(f"\nOriginal loss: {orig_loss:.6f}, Modified loss: {mod_loss:.6f}")
    print(f"Difference: {mod_loss - orig_loss:.6f}")

    return orig_loss, mod_loss


def _load_landscape_peft_model(peft_weights):
    """Load mamba-130m with the adapter at ``peft_weights``, on CUDA, in eval mode."""
    base_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", trust_remote_code=True)
    # NOTE(review): PeftModel is not imported in the visible code -- confirm.
    peft_model = PeftModel.from_pretrained(base_model, peft_weights, torch_dtype=torch.float32)
    peft_model.to("cuda")
    peft_model.eval()
    return peft_model


def _save_loss_landscape(alpha_range, beta_range, loss_grid, save_path):
    """Plot ``loss_grid`` as a 3-D surface to ``save_path`` and store the raw grid as ``.npz``.

    ``loss_grid[i, j]`` must hold the loss at ``(alpha_range[i], beta_range[j])``.
    The raw arrays go to ``save_path`` with its extension replaced by ``.npz``.
    """
    # "ij" indexing gives A[i, j] == alpha_range[i] and B[i, j] == beta_range[j],
    # matching loss_grid's layout; the default "xy" indexing would transpose it.
    A, B = np.meshgrid(alpha_range, beta_range, indexing="ij")
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")
    surf = ax.plot_surface(A, B, loss_grid, cmap="viridis", edgecolor='none')
    fig.colorbar(surf)
    ax.set_xlabel("alpha")
    ax.set_ylabel("beta")
    ax.set_zlabel("Loss")
    ax.set_title("Loss Landscape via lm_eval on HellaSwag")
    fig.tight_layout()
    print(f"Saving loss landscape to {save_path}")
    fig.savefig(save_path)
    plt.close(fig)

    # Derive the data path from the image path whatever its extension is, so
    # the printed path is exactly the file np.savez writes.
    npz_path = os.path.splitext(save_path)[0] + ".npz"
    np.savez(
        npz_path,
        alpha_range=alpha_range,
        beta_range=beta_range,
        loss_grid=loss_grid
    )
    print(f"Saved raw data to {npz_path}")


def draw_loss_landscape_via_checkpoint(args, alphas=(-10, 10), betas=(-10, 10), steps=5,
                                       save_path="loss_landscape_eval.png"):
    """Plot a 2-D slice of the HellaSwag proxy-loss landscape around a LoRA adapter.

    The trainable LoRA parameters θ of the adapter at ``args.peft_weights`` are
    perturbed as ``θ + α·d_x + β·d_y``, where ``d_x`` and ``d_y`` are random,
    mutually orthogonal unit vectors. For each point of a ``steps × steps`` grid
    over ``alphas × betas``, the perturbed adapter is saved to a temporary
    directory and scored with ``evaluate_hellaswag_via_file`` (``1 - acc_norm``,
    limited to ``args.eval_limit`` examples). Points whose evaluation raises are
    recorded as NaN.

    Args:
        args: Namespace providing ``peft_weights`` and ``eval_limit``.
        alphas: ``(start, stop)`` of the α axis, both inclusive.
        betas: ``(start, stop)`` of the β axis, both inclusive.
        steps: Number of grid points per axis.
        save_path: Image path for the surface plot. The raw grid is saved next
            to it with an ``.npz`` extension.
    """
    print("\n===== DRAWING LOSS LANDSCAPE =====\n")

    peft_model = _load_landscape_peft_model(args.peft_weights)
    try:
        orig_params = get_flattened_params(peft_model).clone()
    except ValueError as e:
        print(f"Error getting parameters: {e}")
        return
    finally:
        # Only the flat parameter vector is needed from here on. Free the model
        # so it does not stay on the GPU while each grid point loads its own copy.
        peft_model.cpu()
        del peft_model
        torch.cuda.empty_cache()
    print(f"Original parameters shape: {orig_params.shape}")

    direction_x = random_direction_like(orig_params)
    direction_y = random_direction_like(orig_params)

    # Gram-Schmidt step: make direction_y orthogonal to direction_x, then renormalize.
    direction_y = direction_y - direction_y.dot(direction_x) * direction_x
    direction_y = direction_y / direction_y.norm()

    alpha_range = np.linspace(*alphas, steps)
    beta_range = np.linspace(*betas, steps)
    loss_grid = np.zeros((steps, steps))

    print(f"Evaluating {steps}x{steps} grid of points...")

    for i, alpha in enumerate(alpha_range):
        for j, beta in enumerate(beta_range):
            print(f"\n--- Point ({i + 1}/{steps}, {j + 1}/{steps}): α={alpha:.2f}, β={beta:.2f} ---")
            temp_peft_dir = f"temp_peft_{uuid.uuid4()}"
            try:
                # Reload from disk so every point starts from the unmodified adapter.
                peft_model = _load_landscape_peft_model(args.peft_weights)

                print("Modifying parameters...")
                new_params = orig_params + float(alpha) * direction_x + float(beta) * direction_y
                set_flattened_params(peft_model, new_params)

                os.makedirs(temp_peft_dir, exist_ok=True)
                print(f"Saving modified model to {temp_peft_dir}")
                peft_model.save_pretrained(temp_peft_dir)

                # Free GPU memory before lm_eval loads its own copy of the model.
                peft_model.cpu()
                del peft_model
                torch.cuda.empty_cache()

                try:
                    loss = evaluate_hellaswag_via_file(
                        model_args=f"pretrained=state-spaces/mamba-130m-hf,peft={temp_peft_dir}",
                        limit=args.eval_limit
                    )
                    print(f"α={alpha:.2f}, β={beta:.2f} → loss={loss:.4f}")
                    loss_grid[i, j] = loss
                except Exception as e:
                    # Best effort: a failed evaluation leaves a NaN hole in the
                    # grid instead of aborting the whole sweep.
                    print(f"Error at α={alpha:.2f}, β={beta:.2f}: {str(e)}")
                    loss_grid[i, j] = float('nan')
            finally:
                # Remove the checkpoint even if loading or saving failed.
                shutil.rmtree(temp_peft_dir, ignore_errors=True)

    _save_loss_landscape(alpha_range, beta_range, loss_grid, save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Probe a LoRA adapter's HellaSwag proxy-loss landscape via lm_eval.")
    parser.add_argument("--draw-landscape", action="store_true", help="Draw loss landscape")
    parser.add_argument("--test-extreme-change", action="store_true",
                        help="Test if extreme parameter changes affect evaluation")
    parser.add_argument("--peft-weights", type=str, required=True, help="Path to PEFT adapter weights")
    # The underscore spellings are listed first so the attribute names
    # (args.save_path, args.eval_limit) are unchanged; the dash aliases match
    # the style of the other flags.
    parser.add_argument("--save_path", "--save-path", type=str, default="loss_landscape_eval.png",
                        help="Image path for the landscape plot (raw grid is also saved as .npz)")
    parser.add_argument("--eval_limit", "--eval-limit", type=int, default=40,
                        help="Number of HellaSwag examples per evaluation (lm_eval --limit)")
    parser.add_argument("--steps", type=int, default=5,
                        help="Grid points per axis for --draw-landscape")
    args = parser.parse_args()
    if args.steps < 2:
        parser.error("--steps must be at least 2")

    # --test-extreme-change takes precedence if both actions are requested.
    if args.test_extreme_change:
        test_extreme_parameter_change(args)
    elif args.draw_landscape:
        draw_loss_landscape_via_checkpoint(args, steps=args.steps, save_path=args.save_path)
    else:
        # Without an action flag the script would otherwise exit silently.
        parser.print_help()