import numpy as np
from PIL import Image
from math import inf
from scipy import stats


import torch
import torch.utils.data as data
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms

from model import Layer9
from Lenet import Lenet

from Resnet import ResNet18, ResNet34
from utils import noisify_multiclass_symmetric, noisify_pairflip


# From Co-teaching: per-epoch learning-rate and Adam beta1 schedule.
# The learning rate stays at 0.001 until `epoch_decay_start`, then decays
# linearly towards 0 (which it would reach at epoch 200); beta1 switches from
# `mom1` to `mom2` at the same epoch.
mom1 = 0.9
mom2 = 0.1
epoch_decay_start = 50  # 100
_SCHEDULE_EPOCHS = 200  # length of both plans; adjust_learning_rate indexes them by epoch

# Comprehensions (rather than a module-level loop) keep the loop index from
# leaking into the module namespace.
alpha_plan = [
    0.001 if epoch < epoch_decay_start
    else float(_SCHEDULE_EPOCHS - epoch) / (_SCHEDULE_EPOCHS - epoch_decay_start) * 0.001
    for epoch in range(_SCHEDULE_EPOCHS)
]
beta1_plan = [mom1 if epoch < epoch_decay_start else mom2 for epoch in range(_SCHEDULE_EPOCHS)]


# Only for Co-teaching
def adjust_learning_rate(optimizer, epoch):
    """Apply the Co-teaching schedule for ``epoch`` to every param group.

    Sets ``lr`` to ``alpha_plan[epoch]`` and beta1 to ``beta1_plan[epoch]``.
    Only beta1 is scheduled: each group's existing beta2 is kept, and 0.999
    is used when the group has no ``betas`` entry yet.

    Args:
        optimizer: a torch optimizer (Adam-style ``betas`` expected).
        epoch: index into the module-level schedules; epochs beyond the
            schedule length raise IndexError.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = alpha_plan[epoch]
        # Preserve beta2 instead of silently resetting it to 0.999.
        beta2 = param_group.get('betas', (None, 0.999))[1]
        param_group['betas'] = (beta1_plan[epoch], beta2)


def _print_flip_table(clean_labels, noisy_labels, label_num):
    """Print the row-normalised clean -> noisy label transition table."""
    record = [[0] * label_num for _ in range(label_num)]
    for a, b in zip(clean_labels, noisy_labels):
        record[int(a)][int(b)] += 1

    print('****************************************')
    print('following is flip percentage:')
    for row in record:
        # A class absent from `clean_labels` has an all-zero row; max(..., 1)
        # prints zeros for it instead of dividing by zero.
        row_total = max(sum(row), 1)
        for count in row:
            print(f"{count / row_total: .2f}", end='\t')
        print()


def _print_class0_rows(P, clean_labels, label_num, num_draws=1000, max_rows=10):
    """Print the flip distributions of up to `max_rows` randomly drawn class-0 samples."""
    pidx = np.random.choice(range(P.shape[0]), num_draws)
    shown = 0
    for idx in pidx:
        if clean_labels[idx] == 0:
            for j in range(label_num):
                print(f"{P[idx, j]:.2f}", end="\t")
            print()
            shown += 1
        if shown >= max_rows:
            break


def get_instance_noisy_label(n, newdataset, labels, num_classes, feature_size, norm_std, seed):
    """Generate instance-dependent noisy labels.

    Every sample i gets a flip rate q_i drawn from a normal distribution with
    mean ``n`` and std ``norm_std`` truncated to [0, 1]. Its label
    distribution is ``q_i * softmax(x_i @ W[y_i])`` with the true class masked
    out, plus ``1 - q_i`` on the true class; the noisy label is sampled from it.
    ``W`` is a random Gaussian tensor of shape (num_classes, feature_size,
    num_classes).

    Args:
        n: noise rate (mean of the truncated-normal flip rate).
        newdataset: iterable of ``(x, y)`` pairs (a dataset, not a DataLoader),
            where ``x`` has ``feature_size`` elements and ``y`` is an integer
            class index. Must yield samples in the same order as ``labels``.
        labels: clean labels; a tensor, or a list / array-like.
        num_classes: number of classes.
        feature_size: number of elements of a flattened ``x`` (e.g. 28*28).
        norm_std: std of the flip-rate distribution (e.g. 0.1).
        seed: seed for the numpy and torch global RNGs.

    Returns:
        np.ndarray with one sampled noisy label per entry of ``labels``.

    Also prints the realised noise rate, the per-class flip table and the
    flip distributions of up to 10 randomly drawn class-0 samples.
    """
    label_num = num_classes
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))
    torch.cuda.manual_seed(int(seed))
    use_cuda = torch.cuda.is_available()

    # Convert before anything reads `labels.shape`: lists and arrays lack the
    # tensor methods (.shape on lists, .cuda(), .cpu()) used below.
    if not torch.is_tensor(labels):
        labels = torch.as_tensor(np.asarray(labels), dtype=torch.float32)
    if use_cuda:
        labels = labels.cuda()
    num_samples = labels.shape[0]

    # Per-sample flip rate: N(n, norm_std^2) truncated to [0, 1].
    flip_distribution = stats.truncnorm((0 - n) / norm_std, (1 - n) / norm_std, loc=n, scale=norm_std)
    flip_rate = flip_distribution.rvs(num_samples)

    # One random (feature_size x label_num) projection per true class.
    W = torch.FloatTensor(np.random.randn(label_num, feature_size, label_num))
    if use_cuda:
        W = W.cuda()
    print(W.shape)

    P = []
    for i, (x, y) in enumerate(newdataset):
        if use_cuda:
            x = x.cuda()
        # (1 x feature_size) @ (feature_size x label_num) -> label_num scores.
        A = x.view(1, -1).mm(W[y]).squeeze(0)
        # Mask the true class so the softmax mass goes only to other classes...
        A[y] = -inf
        A = flip_rate[i] * F.softmax(A, dim=0)
        # ...and keep the remaining 1 - flip_rate on the true class.
        A[y] += 1 - flip_rate[i]
        P.append(A)
    P = torch.stack(P, 0).cpu().numpy()

    classes = list(range(label_num))
    new_label = np.array([np.random.choice(classes, p=P[i]) for i in range(num_samples)])

    clean = labels.cpu().numpy()
    print(f'noise rate = {(new_label != clean).mean()}')
    _print_flip_table(clean, new_label, label_num)
    _print_class0_rows(P, clean, label_num)
    return new_label


# flip clean labels to noisy labels
# train set and val set split
def dataset_split(train_images, train_labels, noise_rate=0.5, noise_type='symmetric', split_per=0.9, random_seed=1, num_classes=10):
    """Corrupt ``train_labels`` with synthetic noise, then split into train / val.

    Args:
        train_images: numpy array indexed by sample along axis 0.
        train_labels: 1-D numpy array of clean integer labels.
        noise_rate: amount of label noise to inject.
        noise_type: 'pairflip', 'instance', or anything else for symmetric noise.
        split_per: fraction of samples used for training; the rest is validation.
        random_seed: seed for the noise generator and for the split.
        num_classes: number of classes.

    Returns:
        (train_set, val_set, train_labels, val_labels,
         train_clean_labels, val_clean_labels). The noisy label arrays are
        squeezed to 1-D; the clean label arrays keep a trailing axis of size 1.
    """
    clean_train_labels = train_labels[:, np.newaxis]

    if noise_type == 'pairflip':
        noisy_labels, _, _ = noisify_pairflip(
            clean_train_labels, noise=noise_rate, random_state=random_seed, nb_classes=num_classes)
    elif noise_type == 'instance':
        norm_std = 0.1
        # Features per sample = product of every non-batch dimension.
        feature_size = int(np.prod(train_images.shape[1:]))
        images_tensor = torch.from_numpy(train_images).float()
        targets = torch.from_numpy(train_labels)
        noisy_labels = get_instance_noisy_label(
            noise_rate, zip(images_tensor, targets), targets, num_classes, feature_size, norm_std, random_seed)
    else:
        noisy_labels, _, _ = noisify_multiclass_symmetric(
            clean_train_labels, noise=noise_rate, random_state=random_seed, nb_classes=num_classes)

    noisy_labels = noisy_labels.squeeze()
    num_samples = int(noisy_labels.shape[0])

    # Reseed so the split depends only on `random_seed`, not on how much
    # randomness the noise generator consumed.
    np.random.seed(random_seed)
    n_train = int(num_samples * split_per)
    train_set_index = np.random.choice(num_samples, n_train, replace=False)
    val_set_index = np.delete(np.arange(train_images.shape[0]), train_set_index)

    return (
        train_images[train_set_index, :], train_images[val_set_index, :],
        noisy_labels[train_set_index], noisy_labels[val_set_index],
        clean_train_labels[train_set_index], clean_train_labels[val_set_index],
    )


# create different models for different datasets
def createModel(modelName, input_channel=3, num_classes=10):
    """Build the network named ``modelName``, moved to the GPU when available.

    Args:
        modelName: one of 'Lenet', 'ResNet18', 'ResNet34', '9-layer'.
        input_channel: number of input channels (only used by '9-layer').
        num_classes: number of outputs. NOTE(review): ``Lenet()`` is built
            without it, so Lenet keeps its own output size — confirm it matches.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``modelName`` is not one of the supported names.
    """
    if modelName == 'Lenet':
        print('Building new Lenet(' + str(num_classes) + ')')
        model = Lenet()
    elif modelName == 'ResNet18':
        print('Building new ResNet18(' + str(num_classes) + ')')
        model = ResNet18(num_classes)
    elif modelName == 'ResNet34':
        print('Building new ResNet34(' + str(num_classes) + ')')
        model = ResNet34(num_classes)
    elif modelName == '9-layer':
        print('Building new 9-layer(' + str(input_channel) + ',' + str(num_classes) + ')')
        model = Layer9(input_channel=input_channel, n_outputs=num_classes)
    else:
        # Without this, an unknown name fails later with UnboundLocalError.
        raise ValueError(
            f"Unknown model name {modelName!r}; expected 'Lenet', 'ResNet18', 'ResNet34' or '9-layer'")

    if torch.cuda.is_available():
        model.cuda()
    return model


# tools
def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        logit: (batch, num_classes) scores.
        target: (batch,) integer class indices.
        topk: the values of k to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes. If k is at least the
        number of classes, every class is in the top k and the result is 100.
    """
    output = F.softmax(logit, dim=1)
    # Tensor.topk raises when asked for more entries than there are classes.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` is non-contiguous after .t(), so .view(-1) fails for k > 1;
        # reshape copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# Evaluate the Model
