import torch
import random
import numpy as np
import os
import math

import matplotlib.pyplot as plt
# Process-wide matplotlib style: applies to every figure created after this
# line, including the ones written by save_plots below.
plt.style.use(style='ggplot')


def set_seed(seed, deterministic=False):
    """Seed every random number generator used by the training code.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices) with ``seed``, and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
        deterministic: cuDNN mode. The default (``False``) keeps the original
            speed-oriented configuration (``cudnn.deterministic = False``,
            ``cudnn.benchmark = True``). Under it, cuDNN may autotune and pick
            non-deterministic kernels, so GPU results can still differ between
            runs despite the seeds. Pass ``True`` to request deterministic cuDNN
            kernels and disable autotuning.
    """
    deterministic = bool(deterministic)
    random.seed(seed)
    # Python fixes its str-hash seed at interpreter start-up, so this only
    # influences interpreters launched afterwards (e.g. spawned subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # Autotuning picks kernels per input shape and can break reproducibility,
    # so it is only enabled in the non-deterministic (default) mode.
    torch.backends.cudnn.benchmark = not deterministic
    torch.manual_seed(seed)
    
def soft_target_cross_entropy(x, target, reduction='mean'):
    """Cross-entropy between logits ``x`` and a soft (per-class weight) target.

    The per-sample loss is ``sum(-target * log_softmax(x, dim=-1), dim=-1)``.
    ``target`` is used as given (not renormalised), so a multi-hot target adds
    up the ``-log p`` terms of all its active classes.

    Args:
        x: Logits, with classes along the last dimension.
        target: Per-class target weights, broadcastable against ``x``.
        reduction: ``'mean'`` averages the per-sample losses and ``'sum'`` adds
            them. Any other value (e.g. ``'none'``) returns the unreduced
            per-sample losses.

    Returns:
        A scalar tensor for ``'mean'``/``'sum'``; otherwise a tensor of
        per-sample losses (``x`` without its last dimension).
    """
    loss = torch.sum(-target * torch.nn.functional.log_softmax(x, dim=-1), dim=-1)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss
    
def v9_input_target_max_all(h1, t1, h0, t0):
    """Soft cross-entropy towards every row-wise maximum of the ``t0`` similarity.

    Target: the Gram matrix of ``t0``. If ``t0`` is not 2-D, its trailing
    dims are flattened into positions, one Gram matrix is built per position,
    and the results are averaged. Every entry equal to its row maximum becomes
    1.0 and all others 0.0, so ties keep all maxima.
    Logits: the ``h1`` x ``t1`` similarity, built the same way.
    ``h0`` is accepted for interface compatibility but not used.
    """
    # --- target similarity between rows of t0 ---
    if len(t0.shape) != 2:
        # (N, C, *rest) -> (P, N, C) with P = prod(rest); N/C = dims 0/1.
        t0_pos = t0.flatten(2).permute(2, 0, 1)
        target_sim = torch.matmul(t0_pos, t0_pos.transpose(1, 2)).mean(0)
    else:
        t0_flat = t0.flatten(1)
        target_sim = t0_flat @ t0_flat.t()

    # --- logits: similarity of h1 rows against t1 rows ---
    if len(t1.shape) <= 2:
        logits = h1.flatten(1) @ t1.flatten(1).t()
    else:
        h1_pos = h1.flatten(2).permute(2, 0, 1)  # (P, N, C)
        t1_pos = t1.flatten(2).permute(2, 1, 0)  # (P, C, M)
        logits = torch.matmul(h1_pos, t1_pos).mean(0)

    row_best = target_sim.max(dim=1, keepdim=True).values
    hot = target_sim.eq(row_best).float()
    return soft_target_cross_entropy(logits, hot)

