import torch
import torch.nn as nn
from tqdm import tqdm
from gsam import enable_running_stats, disable_running_stats
import torchvision.datasets as datasets

from torchvision.transforms import v2
import pickle




def train_one_epoch(model, loader, optimizer, scheduler, preprocess_op, train = True, device = 'cuda:0', aug = nn.Identity(), temp = 1.0, soft_labels = False, teacher_models = None, eval_mode = None, freeze_bn = False):
    """Run one pass over ``loader``, optionally updating ``model``.

    Args:
        model: network returning per-class logits (presumably shape
            (batch, num_classes), given ``torch.max(outputs, 1)``).
        loader: iterable of ``(x, y)`` batches. ``y`` may be 1-D class
            indices or a 2-D per-class tensor (one-hot / scores / logits).
        optimizer: zeroed and stepped each batch when ``train`` is True.
        scheduler: stepped once after the epoch when ``train`` is True.
        preprocess_op: unused; kept for signature compatibility.
        train: if True, apply ``aug``, backpropagate and step the optimizer.
        device: device each batch is moved to.
        aug: callable applied to the on-device input batch, training only.
        temp: temperature; logits are divided by it in every loss.
        soft_labels: if True, ``y`` (2-D) is treated as teacher logits and the
            loss is ``temp**2 * KL(softmax(y/temp) || softmax(outputs/temp))``
            averaged over the batch. Otherwise cross-entropy against the class
            index (``argmax(y, 1)`` when ``y`` is 2-D).
        teacher_models: when non-empty and ``soft_labels`` is True, ``y`` is
            replaced by the mean of the teachers' outputs on the input batch.
            ``None`` (the default) means no teachers.
        eval_mode: True -> ``model.eval()``, False -> ``model.train()``;
            defaults to ``not train``.
        freeze_bn: if True, ``model.eval()`` is called regardless of
            ``eval_mode``.

    Returns:
        tuple: (mean per-sample loss, accuracy as a 0-dim tensor, sample count).
    """
    if teacher_models is None:
        teacher_models = []
    if eval_mode is None:
        eval_mode = not train

    if not eval_mode:
        model.train()
    else:
        model.eval()

    # freeze_bn overrides eval_mode: eval() stops BatchNorm running-stat updates.
    if freeze_bn:
        model.eval()

    total_loss = 0
    total_correct = 0
    total = 0

    for x, y in tqdm(loader):
        x_device = x.to(device)
        y_device = y.to(device)

        if train:
            optimizer.zero_grad()
            x_device = aug(x_device)

        with torch.set_grad_enabled(train):
            outputs = model(x_device)
            _, preds = torch.max(outputs, 1)

            if soft_labels:
                if teacher_models:
                    # Average the teachers' outputs. Summing the outputs directly
                    # (instead of into zeros_like(y)) gives the buffer the
                    # teachers' shape/dtype even when y holds class indices.
                    # NOTE(review): teachers are not switched to eval() here;
                    # confirm the caller does so.
                    with torch.no_grad():
                        y_device = sum(teacher(x_device) / len(teacher_models) for teacher in teacher_models)

                soft_targets = nn.functional.softmax(y_device / temp, dim=-1)
                soft_prob = nn.functional.log_softmax(outputs / temp, dim=-1)
                log_soft_targets = nn.functional.log_softmax(y_device / temp, dim=-1)
                batch_size = soft_prob.shape[0]

                # KL(targets || student) scaled by T**2, as suggested in
                # "Distilling the Knowledge in a Neural Network".
                loss = -torch.sum(soft_targets * soft_prob) / batch_size * (temp**2)
                loss += torch.sum(soft_targets * log_soft_targets) / batch_size * (temp**2)
            else:
                hard_targets = torch.argmax(y_device, 1) if y_device.dim() == 2 else y_device
                loss = torch.nn.functional.cross_entropy(outputs/temp, hard_targets, reduction = 'mean')

            if train:
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * x.shape[0]
        # Support both 2-D per-class targets and 1-D class indices.
        labels = torch.argmax(y_device, dim=1) if y_device.dim() == 2 else y_device
        total_correct += torch.sum(preds == labels)
        total += x.shape[0]

    if train:
        scheduler.step()

    return total_loss/total, total_correct/total, total



