import os
import datetime
import torch
import pandas as pd
import datetime
from torchvision import models
from datasets import *
from tqdm import tqdm
from torchmetrics.classification import Accuracy, F1Score #type:ignore

import warnings
# Suppress every warning category for the whole run (keeps the evaluation
# output readable).
warnings.simplefilter("ignore")

# Point TORCH_HOME (torch's download/cache directory) at a relative folder;
# per the original note this is for using pretrained weights inside Docker.
os.environ.update(TORCH_HOME='pytorch_models/')

def _load_test_dataset(csv_file):
    """Return the test split of the dataset named in *csv_file*'s file name.

    The dataset is chosen by substring match on the file name. Returns None
    when the name matches none of the known datasets.
    """
    file_name = os.path.basename(csv_file)
    # 'fashionmnist' must be tested before 'mnist' because it contains
    # 'mnist' as a substring.
    if 'fashionmnist' in file_name:
        full_dataset = CustomFashionMNIST(image_dir='/data/progressive_data_dropout/fashionmnist')
    elif 'mnist' in file_name:
        full_dataset = CustomMNIST(image_dir='/data/progressive_data_dropout/mnist')
    elif 'svhn' in file_name:
        full_dataset = CustomSVHN(image_dir='/data/progressive_data_dropout/svhn')
    elif 'cifar10' in file_name:
        full_dataset = CustomCIFAR10(image_dir='/data/progressive_data_dropout/cifar10')
    else:
        return None
    full_dataset.setup(stage='test')
    return full_dataset.test


def _evaluate(model, dataloader, num_classes, device):
    """Evaluate *model* on every batch of *dataloader*.

    Runs under ``torch.no_grad()`` so no autograd graph is built during
    evaluation.

    Returns:
        (mean_loss, accuracy, f1score, per_class_f1score): the first three are
        Python numbers (``.item()``), the last is the tensor returned by the
        per-class metric's ``compute()``. mean_loss is NaN if the dataloader
        yields no batches.
    """
    loss_fn = torch.nn.CrossEntropyLoss()

    # NOTE(review): torchmetrics >= 0.11 requires a ``task=`` argument for
    # these constructors; these calls match the older API — confirm the
    # installed version.
    accuracy = Accuracy().to(device=device)
    f1score = F1Score().to(device=device)
    per_class_f1score = F1Score(num_classes=num_classes, average=None).to(device=device)

    loss_total = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device=device), labels.to(device=device)
            outputs = model(images)

            loss_total += loss_fn(outputs, labels).item()
            accuracy.update(outputs, labels)
            f1score.update(outputs, labels)
            per_class_f1score.update(outputs, labels)
            num_batches += 1

    mean_loss = loss_total / num_batches if num_batches else float('nan')
    return (
        mean_loss,
        accuracy.compute().item(),
        f1score.compute().item(),
        per_class_f1score.compute(),
    )


def _summarize_training(csv_file, last_epoch):
    """Summarize the training log *csv_file* for rows 0..*last_epoch* (positional).

    Returns:
        (rounds, training_time_converted, training_datapoints):
        - rounds: number of kept rows;
        - training_time_converted: sum of the 'Epoch Time' column, treated as
          seconds, rounded and formatted with ``str(datetime.timedelta)``;
        - training_datapoints: sum of the 'training_datapoints' column.
    """
    # Cut the log off at the epoch stored in the checkpoint.
    training_df = pd.read_csv(csv_file).iloc[:last_epoch + 1, :]
    # to_numpy().sum() matches the previous np.sum semantics (NaN propagates)
    # without depending on `np` leaking in through a star import.
    training_time = training_df['Epoch Time'].to_numpy().sum()
    training_time_converted = str(datetime.timedelta(seconds=round(training_time)))
    training_datapoints = training_df['training_datapoints'].to_numpy().sum()
    return len(training_df.index), training_time_converted, training_datapoints


def _dataset_name(csv_file):
    """Return the text of *csv_file*'s file name before its first '.' or '_'."""
    # Uses the basename so the result does not depend on the folder depth.
    return os.path.basename(csv_file).split('.')[0].split('_')[0]


def _threshold_label(csv_file):
    """Return '0.85' or '0.95' if that substring occurs in *csv_file* (checked
    in that order), otherwise 'N/A'."""
    for label in ('0.85', '0.95'):
        if label in csv_file:
            return label
    return 'N/A'


if __name__ == '__main__':
    csv_folder = './ablation_logs'
    checkpoint_folder = './ablation_checkpoints'
    num_classes = 10
    device = 'cuda:0'

    sorted_csv_files = sorted(os.listdir(csv_folder))
    sorted_pth_files = sorted(os.listdir(checkpoint_folder))

    # Logs and checkpoints are paired purely by sorted position, so a count
    # mismatch means the pairing is wrong; fail loudly instead of evaluating
    # the wrong checkpoint (or crashing with IndexError mid-run).
    if len(sorted_csv_files) != len(sorted_pth_files):
        raise RuntimeError(
            f'{csv_folder} has {len(sorted_csv_files)} files but '
            f'{checkpoint_folder} has {len(sorted_pth_files)}; cannot pair them'
        )

    for csv_name, pth_name in zip(sorted_csv_files, sorted_pth_files):
        csv_file = os.path.join(csv_folder, csv_name)
        pth_file = os.path.join(checkpoint_folder, pth_name)

        test_dataset = _load_test_dataset(csv_file)
        if test_dataset is None:
            print('Incorrect experiment')
            continue

        base_model = models.resnet18(pretrained=False, num_classes=num_classes)

        print('Loading:', pth_file)
        # map_location lets checkpoints saved on a different device load here.
        checkpoint = torch.load(pth_file, map_location=device)
        base_model.load_state_dict(checkpoint['model_state_dict'])
        last_epoch = checkpoint['epoch']

        base_model.eval()
        base_model.to(device=device)

        test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            shuffle=False,
            batch_size=32,
            num_workers=8,
            pin_memory=False,
        )

        test_loss, total_accuracy, total_f1score, total_per_class_f1score = _evaluate(
            base_model, test_dataloader, num_classes, device
        )

        print(last_epoch)
        rounds, training_time_converted, training_datapoints = _summarize_training(
            csv_file, last_epoch
        )

        print(csv_file)
        print(f'Test Accuracy: {total_accuracy:.2f}')
        print(f'Rounds of Training: {rounds}')
        print(f'Training Time: {training_time_converted}')
        print(f'Training Datapoints: {training_datapoints:.2f}')
        # One LaTeX table row per experiment.
        print(
            f'{_dataset_name(csv_file)} & {_threshold_label(csv_file)} & '
            f'{total_accuracy:.2f} & {training_time_converted} & {rounds} & '
            f'{int(training_datapoints)}  \\\\'
        )
        print()