import re
import os
import json
import random

import numpy as np
import torch
from transformers import TrainerState
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.patches import Rectangle, Polygon


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Also exports the seed through the ``PYTHONHASHSEED`` and ``DATA_SEED``
    environment variables. Note that changing ``PYTHONHASHSEED`` at runtime
    only affects interpreters started afterwards (e.g. subprocesses); the hash
    seed of the running interpreter is fixed at startup. ``DATA_SEED`` is not
    read in this part of the file — presumably consumed by data-loading code
    elsewhere (TODO confirm).

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)

    seed_as_text = str(seed)
    for env_var in ("PYTHONHASHSEED", "DATA_SEED"):
        os.environ[env_var] = seed_as_text

    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    

def get_model_storage_path(model_name):
    """Resolve ``model_name`` to a local directory containing the model files.

    If ``model_name`` is an absolute path, or any path that exists, it is
    returned unchanged. Otherwise it is treated as a Hugging Face Hub repo id
    (e.g. ``"org/name"``) and looked up in the local hub cache, whose layout is
    ``<cache>/models--org--name/snapshots/<revision>/``.

    The cache root is the first available of: ``$TRANSFORMERS_CACHE``,
    ``$HF_HUB_CACHE``, ``$HF_HOME/hub``, ``$XDG_CACHE_HOME/huggingface/hub``,
    ``~/.cache/huggingface/hub``.

    When several snapshots are cached, the revision named in ``refs/main`` is
    preferred if its snapshot directory exists; otherwise the lexicographically
    first snapshot directory is returned, so the result does not depend on
    ``os.listdir`` ordering.

    Args:
        model_name: Local path or hub repo id.

    Returns:
        ``model_name`` itself for local paths, else the snapshot directory path.

    Raises:
        FileNotFoundError: If the repo folder, its ``snapshots`` directory, or
            any snapshot directory inside it is missing.
    """
    if os.path.isabs(model_name) or os.path.exists(model_name):
        return model_name

    # TRANSFORMERS_CACHE may be unset; fall back to the standard HF hub cache
    # locations instead of failing inside os.path.join(None, ...).
    hf_home = os.environ.get("HF_HOME") or os.path.join(
        os.environ.get("XDG_CACHE_HOME") or os.path.join("~", ".cache"),
        "huggingface",
    )
    cache_dir = os.path.expanduser(
        os.environ.get("TRANSFORMERS_CACHE")
        or os.environ.get("HF_HUB_CACHE")
        or os.path.join(hf_home, "hub")
    )
    model_folder_path = os.path.join(
        cache_dir, "models--" + model_name.replace("/", "--")
    )
    if not os.path.exists(model_folder_path):
        raise FileNotFoundError(f"model folder path doesn't exist: {model_folder_path}")

    snapshots_dir = os.path.join(model_folder_path, "snapshots")
    if not os.path.exists(snapshots_dir):
        raise FileNotFoundError(f"can not find snapshots dir: {snapshots_dir}")

    # Only directories are snapshots; sort for a deterministic fallback choice.
    snapshot_names = sorted(
        name
        for name in os.listdir(snapshots_dir)
        if os.path.isdir(os.path.join(snapshots_dir, name))
    )
    if not snapshot_names:
        raise FileNotFoundError(f"no snapshot sub dir in {snapshots_dir}")

    # In the HF cache layout, refs/main records the commit hash of the
    # default-branch revision; prefer it when that snapshot is present.
    ref_path = os.path.join(model_folder_path, "refs", "main")
    if os.path.isfile(ref_path):
        with open(ref_path, encoding="utf-8") as ref_file:
            revision = ref_file.read().strip()
        if revision in snapshot_names:
            return os.path.join(snapshots_dir, revision)

    return os.path.join(snapshots_dir, snapshot_names[0])


def save_state_and_model_for_hf_trainer(trainer):
    """Save the trainer state and a CPU copy of the model weights.

    ``trainer.save_state()`` is always called (it writes the trainer state into
    ``trainer.args.output_dir``). The weights are written only when
    ``trainer.args.should_save`` is truthy — presumably true only on the
    process responsible for writing checkpoints (TODO confirm against the HF
    ``TrainingArguments`` in use).

    Args:
        trainer: A Hugging Face ``Trainer``-like object exposing
            ``save_state()``, ``_save(output_dir, state_dict=...)``, ``args``
            and ``model``.
    """
    trainer.save_state()
    if not trainer.args.should_save:
        return

    # Move every tensor to the CPU before handing the state dict to _save.
    weights_on_cpu = dict(
        (param_name, tensor.cpu())
        for param_name, tensor in trainer.model.state_dict().items()
    )
    trainer._save(trainer.args.output_dir, state_dict=weights_on_cpu)

