import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from ..models.wrappers import Supernet

# wandb is optional: record whether it is usable so callers can skip W&B logging.
try:
    import wandb
    has_wandb = True
except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
    has_wandb = False


def plot(df: pd.DataFrame, current_task_name, layer_names, save_path, name=None, min_lim=-1.1, max_lim=1.1, label=None, log_wandb=False):
    """Draw a seaborn line plot of *df* over backbone blocks and save it to disk.

    *df* is passed to ``sns.lineplot`` in wide form. Its rows are placed at
    x-ticks ``0 .. len(df) - 1`` and labelled with *layer_names*.

    Args:
        df: Data to plot, one row per x-axis position (block).
        current_task_name: Used as the plot title.
        layer_names: x-tick labels, one per row of *df*.
        save_path: Output file (str or path-like). If its suffix is ``.pdf``,
            a ``.png`` copy is also written next to it.
        name: Key under which the figure is logged to W&B.
        min_lim: Lower y-axis limit.
        max_lim: Upper y-axis limit.
        label: y-axis label.
        log_wandb: Log the figure to W&B. This is ignored when wandb could not
            be imported, instead of failing with a NameError.
    """
    fig, ax = plt.subplots()
    try:
        sns.lineplot(data=df, markers=True, ax=ax)

        ax.set_ylim(min_lim, max_lim)
        ax.set_xticks(list(range(df.shape[0])))
        ax.set_xticklabels(layer_names, rotation=90)
        ax.set_title(f"{current_task_name}")
        ax.set_xlabel("Block")
        ax.set_ylabel(label)

        save_path = Path(save_path)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        # Keep a raster copy next to PDF output. Only the suffix is swapped,
        # so ".pdf" appearing elsewhere in the path is left untouched.
        if save_path.suffix == ".pdf":
            fig.savefig(save_path.with_suffix(".png"), dpi=300, bbox_inches="tight")

        if log_wandb and has_wandb:
            wandb.log({name: wandb.Image(fig)})
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)


def experts_to_tasks(expert_ops, expert_ids, quantity, tasks):
    """Scatter per-expert values into a per-task array.

    Each expert listed in *expert_ids* is looked up in *expert_ops*. The first
    entry of its ``associated_tasks`` is used as the destination index in the
    result.

    Args:
        expert_ops: Indexable collection of experts, each exposing ``associated_tasks``.
        expert_ids: Indices into *expert_ops*, aligned position-by-position with *quantity*.
        quantity: Indexable values, one per position of *expert_ids*.
        tasks: Only its length is used, to size the output.

    Returns:
        A float ``np.ndarray`` of length ``len(tasks)``. Slots that no expert
        maps to remain NaN.
    """
    per_task = np.full((len(tasks),), np.nan)
    for slot, expert_id in enumerate(expert_ids):
        target = expert_ops[expert_id].associated_tasks[0]
        per_task[target] = quantity[slot]
    return per_task


def _collect_block_statistics(model, tasks, normalize_similarities):
    """Gather per-layer statistics from *model*, scattered onto *tasks*.

    Returns:
        ``(logs, layer_names)``. ``logs[quantity][task]`` is a list with one
        value per backbone layer whose statistics entry is not None.
        ``layer_names`` lists those layers in the same order.
    """
    quantities = ["similarities", "expert_sampling_probabilities", "retention_probabilities"]
    if normalize_similarities:
        quantities.append("normalized_similarities")
    logs = {key: {task: [] for task in tasks} for key in quantities}
    layer_names = []

    backbone_stats = model.iter_backbone(model.backbone, mode="statistics")
    for layer, (layer_name, mean_layer, _dim) in enumerate(backbone_stats):
        # Skipped layers still advance `layer`, so the per-layer model
        # containers below are indexed by raw position in iter_backbone.
        # NOTE(review): confirm nas_expert_list/samplers/experts use that indexing.
        if mean_layer is None:
            continue
        layer_names.append(layer_name)

        experts = model.experts[layer]
        expert_ids = model.nas_expert_list[layer]
        sampler = model.samplers[layer]

        layer_log = {
            "similarities": experts_to_tasks(
                experts, expert_ids, model._similarities[layer].cpu().numpy(), tasks
            ),
            "expert_sampling_probabilities": experts_to_tasks(
                experts, expert_ids, sampler.expert_sampling_prob, tasks
            ),
            "retention_probabilities": experts_to_tasks(
                experts, expert_ids, sampler.retention_prob, tasks
            ),
        }
        if normalize_similarities:
            layer_log["normalized_similarities"] = experts_to_tasks(
                experts, expert_ids, model._normalized_similarities[layer].cpu().numpy(), tasks
            )

        # experts_to_tasks yields one value per entry of `tasks`, in order.
        for key, per_task in layer_log.items():
            for task, value in zip(tasks, per_task):
                logs[key][task].append(value)

    return logs, layer_names


