import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from transformers import HfArgumentParser

from nxm_transformer.admm_core import DEBUG_FILE_NAME
from nxm_transformer.debug_wrapper import DebugContainer
from nxm_transformer.admm_types import NxM_HEURISTIC_NAME

@dataclass
class SingleModelAnalysisArguments:
    """Arguments for analysing a single saved debug checkpoint."""

    # Directory that contains the debug checkpoint file; required (no default).
    path: str = field(
        metadata=dict(help="Path to where the debug checkpoint was saved."),
    )

def validateRestoredObject(obj):
    """Check that an unpickled checkpoint object is a well-formed DebugContainer.

    Args:
        obj: the object returned by loading the debug checkpoint.

    Raises:
        TypeError: if ``obj`` is not a DebugContainer (subclasses are accepted).
        ValueError: if ``obj.validateStructure()`` returns a falsy value.
        Both derive from ``Exception``, so callers catching ``Exception`` still work.
    """
    if not isinstance(obj, DebugContainer):
        raise TypeError("Checkpoint path was not to a debug container.")
    if not obj.validateStructure():
        raise ValueError("Malformed DebugContainer restored.")

def _load_checkpoint(checkpoint):
    """Unpickle the debug checkpoint at ``checkpoint`` with all tensors on the CPU.

    Mapping everything to the CPU keeps masks, weights and the accumulators
    below on one device, so checkpoints saved from GPU runs can be analysed
    on any machine.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # arbitrary objects such as the DebugContainer; older torch versions do
    # not accept the keyword at all, hence the signature check.
    # NOTE(security): full unpickling can execute arbitrary code - only point
    # this script at checkpoints you trust.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(checkpoint, **load_kwargs)


def _parameter_frequency_stats(name, w_list, mask_list):
    """Summarise one parameter's weight magnitudes grouped by mask frequency.

    A position's *frequency* is the element-wise sum of the masks in
    ``mask_list`` at that position (presumably the masks are 0/1 or bool, so
    this counts how many masks kept it - confirm against DebugContainer).

    Args:
        name: parameter name, used only in error messages.
        w_list: recorded weight snapshots; only the first and last are used.
        mask_list: recorded masks (assumed to share the weights' shape).

    Returns:
        A list indexed by frequency ``0..len(mask_list)`` of
        ``(start_abs_sum, end_abs_sum, count)`` tuples: the sum of ``|w|`` over
        the positions with that frequency in ``w_list[0]`` and ``w_list[-1]``,
        and the number of such positions.

    Raises:
        ValueError: if no weights or no masks were recorded for ``name``.
    """
    if not w_list or not mask_list:
        raise ValueError(
            "No weights or masks recorded for parameter {}.".format(name)
        )

    with torch.no_grad():
        occurrences = torch.zeros(mask_list[0].shape, device=mask_list[0].device)
        for mask in mask_list:
            occurrences += mask

        # Only the first and last snapshots are reported; skip the others.
        start_abs = w_list[0].abs()
        end_abs = w_list[-1].abs()

        stats = []
        for frequency in range(len(mask_list) + 1):
            frequency_mask = occurrences == frequency
            stats.append((
                (start_abs * frequency_mask).sum().item(),
                (end_abs * frequency_mask).sum().item(),
                int(frequency_mask.count_nonzero()),
            ))
    return stats


def _aggregate_decay(per_parameter_stats):
    """Pool per-parameter stats into ``(frequency, start_mean, end_mean)`` rows.

    Each mean is the average ``|w|`` over every position, across all
    parameters, whose mask frequency equals ``frequency``. Pooling raw sums and
    counts (instead of weighting per-parameter means) means a parameter with
    no positions at some frequency cannot turn the pooled value into NaN.
    A frequency with no positions in any parameter yields ``nan``.

    Raises:
        ValueError: if parameters recorded different numbers of masks.
    """
    if not per_parameter_stats:
        return []
    num_frequencies = len(per_parameter_stats[0])
    if any(len(stats) != num_frequencies for stats in per_parameter_stats):
        raise ValueError(
            "All debug parameters must record the same number of masks."
        )

    rows = []
    for frequency in range(num_frequencies):
        start_sum = sum(stats[frequency][0] for stats in per_parameter_stats)
        end_sum = sum(stats[frequency][1] for stats in per_parameter_stats)
        count = sum(stats[frequency][2] for stats in per_parameter_stats)
        if count:
            rows.append((frequency, start_sum / count, end_sum / count))
        else:
            rows.append((frequency, float("nan"), float("nan")))
    return rows


def main():
    """Print mean |weight| per mask frequency at the first and last snapshot.

    Arguments come from a single ``.json`` file argument or from the command
    line. The checkpoint ``DEBUG_FILE_NAME`` inside ``path`` is loaded and
    validated, and one CSV row ``frequency,start_value,end_value`` is printed
    for every frequency from 0 to the number of recorded masks.
    """
    parser = HfArgumentParser(SingleModelAnalysisArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        analysis_args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        analysis_args, = parser.parse_args_into_dataclasses()

    checkpoint = os.path.join(analysis_args.path, DEBUG_FILE_NAME)
    debug_wrapper = _load_checkpoint(checkpoint)
    validateRestoredObject(debug_wrapper)

    per_parameter_stats = []
    for name in debug_wrapper.debugParameters:
        _, w_list, mask_list = debug_wrapper.getLists(name)
        per_parameter_stats.append(_parameter_frequency_stats(name, w_list, mask_list))

    # Values are plain Python floats, so the output is clean CSV
    # (not "tensor(...)").
    for frequency, start_value, end_value in _aggregate_decay(per_parameter_stats):
        print("{},{},{}".format(frequency, start_value, end_value))

# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