class v15_input_target_topk(object):
    """Callable soft cross-entropy loss whose target marks the ``topk`` most
    similar ``t0`` rows for each row.

    Args:
        topk: Number of positive columns per target row. It is clamped to the
            number of columns at call time, so batches smaller than ``topk``
            still work.
    """

    def __init__(self, topk):
        super().__init__()
        self.topk = topk

    def __call__(self,sp_learn,h1,t1,h0,t0,context):
        """Compute the loss for one batch.

        Args:
            sp_learn, h0, context: Accepted for interface compatibility; not
                used by this loss.
            h1, t1: Produce the logits (``h1`` x ``t1`` similarity; inputs with
                more than two dims are compared per flattened trailing
                position and averaged).
            t0: Produces the target similarity (Gram matrix, same treatment).

        Returns:
            The scalar loss from ``soft_target_cross_entropy`` (mean reduction).
        """
        if len(t0.shape) == 2:
            t0 = t0.flatten(1)
            yy = t0 @ t0.t()
        else:
            # (N, C, *rest) -> (P, N, C): one Gram matrix per position, averaged.
            t0 = t0.flatten(2).permute(2,0,1)
            yy = (t0 @ t0.permute(0,2,1)).mean(0)

        if len(t1.shape) > 2:
            h1 = h1.flatten(2).permute(2,0,1)
            t1 = t1.flatten(2).permute(2,1,0)
            y = (h1 @ t1).mean(0)

        else:
            h1 = h1.flatten(1)
            t1 = t1.flatten(1)

            y = h1 @ t1.t()

        # torch.topk raises when k exceeds the dimension size (e.g. a final,
        # smaller batch); never ask for more positives than there are columns.
        k = min(self.topk, yy.shape[1])
        yym = yy.topk(k,dim=1)[1]
        # Multi-hot target: 1.0 at each of the k selected columns.
        yy = torch.nn.functional.one_hot(yym,yy.shape[1]).sum(1).clamp(0,1).float()

        l = soft_target_cross_entropy(y, yy)

        # The loss was previously computed and then discarded (bare `return`).
        return l
    
def v14_input_target_max_rand(h1, t1, h0, t0):
    """Hard-label cross-entropy on the ``h1`` x ``t1`` similarity.

    Each row's label is the column index of the largest entry in the ``t0``
    Gram matrix (``torch.argmax``; on ties the first maximal index is used, so
    despite the name no randomness is involved). Inputs with more than two
    dims are compared per flattened trailing position and averaged.
    ``h0`` is accepted for interface compatibility but not used.
    """
    # Label source: pairwise similarity of t0 rows.
    if len(t0.shape) != 2:
        t0_seq = t0.flatten(2).permute(2, 0, 1)  # (P, N, C)
        sims = torch.matmul(t0_seq, t0_seq.transpose(1, 2)).mean(dim=0)
    else:
        t0_rows = t0.flatten(1)
        sims = torch.matmul(t0_rows, t0_rows.t())

    # Logits: similarity of h1 rows against t1 rows.
    if len(t1.shape) <= 2:
        logits = torch.matmul(h1.flatten(1), t1.flatten(1).t())
    else:
        h_seq = h1.flatten(2).permute(2, 0, 1)  # (P, N, C)
        t_seq = t1.flatten(2).permute(2, 1, 0)  # (P, C, M)
        logits = torch.matmul(h_seq, t_seq).mean(dim=0)

    labels = torch.argmax(sims, dim=1)
    return torch.nn.functional.cross_entropy(logits, labels)



def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Scores of shape (batch, num_classes); larger means more likely.
        target: Either per-class scores / one-hot rows of shape
            (batch, num_classes), whose argmax is taken as the true class, or
            a 1-D tensor of integer class indices of length ``batch``.
        topk: The k values to evaluate.

    Returns:
        A list of Python floats, one percentage per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # One true class per sample. Comparing against target.topk(maxk) rank
        # by rank would let arbitrary tie-broken "ground-truth" indices count
        # as hits for k > 1 and could exceed 100%.
        labels = target if target.dim() == 1 else target.argmax(dim=1)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row j holds every sample's rank-j guess

        # topk indices are distinct per sample, so each sample's label matches
        # at most one row, and each sample counts at most once for any k.
        correct = pred.eq(labels.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size).item())
        return res


