"""
Inspect network parameters and architecture using PyTorch.
"""
__author__ = 'XYZ'


import json
import os
import re
import platform
import time

from copy import deepcopy

import pandas as pd

try:
  import torch
  ## clear memory
  torch.cuda.empty_cache()
except ImportError:
  print('torch is not installed')

from ..core import encoders


def get_device(args):
  """Choose the torch device string for a run.

  Returns 'cuda' only when the caller requested a GPU via ``args.gpu`` and
  ``torch.cuda.is_available()`` reports a usable CUDA runtime; otherwise 'cpu'.
  """
  if not args.gpu:
    return 'cpu'
  return 'cuda' if torch.cuda.is_available() else 'cpu'


def gather_system_info():
  """Describe the compute environment as a small dict.

  Keys:
    device: name of CUDA device 0 when CUDA is available, else the string "CPU".
    torch_version: ``torch.__version__``.
    cuda_version: ``torch.version.cuda`` when CUDA is available, else ``None``.
  """
  has_cuda = torch.cuda.is_available()
  device_name = torch.cuda.get_device_name(0) if has_cuda else "CPU"
  cuda_version = torch.version.cuda if has_cuda else None
  return {
    "device": device_name,
    "torch_version": torch.__version__,
    "cuda_version": cuda_version,
  }


def _output_shape_summary(output):
  """Return a printable shape description for a tensor or a (nested) tuple/list/dict of tensors."""
  if isinstance(output, torch.Tensor):
    return output.shape
  if isinstance(output, list):
    return [_output_shape_summary(item) for item in output]
  if isinstance(output, tuple):
    return tuple(_output_shape_summary(item) for item in output)
  if isinstance(output, dict):
    return {key: _output_shape_summary(value) for key, value in output.items()}
  ## Non-tensor leaf (e.g. None or a custom object): report its type instead of failing
  return type(output).__name__


def check_forward_pass(model:torch.nn.Module, input_size:tuple):
  """Check if the model can run a forward pass on a dummy input.

  Builds an all-zeros tensor of shape ``(1, *input_size)`` on the device of the
  model's first parameter (CPU for parameter-free models), runs it through
  ``model`` without gradient tracking and prints the output shape(s).

  Args:
    model: network to probe; its train/eval mode is not changed here.
    input_size: per-sample input shape, without the batch dimension.

  Returns:
    The dummy input tensor when the forward pass succeeds, otherwise ``None``
    (the error is printed, not raised).
  """
  img = None
  try:
    first_param = next(model.parameters(), None)
    ## Parameter-free models (e.g. nn.Identity) have no device to follow; default to CPU
    img_device = first_param.device if first_param is not None else torch.device("cpu")
    img = torch.zeros((1, *input_size), device=img_device)
    with torch.no_grad():  ## probing only: no need to build an autograd graph
      output = model(img)
    ## Models may return tuples/lists/dicts (not just a tensor); describe any of them
    print(f"Model forward pass output shape: {_output_shape_summary(output)}")
  except Exception as e:
    print(f"Error during model forward pass: {e}")
    img = None
  return img


