import os
import pathlib
import random
import numpy as np
import math

import torch

from args import args
from models.module_util import get_subnet, mask_overlap
import utils

from collections import defaultdict

# Parameter-name templates for every masked layer, per supported checkpoint
# type. Each template is formatted with a suffix such as 'scores.<task>' to get
# the state-dict key holding that layer's per-task scores, and with '' to get
# the layer name recorded in the CSV. List order determines CSV row order.
_RESNET18_LAYER_FMT = [
    'module.conv1.{}',
    'module.layer1.0.conv1.{}',
    'module.layer1.0.conv2.{}',
    'module.layer1.1.conv1.{}',
    'module.layer1.1.conv2.{}',
    'module.layer2.0.conv1.{}',
    'module.layer2.0.conv2.{}',
    'module.layer2.0.shortcut.0.{}',
    'module.layer2.1.conv1.{}',
    'module.layer2.1.conv2.{}',
    'module.layer3.0.conv1.{}',
    'module.layer3.0.conv2.{}',
    'module.layer3.0.shortcut.0.{}',
    'module.layer3.1.conv1.{}',
    'module.layer3.1.conv2.{}',
    'module.layer4.0.conv1.{}',
    'module.layer4.0.conv2.{}',
    'module.layer4.0.shortcut.0.{}',
    'module.layer4.1.conv1.{}',
    'module.layer4.1.conv2.{}',
    'module.linear.{}',
]

_RESNET50_LAYER_FMT = [
    'module.conv1.{}',
    'module.fc.{}',
    'module.layer1.0.conv1.{}',
    'module.layer1.0.conv2.{}',
    'module.layer1.0.conv3.{}',
    'module.layer1.0.downsample.0.{}',
    'module.layer1.1.conv1.{}',
    'module.layer1.1.conv2.{}',
    'module.layer1.1.conv3.{}',
    'module.layer1.2.conv1.{}',
    'module.layer1.2.conv2.{}',
    'module.layer1.2.conv3.{}',
    'module.layer2.0.conv1.{}',
    'module.layer2.0.conv2.{}',
    'module.layer2.0.conv3.{}',
    'module.layer2.0.downsample.0.{}',
    'module.layer2.1.conv1.{}',
    'module.layer2.1.conv2.{}',
    'module.layer2.1.conv3.{}',
    'module.layer2.2.conv1.{}',
    'module.layer2.2.conv2.{}',
    'module.layer2.2.conv3.{}',
    'module.layer2.3.conv1.{}',
    'module.layer2.3.conv2.{}',
    'module.layer2.3.conv3.{}',
    'module.layer3.0.conv1.{}',
    'module.layer3.0.conv2.{}',
    'module.layer3.0.conv3.{}',
    'module.layer3.0.downsample.0.{}',
    'module.layer3.1.conv1.{}',
    'module.layer3.1.conv2.{}',
    'module.layer3.1.conv3.{}',
    'module.layer3.2.conv1.{}',
    'module.layer3.2.conv2.{}',
    'module.layer3.2.conv3.{}',
    'module.layer3.3.conv1.{}',
    'module.layer3.3.conv2.{}',
    'module.layer3.3.conv3.{}',
    'module.layer3.4.conv1.{}',
    'module.layer3.4.conv2.{}',
    'module.layer3.4.conv3.{}',
    'module.layer3.5.conv1.{}',
    'module.layer3.5.conv2.{}',
    'module.layer3.5.conv3.{}',
    'module.layer4.0.conv1.{}',
    'module.layer4.0.conv2.{}',
    'module.layer4.0.conv3.{}',
    'module.layer4.0.downsample.0.{}',
    'module.layer4.1.conv1.{}',
    'module.layer4.1.conv2.{}',
    'module.layer4.1.conv3.{}',
    'module.layer4.2.conv1.{}',
    'module.layer4.2.conv2.{}',
    'module.layer4.2.conv3.{}',
]

# Used for non-dict checkpoints (a saved model object). The masked layers sit at
# indices 0, 2 and 4 of its `model` attribute -- presumably an nn.Sequential
# with non-parametric layers in between; verify against the training code.
_SEQUENTIAL_LAYER_FMT = [
    'model.0.{}',
    'model.2.{}',
    'model.4.{}',
]


def _parse_sparsity_from_path(model_path):
    """Return the sparsity encoded in a '~'-separated model path.

    The first '~'-separated component containing 'sparsity' is taken, and the
    text after its last '=' is parsed as a float (e.g. '...~sparsity=50~...'
    gives 50.0).

    Raises:
        ValueError: if no path component mentions 'sparsity'.
    """
    parts = [p for p in model_path.split('~') if 'sparsity' in p]
    if not parts:
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f"=> Cannot infer sparsity from model path '{model_path}'")
    return float(parts[0].split('=')[-1])