def evaluate(test_loader, model, num_classes):
    """Class-balanced accuracy of ``model`` on ``test_loader``.

    Per-class accuracy is computed for every class that occurs in the loader
    and averaged with equal weight per class, to overcome the imbalance of
    the noisy validation set. Classes with no samples are skipped (they would
    otherwise contribute 0/0 = nan to the average).

    Args:
        test_loader: iterable of ``(images, labels, _)`` batches.
        model: module returning ``(logits, _)``; switched to eval mode.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.

    Returns:
        (acc, std): mean and standard deviation, in percent, of the per-class
        accuracies. Both are nan if the loader yields no samples.
    """
    model.eval()
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)

    with torch.no_grad():
        for images, labels, _ in test_loader:
            if torch.cuda.is_available():
                images = images.cuda()

            logits, _ = model(images)
            outputs = F.softmax(logits, dim=1)
            _, pred = torch.max(outputs, 1)

            labels = labels.cpu()
            labels_np = labels.numpy()
            hit = (pred.cpu() == labels).numpy()
            # Vectorised per-class counting instead of a per-sample Python loop.
            class_total += np.bincount(labels_np, minlength=num_classes)
            class_correct += np.bincount(labels_np[hit], minlength=num_classes)

    seen = class_total > 0
    if seen.any():
        per_class = 100 * class_correct[seen] / class_total[seen]
        acc = np.average(per_class)
        std = np.std(per_class)
    else:
        acc = std = float('nan')

    print('Evaluating accuracy:', acc, 'Standard deviation:', std)
    return acc, std