def load_state_and_model_for_hf_trainer(model, load_model_dir,
                                        map_location=None):
    """Restore model weights and the trainer state from a checkpoint dir.

    Reads ``pytorch_model.bin`` (a state dict) into ``model`` via
    ``load_state_dict`` and parses ``trainer_state.json`` with
    ``TrainerState.load_from_json``.

    Args:
        model: Module whose weights are overwritten in place.
        load_model_dir: Directory holding both checkpoint files.
        map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``).

    Returns:
        ``(model, trainer_state)`` — the same ``model`` object, now loaded,
        and the parsed ``TrainerState``.
    """
    weights_file = os.path.join(load_model_dir, "pytorch_model.bin")
    state_file = os.path.join(load_model_dir, "trainer_state.json")

    # NOTE(review): torch.load unpickles the file (unless the installed torch
    # defaults to weights_only=True); only load checkpoints from trusted sources.
    loaded_weights = torch.load(weights_file, map_location=map_location)
    model.load_state_dict(loaded_weights)

    trainer_state = TrainerState.load_from_json(state_file)
    return model, trainer_state


def get_param_names_to_merge(input_param_names, exclude_param_names_regex):
    """Filter out parameter names matching any exclusion regex.

    Matching uses ``re.match``, i.e. a pattern must match at the *start* of
    the name. Input order is preserved.

    Args:
        input_param_names: Iterable of parameter names.
        exclude_param_names_regex: Iterable of regex patterns; an empty
            iterable keeps every name.

    Returns:
        List of the names that match none of the patterns.
    """
    def _hits_any_pattern(name):
        # Every pattern is tried (no early exit) before deciding.
        hits = [re.match(pattern, name) for pattern in exclude_param_names_regex]
        return any(hits)

    return [name for name in input_param_names if not _hits_any_pattern(name)]


def get_modules_to_merge(model, include_module_types):
    """Collect the model's named submodules, optionally filtered by type.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        include_module_types: Iterable of classes; a module is kept if it is an
            instance of any of them. If falsy (``None`` or empty), every module
            is kept, including the root module under the name ``""``.

    Returns:
        Dict mapping module name to module, in ``named_modules()`` order.
    """
    def _is_selected(module):
        if not include_module_types:
            return True
        return any([isinstance(module, wanted) for wanted in include_module_types])

    return {
        name: module
        for name, module in model.named_modules()
        if _is_selected(module)
    }


# Fixed fill colour per source model, so each model is drawn in the same
# colour in every slice of the architecture plot; models not listed here are
# drawn in grey ("#cccccc"). Insertion order is the order of legend entries.
model_color_map = dict([
    ("vanillaOVO/WizardMath-13B-V1.0", "#ff9999"),
    ("WizardLMTeam/WizardLM-13B-V1.2", "#66b3ff"),
    ("layoric/llama-2-13b-code-alpaca", "#ffcc99"),
    ("meta-llama/Llama-2-13b-hf", "#99ff99"),
])


# Plotting of merged-model architectures (layer slices and their sources).
def _slice_num_layers(slice_data):
    """Return a slice's layer count, taken from its first source's layer_range."""
    layer_range = slice_data["sources"][0]["layer_range"]
    return layer_range[1] - layer_range[0]


def _format_merging_method(merging_method):
    """Render a slice's ``merging_method`` mapping as multi-line label text."""
    lines = []
    for method, params in merging_method.items():
        lines.append(f"{method}:")
        for key, param_set in params.items():
            lines.append(f"{key}")
            for param in param_set:
                lines.append(" | ".join(f"{k}: {v}" for k, v in param.items()))
    return "\n".join(lines).strip()