def train_one_epoch_adv(model, loader, optimizer, scheduler, preprocess_op, train = True, device = 'cuda:0', aug = nn.Identity(), temp = 1.0, soft_labels = False, teacher_models = None, eval_mode = None, freeze_bn = False, adv_steps = 4, adv_step_size = 0.02, adv_eps = 0.04):
    """Run one epoch of adversarial training (or adversarial evaluation).

    For each batch, an adversarial input is built with ``adv_steps``
    sign-gradient ascent steps of size ``adv_step_size`` on the cross-entropy
    loss. After every step the perturbation is clipped to
    ``[-adv_eps, adv_eps]`` around the (augmented) clean input. This is an
    L-inf PGD-style attack with no random start and no clamping to a valid
    pixel range. The model is then evaluated (and, if ``train``, updated) on
    the adversarial input only.

    Args:
        model: network returning per-class logits.
        loader: iterable of ``(x, y)`` batches; ``y`` may be 1-D class indices
            or 2-D per-class targets (reduced with ``argmax(y, 1)``).
        optimizer: zeroed and stepped each batch when ``train`` is True.
        scheduler: stepped once after the epoch when ``train`` is True.
        preprocess_op, soft_labels, teacher_models: unused; kept for signature
            compatibility with ``train_one_epoch``.
        train: if True, apply ``aug`` and update the model.
        device: device each batch is moved to.
        aug: callable applied to the on-device input batch, training only.
        temp: temperature dividing the logits in the loss.
        eval_mode: True -> ``model.eval()``, False -> ``model.train()``;
            defaults to ``not train``.
        freeze_bn: if True, ``model.eval()`` is called regardless of
            ``eval_mode``.
        adv_steps: number of attack steps (0 disables the attack).
        adv_step_size: per-step perturbation magnitude.
        adv_eps: L-inf bound on the total perturbation.

    Returns:
        tuple: (mean per-sample adversarial loss, adversarial accuracy as a
        0-dim tensor, sample count).
    """
    if eval_mode is None:
        eval_mode = not train

    if not eval_mode:
        model.train()
    else:
        model.eval()

    if freeze_bn:
        model.eval()

    def _hard_label_ce(logits, targets):
        # Cross-entropy against class indices; 2-D targets are reduced by argmax.
        if targets.dim() == 2:
            targets = torch.argmax(targets, 1)
        return torch.nn.functional.cross_entropy(logits/temp, targets, reduction = 'mean')

    total_loss = 0
    total_correct = 0
    total = 0

    for x, y in tqdm(loader):
        x_device = x.to(device)
        y_device = y.to(device)

        if train:
            optimizer.zero_grad()
            x_device = aug(x_device)

        x_orig = x_device.detach().clone()
        x_adv = x_orig.clone()
        for _ in range(adv_steps):
            x_adv.requires_grad_(True)
            # Gradients w.r.t. the input are needed even in evaluation mode.
            # autograd.grad leaves the parameters' .grad untouched, so the
            # attack does not leak into the optimizer update.
            with torch.enable_grad():
                attack_loss = _hard_label_ce(model(x_adv), y_device)
                (input_grad,) = torch.autograd.grad(attack_loss, x_adv)
            with torch.no_grad():
                x_adv = x_adv + adv_step_size * input_grad.sign()
                x_adv = x_orig + torch.clamp(x_adv - x_orig, -adv_eps, adv_eps)
        x_adv = x_adv.detach()

        with torch.set_grad_enabled(train):
            outputs = model(x_adv)
            _, preds = torch.max(outputs, 1)
            loss = _hard_label_ce(outputs, y_device)

            if train:
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * x.shape[0]
        labels = torch.argmax(y_device, dim=1) if y_device.dim() == 2 else y_device
        total_correct += torch.sum(preds == labels)
        total += x.shape[0]

    if train:
        scheduler.step()

    return total_loss/total, total_correct/total, total

