import argparse
import pandas as pd
from pathlib import Path

import yaml

from cheem.config_utils import Config

# Command-line interface of the script. The options in the first group are
# forwarded to `Config`; the ones in the second group select which task
# architecture is read and label the row appended to the statistics CSV.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

group = parser.add_argument_group('Dataset parameters')
group.add_argument('--runtime-config-file', default='', type=str, metavar='RUNTIME_CONFIG',
                   help='YAML config file specifying ImageNet arguments')
group.add_argument('--benchmark', default='vdd', type=str, metavar='BENCHMARK',
                   help='Benchmark to run. Should be one of "vdd", "5-datasets"')
group.add_argument('--root_dir', metavar='ROOT',
                   help='root directory to the repository')

group = parser.add_argument_group('Runtime settings')
group.add_argument('--task-idx', default=0, type=int,
                   help='Task index to visualize')
group.add_argument('--wandb-project', default=None, type=str,
                   help='WandB project name')
group.add_argument('--wandb-username', default=None, type=str,
                   help='WandB username')
group.add_argument('--exp-name', default=None, type=str,
                   help='WandB run group.')
# The help texts below used to be copy-pasted from --exp-name; they now
# describe how each value is actually consumed by this script.
group.add_argument('--epochs', default=None, type=int,
                   help='Supernet training epochs; recorded in the statistics CSV.')
group.add_argument('--sampling_scheme', default=None, type=str,
                   help='Sampling scheme label (e.g. "e", "ee", "ee-w-prompt"); '
                        'recorded in the statistics CSV.')
group.add_argument('--seed', default=None, type=int,
                   help='Seed of the run; recorded in the statistics CSV.')
group.add_argument('--seq', default=None, type=int,
                   help='Sequence identifier; recorded in the "Seq" column of the statistics CSV.')
group.add_argument('--cheem-component', default=None, type=str,
                   help='Component the parameter statistics are computed for. Should be one of '
                        '"attn_proj", "value", "query", "key", "ffn"')

# Parse the args and save in variable args
args = parser.parse_args()


# Build the project configuration from the CLI arguments and expose the
# sections that the rest of the script reads.
_config = Config(
    args.benchmark,
    args.runtime_config_file,
    args.root_dir,
    args.exp_name,
    wandb_username=args.wandb_username,
    wandb_project=args.wandb_project,
    checkpoint_path=None,
)
runtime, benchmark_config, evolutionary_search_config = (
    _config.runtime,
    _config.benchmark_config,
    _config.evolutionary_search_config,
)

# The architecture of the selected task is stored as task_<idx>.yaml inside
# the runtime config directory; parse it with the safe YAML loader.
arch_config_path = Path(runtime.config_dir) / f"task_{args.task_idx}.yaml"
with arch_config_path.open('r') as config_file:
    arch_config = yaml.safe_load(config_file)

# Number of params for each operation (reuse is absorbed into the reused expert)

# vit_base_params: parameter count of the backbone, assembled term by term
# for a hidden size of 768 and 12 blocks.
# NOTE(review): the per-term interpretations below are inferred from the
# arithmetic only -- confirm against the actual model definition.
embed_params = 82 * 768            # presumably the positional embedding (82 tokens)
patch_embed = 8*8*3*768 + 768      # presumably an 8x8x3 patch projection + bias
cls_token = 768
attn_head = 768*768*3 + 768*3 \
    + 768*768 + 768 + 768*2        # presumably qkv + output projection + layer norm
mlp = 768*768*4 + 768*4 \
    + 768*4*768 + 768 + 768*2      # presumably two linear layers + layer norm
last_ln = 768*2
vit_base_params = embed_params + patch_embed + cls_token + 12*(attn_head + mlp) + last_ln

# Pick the cost of a "new" expert and the amount credited back by a skip for
# the requested component. An unknown (or omitted) component used to leave
# `op_num_params` undefined and surface much later as a NameError; fail here
# with an actionable message instead.
if args.cheem_component in ["attn_proj", "value", "query", "key"]:
    _new_params = 768*768 + 768
    # NOTE(review): a skip credits the whole `attn_head` count although "new"
    # only counts a single 768x768 projection -- confirm this is intended.
    _skip_credit = attn_head