def calc_fps(model:torch.nn.Module, input_size:tuple, num_iterations:int=100):
  """Measure inference throughput (forward passes per second) for ``model``.

  A dummy all-zeros input of shape ``(1, *input_size)`` is probed with
  ``check_forward_pass``; the model is then run ``num_iterations`` times under
  ``torch.no_grad()`` and the wall-clock time is averaged.

  Args:
    model: network to benchmark; its train/eval mode is left untouched.
    input_size: per-sample input shape, without the batch dimension.
    num_iterations: number of timed forward passes; must be >= 1.

  Returns:
    ``(fps, avg_time_per_pass)`` where ``avg_time_per_pass`` is in seconds.
    Both are ``0`` / ``0.0`` when timing fails (the error is printed, not raised).
  """
  fps, avg_time_per_pass = 0, 0.0
  img = check_forward_pass(model, input_size)
  try:
    if num_iterations < 1:
      raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if img is None:
      ## The probe reported a failure; rebuild the dummy input so the timing loop
      ## either works or surfaces the real error, instead of calling model(None).
      first_param = next(model.parameters(), None)
      img_device = first_param.device if first_param is not None else torch.device("cpu")
      img = torch.zeros((1, *input_size), device=img_device)

    ## Inference-only timing: autograd bookkeeping would inflate the measured time
    with torch.no_grad():
      if torch.cuda.is_available():
        torch.cuda.synchronize()  ## drain pending work *before* starting the clock
      start_time = time.perf_counter()
      for _ in range(num_iterations):
        model(img)
      if torch.cuda.is_available():
        torch.cuda.synchronize()  ## CUDA launches are asynchronous; wait before stopping the clock
      elapsed = time.perf_counter() - start_time

    avg_time_per_pass = elapsed / num_iterations
    ## FPS is the inverse of the average time per pass (guard a zero-resolution timer)
    fps = 1.0 / avg_time_per_pass if avg_time_per_pass > 0 else 0

    ## Debug: print FPS
    print(f"FPS: {fps}")
  except Exception as e:
    fps, avg_time_per_pass = 0, 0.0
    print(f"Error calculating FPS: {e}")
  return fps, avg_time_per_pass


def calc_flops(model:torch.nn.Module, input_size:tuple, verbose=False):
  """FLOPS calculation using ``thop.profile``.

  Runs ``check_forward_pass`` as a sanity check, then profiles a deep copy of
  ``model`` on an all-zeros ``(1, *input_size)`` input.

  NOTE(review): thop documents ``profile(...)[0]`` as MACs (multiply-accumulate
  operations), roughly half the FLOP count. The value is reported under the
  "FLOPS" name unchanged; confirm which unit downstream consumers expect.

  Args:
    model: network to profile; it is not modified (a deep copy is profiled).
    input_size: per-sample input shape, without the batch dimension.
    verbose: forwarded to ``thop.profile``.

  Returns:
    ``(total_flops, gflops, tops)`` with ``gflops = total_flops / 1e9`` and
    ``tops = total_flops / 1e12``. All are ``0`` when thop is not installed or
    profiling raises (the error is printed, not raised).
  """
  img = check_forward_pass(model, input_size)
  try:
    ## Optional dependency: importing inside the try turns a missing thop into
    ## zeros instead of an ImportError escaping this function.
    from thop import profile

    if img is None:
      ## The probe reported a failure; rebuild the dummy input and let thop
      ## surface whatever the real problem is.
      first_param = next(model.parameters(), None)
      img_device = first_param.device if first_param is not None else torch.device("cpu")
      img = torch.zeros((1, *input_size), device=img_device)

    ## thop instruments the module it measures; profile a copy to keep `model` untouched
    total_flops = profile(deepcopy(model), inputs=(img,), verbose=verbose)[0]

    ## Debug: print the total FLOPS
    print(f"Total FLOPS: {total_flops}")
    gflops = total_flops / 1e9  # GFLOPS
    tops = total_flops / 1e12  # TOPS
  except Exception as e:
    gflops = tops = total_flops = 0
    print(f"Error calculating FLOPS: {e}")

  return total_flops, gflops, tops