def _draw_sources(ax, sources, y_start, y_end):
    """Fill the slice box (x in [1, 2], y in [y_start, y_end]) with one region per source."""
    num_sources = len(sources)
    y_mid = (y_start + y_end) / 2
    text_style = dict(ha="center", va="center", fontsize=8, color="black")
    for i, source in enumerate(sources):
        color = model_color_map.get(source["model"], "#cccccc")
        label = f"{source['layer_range'][0]}-{source['layer_range'][1]}"
        if num_sources == 1:
            ax.add_patch(Rectangle((1, y_start), 1, y_end - y_start, color=color, alpha=0.5))
            ax.text(1.5, y_mid, label, **text_style)
        elif num_sources == 2:
            # Split the box along the diagonal from (1, y_end) to (2, y_start):
            # the first source takes the triangle on the left edge, the second
            # the triangle on the right edge.
            if i == 0:
                corners, text_x = ((1, y_end), (2, y_start), (1, y_start)), 1.25
            else:
                corners, text_x = ((1, y_end), (2, y_end), (2, y_start)), 1.75
            ax.add_patch(Polygon(corners, color=color, alpha=0.5))
            ax.text(text_x, y_mid, label, **text_style)
        else:
            # Three or more sources: stack equal-height horizontal bands.
            band = (y_end - y_start) / num_sources
            ax.add_patch(Rectangle((1, y_start + i * band), 1, band, color=color, alpha=0.5))
            ax.text(1.5, y_start + (i + 0.5) * band, label, **text_style)


def get_figure(slices, save_path, output_scales=None, *, close=True):
    """Draw the layer layout of a sliced/merged model and save it as an image.

    Every slice becomes a box whose height grows with its layer count, filled
    with one coloured region per source model (colours from
    ``model_color_map``), with the merging method on the left and the
    resulting new-layer index range on the right. ``slices[0]`` is drawn at
    the bottom.

    Args:
        slices: Sequence of slice dicts. Each needs ``"sources"`` (list of
            dicts with ``"model"`` and ``"layer_range"`` ``[start, end]``) and
            ``"merging_method"`` (mapping method -> mapping key -> iterable of
            param dicts). A slice's layer count is ``end - start`` of its
            first source.
        save_path: Destination passed to ``Figure.savefig``.
        output_scales: Optional indexable of per-new-layer values; when given,
            the values at each slice's first and last new-layer index are
            printed next to the slice.
        close: Keyword-only. Close the figure after saving (default) so that
            repeated calls do not accumulate open figures; pass ``False`` to
            keep it open for interactive display.

    Returns:
        The matplotlib ``Figure`` that was saved.
    """
    slices = list(slices)
    layer_counts = [_slice_num_layers(slice_data) for slice_data in slices]
    total_layers = sum(layer_counts)

    # Figure height grows with the number of layers (in inches).
    fig_height = max(15, total_layers * 1.5)
    fig_width = 15
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    ax.set_xlim(-2, 5)
    ax.set_ylim(0, total_layers * 1.5 + 5)

    new_layer_counter = 0
    # Iterate in reverse so slices[-1] gets the smallest y; after the y-axis is
    # inverted below, slices[0] ends up at the bottom of the image.
    reversed_pairs = zip(slices[::-1], layer_counts[::-1])
    for slice_idx, (slice_data, num_layers) in enumerate(reversed_pairs):
        y_start = new_layer_counter * 1.5 + 1
        y_end = (new_layer_counter + num_layers - 1) * 1.5 + 2
        y_mid = (y_start + y_end) / 2
        # Alternate label offsets so labels of neighbouring slices don't collide.
        stagger = (slice_idx % 2) * 0.5

        ax.add_patch(Rectangle((1, y_start), 1, y_end - y_start,
                               linewidth=1, edgecolor="black", facecolor="none"))
        _draw_sources(ax, slice_data["sources"], y_start, y_end)

        ax.text(-1, y_mid + stagger,
                _format_merging_method(slice_data["merging_method"]),
                ha="right", va="center", fontsize=8)

        first_new = total_layers - new_layer_counter - num_layers
        last_new = total_layers - new_layer_counter
        ax.text(3.5, y_mid - stagger, f"New Layers {first_new}-{last_new}",
                ha="left", va="center", fontsize=8, color="black")

        if output_scales is not None:
            # NOTE(review): this offset staggers by a full unit rather than the
            # 0.5 used above — kept as-is; confirm whether that is intended.
            ax.text(3.5, y_mid - (slice_idx % 2) - 1.5,
                    f"New Layers scales {output_scales[first_new]}-{output_scales[last_new - 1]}",
                    ha="left", va="center", fontsize=8, color="blue")

        new_layer_counter += num_layers

    # Build the legend from proxy patches instead of placing patches at a
    # y-position derived from the figure height in inches: data coordinates
    # and inches are unrelated, which pushed entries outside the y-limits.
    legend_handles = [
        Rectangle((0, 0), 1, 1, color=color, alpha=0.5, label=model_name)
        for model_name, color in model_color_map.items()
    ]
    ax.legend(handles=legend_handles, loc="lower right", fontsize=8, frameon=False)

    ax.axis("off")
    ax.invert_yaxis()
    fig.savefig(save_path)
    if close:
        plt.close(fig)
    return fig