def predict(train_loader, model1):
    """Return the model's top-1 predicted class for every sample, in loader order.

    Args:
        train_loader: iterable of ``(images, labels, indexes)`` batches;
            only ``images`` is used.
        model1: module returning ``(logits, _)``; switched to eval mode.

    Returns:
        A list of Python ints (empty if the loader yields nothing).
    """
    model1.eval()    # Change model to 'eval' mode.
    batch_preds = []

    # Inference only: don't build autograd graphs for every batch.
    with torch.no_grad():
        for images, _, _ in train_loader:
            if torch.cuda.is_available():
                images = images.cuda()
            logits1, _ = model1(images)
            outputs1 = F.softmax(logits1, dim=1)
            _, pred1 = torch.max(outputs1, 1)
            batch_preds.append(pred1.to("cpu", torch.int).numpy())

    if not batch_preds:
        return []
    return np.concatenate(batch_preds).astype(int).tolist()


def train(train_loader, model1, optimizer1):
    """Run one epoch of plain cross-entropy training.

    Args:
        train_loader: iterable of ``(images, labels, indexes)`` batches.
        model1: module returning ``(logits, _)``; switched to train mode.
        optimizer1: optimizer over ``model1``'s parameters.

    Returns:
        Mean per-batch top-1 training accuracy in percent, rounded to 2
        decimals (every batch weighs the same); 0.0 for an empty loader.
    """
    model1.train()

    train_total = 0
    train_correct = 0.0

    for images, labels, _ in train_loader:
        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()

        # Forward + Backward + Optimize
        logits1, _ = model1(images)
        # Only top-1 is reported; also requesting top-5 would fail with
        # fewer than 5 classes.
        (prec1,) = accuracy(logits1, labels, topk=(1,))
        train_total += 1
        train_correct += prec1.item()

        loss_1 = F.cross_entropy(logits1, labels)

        optimizer1.zero_grad()
        loss_1.backward()
        optimizer1.step()

    train_acc1 = round(train_correct / train_total, 2) if train_total else 0.0
    print('Training accuracy:', train_acc1)
    return train_acc1