elif args.cheem_component == "ffn":
    _new_params = 768*768*4 + 768*4 + 768*4*768 + 768
    _skip_credit = _new_params
else:
    raise ValueError(
        "--cheem-component must be one of 'attn_proj', 'value', 'query', 'key' or 'ffn', "
        f"got {args.cheem_component!r}"
    )

# Width of the adapter bottleneck: the hidden size divided by the configured
# downscale factor.
_bottleneck = 768 // runtime.downscale

op_num_params = {
    # down-projection (weight + bias) followed by up-projection (weight + bias)
    "adapt": 768*_bottleneck + _bottleneck + _bottleneck*768 + 768,
    "new": _new_params,
    "skip_dec": -_skip_credit,
    "skip": 0,
    "identity": 0
}

task_order = runtime.task_order

# Operation labels collected for every task while walking the experts.
task_operations = {name: [] for name in task_order}

# One dict per layer: expert id -> name of the task that owns the expert.
expert_to_task_map = [{} for _ in arch_config]

# Running totals over all experts; the "_wo_skip" variants price a "skip"
# with the (negative) "skip_dec" entry instead of zero.
additional_parameters = 0
additional_parameters_wo_skip = 0

# Per-task totals, keyed by task index; index 0 is deliberately left out.
taskwise_inc = {idx: 0 for idx in range(1, len(task_order))}
taskwise_inc_wo_skip = dict.fromkeys(taskwise_inc, 0)

for layer, experts in enumerate(arch_config):
    layer_owners = expert_to_task_map[layer]
    for expert_id, expert_config in experts.items():
        primitive = expert_config["primitive"]

        cost = op_num_params[primitive]
        additional_parameters += cost
        cost_wo_skip = op_num_params[primitive.replace("skip", "skip_dec")]
        additional_parameters_wo_skip += cost_wo_skip

        # The first associated task owns the expert; the remaining ones reuse it.
        associated_tasks = expert_config["associated_tasks"]
        owner_idx = associated_tasks[0]
        if owner_idx > 0:
            taskwise_inc[owner_idx] += cost
            taskwise_inc_wo_skip[owner_idx] += cost_wo_skip

        original_task = task_order[owner_idx]
        layer_owners.setdefault(expert_id, original_task)

        for reuse_idx in associated_tasks[1:]:
            task_operations[task_order[reuse_idx]].append(f"reuse_{original_task}")

        if primitive == "adapt":
            # An adapted expert records the task owning its parent expert.
            parent_task = layer_owners[expert_config["parent_expert_id"]]
            task_operations[original_task].append(f"adapt_{parent_task}")
        elif primitive in ("new", "skip", "identity"):
            task_operations[original_task].append(primitive)

# Mean per-task increase, expressed in millions of parameters.
avg_inc = sum(taskwise_inc.values()) / len(taskwise_inc) / 1e6
avg_inc_wo_skip = sum(taskwise_inc_wo_skip.values()) / len(taskwise_inc_wo_skip) / 1e6

# Fixed opening of the DOT document: graph-wide attributes followed by
# cluster_0, a hand-written chain of twelve nodes IN_B1..IN_B12 labelled
# "Tsk1_ImNet". The literal is written to the output verbatim, so its
# whitespace must not be reformatted.
header = """
digraph G {
    fontname="Helvetica,Arial,sans-serif"
	fontsize=25;
	node [fontname="Helvetica,Arial,sans-serif"]
	edge [fontname="Helvetica,Arial,sans-serif"]

	rankdir=LR;    
    compound=true;
    newrank=true;
	margin=0;

	subgraph cluster_0 {
		style="filled,solid";
		fillcolor="#E6E6E6";
		color="#FFFFFF";
		weight=15;
		margin=8;
		label = "Tsk1_ImNet";
		labeljust=l;
		node [style="filled,solid",color=black,weight="2",fillcolor=aliceblue,fontsize=23];
		edge [color="black",shape="vee"]
		IN_B1 -> IN_B2 -> IN_B3 -> IN_B4 -> IN_B5 -> IN_B6 -> IN_B7 -> IN_B8 -> IN_B9 -> IN_B10 -> IN_B11 -> IN_B12;
		
	}

	node [shape=box]
"""

