# -*- coding: utf-8 -*-
"""
Created on Tue Jan  7 11:48:11 2025

@author: User
"""

import os
dir_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(dir_path)


from torch.utils import data

import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets
import torchvision
from lib_preprocessing import *
import matplotlib.pyplot as plt
import copy 


class Dataset_v2(data.Dataset):
    """In-memory dataset that yields ``(sample, label, index)`` triples.

    ``inputs`` and ``labels`` are indexed with the same integer position;
    ``inputs`` must expose ``.shape`` (its first dimension is used as the
    dataset length). An optional ``transform`` callable is applied to each
    sample when it is fetched.
    """

    def __init__(self, inputs, labels, transform=None):
        """Store the sample container, the label container and the transform."""
        self.inputs = inputs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (size of the first axis of ``inputs``)."""
        n_samples = self.inputs.shape[0]
        return n_samples

    def __getitem__(self, index):
        """Return ``(sample, label, index)`` for position ``index``.

        The sample is passed through ``transform`` when one was supplied and
        the label is converted to a plain Python ``int``.
        """
        sample = self.inputs[index]

        transform = self.transform
        if transform is None:
            return sample, int(self.labels[index]), index

        sample = transform(sample)
        return sample, int(self.labels[index]), index
    

class ImageFolderWithIndex(datasets.ImageFolder):
    """``torchvision.datasets.ImageFolder`` that also returns the sample index.

    Each item is ``(img, y, index)`` instead of the ``(img, y)`` pair produced
    by the parent class, mirroring the triple returned by ``Dataset_v2``.
    """

    # Override __getitem__: this is the method the DataLoader calls.
    def __getitem__(self, index):
        """Return ``(img, y, index)`` for the sample at ``index``."""
        # Zero-argument super() avoids naming the class explicitly; naming a
        # class that does not exist in this module raises NameError on every
        # lookup.
        img, y = super().__getitem__(index)
        return img, y, index
    

    
def genloaders_fromfolder(train_dir, test_dir, loader_params):
    """Build train / test / IG-train loaders from two image folders.

    Parameters
    ----------
    train_dir, test_dir
        Root directories handed to ``ImageFolderWithIndex``.
    loader_params
        Either a dict or an object with attributes, providing ``transform``,
        ``batch_size`` and ``IG_batch_size``.

    Returns
    -------
    tuple
        ``(trainloader, testloader, IG_trainloader)``. The train loaders
        shuffle the training folder; the test loader iterates the test folder
        in order (same convention as ``genloaders``).
    """

    def _param(name):
        # Accept both mapping-style (as used by genloaders) and
        # attribute-style parameter containers.
        if isinstance(loader_params, dict):
            return loader_params[name]
        return getattr(loader_params, name)

    def _make_loader(dataset, batch_size, shuffle):
        # NOTE(review): a CUDA generator is passed to every loader, which
        # presumably relies on the process default device being CUDA - confirm.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            generator=torch.Generator(device='cuda'),
            num_workers=0,
        )

    transform = _param('transform')
    train_data = ImageFolderWithIndex(train_dir, transform=transform)
    test_data = ImageFolderWithIndex(test_dir, transform=transform)

    batch_size = _param('batch_size')
    trainloader = _make_loader(train_data, batch_size, shuffle=True)
    # Evaluate on the held-out folder, not on the training data.
    testloader = _make_loader(test_data, batch_size, shuffle=False)
    IG_trainloader = _make_loader(train_data, _param('IG_batch_size'), shuffle=True)

    return trainloader, testloader, IG_trainloader
    
    

def genloaders(X_train, y_train, X_test, y_test, loader_params):
    """Preprocess the train/test arrays and wrap them in DataLoaders.

    Steps, all driven by keys of ``loader_params``:

    * ``convert_to_torch``: turn the four NumPy inputs into CUDA tensors
      (features are additionally cast to float).
    * ``training_size``: unless it equals ``'full'``, keep only that many
      leading training samples.
    * ``conversion``: ``'rank'``, ``'uniform'``, ``'uniform_scale'`` or
      ``'normalize'`` selects a converter that is called as
      ``converter(X_train)`` -> ``(X_train, params)`` and then
      ``converter(X_test, params)``; any other value skips this step.
    * ``add_singleton``: append two singleton axes at positions 2 and 3.
    * ``transform``, ``batch_size``, ``IG_batch_size``: dataset transform and
      loader batch sizes.

    Returns ``(trainloader, testloader, IG_trainloader)``; the two training
    loaders shuffle, the test loader does not.
    """
    if loader_params['convert_to_torch']:
        def to_cuda(array, as_float):
            tensor = torch.from_numpy(array)
            if as_float:
                tensor = tensor.float()
            return tensor.cuda()

        X_train = to_cuda(X_train, True)
        y_train = to_cuda(y_train, False)
        X_test = to_cuda(X_test, True)
        y_test = to_cuda(y_test, False)

    limit = loader_params['training_size']
    if limit != 'full':
        X_train, y_train = X_train[:limit], y_train[:limit]

    # Pick the converter by name; each helper is only looked up when its
    # branch is actually taken.
    mode = loader_params['conversion']
    if mode == 'rank':
        converter = rank_convert_data
    elif mode == 'uniform':
        converter = uniform_convert_data
    elif mode == 'uniform_scale':
        converter = uniform_scale_convert_data
    elif mode == 'normalize':
        converter = normalized_convert_data
    else:
        converter = None

    if converter is not None:
        # Parameters come from the training data and are reused for the test data.
        X_train, fitted = converter(X_train)
        X_test = converter(X_test, fitted)

    if loader_params['add_singleton']:
        X_train, X_test = (t.unsqueeze(2).unsqueeze(3) for t in (X_train, X_test))

    transform = loader_params['transform']
    train_set = Dataset_v2(X_train, y_train, transform)
    test_set = Dataset_v2(X_test, y_test, transform)

    def make_loader(dataset, size_key, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=loader_params[size_key],
            shuffle=shuffle,
            generator=torch.Generator(device='cuda'),
            num_workers=0,
        )

    trainloader = make_loader(train_set, 'batch_size', True)
    testloader = make_loader(test_set, 'batch_size', False)
    IG_trainloader = make_loader(train_set, 'IG_batch_size', True)

    return trainloader, testloader, IG_trainloader



def gen_pruned_loaders(X_train, y_train, X_test, y_test, loader_params):
    """Build train/test loaders from a (possibly pruned) training subset.

    Works like ``genloaders`` but additionally honours
    ``loader_params['train_indices']`` and returns no IG loader.

    Keys read from ``loader_params``:

    * ``convert_to_torch``: convert ``X_train`` / ``y_train`` from NumPy to
      CUDA tensors (features cast to float).
    * ``training_size``: unless ``'full'``, keep that many leading samples.
    * ``train_indices``: unless the string ``'all'``, an index (list, array
      or tensor) selecting the training samples to keep, applied after the
      ``training_size`` cut.
    * ``conversion``: ``'rank'``, ``'uniform'``, ``'uniform_scale'`` or
      ``'normalize'``; the converter is called as ``converter(X_train)`` ->
      ``(X_train, params)`` and then ``converter(X_test, params)``. Other
      values skip the step.
    * ``add_singleton``: append singleton axes at positions 2 and 3.
    * ``transform``, ``batch_size``: dataset transform and loader batch size.

    Returns ``(trainloader, testloader)``.
    """
    if loader_params['convert_to_torch']:
        # NOTE(review): only the training arrays are converted here; X_test and
        # y_test are used exactly as passed in (unlike genloaders) - confirm
        # that callers supply them already converted.
        X_train = torch.from_numpy(X_train).float().cuda()
        y_train = torch.from_numpy(y_train).cuda()

    training_size = loader_params['training_size']
    if training_size != 'full':
        X_train = X_train[0:training_size]
        y_train = y_train[0:training_size]

    train_indices = loader_params['train_indices']
    # Compare with the 'all' sentinel only when the value is a string: for a
    # NumPy index array `array != 'all'` can be evaluated element-wise, and the
    # resulting array has no unambiguous truth value.
    keep_all = isinstance(train_indices, str) and train_indices == 'all'
    if not keep_all:
        X_train = X_train[train_indices]
        y_train = y_train[train_indices]

    conversion = loader_params['conversion']
    if conversion == 'rank':
        X_train, params = rank_convert_data(X_train)
        X_test = rank_convert_data(X_test, params)
    elif conversion == 'uniform':
        X_train, params = uniform_convert_data(X_train)
        X_test = uniform_convert_data(X_test, params)
    elif conversion == 'uniform_scale':
        X_train, params = uniform_scale_convert_data(X_train)
        X_test = uniform_scale_convert_data(X_test, params)
    elif conversion == 'normalize':
        X_train, params = normalized_convert_data(X_train)
        X_test = normalized_convert_data(X_test, params)

    if loader_params['add_singleton']:
        X_train = X_train.unsqueeze(2).unsqueeze(3)
        X_test = X_test.unsqueeze(2).unsqueeze(3)

    my_dataset = Dataset_v2(X_train, y_train, loader_params['transform'])
    my_dataset_test = Dataset_v2(X_test, y_test, loader_params['transform'])

    trainloader = torch.utils.data.DataLoader(
        my_dataset,
        batch_size=loader_params['batch_size'],
        shuffle=True,
        generator=torch.Generator(device='cuda'),
        num_workers=0,
    )
    testloader = torch.utils.data.DataLoader(
        my_dataset_test,
        batch_size=loader_params['batch_size'],
        shuffle=False,
        generator=torch.Generator(device='cuda'),
        num_workers=0,
    )

    return trainloader, testloader



def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total



def test_model(model, testloader):
    """Return the classification accuracy of ``model`` over ``testloader``.

    Parameters
    ----------
    model
        Callable module; it is switched to eval mode and called on each batch
        of inputs. The predicted class is the argmax over dimension 1 of its
        output.
    testloader
        Iterable yielding ``(inputs, labels, indices)`` batches.

    Returns
    -------
    float
        Number of correct predictions divided by the number of samples that
        were actually evaluated.

    Raises
    ------
    ZeroDivisionError
        If ``testloader`` yields no samples.
    """
    model = model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels, _indices in testloader:
            predicted = torch.argmax(model(inputs), 1)
            # Integer sum keeps the count exact regardless of dataset size.
            correct = correct + torch.sum(predicted == labels)
            # Count what was evaluated instead of reading dataset.labels, so
            # datasets without a `labels` attribute and loaders that drop or
            # subsample items are handled correctly.
            total += len(labels)

    return float(correct) / float(total)


# The function name matches pytest's ``test_*`` discovery pattern; mark it so
# that importing it into a test module does not make pytest collect it.
test_model.__test__ = False



def _build_optimizer(model, train_params):
    """Create the optimizer named by ``train_params['optimizer']``."""
    name = train_params['optimizer']
    lr = train_params['init_rate']
    weight_decay = train_params['weight_decay']
    if name == 'SGD':
        return optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if name == 'Adam':
        return optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    if name == 'AdamW':
        return optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    raise ValueError("Unsupported optimizer: {!r}".format(name))


def _build_scheduler(optimizer, train_params):
    """Create the LR scheduler named by ``train_params['scheduler']['name']``."""
    cfg = train_params['scheduler']
    name = cfg['name']
    if name == 'StepLR':
        return optim.lr_scheduler.StepLR(optimizer, step_size=cfg['step_size'], gamma=cfg['gamma'])
    if name == 'MultiStepLR':
        return optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones'], gamma=cfg['gamma'])
    if name == 'CyclicLR':
        return optim.lr_scheduler.CyclicLR(
            optimizer,
            train_params['init_rate'],
            cfg['max_lr'],
            cfg['step_size'],
            step_size_down=cfg['step_size'],
            mode='triangular',
            gamma=cfg['gamma'],
        )
    # Any other name keeps the learning rate constant.
    return optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)


def _build_criterion(name):
    """Create the loss module named ``name``."""
    if name == 'BCEWithLogitsLoss':
        return nn.BCEWithLogitsLoss()
    if name == 'CrossEntropyLoss':
        return nn.CrossEntropyLoss()
    if name == 'MSELoss':
        return nn.MSELoss()
    raise ValueError("Unsupported criterion: {!r}".format(name))


def train_model_general(model, trainloader, train_params):
    """Train ``model`` on ``trainloader`` and return it together with the loss curve.

    Parameters
    ----------
    model
        Module to train; it is moved to CUDA and put in train mode.
    trainloader
        Iterable yielding ``(inputs, labels, indices)`` batches. Batches are
        fed to the model as-is, so they are expected to already live on the
        model's device.
    train_params
        Dict with the keys ``optimizer`` (``'SGD'``, ``'Adam'``, ``'AdamW'``),
        ``init_rate``, ``weight_decay``, ``scheduler`` (dict with ``name`` and
        the scheduler-specific settings), ``criterion``
        (``'BCEWithLogitsLoss'``, ``'CrossEntropyLoss'``, ``'MSELoss'``),
        ``total_epochs`` and the display flags ``disp_epoch``,
        ``disp_loss_epoch``, ``disp_time_per_epoch``, ``disp_loss_final`` and
        ``disp_accuracy_final``.

    Returns
    -------
    tuple
        ``(model, all_train_losses)`` where the model is in eval mode and
        ``all_train_losses`` holds one batch-size-weighted mean loss per epoch.

    Raises
    ------
    ValueError
        For an unknown optimizer or criterion name, or when ``trainloader``
        yields no batches.
    """
    # Function-scope import: the module's visible import block does not bring
    # `time` into scope.
    import time

    print("Total Model Params: ", count_parameters(model))

    model = model.cuda()
    model = model.train()

    optimizer = _build_optimizer(model, train_params)
    scheduler = _build_scheduler(optimizer, train_params)
    criterion = _build_criterion(train_params['criterion'])
    # CrossEntropyLoss takes integer class indices; the other supported losses
    # compare the output against floating-point targets.
    integer_targets = train_params['criterion'] == 'CrossEntropyLoss'

    all_train_losses = []
    epoch_time_reported = False

    for epoch in range(train_params['total_epochs']):
        epoch_start = time.time()

        if train_params['disp_epoch']:
            print('epoch: ' + str(epoch))

        batch_losses = []
        batch_sizes = []

        for inputs, labels, _indices in trainloader:
            optimizer.zero_grad()
            allouts = model(inputs)

            if integer_targets:
                targets = labels.long()
            else:
                # NOTE(review): only the dtype is adapted; the labels' shape
                # must already match the model output - confirm with callers.
                targets = labels.to(allouts.dtype)

            loss = criterion(allouts, targets)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
            batch_sizes.append(len(labels))

        if not batch_losses:
            raise ValueError("trainloader produced no batches")

        # The scheduler advances once per epoch, so step-based settings such as
        # `step_size` are measured in epochs here.
        scheduler.step()

        # Weight each batch by its size so a short final batch does not skew
        # the epoch mean.
        all_train_losses.append(np.average(np.array(batch_losses), weights=np.array(batch_sizes)))
        if train_params['disp_loss_epoch']:
            print("Training Loss:", all_train_losses[-1])

        # Report the duration of the first epoch only.
        if train_params['disp_time_per_epoch'] and not epoch_time_reported:
            print("Time for one epoch:", time.time() - epoch_start)
            epoch_time_reported = True

    if train_params['disp_loss_final'] and all_train_losses:
        print(all_train_losses[-1])

    if train_params['disp_accuracy_final']:
        accuracy = test_model(model, trainloader)
        print("Training Accuracy:", accuracy)

    model = model.eval()

    return model, all_train_losses