def model_perfstats(
  model: torch.nn.Module,
  input_size: tuple,
  device: torch.device,
  verbose: bool = False,
  num_iterations: int=100,
  weights_path=None,
  dnnarch=None,
  num_class=None,
):
  """Collect weight-file size, FLOPS and FPS figures for ``model``.

  The model is switched to eval mode and, when ``device`` is given, moved there
  (matching ``model_summary``) before ``calc_flops`` and ``calc_fps`` run. Both of
  those build their dummy input on the device of the model's parameters.

  Args:
    model: network to measure.
    input_size: per-sample input shape, without the batch dimension.
    device: target device for the model; ``None`` leaves it where it is.
    verbose: forwarded to ``calc_flops``.
    num_iterations: number of timed forward passes used by ``calc_fps``.
    weights_path: optional weight file whose on-disk size is reported.
    dnnarch: architecture label copied verbatim into the result.
    num_class: class count copied verbatim into the result.

  Returns:
    A dict with the keys:
      - ``dnnarch`` and ``num_class``;
      - ``model_weight_filesize_in_mb``: MB of 1024**2 bytes, rounded to 2 decimals;
      - ``fps``: rounded to 1 decimal;
      - ``avg_time_per_pass``: in seconds;
      - ``flops``, ``gflops`` and ``tops``.

  Raises:
    OSError: if ``weights_path`` is given but its size cannot be read.
  """
  ## Calculate model weight file size in MB
  model_weight_filesize_in_mb = 0.0
  if weights_path:
    model_weight_filesize_in_mb = os.path.getsize(weights_path) / (1024 ** 2)

  model.eval()
  if device is not None:
    ## Previously `device` was accepted but ignored; honour it like model_summary does
    model.to(device)

  total_flops, gflops, tops = calc_flops(model, input_size, verbose)
  fps, avg_time_per_pass = calc_fps(model, input_size, num_iterations=num_iterations)

  ## Compile results into a dictionary
  perfstats = {
    "dnnarch": dnnarch,
    "num_class": num_class,
    "model_weight_filesize_in_mb": round(model_weight_filesize_in_mb, 2),
    "fps": round(fps, 1),
    "avg_time_per_pass": avg_time_per_pass,
    "flops": total_flops,
    "gflops": gflops,
    "tops": tops,
  }
  return perfstats