# Hand-written closing section with a single rank constraint.
# NOTE(review): `tail` is reassigned further down in this script before the
# final document is assembled, so this literal appears to be unused -- confirm
# before relying on it.
tail = """
    { rank=same; B1; C100_B1; SVHN_B1; UCF_B1; OGlot_B1; GTSR_B1; DPed_B1; Flwr_B1; Airc_B1; DTD_B1; }

}
"""

def generate_subcluster(id, task_num, code, blocks, color, edges):
    """Render one task as a DOT ``subgraph cluster`` followed by its edges.

    Args:
        id: Suffix of the cluster name (``cluster_<id>``).
        task_num: Number shown in the cluster label (``Tsk<task_num>_<code>``).
        code: Node-name prefix; the cluster chains ``<code>_B1`` through
            ``<code>_B12``.
        blocks: Node declaration lines placed inside the cluster.
        color: Colour of the edge statements emitted after the cluster.
        edges: Edge statement lines emitted after the cluster is closed.

    Returns:
        The DOT fragment as a single string.
    """
    cluster_open = "\n".join([
        f"\tsubgraph cluster_{id} {{",
        '\t\tstyle="filled,solid";',
        '\t\tfillcolor="#E6E6E6"',
        '\t\tcolor="#FFFFFF";',
        "\t\tweight=15;",
        "\t\tmargin=8;",
        f'\t\tlabel="Tsk{task_num}_{code}";',
        "\t\tlabeljust=l",
    ])
    node_lines = "\n\t\t".join(blocks)
    # Black chain linking the twelve blocks of this task inside the cluster.
    chain = " -> ".join(f"{code}_B{n}" for n in range(1, 13))
    edge_lines = "\n\t".join(edges)

    pieces = (
        cluster_open, "\n\t\t", node_lines, "\n",
        '\t\tedge [color="black",shape="vee"]\n\t\t',
        chain, ";\n\t}\n\t\t",
        # Edges from other tasks are declared outside the cluster, in `color`.
        f'\n\tedge [color="{color}",shape="normal"]\n\t',
        edge_lines,
    )
    return "".join(pieces)

# Short node-name prefix used in the DOT graph for every task.
task_to_nodename = {
    "imagenet12": "IN", 
    "cifar100": "C100",
    "svhn": "SVHN",
    "ucf101": "UCF",
    "omniglot": "OGlot",
    "gtsrb": "GTSR",
    "daimlerpedcls": "DPed",
    "vgg-flowers": "Flwr",
    "aircraft": "Airc",
    "dtd": "DTD"    
}

# Edge colours, looked up by the position of a task within task_order[1:].
colors = [
    "blue", "burlywood4", "cadetblue4", "blueviolet", "brown4", "cyan2", 
    "darkolivegreen", "darkslateblue", "deeppink3"
]

# Node declarations and incoming edges collected per task (first task excluded).
blocks = {task: [] for task in task_order[1:]}
edges = {task: [] for task in task_order[1:]}

# One row per layer, one column per task, each cell an operation label.
df = pd.DataFrame(task_operations)

