from pathlib import Path
import csv
import os
from typing import List

import itertools
import numpy as np
import torch
from ase import io
from ase.calculators.singlepoint import SinglePointCalculator
from torch_geometric.loader import DataLoader
from tqdm import tqdm

import hienet._keys as KEY
from hienet._const import LossType
from hienet.train.dataload import graph_build
from hienet.train.dataset import AtomGraphDataset
from hienet.util import (
    model_from_checkpoint,
    postprocess_output,
    to_atom_graph_list,
)
import torch
import time
import argparse

# Default the benchmark to GPU 6, but respect a device selection the caller
# already made via the environment. The variable only takes effect if it is
# set before CUDA is initialised, hence doing it at import time.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")

def load_data(datas: List[str], cutoff, type_map):
    """Load one or more serialized datasets and merge them into one.

    Each path in ``datas`` is opened in binary mode and read with
    ``torch.load``. The first dataset loaded is kept. ``augment`` is then
    called on it with every later dataset.

    Args:
        datas: Paths of dataset files. A single ``str`` / ``os.PathLike`` is
            accepted and treated as a one-element list.
        cutoff: Expected cutoff; compared with ``!=`` against the merged
            dataset's ``cutoff`` attribute.
        type_map: Currently unused; kept so existing call sites still work.

    Returns:
        The first loaded dataset, after ``augment`` has been called with each
        subsequent one.

    Raises:
        ValueError: If ``datas`` is empty or the cutoff does not match.
    """
    # A bare path string would otherwise be iterated character by character.
    if isinstance(datas, (str, os.PathLike)):
        datas = [datas]

    full_dataset = None
    for data in datas:
        # NOTE(review): torch.load unpickles arbitrary objects -- only use on
        # trusted files. On newer torch releases weights_only may default to
        # True and reject custom dataset classes; confirm for the torch
        # version in use.
        with open(data, 'rb') as f:
            dataset = torch.load(f)
        if full_dataset is None:
            full_dataset = dataset
        else:
            full_dataset.augment(dataset)

    if full_dataset is None:
        raise ValueError('no dataset files were given')
    if full_dataset.cutoff != cutoff:
        raise ValueError(f'cutoff mismatch: {full_dataset.cutoff} != {cutoff}')

    return full_dataset

def benchmark_throughput(model, fnames, device='cuda', batch_size=1,
                         num_iterations=100, num_warmup=10):
    """Measure how many graphs per second ``model`` processes on a dataset.

    The files in ``fnames`` are loaded with ``load_data`` and converted to the
    model's type indices. They are then fed through ``model`` in order, without
    shuffling. ``num_warmup`` batches run untimed first. After that, up to
    ``num_iterations`` batches are timed.

    The timed interval covers data-loader iteration, host-to-device transfer
    and the forward pass. On CUDA devices it is measured with CUDA events on
    the device's current stream; on other devices wall-clock
    ``time.perf_counter`` is used.

    Args:
        model: Model exposing ``cutoff``, ``type_map`` and
            ``set_is_batch_data``. It is put in eval mode and moved to
            ``device`` in place.
        fnames: Dataset file path(s) forwarded to ``load_data``.
        device: Target device, as a string or ``torch.device``.
        batch_size: Number of graphs per batch.
        num_iterations: Maximum number of timed batches. Fewer are timed if the
            dataset runs out; values <= 0 time nothing.
        num_warmup: Number of untimed warm-up batches; values <= 0 skip
            warm-up.

    Returns:
        Throughput in graphs per second, or 0.0 if nothing was timed.
    """
    dev = torch.device(device)
    use_cuda = dev.type == 'cuda'

    model.eval()
    model.to(dev)
    model.set_is_batch_data(True)

    inference_set = load_data(fnames, model.cutoff, model.type_map)
    inference_set.x_to_one_hot_idx(model.type_map)
    infer_list, _ = inference_set.separate_info()

    loader = DataLoader(infer_list, batch_size=batch_size, shuffle=False)

    def _run_batch(batch):
        """Move one batch to ``dev``, run the model, return its graph count."""
        batch = batch.to(dev, non_blocking=True)
        # Fresh leaf tensor with requires_grad so the model can differentiate
        # w.r.t. positions (presumably for forces -- TODO confirm).
        batch.pos = batch.pos.detach().requires_grad_(True)
        model(batch)
        return batch.num_graphs

    for batch in itertools.islice(loader, max(num_warmup, 0)):
        _run_batch(batch)
    if use_cuda:
        torch.cuda.synchronize(dev)

    total_samples = 0
    if use_cuda:
        # Record on the target device's stream, not the default device's,
        # so e.g. device='cuda:1' is timed correctly.
        stream = torch.cuda.current_stream(dev)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record(stream)
    else:
        # CUDA events would not observe CPU work at all.
        start_time = time.perf_counter()

    for batch in itertools.islice(loader, max(num_iterations, 0)):
        total_samples += _run_batch(batch)

    if use_cuda:
        end_event.record(stream)
        torch.cuda.synchronize(dev)
        elapsed = start_event.elapsed_time(end_event) / 1000  # ms -> s
    else:
        elapsed = time.perf_counter() - start_time

    # Guard against a zero-length interval (e.g. nothing was timed).
    throughput = total_samples / elapsed if elapsed > 0 else 0.0
    print(f"Processed {total_samples} samples in {elapsed:.3f} seconds")

    return throughput

def main():
    """Command-line entry point: load a checkpoint and report its throughput.

    Prints the model's parameter count and the throughput (graphs/second)
    returned by ``benchmark_throughput``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark model inference throughput from a checkpoint."
    )
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="path to the model checkpoint")
    # nargs="+" keeps the single-file form working while allowing several
    # dataset files, which load_data merges.
    parser.add_argument("--data", type=str, nargs="+", required=True,
                        help="one or more torch-saved dataset files")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="graphs per batch")
    parser.add_argument("--num_iterations", type=int, default=100,
                        help="maximum number of timed batches")
    args = parser.parse_args()

    model, _ = model_from_checkpoint(args.checkpoint)
    throughput = benchmark_throughput(
        model=model,
        fnames=args.data,
        batch_size=args.batch_size,
        num_iterations=args.num_iterations,
    )
    n_params = sum(p.numel() for p in model.parameters())

    print(f"Parameters: {n_params:,}")
    print(f"Throughput: {throughput:.1f} samples/sec")


if __name__ == "__main__":
    main()