def train_one_epoch_gsam(model, loader, optimizer, scheduler, preprocess_op, train = True, device = 'cuda:0', aug = nn.Identity(), temp = 1.0, soft_labels = False, teacher_models = None, eval_mode = None, freeze_bn = False):
    """Run one epoch with a GSAM-style optimizer, or evaluate without one.

    ``optimizer`` must expose ``set_closure(loss_fn, inputs, targets)``,
    ``step() -> (outputs, loss)`` and ``update_rho_t()``. ``step()``
    presumably performs the (sharpness-aware) parameter update, as GSAM's
    does. These calls are therefore made only when ``train`` is True. In
    evaluation, a plain no-grad forward pass computes outputs and loss.

    Args:
        model: network returning per-class logits.
        loader: iterable of ``(x, y)`` batches; ``y`` may be 1-D class indices
            or 2-D per-class targets.
        optimizer: GSAM-style optimizer (see above), used only when training.
        scheduler: stepped once after the epoch when ``train`` is True.
        preprocess_op, teacher_models: unused; kept for signature compatibility.
        train: if True, apply ``aug`` and update the model via ``optimizer``.
        device: device each batch is moved to.
        aug: callable applied to the on-device input batch, training only.
        temp: temperature; logits (and soft targets) are divided by it.
        soft_labels: if True, ``y`` (2-D) is treated as teacher logits and the
            loss is ``temp**2 * KL(softmax(y/temp) || softmax(outputs/temp))``
            averaged over the batch; otherwise cross-entropy on class indices.
        eval_mode: True -> ``model.eval()``, False -> ``model.train()``;
            defaults to ``not train``.
        freeze_bn: if True, ``model.eval()`` is called regardless of
            ``eval_mode``.

    Returns:
        tuple: (mean per-sample loss, accuracy as a 0-dim tensor, sample count).
    """
    if eval_mode is None:
        eval_mode = not train

    if not eval_mode:
        model.train()
    else:
        model.eval()

    if freeze_bn:
        model.eval()

    def loss_fn(outputs, targets):
        if soft_labels:
            soft_targets = nn.functional.softmax(targets / temp, dim=-1)
            soft_prob = nn.functional.log_softmax(outputs / temp, dim=-1)
            log_soft_targets = nn.functional.log_softmax(targets / temp, dim=-1)
            batch_size = soft_prob.shape[0]

            # KL(targets || student) scaled by T**2, as suggested in
            # "Distilling the Knowledge in a Neural Network".
            loss = -torch.sum(soft_targets * soft_prob) / batch_size * (temp**2)
            loss += torch.sum(soft_targets * log_soft_targets) / batch_size * (temp**2)
            return loss
        if targets.dim() == 2:
            targets = torch.argmax(targets, 1)
        return torch.nn.functional.cross_entropy(outputs/temp, targets, reduction = 'mean')

    total_loss = 0
    total_correct = 0
    total = 0

    for x, y in tqdm(loader):
        x_device = x.to(device)
        y_device = y.to(device)

        if train:
            optimizer.zero_grad()
            x_device = aug(x_device)
            with torch.enable_grad():
                optimizer.set_closure(loss_fn, x_device, y_device)
                outputs, loss = optimizer.step()
            with torch.no_grad():
                optimizer.update_rho_t()
        else:
            # Never call optimizer.step() here: it would update the weights
            # (and advance the rho schedule) on evaluation data.
            with torch.no_grad():
                outputs = model(x_device)
                loss = loss_fn(outputs, y_device)

        _, preds = torch.max(outputs, 1)

        total_loss += loss.item() * x.shape[0]
        labels = torch.argmax(y_device, dim=1) if y_device.dim() == 2 else y_device
        total_correct += torch.sum(preds == labels)
        total += x.shape[0]

    if train:
        scheduler.step()

    return total_loss/total, total_correct/total, total