def model_summary(
  model: torch.nn.Module,
  input_size: tuple,
  device: torch.device,
  verbose: bool = False,
  num_iterations: int=100,
  weights_path=None,
  dnnarch=None,
  num_class=None,
  depth: int=10,
):
  """Summarize a model's architecture, size, environment and performance.

  Runs ``torchinfo.summary`` on a ``(1, *input_size)`` input. It then adds parameter
  counts, system/library versions, the weight-file size and the FLOPS/FPS figures
  from ``calc_flops`` / ``calc_fps``.

  Args:
    model: network to inspect; it is switched to eval mode and moved to ``device``.
    input_size: per-sample input shape, without the batch dimension.
    device: device passed to ``model.to`` and to ``torchinfo.summary``.
    verbose: forwarded to ``torchinfo.summary`` and to ``calc_flops``.
    num_iterations: number of timed forward passes used by ``calc_fps``.
    weights_path: optional weight file whose on-disk size is reported.
    dnnarch: architecture label copied verbatim into the result.
    num_class: class count copied verbatim into the result.
    depth: nesting depth passed to ``torchinfo.summary``.

  Returns:
    ``(key_stats, summary_info)``. ``summary_info`` is the torchinfo result object.
    ``key_stats`` is a dict with these groups of keys:

    - **Parameters**:
      - ``total_params``, ``trainable_params`` and ``non_trainable_params`` come
        from torchinfo.
      - ``n_p`` and ``n_g`` are the total and ``requires_grad`` element counts
        taken directly from ``model.parameters()``.
    - **Size**: ``total_mult_adds_in_millions``, plus these values in MB of
      1024**2 bytes:
      - ``input_size_mb``;
      - ``forward_backward_pass_size_mb``;
      - ``params_size_mb``;
      - ``estimated_total_size_mb`` (the sum of the three above).
    - **Model details**:
      - ``network``: the model's class name;
      - ``dnnarch`` and ``num_class``;
      - ``input_shape``;
      - ``output_shape``: taken from the last torchinfo row;
      - ``total_layers``: the number of torchinfo rows, including the top-level model.
    - **Environment**: ``device``, ``torch_version``, ``cuda_version`` and
      ``python_version``.
    - **Performance**:
      - ``model_weight_filesize_in_mb``;
      - ``fps`` and ``num_iterations_for_fps``;
      - ``avg_time_per_pass``;
      - ``flops``, ``gflops`` and ``tops``.
    - ``creator``: the module's ``__author__``.
    - ``layers``: a ``{layer class name: count}`` dict over all torchinfo rows.

  This breakdown helps gauge model complexity and resource needs when tuning a
  model for deployment.
  """
  from collections import Counter

  from torchinfo import summary

  ## Ensure the model is in evaluation mode and on the requested device
  model.eval()
  model.to(device)

  ## Generate the model summary
  summary_info = summary(
    model,
    input_size=(1, *input_size),
    device=device,
    col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds", "trainable"],
    depth=depth,
    verbose=verbose,
  )

  ## Parameter counts taken directly from the module (cross-check for torchinfo)
  n_p = sum(x.numel() for x in model.parameters())  ## Total number of parameters
  n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  ## Trainable parameters

  ## System environment information.
  ## NOTE(review): "device" reports whether CUDA is available on the host, not the
  ## `device` argument the model ran on. It uses its own name so the parameter is
  ## no longer shadowed.
  cuda_available = torch.cuda.is_available()
  env_device = "cuda" if cuda_available else "cpu"
  cuda_version = torch.version.cuda if cuda_available else "N/A"

  ## Calculate model weight file size in MB
  model_weight_filesize_in_mb = 0.0
  if weights_path:
    model_weight_filesize_in_mb = os.path.getsize(weights_path) / (1024 ** 2)

  total_flops, gflops, tops = calc_flops(model, input_size, verbose)
  fps, avg_time_per_pass = calc_fps(model, input_size, num_iterations=num_iterations)

  def to_mb(num_bytes):
    """Convert a byte count to MB (1024**2 bytes), rounded to 2 decimals."""
    return round(num_bytes / (1024 ** 2), 2)

  rows = summary_info.summary_list
  total_bytes = summary_info.total_input + summary_info.total_output_bytes + summary_info.total_param_bytes

  ## Parse key statistics from summary info for CSV and JSON
  key_stats = {
    "total_params": summary_info.total_params,
    "trainable_params": summary_info.trainable_params,
    "non_trainable_params": summary_info.total_params - summary_info.trainable_params,
    "n_p": n_p,
    "n_g": n_g,
    "total_mult_adds_in_millions": round(summary_info.total_mult_adds / 1e6, 2),
    "input_size_mb": to_mb(summary_info.total_input),
    "forward_backward_pass_size_mb": to_mb(summary_info.total_output_bytes),
    "params_size_mb": to_mb(summary_info.total_param_bytes),
    "estimated_total_size_mb": to_mb(total_bytes),
    "network": model.__class__.__name__,
    "dnnarch": dnnarch,
    "num_class": num_class,
    "input_shape": str(summary_info.input_size[0] if summary_info.input_size else None),
    "output_shape": str(rows[-1].output_size if rows else None),
    "total_layers": len(rows),
    ## System environment information
    "device": env_device,
    "torch_version": torch.__version__,
    "cuda_version": cuda_version,
    "python_version": platform.python_version(),
    ## performance stats
    "model_weight_filesize_in_mb": round(model_weight_filesize_in_mb, 2),
    "fps": round(fps, 1),
    "num_iterations_for_fps": num_iterations,
    "avg_time_per_pass": avg_time_per_pass,
    "flops": total_flops,
    "gflops": gflops,
    "tops": tops,
    "creator": __author__,
  }

  ## Count layer types (Counter keeps first-seen order, like the dict it replaces)
  key_stats["layers"] = dict(Counter(row.class_name for row in rows))
  return key_stats, summary_info


def save_metrics(to_path, metrics):
  """Save evaluation metrics to JSON.

  ``metrics`` is serialized with ``encoders.numpy_to_json`` and written to
  ``<to_path>/classification_metrics.json``. Returns that file path.
  """
  out_file = os.path.join(to_path, "classification_metrics.json")
  json_text = encoders.numpy_to_json(metrics)
  with open(out_file, 'w') as handle:
    handle.write(json_text)
  print(f"Metrics saved to {out_file}")
  return out_file


