import argparse
import pandas as pd
from pathlib import Path

import yaml

from cheem.config_utils import Config

# Command-line interface. Every option is read into the module-level ``args``
# namespace when ``parser.parse_args()`` is called below.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

group = parser.add_argument_group('Dataset parameters')
group.add_argument('--runtime-config-file', default='', type=str, metavar='RUNTIME_CONFIG',
                   help='YAML config file specifying ImageNet arguments')
group.add_argument('--benchmark', default='vdd', type=str, metavar='BENCHMARK',
                   help='Benchmark to run. Should be one of "vdd", "5-datasets"')
group.add_argument('--root_dir', metavar='ROOT',
                   help='root directory to the repository')

group = parser.add_argument_group('Runtime settings')
group.add_argument('--task-idx', default=0, type=int,
                   help='Task index to visualize')
group.add_argument('--wandb-project', default=None, type=str,
                   help='WandB project name')
group.add_argument('--wandb-username', default=None, type=str,
                   help='WandB username')
group.add_argument('--exp-name', default=None, type=str,
                   help='WandB run group.')
# The four options below only label the row written to the statistics CSV;
# their help text previously repeated the --exp-name description verbatim.
group.add_argument('--epochs', default=None, type=int,
                   help='Supernet training epochs, recorded in the statistics CSV.')
group.add_argument('--sampling_scheme', default=None, type=str,
                   help='Sampling scheme label, recorded in the statistics CSV.')
group.add_argument('--seed', default=None, type=int,
                   help='Seed value, recorded in the statistics CSV.')
group.add_argument('--seq', default=None, type=int,
                   help='Sequence identifier, recorded in the "Seq" column of the statistics CSV.')

# Read the command line into the module-level ``args`` namespace.
args = parser.parse_args()


# Build the project configuration. The first four values are passed
# positionally, in the same order as before; the rest are keywords.
_config = Config(
    args.benchmark,
    args.runtime_config_file,
    args.root_dir,
    args.exp_name,
    wandb_username=args.wandb_username,
    wandb_project=args.wandb_project,
    checkpoint_path=None,
)
# Unpack the sub-configs used further down in the script.
runtime, benchmark_config, evolutionary_search_config = (
    _config.runtime,
    _config.benchmark_config,
    _config.evolutionary_search_config,
)

# Load this task's architecture description from
# <runtime.config_dir>/task_<task_idx>.yaml using yaml.safe_load.
arch_config_path = Path(runtime.config_dir) / f"task_{args.task_idx}.yaml"
with arch_config_path.open('r') as arch_file:
    arch_config = yaml.safe_load(arch_file)

# Number of params for each operation (reuse is absorbed into the reused expert)

# ViT-Base parameter counts; every term below is a multiple of the width 768.
_VIT_WIDTH = 768

# NOTE(review): 82 presumably = number of patch positions + 1 cls token; confirm.
embed_params = 82 * _VIT_WIDTH
patch_embed = (8 * 8 * 3) * _VIT_WIDTH + _VIT_WIDTH
cls_token = _VIT_WIDTH
# Per-block attention cost; the terms look like fused qkv (weights + biases),
# output projection (weight + bias) and a norm's 2*width params -- confirm.
attn_head = (
    3 * _VIT_WIDTH * _VIT_WIDTH + 3 * _VIT_WIDTH
    + _VIT_WIDTH * _VIT_WIDTH + _VIT_WIDTH
    + 2 * _VIT_WIDTH
)
# Per-block MLP cost with a 4x hidden expansion, plus 2*width for a norm.
mlp = (
    4 * _VIT_WIDTH * _VIT_WIDTH + 4 * _VIT_WIDTH
    + 4 * _VIT_WIDTH * _VIT_WIDTH + _VIT_WIDTH
    + 2 * _VIT_WIDTH
)
last_ln = 2 * _VIT_WIDTH
# 12 transformer blocks of (attention + MLP) plus the non-repeated parts.
vit_base_params = sum((embed_params, patch_embed, cls_token, 12 * (attn_head + mlp), last_ln))