def train_one_epoch_fkd(model, loader, optimizer, scheduler, epoch, train = True, device = 'cuda:0', temp = 1.0, soft_labels = True):
    """Run one epoch over a loader yielding ``((x, _, _, _), y)`` batches.

    ``loader.dataset.set_epoch(epoch)`` is called before iterating, so the
    dataset must provide that method. Only ``x`` and ``y`` of each batch are
    used.

    Args:
        model: network returning per-class logits.
        loader: see above.
        optimizer: zeroed and stepped each batch when ``train`` is True.
        scheduler: stepped once after the epoch when ``train`` is True.
        epoch: forwarded to ``loader.dataset.set_epoch``.
        train: if True, ``model.train()`` and update weights; otherwise
            ``model.eval()`` with gradients disabled.
        device: device each batch is moved to.
        temp: temperature dividing the logits (and ``y`` in the soft loss).
        soft_labels: if True, ``y`` is treated as teacher logits and the loss is
            ``temp**2 * KL(softmax(y/temp) || softmax(outputs/temp))`` averaged
            over the batch; otherwise ``cross_entropy(outputs/temp, y)``.

    Returns:
        tuple: (mean per-sample loss, accuracy as a 0-dim tensor, sample count).
    """
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0
    total_correct = 0
    total = 0

    loader.dataset.set_epoch(epoch)
    for (x, _, _, _), y in tqdm(loader):
        x_device = x.to(device)
        y_device = y.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(x_device)
            _, preds = torch.max(outputs, 1)
            if soft_labels:
                soft_targets = nn.functional.softmax(y_device / temp, dim=-1)
                soft_prob = nn.functional.log_softmax(outputs / temp, dim=-1)
                log_soft_targets = nn.functional.log_softmax(y_device / temp, dim=-1)
                batch_size = soft_prob.shape[0]

                # KL(targets || student) scaled by T**2, as suggested in
                # "Distilling the Knowledge in a Neural Network".
                loss = -torch.sum(soft_targets * soft_prob) / batch_size * (temp**2)
                loss += torch.sum(soft_targets * log_soft_targets) / batch_size * (temp**2)
            else:
                loss = torch.nn.functional.cross_entropy(outputs/temp, y_device, reduction = 'mean')

            if train:
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * x.shape[0]
        # With soft_labels=False, y may be 1-D class indices; argmax(dim=1)
        # would raise on those.
        labels = torch.argmax(y_device, dim=1) if y_device.dim() == 2 else y_device
        total_correct += torch.sum(preds == labels)
        total += x.shape[0]

    if train:
        scheduler.step()

    return total_loss/total, total_correct/total, total


def train_one_epoch_fkd_new(model, loader, optimizer, scheduler, epoch, train = True, device = 'cuda:0', temp = 1.0, soft_labels = True):
    """Run one epoch over a loader yielding ``(x, y, _, _)`` batches.

    Same as ``train_one_epoch_fkd`` except for the batch layout.
    ``loader.dataset.set_epoch(epoch)`` is called before iterating, so the
    dataset must provide that method. Only ``x`` and ``y`` are used.

    Args:
        model: network returning per-class logits.
        loader: see above.
        optimizer: zeroed and stepped each batch when ``train`` is True.
        scheduler: stepped once after the epoch when ``train`` is True.
        epoch: forwarded to ``loader.dataset.set_epoch``.
        train: if True, ``model.train()`` and update weights; otherwise
            ``model.eval()`` with gradients disabled.
        device: device each batch is moved to.
        temp: temperature dividing the logits (and ``y`` in the soft loss).
        soft_labels: if True, ``y`` is treated as teacher logits and the loss is
            ``temp**2 * KL(softmax(y/temp) || softmax(outputs/temp))`` averaged
            over the batch; otherwise ``cross_entropy(outputs/temp, y)``.

    Returns:
        tuple: (mean per-sample loss, accuracy as a 0-dim tensor, sample count).
    """
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0
    total_correct = 0
    total = 0

    loader.dataset.set_epoch(epoch)
    for x, y, _, _ in tqdm(loader):
        x_device = x.to(device)
        y_device = y.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(x_device)
            _, preds = torch.max(outputs, 1)
            if soft_labels:
                soft_targets = nn.functional.softmax(y_device / temp, dim=-1)
                soft_prob = nn.functional.log_softmax(outputs / temp, dim=-1)
                log_soft_targets = nn.functional.log_softmax(y_device / temp, dim=-1)
                batch_size = soft_prob.shape[0]

                # KL(targets || student) scaled by T**2, as suggested in
                # "Distilling the Knowledge in a Neural Network".
                loss = -torch.sum(soft_targets * soft_prob) / batch_size * (temp**2)
                loss += torch.sum(soft_targets * log_soft_targets) / batch_size * (temp**2)
            else:
                loss = torch.nn.functional.cross_entropy(outputs/temp, y_device, reduction = 'mean')

            if train:
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * x.shape[0]
        # With soft_labels=False, y may be 1-D class indices; argmax(dim=1)
        # would raise on those.
        labels = torch.argmax(y_device, dim=1) if y_device.dim() == 2 else y_device
        total_correct += torch.sum(preds == labels)
        total += x.shape[0]

    if train:
        scheduler.step()

    return total_loss/total, total_correct/total, total