def save_model_summary(to_path, key_stats, summary_info):
  """Write the model summary text and key stats to TXT, CSV and JSON files.

  Creates, inside the existing directory ``to_path``:
    - ``modelsummary.txt``: ``str(summary_info)``;
    - ``modelsummary.csv``: ``key_stats`` as a single-row table;
    - ``modelsummary.json``: ``key_stats`` pretty-printed with indent 2.

  Args:
    to_path: existing output directory.
    key_stats: dict of JSON-serializable values (e.g. from ``model_summary``).
    summary_info: object whose ``str()`` is the human-readable summary.

  Returns:
    ``(txt_file_path, csv_file_path, json_file_path)``.
  """
  ## Write the summary to a TXT file. The text may contain non-ASCII characters
  ## (torchinfo draws its layer tree with box-drawing glyphs), so force UTF-8
  ## rather than relying on the platform default encoding.
  txt_file_path = os.path.join(to_path, "modelsummary.txt")
  with open(txt_file_path, "w", encoding="utf-8") as txt_file:
    txt_file.write(str(summary_info))

  ## Save key statistics as CSV
  csv_file_path = os.path.join(to_path, "modelsummary.csv")
  pd.DataFrame([key_stats]).to_csv(csv_file_path, index=False)

  ## Save key statistics as JSON
  json_file_path = os.path.join(to_path, "modelsummary.json")
  with open(json_file_path, "w", encoding="utf-8") as json_file:
    json.dump(key_stats, json_file, indent=2)

  print(f"Model summary saved to: {txt_file_path}, {csv_file_path}, and {json_file_path}")
  return txt_file_path, csv_file_path, json_file_path


def save_model_perfstats(to_path, perfstats):
  """Write model performance stats to CSV and JSON files.

  Creates, inside the existing directory ``to_path``:
    - ``modelperfstats.csv``: ``perfstats`` as a single-row table;
    - ``modelperfstats.json``: ``perfstats`` pretty-printed with indent 2.

  Args:
    to_path: existing output directory.
    perfstats: dict of JSON-serializable values (e.g. from ``model_perfstats``).

  Returns:
    ``(csv_file_path, json_file_path)``.
  """
  ## Save performance stats as CSV
  csv_file_path = os.path.join(to_path, "modelperfstats.csv")
  pd.DataFrame([perfstats]).to_csv(csv_file_path, index=False)

  ## Save performance stats as JSON (explicit encoding, consistent with the other writers)
  json_file_path = os.path.join(to_path, "modelperfstats.json")
  with open(json_file_path, "w", encoding="utf-8") as json_file:
    json.dump(perfstats, json_file, indent=2)

  print(f"Model perfstats saved to: {csv_file_path}, and {json_file_path}")
  return csv_file_path, json_file_path


def save_model_architecture(model, save_dir, model_name):
  """Dump ``str(model)`` to ``<save_dir>/<model_name>-model.txt`` and return that path."""
  out_path = os.path.join(save_dir, f"{model_name}-model.txt")
  arch_text = str(model)
  with open(out_path, 'w') as handle:
    handle.write(arch_text)
  print(f"Model architecture saved to {out_path}")
  return out_path


def save_report(report_data, save_dir):
  """Save a classification report as JSON.

  ``report_data`` is serialized with ``encoders.numpy_to_json`` and written to
  ``<save_dir>/classification_report.json``. Returns that file path.
  """
  out_file = os.path.join(save_dir, "classification_report.json")
  json_text = encoders.numpy_to_json(report_data)
  with open(out_file, 'w') as handle:
    handle.write(json_text)
  print(f"Classification report saved to {out_file}")
  return out_file


def extract_stat_value(pattern: str, text: str):
  """Return capture group 1 of the first ``pattern`` match in ``text``, with commas removed.

  Returns the string "N/A" when ``pattern`` does not match anywhere in ``text``.
  """
  found = re.search(pattern, text)
  if found is None:
    return "N/A"
  captured = found.group(1)
  return captured.replace(",", "")


def count_layer_type(layer_pattern: str, text: str):
  """Return the number of non-overlapping matches of ``layer_pattern`` in ``text``."""
  return sum(1 for _match in re.finditer(layer_pattern, text))
