import torch.nn as nn
import torch

def find_layers(module, layers=None, name=''):
    """
    Recursively find the layers of a certain type in a module.

    Matching uses the exact type (``type(module) in layers``), so subclasses
    of a listed layer type are NOT matched. Once a module matches, its
    children are not searched any further.

    Args:
        module (nn.Module): PyTorch module to search.
        layers (list | tuple | None): Layer types to find. ``None`` (the
            default) means ``[nn.Linear]``.
        name (str): Qualified name of ``module``; used as the prefix of the
            dot-separated names of its descendants.

    Returns:
        dict: Mapping from dot-separated module path (e.g. ``'2.0'``) to each
        matching layer. If ``module`` itself matches, returns ``{name: module}``.
    """
    # Resolve the default per call instead of sharing one mutable list
    # between all invocations.
    if layers is None:
        layers = [nn.Linear]
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        # The top-level call has an empty name, so avoid a leading '.'.
        full_name = f"{name}.{child_name}" if name else child_name
        res.update(find_layers(child, layers=layers, name=full_name))
    return res

import time
import csv
import json
import statistics
from collections import defaultdict

class Timer:
    """Named timer that accumulates durations (in seconds) per name.

    Each ``start(name)`` / ``end(name)`` pair appends one duration to
    ``records[name]``. When CUDA is available the device is synchronized
    before reading the clock at both ends, so queued GPU work is included
    in the measurement.
    """

    def __init__(self):
        # name -> start timestamp of timers that are currently running
        self._start_times = {}
        # name -> list of recorded durations in seconds
        self.records = defaultdict(list)

    @staticmethod
    def _synchronize():
        """Wait for pending CUDA work; no-op when CUDA is unavailable."""
        # torch.cuda.synchronize() raises on CPU-only setups, which would
        # otherwise make the timer unusable without a GPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self, name: str):
        """Start timing ``name``.

        Raises:
            RuntimeError: if a timer with this name is already running.
        """
        if name in self._start_times:
            raise RuntimeError(f"Timer for '{name}' already started.")
        self._synchronize()
        # perf_counter is monotonic and high-resolution; time.time() can
        # jump when the system clock is adjusted.
        self._start_times[name] = time.perf_counter()

    def end(self, name: str):
        """Stop timing ``name`` and append the elapsed seconds to ``records``.

        Raises:
            RuntimeError: if no timer with this name is running.
        """
        if name not in self._start_times:
            raise RuntimeError(f"Timer for '{name}' was not started.")
        self._synchronize()
        duration = time.perf_counter() - self._start_times.pop(name)
        self.records[name].append(duration)

    def mean(self, name: str):
        """Return the mean duration for ``name``, or 0.0 if nothing was recorded."""
        times = self.records.get(name, [])
        return sum(times) / len(times) if times else 0.0

    @staticmethod
    def _stats(durations):
        """Return count/mean/min/max/std for ``durations`` (all zeros if empty)."""
        if not durations:
            return {"count": 0, "mean": 0.0, "min": 0.0, "max": 0.0, "std": 0.0}
        return {
            "count": len(durations),
            "mean": sum(durations) / len(durations),
            "min": min(durations),
            "max": max(durations),
            # Sample standard deviation needs at least two points.
            "std": statistics.stdev(durations) if len(durations) > 1 else 0.0,
        }

    def summary(self):
        """Return ``{name: {"count", "mean", "min", "max", "std"}}`` for every record."""
        return {name: self._stats(durations) for name, durations in self.records.items()}

    def save(self, filepath: str):
        """Write every raw duration as a ``name,duration_seconds`` CSV row."""
        with open(filepath, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["name", "duration_seconds"])
            for name, durations in self.records.items():
                for t in durations:
                    writer.writerow([name, t])

    def save_summary(self, filepath: str, fmt="csv"):
        """Write ``summary()`` to ``filepath`` as CSV (default) or JSON.

        Raises:
            ValueError: if ``fmt`` is neither ``"csv"`` nor ``"json"``.
        """
        summary = self.summary()
        if fmt == "csv":
            with open(filepath, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(["name", "count", "mean", "min", "max", "std"])
                for name, stats in summary.items():
                    writer.writerow([name, stats["count"], stats["mean"], stats["min"], stats["max"], stats["std"]])
        elif fmt == "json":
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(summary, f, indent=2)
        else:
            raise ValueError(f"Unsupported format: {fmt}. Use 'csv' or 'json'.")

    def reset(self):
        """Discard all running timers and recorded durations."""
        self._start_times.clear()
        self.records.clear()

    def structure_print(self):
        """Print per-name run count and mean (ms), then the sum of those means."""
        sum_time = 0
        for name, durations in self.records.items():
            mean = self.mean(name)
            sum_time += mean
            print(f"{name}: {len(durations)} runs, mean: {mean * 1000} ms")
        print(f"Total time: {sum_time * 1000} ms")

    def time_merge(self, name_list: list, new_name: str):
        """Move the durations of every name in ``name_list`` under ``new_name``.

        Duration lists are concatenated, so the merged entry's count is the
        total number of runs and its mean is taken over all of them. Names
        that were never recorded produce a warning and are skipped.
        """
        for name in name_list:
            if name == new_name:
                # Already stored under the target; extending the list with
                # itself and then deleting it would double and then drop it.
                continue
            if name in self.records:
                self.records[new_name] += self.records[name]
                del self.records[name]
            else:
                print(f"Warning: {name} not found in timer.records.")
        
# timer = Timer()