# Reset the per-task operation lists (kept for parity with the original script).
task_operations = {task_order[0]: ["identity" for _ in range(len(arch_config))]}
task_operations.update({t: [] for t in task_order[1:]})
# Iter over layers
for i, r in enumerate(df.to_dict(orient="records")):
    # For each expert owner, the task that most recently reused it in this
    # layer, so consecutive reuses form a chain rather than a star.
    task_origins = {task: task for task in task_order}
    for task_name in task_order[1:]:
        _block = task_to_nodename[task_name] + f"_B{i+1}"
        expert = r[task_name]
        # Split on the first underscore only, so that a source task whose name
        # itself contains "_" is still recovered in full.
        primitive, _, source_task = expert.partition("_")

        if primitive == "reuse":
            start = f"{task_to_nodename[task_origins[source_task]]}_B{i+1}"
            edges[task_name].append(f"{start} -> {_block};")
            task_origins[source_task] = task_name
            label, color, fillcolor = "R", "#06592A", "#9CCEA7"
        elif primitive == "adapt":
            start = f"{task_to_nodename[task_origins[source_task]]}_B{i+1}"
            edges[task_name].append(f"{start} -> {_block};")
            label, color, fillcolor = "A", "#E31A1C", "#FEB24C"
        elif "skip" in expert:
            label, color, fillcolor = "S", "#C40F5B", "#F2ACCA"
        elif "new" in expert:
            label, color, fillcolor = "N", "#226E9C", "#9EC9E2"
        else:
            # Previously such a cell silently inherited the style of the
            # preceding node (or raised NameError on the very first one).
            raise ValueError(
                f"Unsupported operation {expert!r} for task {task_name!r} at layer {i+1}"
            )
        _block = _block + f" [label=\"{label}\",style=\"filled,solid\",color=\"{color}\",weight=\"2\",fillcolor=\"{fillcolor}\",fontsize=23];"
        blocks[task_name].append(_block)

# One cluster per task after the first. Cluster ids start at 1 and displayed
# task numbers at 2, since cluster_0 / "Tsk1" is the fixed chain in `header`.
clusters = "\n \n".join(
    generate_subcluster(
        pos + 1, pos + 2, task_to_nodename[task], blocks[task], colors[pos], edges[task]
    )
    for pos, task in enumerate(task_order[1:])
)

# One rank=same group per block index (1..12), listing that block of every task.
rank_rows = []
for block_no in range(1, 13):
    members = " ".join(f"{task_to_nodename[task]}_B{block_no};" for task in task_order[1:])
    rank_rows.append(f"\t{{rank=same; IN_B{block_no}; {members} }}")
tail = "\n".join(rank_rows) + "\n}"

final = "\n \n".join((header, clusters, tail))

# Write the DOT document next to the statistics, under <root>/statistics_and_vis.
vis_path = Path(runtime.root_path, "statistics_and_vis")
vis_path.mkdir(exist_ok=True, parents=True)
dot_path = Path(vis_path, f"{args.exp_name}.dot")
with open(dot_path, "w") as f:
    f.write(final)

stats_df_path = Path(vis_path, f"{args.cheem_component}-statistics.csv")

# The scheme labels the CSV row; without it the original code failed with an
# opaque AttributeError on None, so report the missing option explicitly.
if args.sampling_scheme is None:
    raise ValueError("--sampling_scheme is required to record a statistics row")

# Normalise the scheme to its display form and suffix it with the component.
args.sampling_scheme = args.sampling_scheme.replace("ee-w-prompt", "EE w/ Prompt")
if args.sampling_scheme in ("ee", "e"):
    args.sampling_scheme = args.sampling_scheme.upper()
args.sampling_scheme = f"{args.sampling_scheme}-{args.cheem_component}"

# Single-row frame for this run; every column is a one-element list.
_df = pd.DataFrame({
    "Seq": [args.seq],
    "Supernet Training Epochs": [args.epochs],
    "Sampling Scheme": [args.sampling_scheme],
    "Seed": [args.seed],
    "Average Accuracy": [""],
    "Additional Parameters": [additional_parameters],
    "Additional Parameters w/ skip": [additional_parameters_wo_skip],
    "% Inc": [100*(additional_parameters/vit_base_params)],
    "% Inc w/ skip": [100*(additional_parameters_wo_skip/vit_base_params)],
    "Avg. Param Inc./Task (M)": [avg_inc],
    "Avg. Param Inc./Task (M) w/ skip": [avg_inc_wo_skip],
})

# Append to an existing statistics file (kept sorted) or start a new one.
if stats_df_path.exists():
    stat_df = pd.read_csv(stats_df_path)
    stat_df = pd.concat([stat_df, _df], axis=0)
    stat_df = stat_df.sort_values(by=["Seq", "Sampling Scheme", "Supernet Training Epochs", "Seed"])
    stat_df.reset_index(drop=True, inplace=True)
else:
    stat_df = _df

stat_df.to_csv(stats_df_path, index=False)