def eval_ensemble(models, loader, device = 'cuda:0', temp = 1.0):
    """Measure top-1 accuracy of an ensemble that averages softmax probabilities.

    Every model is switched to ``eval()``. Per batch, the prediction is the
    argmax of the mean of ``softmax(model(x) / temp, dim=1)`` over all
    models.

    Args:
        models: non-empty sized collection of models returning per-class logits.
        loader: iterable of ``(x, y)`` batches; ``y`` may be 1-D class indices
            or 2-D per-class targets (compared via ``argmax(y, 1)``).
        device: device each batch is moved to.
        temp: temperature dividing each model's logits before softmax.

    Returns:
        tuple: (accuracy as a 0-dim tensor, sample count).

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("eval_ensemble needs at least one model")

    for model in models:
        model.eval()

    num_models = len(models)
    total_correct = 0
    total = 0

    for x, y in tqdm(loader):
        x_device = x.to(device)
        y_device = y.to(device)

        with torch.no_grad():
            # Sum the models' probabilities directly (rather than into
            # zeros_like(y)) so the buffer has the logits' shape and dtype
            # even when y holds integer class indices.
            probs = sum(nn.functional.softmax(model(x_device)/temp, 1)/num_models for model in models)

        _, preds = torch.max(probs, 1)
        labels = torch.argmax(y_device, dim=1) if y_device.dim() == 2 else y_device
        total_correct += torch.sum(preds == labels)
        total += x.shape[0]

    return total_correct/total, total

class DoubleImageFolder(datasets.ImageFolder):
    """ImageFolder that returns two separately transformed views of each image."""

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample1, sample2, target). ``sample1`` and ``sample2`` are two
            independent calls of ``self.transform`` on the same loaded image
            (they differ if the transform is random). Both are the untransformed
            image when ``transform`` is None. ``target`` is the class index,
            passed through ``target_transform`` if one is set.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        # Default to the loaded image so the method also works without a transform.
        sample1 = sample2 = sample
        if self.transform is not None:
            sample1 = self.transform(sample)
            sample2 = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample1, sample2, target
    
    
    
    
    

class ImageFolderWithPath(datasets.ImageFolder):
    """ImageFolder that also returns the file path of each sample."""

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, path). ``sample`` is the loaded image
            (transformed if ``transform`` is set). ``target`` is the class index,
            passed through ``target_transform`` if set. ``path`` is the image
            file path from ``self.samples``.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        # Default to the loaded image so the method also works without a transform.
        sample1 = sample
        if self.transform is not None:
            sample1 = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample1, target, path
    
    
class ImageFolderWithSoftLabel(datasets.ImageFolder):
    """ImageFolder that returns a pickled per-image label instead of the class index."""

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, label). ``sample`` is the loaded image (transformed
            if ``transform`` is set). ``label`` is ``torch.tensor`` of the
            object unpickled from the matching label file. That path is the
            image path with '/images/' replaced by '/labels/' and its last
            three characters replaced by 'pkl'.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        label_path = path.replace('/images/', '/labels/')
        # NOTE(review): replacing the last three characters maps 'x.jpg' to
        # 'x.pkl' but 'x.JPEG' to 'x.Jpkl'. Kept unchanged on the assumption
        # that the existing label files were named with the same rule; confirm
        # before switching to os.path.splitext.
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # dataset at trusted, locally generated label files.
        with open(label_path[:-3] + 'pkl', 'rb') as label_file:
            label = pickle.load(label_file)
        # Default to the loaded image so the method also works without a transform.
        sample1 = sample
        if self.transform is not None:
            sample1 = self.transform(sample)
        if self.target_transform is not None:
            # The class-index target is not returned. target_transform is
            # still invoked in case it has side effects.
            target = self.target_transform(target)

        return sample1, torch.tensor(label)