def _export_block_statistics(logs, layer_names, current_task_name, save_root, log_wandb):
    """Write every quantity in *logs* to CSV and plot it with ``plot``.

    For each quantity ``k`` this writes ``<save_root>/k.csv`` and plots to
    ``<save_root>/k.pdf``. File names depend only on the quantity, so a later
    call with the same *save_root* overwrites them.
    """
    labels = {
        "similarities": "Cosine Similarity",
        "normalized_similarities": "Normalized Cosine Similarity",
        "expert_sampling_probabilities": "Expert Sampling Prob.",
        "retention_probabilities": "Retention Prob",
    }
    # Apply the same has_wandb gate to plot() as to the table upload.
    # Forwarding the raw flag would reach an undefined `wandb` when the
    # package is not installed.
    use_wandb = has_wandb and log_wandb

    for key, per_task in logs.items():
        df = pd.DataFrame(per_task)
        df.to_csv(Path(save_root, f"{key}.csv"), index=False, header=True)
        if use_wandb:
            wandb.log({f"tab_{current_task_name}_{key}": wandb.Table(dataframe=df)})

        # Probability-valued quantities get a lower bound just below 0;
        # everything else is plotted down to just below -1.
        min_lim = -0.1 if "probabilities" in key else -1.1
        plot(
            df,
            current_task_name,
            layer_names,
            Path(save_root, f"{key}.pdf"),
            name=f"{current_task_name}_{key}",
            label=labels[key],
            min_lim=min_lim,
            max_lim=1.1,
            log_wandb=use_wandb,
        )


def plot_block_similarities(model: Supernet, task_idx, task_order, save_root, normalize_similarities=True, log_wandb=False):
    """Export and plot per-block expert statistics for the tasks before *task_idx*.

    The statistics are the similarities, expert sampling probabilities,
    retention probabilities and, optionally, normalized similarities. They are
    collected for ``task_order[:task_idx]`` and written as CSV plus a plot
    under *save_root*, titled with ``task_order[task_idx]``.

    Args:
        model: Supernet providing per-layer statistics and samplers.
        task_idx: Index of the current task in *task_order*.
        task_order: Sequence of (hashable) task identifiers.
        save_root: Directory for the ``<quantity>.csv`` / ``<quantity>.pdf`` files.
        normalize_similarities: Also export ``model._normalized_similarities``.
        log_wandb: Also log tables and figures to W&B (skipped if wandb is unavailable).
    """
    logs, layer_names = _collect_block_statistics(model, task_order[:task_idx], normalize_similarities)
    _export_block_statistics(logs, layer_names, task_order[task_idx], save_root, log_wandb)


def plot_t2t_block_similarities(model: Supernet, task_idx, task_order, save_root, normalize_similarities=True, log_wandb=False):
    """Like ``plot_block_similarities``, but only for the first task in *task_order*.

    Statistics are collected for ``task_order[:1]``. Plots are still titled
    with ``task_order[task_idx]`` and written under *save_root* with the same
    file names.
    """
    logs, layer_names = _collect_block_statistics(model, task_order[:1], normalize_similarities)
    _export_block_statistics(logs, layer_names, task_order[task_idx], save_root, log_wandb)
