import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import torch
import numpy as np

from transformers import HfArgumentParser

from nxm_transformer.admm_core import DEBUG_FILE_NAME
from nxm_transformer.debug_wrapper import DebugContainer
from nxm_transformer.admm_types import NxM_HEURISTIC_NAME

@dataclass
class DoubleModelAnalysisArguments:
    """Arguments for comparing two saved debug checkpoints.

    Neither field has a default, so both must be supplied, either on the
    command line or through a JSON file handed to ``HfArgumentParser``.
    """

    # Directory holding the first debug checkpoint.
    first_path: str = field(
        metadata=dict(help="Path to where the first debug checkpoint was saved.")
    )

    # Directory holding the second debug checkpoint.
    second_path: str = field(
        metadata=dict(help="Path to where the second debug checkpoint was saved.")
    )

def validateRestoredObject(obj):
    """Check that an object restored from disk is a well-formed DebugContainer.

    Args:
        obj: The object returned by ``torch.load``.

    Returns:
        True when ``obj`` is a ``DebugContainer`` whose
        ``validateStructure()`` check passes. Returning an explicit truthy
        value lets callers write ``if not validateRestoredObject(x)``; the
        previous implicit ``None`` made that test always fail.

    Raises:
        Exception: if ``obj`` is not a ``DebugContainer``, or if its
            ``validateStructure()`` call returns a falsy value.
    """
    if not isinstance(obj, DebugContainer):
        raise Exception("Checkpoint path was not to a debug container.")
    if not obj.validateStructure():
        raise Exception("Malformed DebugContainer restored.")
    return True

def loadWrapper(path: str, map_location=None):
    """Load and validate the debug container saved under ``path``.

    Args:
        path: Directory containing the checkpoint file named
            ``DEBUG_FILE_NAME``.
        map_location: Forwarded to ``torch.load``. Defaults to ``None``,
            which keeps torch's default device placement. Pass ``"cpu"``,
            for example, to open a GPU-saved checkpoint on a CPU-only
            machine.

    Returns:
        The restored object, after it has passed ``validateRestoredObject``.

    Raises:
        Exception: "Loaded debug wrapper malformed at <path>" when
            validation fails. The validator's own error is chained as the
            cause.
    """
    checkpoint_path = os.path.join(path, DEBUG_FILE_NAME)
    # NOTE(review): torch.load unpickles arbitrary objects. Only point this
    # at trusted checkpoints. Recent torch releases default to
    # weights_only=True, which may refuse a pickled custom container;
    # confirm against the installed torch version.
    wrapper = torch.load(checkpoint_path, map_location=map_location)
    # The validator reports failure by raising, not through its return
    # value, so call it and attach the checkpoint path to any error.
    try:
        validateRestoredObject(wrapper)
    except Exception as err:
        raise Exception("Loaded debug wrapper malformed at {}".format(path)) from err
    return wrapper

def _mask_overlap_ratios(first_mask, second_mask):
    """Return, for each paired step, the overlap of two masks relative to the first.

    Each ratio is ``(first * second).sum() / first.sum()``, converted to a
    Python float. The measure is asymmetric: it is normalised by the first
    mask only. If a step's first mask sums to zero, the ratio is 0/0, which
    becomes ``nan``.

    Assumes each element is a 0/1 or bool tensor of matching shape (TODO
    confirm against DebugContainer.getLists). Steps are paired with
    ``zip``, so extra trailing steps in the longer list are ignored.
    """
    ratios = []
    with torch.no_grad():
        for first, second in zip(first_mask, second_mask):
            ratio = (first * second).sum() / first.sum()
            # .item() gives a plain float, so printing shows numbers rather
            # than tensor reprs (with device annotations for GPU tensors).
            ratios.append(ratio.item())
    return ratios

def main():
    """Compare the per-step masks recorded in two debug checkpoints.

    Arguments are read from a single ``.json`` file path if exactly one
    such argument is given; otherwise they are parsed from the command
    line. For every debugged parameter, the list of overlap ratios is
    printed (see ``_mask_overlap_ratios``).

    Raises:
        Exception: if the two checkpoints debug different parameter sets,
            or if a parameter's mask histories differ in length.
    """
    # The trailing comma makes this a one-element tuple; ``(X)`` is just X.
    parser = HfArgumentParser((DoubleModelAnalysisArguments,))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        analysis_args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        analysis_args, = parser.parse_args_into_dataclasses()

    first_wrapper = loadWrapper(analysis_args.first_path)
    second_wrapper = loadWrapper(analysis_args.second_path)

    if set(first_wrapper.debugParameters) != set(second_wrapper.debugParameters):
        raise Exception("Debugged parameters do not match between checkpoints, aborting.")

    for name in first_wrapper.debugParameters:
        # getLists yields (z, w, mask); only the masks are compared here.
        _, _, first_mask = first_wrapper.getLists(name)
        _, _, second_mask = second_wrapper.getLists(name)

        if len(first_mask) != len(second_mask):
            raise Exception("Analysis currently only supported on experiments of same length")

        print("Parameter: {}".format(name))
        print(_mask_overlap_ratios(first_mask, second_mask), end="\n\n")

# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