def isSame(cifar10_train_labels_clean, noise_labels, clean_labels, images, num_classes):
    """Keep the samples on which two label sources agree and report their accuracy.

    Sample i is kept when ``cifar10_train_labels_clean[i] == noise_labels[i]``.
    A kept sample counts as correct when that label also equals
    ``clean_labels[i]``; the overall and per-class percentages are printed.

    Args:
        cifar10_train_labels_clean: first label sequence (integer labels).
        noise_labels: second label sequence, compared element-wise.
        clean_labels: reference labels used to score the kept samples.
        images: per-sample data, indexed like the label sequences.
        num_classes: number of classes.

    Returns:
        (data, labels): the kept ``images[i]`` and their agreed labels.
    """
    labels = []
    data = []
    correct_number = 0
    class_correct = [0.] * num_classes
    class_total = [0.] * num_classes

    for i, label in enumerate(cifar10_train_labels_clean):
        if label == noise_labels[i]:
            data.append(images[i])
            labels.append(label)
            class_total[label] += 1

            if label == clean_labels[i]:
                correct_number += 1
                class_correct[label] += 1

    # Guard: with no agreeing samples the ratio would divide by zero.
    agree_accuracy = round(100 * correct_number / len(labels), 2) if labels else 0.0
    print('Same labels number:', len(labels), agree_accuracy)

    for i in range(num_classes):
        if class_total[i] > 0:
            print('Accuracy of %5s : %.2f %%' % (i, 100 * class_correct[i] / class_total[i]))

    return data, labels


class NewDataset(data.Dataset):
    """Map-style dataset over parallel in-memory sequences of images and labels.

    Each item is ``(image, target, index)``. The raw image is converted with
    ``Image.fromarray`` before ``transform`` is applied (``transforms.ToTensor()``
    when no transform is given); ``target_transform``, if set, is applied to
    the target.
    """

    def __init__(self, data, labels, transform=None, target_transform=None):
        self.train_data, self.train_labels = data, labels
        self.transform = transforms.ToTensor() if transform is None else transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        raw_img, target = self.train_data[index], self.train_labels[index]

        # Go through a PIL Image so the transforms receive the same input
        # type as with other datasets.
        sample = Image.fromarray(raw_img)
        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, index
