import os
import matplotlib.pyplot as plt
import numpy as np


import pickle
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from copy import deepcopy


import logging
from time import time, strftime, localtime


from utils import MyDataset, requires_grad, update_ema, create_logger

import numpy as np 
from time import time, strftime, localtime



import matplotlib.pyplot as plt
from collections import OrderedDict

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import MinMaxScaler, RobustScaler

from inception import Inception, InceptionBlock


from utils import draw_figure, show_figure, if_nan
from collections import Counter

class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into a single one.

    An input of shape ``(B, d1, d2, ...)`` comes out as ``(B, d1 * d2 * ...)``;
    the batch dimension (dim 0) is left untouched.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Keep dim 0 (batch), merge dims 1..end.
        flat = x.flatten(1)
        return flat


def find_model(model_name):
    """Load a training checkpoint and return its ``'model'`` entry.

    Args:
        model_name: Path to a checkpoint file written with ``torch.save``.
            The file must deserialize to a mapping with a ``'model'`` key.

    Returns:
        ``checkpoint['model']`` -- in this script it is passed straight to
        ``model.load_state_dict``, i.e. used as a state dict.

    Raises:
        FileNotFoundError: If ``model_name`` is not an existing file.
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop this check.
    if not os.path.isfile(model_name):
        raise FileNotFoundError(f'Could not find checkpoint at {model_name}')

    # Map every storage onto the CPU regardless of the device it was saved
    # from (same effect as ``lambda storage, loc: storage``).
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_name, map_location='cpu')

    # NOTE(review): an 'ema' entry, if present, is intentionally ignored;
    # the raw 'model' weights are always returned.
    return checkpoint['model']


def _inception_block(in_channels):
    """Build one InceptionBlock with the hyper-parameters shared by both stages."""
    return InceptionBlock(
        in_channels=in_channels,
        n_filters=32,
        kernel_sizes=[5, 11, 23],
        bottleneck_channels=32,
        use_residual=True,
        activation=nn.ReLU(),
    )


# Input layout: (B, C, L) with C == 1 for the first block.
# The second block's in_channels=32*4 and the head's in_features=4*32*1 both
# assume an InceptionBlock emits 4 * n_filters channels -- TODO confirm
# against inception.py.
model = nn.Sequential(
    _inception_block(in_channels=1),
    _inception_block(in_channels=32 * 4),
    nn.AdaptiveAvgPool1d(output_size=1),  # (B, 128, L) -> (B, 128, 1)
    Flatten(),                            # (B, 128, 1) -> (B, 128)
    nn.Linear(in_features=4 * 32 * 1, out_features=4),
).cuda()

# Checkpoints to evaluate: every ``*.pt`` file in ``./model``, in name order.
folder_path = './model'
files = os.listdir(folder_path)
pt_files = sorted(file for file in files if file.endswith('.pt'))

# NOTE: this unpickles a whole Dataset object, so the class it was saved from
# (presumably utils.MyDataset -- confirm) must be importable here. On
# PyTorch >= 2.6 torch.load defaults to weights_only=True and may refuse to
# unpickle an arbitrary Dataset; pass weights_only=False if that happens.
test_dataset = torch.load('./dataset/test_dataset.pth')

# Evaluation loader: no shuffling and no dropped tail batch, so every test
# sample is scored exactly once and every checkpoint sees identical data.
# (shuffle=True + drop_last=True silently skipped a random subset per model.)
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)
# Test-set accuracy of each checkpoint, keyed by checkpoint filename.
results = {}

for model_name in pt_files:
    # Same directory the checkpoint list was built from.
    model_path = os.path.join(folder_path, model_name)
    state_dict = find_model(model_path)
    model.load_state_dict(state_dict)

    model.eval()

    total_num = 0
    correct_num = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to('cuda')
            y = y.to('cuda').long()

            out = model(x)
            # Predicted class = index of the largest logit along dim 1.
            predicted = out.argmax(dim=1)

            correct_num += (predicted == y).sum().item()
            total_num += y.shape[0]

    # An empty loader (e.g. fewer samples than one batch with drop_last=True)
    # would otherwise surface as an unexplained ZeroDivisionError.
    if total_num == 0:
        raise RuntimeError(
            f'test_loader yielded no samples; cannot score {model_name}'
        )
    correct_rate = correct_num / total_num

    results[model_name] = correct_rate
    print(f'{model_name}:{correct_rate}')