# Extra parameters contributed by each primitive a search can choose.
# "adapt" is a 768 -> 768//downscale -> 768 two-layer bottleneck with biases;
# "skip_dec" is used for the *_wo_skip totals, where a skip removes one
# attention block's parameters.
_adapt_bottleneck = 768 // runtime.downscale
op_num_params = dict(
    adapt=768 * _adapt_bottleneck + _adapt_bottleneck + _adapt_bottleneck * 768 + 768,
    new=768 * 768 + 768,
    skip_dec=-attn_head,
    skip=0,
    identity=0,
)

task_order = runtime.task_order

# Per-task list of operation labels: "new", "skip", "identity",
# "adapt_<parent task>" or "reuse_<owning task>".
task_operations = {name: [] for name in task_order}

# One dict per layer mapping expert id -> name of the task that owns it.
expert_to_task_map = [{} for _ in arch_config]

# Running totals over all experts in all layers.
additional_parameters = 0
additional_parameters_wo_skip = 0

# Per-task totals keyed by task index 1..len(task_order)-1; experts owned by
# task 0 are never added here.
taskwise_inc = dict.fromkeys(range(1, len(task_order)), 0)
taskwise_inc_wo_skip = dict.fromkeys(range(1, len(task_order)), 0)

for layer_idx, layer_experts in enumerate(arch_config):
    layer_owners = expert_to_task_map[layer_idx]
    for expert_id, expert_cfg in layer_experts.items():
        primitive = expert_cfg["primitive"]
        # The *_wo_skip totals price "skip" as op_num_params["skip_dec"].
        cost = op_num_params[primitive]
        cost_wo_skip = op_num_params[primitive.replace("skip", "skip_dec")]
        additional_parameters += cost
        additional_parameters_wo_skip += cost_wo_skip

        # First entry is the owning task; any further entries reuse the expert.
        associated_tasks = expert_cfg["associated_tasks"]
        owner_idx = associated_tasks[0]
        if owner_idx > 0:
            taskwise_inc[owner_idx] += cost
            taskwise_inc_wo_skip[owner_idx] += cost_wo_skip

        owner_task = task_order[owner_idx]
        if layer_owners.get(expert_id) is None:
            layer_owners[expert_id] = owner_task

        for reuser_idx in associated_tasks[1:]:
            task_operations[task_order[reuser_idx]].append(f"reuse_{owner_task}")

        if primitive == "adapt":
            # The parent expert must already have been recorded for this layer.
            parent_task = layer_owners[expert_cfg["parent_expert_id"]]
            task_operations[owner_task].append(f"adapt_{parent_task}")
        elif primitive in ("new", "skip", "identity"):
            task_operations[owner_task].append(primitive)

# Append this run's cumulative per-task parameter increase to
# <runtime.root_path>/statistics_and_vis/taskwise_statistics.csv.
vis_path = Path(runtime.root_path, "statistics_and_vis")
# DataFrame.to_csv does not create missing parent directories.
vis_path.mkdir(parents=True, exist_ok=True)

stats_df_path = Path(vis_path, "taskwise_statistics.csv")

# Normalise the sampling-scheme label for display. --sampling_scheme defaults
# to None, so guard before calling str methods on it (previously this raised
# AttributeError when the flag was omitted).
if args.sampling_scheme is not None:
    args.sampling_scheme = args.sampling_scheme.replace("ee-w-prompt", "EE w/ Prompt")
    if args.sampling_scheme in ("ee", "e"):
        args.sampling_scheme = args.sampling_scheme.upper()

# Build one row: run metadata (every value wrapped in a list for consistency)
# plus, for every task index k >= 1, the running total of taskwise_inc[1..k].
stats = {
    "Seq": [args.seq],
    "Supernet Training Epochs": [args.epochs],
    "Sampling Scheme": [args.sampling_scheme],
    "Seed": [args.seed],
}
_sum = 0
for task_idx in range(1, len(task_order)):
    _sum += taskwise_inc[task_idx]
    stats[str(task_idx)] = [_sum]
_df = pd.DataFrame(stats)

if stats_df_path.exists():
    # Merge with rows from earlier runs and keep the file in a stable order.
    stat_df = pd.concat([pd.read_csv(stats_df_path), _df], axis=0)
    stat_df = stat_df.sort_values(
        by=["Seq", "Sampling Scheme", "Supernet Training Epochs", "Seed"]
    ).reset_index(drop=True)
else:
    stat_df = _df

stat_df.to_csv(stats_df_path, index=False)