class SaveBestModel:
    """
    Save a checkpoint whenever the validation loss improves.

    Each call compares the current validation loss with the lowest loss seen
    so far. If it is strictly lower, the record is updated and a checkpoint
    (epoch, model/optimizer state dicts, criterion) is written to
    ``<output_dir>/<name>_best_model.pth``, overwriting any previous one.

    Args:
        name: Prefix for the checkpoint file name.
        best_valid_loss: Initial "best" loss to beat (default: +inf, so the
            first call always saves).
        output_dir: Directory for the checkpoint; created on demand. Defaults
            to ``'Output'``, the previously hard-coded location.
    """

    def __init__(
        self, name='', best_valid_loss=float('inf'), output_dir='Output'
    ):
        self.best_valid_loss = best_valid_loss
        self.name = name
        self.output_dir = output_dir

    def __call__(
        self, current_valid_loss,
        epoch, model, optimizer, criterion
    ):
        """Save a checkpoint if ``current_valid_loss`` beats the best so far.

        ``epoch`` is treated as 0-based: ``epoch + 1`` is printed and stored.
        ``criterion`` is stored as-is under the ``'loss'`` key.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # torch.save does not create missing parent directories.
            os.makedirs(self.output_dir, exist_ok=True)
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion,
            }, os.path.join(self.output_dir, f'{self.name}_best_model.pth'))


def save_model(epochs, model, optimizer, criterion, output_dir='Output'):
    """
    Save the final training checkpoint to ``<output_dir>/final_model.pth``.

    Args:
        epochs: Stored under ``'epoch'``.
        model: Its ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Its ``state_dict()`` is stored under ``'optimizer_state_dict'``.
        criterion: Stored as-is under ``'loss'``.
        output_dir: Target directory, created if missing. Defaults to
            ``'Output'``, the previously hard-coded location.
    """
    print("Saving final model...")
    # torch.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, os.path.join(output_dir, 'final_model.pth'))


def save_plots(name, train_acc, valid_acc, train_loss, valid_loss, output_dir='Output'):
    """
    Save accuracy and loss curves as PNG files.

    Writes ``<output_dir>/<name>_accuracy.png`` (train/validation accuracy)
    and ``<output_dir>/<name>_loss.png`` (train/validation loss), with the
    x-axis labelled "Epochs".

    Args:
        name: File-name prefix.
        train_acc, valid_acc: Accuracy sequences to plot.
        train_loss, valid_loss: Loss sequences to plot.
        output_dir: Target directory, created if missing. Defaults to
            ``'Output'``, the previously hard-coded location.
    """
    os.makedirs(output_dir, exist_ok=True)

    # accuracy plots
    fig = plt.figure(figsize=(10, 7))
    plt.plot(
        train_acc, color='green', linestyle='-',
        label='train accuracy'
    )
    plt.plot(
        valid_acc, color='blue', linestyle='-',
        label='validation accuracy'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{name}_accuracy.png'))
    # Close each figure once saved; otherwise pyplot keeps it alive and
    # repeated calls accumulate open figures (memory + too-many-figures warning).
    plt.close(fig)

    # loss plots
    fig = plt.figure(figsize=(10, 7))
    plt.plot(
        train_loss, color='orange', linestyle='-',
        label='train loss'
    )
    plt.plot(
        valid_loss, color='red', linestyle='-',
        label='validation loss'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'{name}_loss.png'))
    plt.close(fig)


def cosine_annealing_lr_with_warmup(epoch, num_epochs, initial_lr, warmup_epochs=10, start_lr=0.0001, adaptive=1, min_lr=0.0):
    """
    Cosine annealing learning-rate schedule with a half-cosine warm-up.

    * ``epoch < warmup_epochs``: the rate rises from ``start_lr`` (epoch 0)
      towards ``initial_lr`` along a half cosine.
    * Afterwards, if ``adaptive == 1``: the rate follows a cosine from
      ``initial_lr`` (at ``epoch == warmup_epochs``) down to ``min_lr`` (at
      ``epoch == num_epochs``). Any other ``adaptive`` value holds it at
      ``initial_lr``.

    Args:
        epoch: Current (0-based) epoch.
        num_epochs: Total epochs in the schedule, including warm-up.
        initial_lr: Peak learning rate, reached at the end of warm-up.
        warmup_epochs: Length of the warm-up phase.
        start_lr: Learning rate at epoch 0.
        adaptive: 1 enables cosine decay after warm-up; any other value
            disables it.
        min_lr: Floor reached at the end of the decay (default 0.0, which
            reproduces the original schedule exactly).

    Returns:
        The learning rate for ``epoch`` as a float.
    """
    if epoch < warmup_epochs:
        # Linear alternative: start_lr + ((initial_lr - start_lr) / warmup_epochs) * epoch
        # Sine-shaped increase instead:
        return start_lr + (initial_lr - start_lr) * (1 - math.cos(math.pi * epoch / warmup_epochs)) / 2

    if adaptive != 1:
        return initial_lr

    decay_epochs = num_epochs - warmup_epochs
    if decay_epochs == 0:
        # No annealing window (schedule ends with warm-up); previously this
        # divided by zero. Hold the post-warm-up rate.
        return initial_lr

    # Epochs elapsed since the end of warm-up.
    shifted_epoch = epoch - warmup_epochs
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * shifted_epoch / decay_epochs))

# class CIFAR10Dataset(Dataset):
#     def __init__(self, data, labels, tf_augmentations=None):
#         self.data = data
#         self.labels = labels
#         self.tf_augmentations = tf_augmentations

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         image = self.data[idx]
        
#         # Apply TensorFlow-based augmentations
#         if self.tf_augmentations:
#             image = tf.convert_to_tensor(image, dtype=tf.float32)
#             image = self.tf_augmentations.flow(image.numpy().reshape((1,) + image.shape), batch_size=1).next()[0]
        
#         image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) # C, H, W format
#         label = torch.tensor(self.labels[idx], dtype=torch.long)

#         return image, label
    

# def unpickle(file):
#     """load the cifar-10 data"""

#     with open(file, 'rb') as fo:
#         data = pickle.load(fo, encoding='bytes')
#     return data


# def load_cifar_10_data(data_dir, negatives=False):
#     """
#     Return train_data, train_filenames, train_labels, test_data, test_filenames, test_labels, label_names
#     (train_labels and test_labels are one-hot encoded).
#     """

