import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from transformers import HfArgumentParser

from nxm_transformer.admm_core import DEBUG_FILE_NAME
from nxm_transformer.debug_wrapper import DebugContainer
from nxm_transformer.admm_types import NxM_HEURISTIC_NAME




@dataclass
class SingleModelAnalysisArguments:
    """Arguments for analysing a single saved debug checkpoint.

    Filled in by ``HfArgumentParser`` either from command-line flags or from a
    JSON file.
    """

    # Required: there is no default, so the directory must always be supplied.
    path: str = field(
        metadata=dict(help="Path to where the debug checkpoint was saved."),
    )

def validateRestoredObject(obj):
    """Check that ``obj`` is a well-formed ``DebugContainer``.

    Args:
        obj: The object restored from a debug checkpoint.

    Raises:
        TypeError: If ``obj`` is not a ``DebugContainer`` (subclasses are
            accepted).
        ValueError: If ``obj.validateStructure()`` returns a falsy value.

    Both exception types derive from ``Exception``, so callers that catch
    ``Exception`` behave as before.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # DebugContainer are not rejected.
    if not isinstance(obj, DebugContainer):
        raise TypeError("Checkpoint path was not to a debug container.")
    if not obj.validateStructure():
        raise ValueError("Malformed DebugContainer restored.")

def _load_debug_container(checkpoint):
    """Load the pickled debug object stored at ``checkpoint`` onto the CPU.

    ``map_location="cpu"`` lets checkpoints written during a GPU run be
    analysed on any machine, and keeps every tensor convertible to Python
    values. The checkpoint is an arbitrary pickled object (a
    ``DebugContainer``), not just tensors. It therefore has to be loaded with
    ``weights_only=False`` on torch versions that have that parameter;
    torch >= 2.6 defaults it to ``True``. Older versions do not accept the
    keyword at all, so it is passed only when supported.

    SECURITY: this unpickles the file -- only load checkpoints you trust.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(checkpoint, **load_kwargs)


def _consecutive_mask_overlaps(mask_list):
    """Return the overlap ratio of every consecutive pair of masks.

    For masks ``prev`` (sample i) and ``nxt`` (sample i + 1) the ratio is
    ``(prev * nxt).sum() / prev.sum()``. Assuming 0/1 masks (TODO confirm),
    this is the fraction of entries kept in ``prev`` that are still kept in
    ``nxt``. A ``prev`` summing to zero gives a non-finite ratio. Fewer than
    two masks give an empty list.
    """
    with torch.no_grad():
        return [
            ((prev * nxt).sum() / prev.sum()).item()
            for prev, nxt in zip(mask_list, mask_list[1:])
        ]


def main():
    """Print, per debug parameter, the mean overlap of consecutive masks.

    Arguments come from a single ``.json`` file given as the only CLI
    argument, or otherwise from command-line flags. The file
    ``DEBUG_FILE_NAME`` inside ``path`` is loaded and validated. Then, for
    each name in ``debugParameters``, one line is printed: the mean of the
    consecutive-mask overlap ratios. The line reads ``nan`` when fewer than
    two masks were recorded.
    """
    parser = HfArgumentParser((SingleModelAnalysisArguments,))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        (analysis_args,) = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        (analysis_args,) = parser.parse_args_into_dataclasses()

    checkpoint = os.path.join(analysis_args.path, DEBUG_FILE_NAME)
    debug_wrapper = _load_debug_container(checkpoint)
    validateRestoredObject(debug_wrapper)

    for name in debug_wrapper.debugParameters:
        # Only the mask history is analysed; the z/w lists are not used here.
        _, _, mask_list = debug_wrapper.getLists(name)
        similarities = _consecutive_mask_overlaps(mask_list)
        # np.mean([]) emits a RuntimeWarning; print nan explicitly instead.
        print(np.mean(similarities) if similarities else float("nan"))


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