def _resolve_checkpoint(checkpoint, model_path):
    """Return ``(layer_fmt, sparsity, state_dict)`` for a loaded checkpoint.

    A dict checkpoint must hold an ``'args'`` object whose ``model`` name ends
    in 'ResNet18' or 'ResNet50', plus a ``'state_dict'`` entry. Its sparsity
    is ``float(checkpoint['args'].sparsity)``.

    Any other checkpoint is treated as a model object. Its ``state_dict()`` is
    used, and sparsity is parsed from ``model_path``.

    Raises:
        RuntimeError: if a dict checkpoint names an unsupported architecture.
        ValueError: if sparsity cannot be parsed from ``model_path``.
    """
    if isinstance(checkpoint, dict):
        model_name = checkpoint['args'].model
        if model_name.endswith('ResNet18'):
            layer_fmt = _RESNET18_LAYER_FMT
        elif model_name.endswith('ResNet50'):
            layer_fmt = _RESNET50_LAYER_FMT
        else:
            # A dict has no .state_dict(), so falling through to the
            # model-object path could only crash with an opaque error.
            raise RuntimeError(
                f"=> Unsupported model '{model_name}' in checkpoint '{model_path}'"
            )
        return layer_fmt, float(checkpoint['args'].sparsity), checkpoint['state_dict']
    return _SEQUENTIAL_LAYER_FMT, _parse_sparsity_from_path(model_path), checkpoint.state_dict()


def main():
    """Report pairwise per-layer mask overlap between tasks of one checkpoint.

    The checkpoint path is ``args.seed_model_format`` formatted with
    ``sparsity=int(args.sparsity)`` and ``seed=args.seed``. For each task the
    per-layer mask is built with ``get_subnet`` from the ``scores.<task>``
    entries of the state dict. For every task pair (i, j) with i <= j, and
    every layer, the ``mask_overlap`` value is written to a CSV
    (``utils.CsvWriter``) and printed.

    Raises:
        RuntimeError: if no checkpoint exists at the formatted path, or a dict
            checkpoint names an unsupported architecture.
        ValueError: if sparsity cannot be parsed from the path of a non-dict
            checkpoint.
    """
    # Create the output directory for this run. exist_ok avoids the race
    # between a separate existence check and the creation.
    os.makedirs(args.log_dir, exist_ok=True)
    args.run_base_dir = args.log_dir
    print(f"=> Saving data in {args.run_base_dir}")

    print('Seed model format: {}'.format(args.seed_model_format))

    model_path = args.seed_model_format.format(sparsity=int(args.sparsity), seed=args.seed)
    if not os.path.isfile(model_path):
        raise RuntimeError(f"=> No seed model found at '{model_path}'!")
    print(f"=> Loading seed model from '{model_path}'")
    map_location = (
        f"cuda:{args.multigpu[0]}" if torch.cuda.is_available() else torch.device('cpu')
    )
    checkpoint = torch.load(model_path, map_location=map_location)

    layer_fmt, sparsity, state_dict = _resolve_checkpoint(checkpoint, model_path)
    # get_subnet receives sparsity / 100 (presumably percentage -> fraction).
    subnet_sparsity = sparsity / 100

    csv_writer = utils.CsvWriter(
        colnames=['task_i', 'task_j', 'layer', 'overlap', 'sparsity', 'model', 'save_dir'],
        prefix=f'{args.prefix}_{args.sparsity}',
    )

    for task_i in range(args.num_tasks):
        # task_i's masks do not depend on task_j, so build them once per task_i.
        masks_i = {
            layer: get_subnet(state_dict[layer.format(f'scores.{task_i}')], subnet_sparsity)
            for layer in layer_fmt
        }
        for task_j in range(task_i, args.num_tasks):
            for layer in layer_fmt:
                mask_j = get_subnet(state_dict[layer.format(f'scores.{task_j}')], subnet_sparsity)
                overlap = mask_overlap(masks_i[layer], mask_j)
                # NOTE(review): ' 04f' gives a leading space and six decimals
                # (e.g. ' 0.500000'); '.4f' may have been intended. Kept as-is so
                # CSVs stay comparable with earlier runs.
                overlap_str = "{overlap: 04f}".format(overlap=overlap)
                layer_name = layer.format('')
                csv_writer.write(task_i=task_i,
                                 task_j=task_j,
                                 layer=layer_name,
                                 overlap=overlap_str,
                                 sparsity=sparsity,
                                 model=model_path,
                                 save_dir=os.path.dirname(csv_writer.results))
                print(
                    f"Overlap between mask {task_i} and {task_j} for layer {layer_name} is {overlap_str}"
                )


if __name__ == "__main__":
    main()