#     # get the meta_data_dict
#     # num_cases_per_batch: 10000
#     # label_names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#     # num_vis: 3072

#     meta_data_dict = unpickle(data_dir + "/batches.meta")
#     cifar_label_names = meta_data_dict[b'label_names']
#     cifar_label_names = np.array(cifar_label_names)

#     # training data
#     cifar_train_data = None
#     cifar_train_filenames = []
#     cifar_train_labels = []

#     for i in range(1, 6):
#         cifar_train_data_dict = unpickle(data_dir + "/data_batch_{}".format(i))
#         if i == 1:
#             cifar_train_data = cifar_train_data_dict[b'data']
#         else:
#             cifar_train_data = np.vstack((cifar_train_data, cifar_train_data_dict[b'data']))
#         cifar_train_filenames += cifar_train_data_dict[b'filenames']
#         cifar_train_labels += cifar_train_data_dict[b'labels']

#     cifar_train_data = cifar_train_data.reshape((len(cifar_train_data), 3, 32, 32))
#     if negatives:
#         cifar_train_data = cifar_train_data.transpose(0, 2, 3, 1).astype(np.float32)
#     else:
#         cifar_train_data = np.rollaxis(cifar_train_data, 1, 4)
#     cifar_train_filenames = np.array(cifar_train_filenames)
#     cifar_train_labels = np.array(cifar_train_labels)

#     cifar_test_data_dict = unpickle(data_dir + "/test_batch")
#     cifar_test_data = cifar_test_data_dict[b'data']
#     cifar_test_filenames = cifar_test_data_dict[b'filenames']
#     cifar_test_labels = cifar_test_data_dict[b'labels']

#     cifar_test_data = cifar_test_data.reshape((len(cifar_test_data), 3, 32, 32))
#     if negatives:
#         cifar_test_data = cifar_test_data.transpose(0, 2, 3, 1).astype(np.float32)
#     else:
#         cifar_test_data = np.rollaxis(cifar_test_data, 1, 4)
#     cifar_test_filenames = np.array(cifar_test_filenames)
#     cifar_test_labels = np.array(cifar_test_labels)

#     return cifar_train_data, cifar_train_filenames, tf.keras.utils.to_categorical(cifar_train_labels), \
#         cifar_test_data, cifar_test_filenames, tf.keras.utils.to_categorical(cifar_test_labels), cifar_label_names