import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

from utils import set_random_seed, get_minibatches_idx
from models import ResNet18, VGG, CNN
from data import save_train_data, save_test_data, load_data_from_pickle
from minority_collapse import analyze_collapse_new, analyze_dual, get_features

# cuDNN reproducibility flags: request deterministic behaviour and switch off
# the benchmark auto-tuner (see the PyTorch reproducibility notes).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Disabled: would make CUDA float tensors the process-wide default tensor type.
#torch.set_default_tensor_type('torch.cuda.FloatTensor')


def build_model(config):
    """Create the network, loss function and optimizer described by ``config``.

    Args:
        config: dict read for ``'model'`` (``'ResNet18'``, ``'VGG11'``,
            ``'VGG13'`` or ``'CNN'``), ``'lr'``, ``'momentum'`` and
            ``'weight_decay'``. ``'color_channel'`` and ``'n_class'`` are
            additionally read for every option except ``'CNN'``.

    Returns:
        Tuple ``(model, loss_function, optimizer)`` where ``loss_function`` is
        ``nn.CrossEntropyLoss()`` and ``optimizer`` is SGD over the model's
        parameters.

    Raises:
        ValueError: if ``config['model']`` is not one of the supported options.
    """
    model_name = config['model']
    if model_name == 'ResNet18':
        model = ResNet18(color_channel=config['color_channel'], num_classes=config['n_class'])
    elif model_name in ('VGG11', 'VGG13'):
        model = VGG(model_name, color_channel=config['color_channel'], num_classes=config['n_class'])
    elif model_name == 'CNN':
        model = CNN()
    else:
        # Fail here with a clear message instead of crashing later on
        # ``None.parameters()``.
        raise ValueError(
            "wrong model option: %r (expected 'ResNet18', 'VGG11', 'VGG13' or 'CNN')" % (model_name,))

    # Place the model on the GPU when one is available, and do so before the
    # optimizer is built, so the parameters it tracks are the ones that will be
    # trained.  This is a no-op if the model already lives on the GPU.
    if torch.cuda.is_available():
        model = model.cuda()

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=config['lr'], momentum=config['momentum'],
                          weight_decay=config['weight_decay'])

    return model, loss_function, optimizer


def simple_train_batch(trainloader, model, loss_function, optimizer, config):
    """Train ``model`` and print collapse statistics once training is done.

    The learning rate is set to ``config['lr'] / 10`` at epoch
    ``int(epoch_num / 3)`` and to ``config['lr'] / 100`` at epoch
    ``int(epoch_num * 2 / 3)``.

    Args:
        trainloader: indexable dataset; ``trainloader[i][0]`` is the input
            tensor and ``trainloader[i][1]`` the label tensor of sample ``i``.
        model: network to train; must expose ``model.classifier.weight``.
        loss_function: callable ``loss_function(outputs, targets)``.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        config: dict read for ``'epoch_num'``, ``'lr'`` and
            ``'simple_train_batch_size'``; also forwarded to the analysis
            helpers.

    Returns:
        None. Per-epoch loss and the final statistics are printed.
    """
    model.train()
    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device
    epoch_num = config['epoch_num']
    first_decay_epoch = int(epoch_num / 3)
    second_decay_epoch = int(epoch_num * 2 / 3)

    for epoch in range(epoch_num):
        if epoch == first_decay_epoch:
            for g in optimizer.param_groups:
                g['lr'] = config['lr'] / 10
            print('divide current learning rate by 10')
        elif epoch == second_decay_epoch:
            for g in optimizer.param_groups:
                g['lr'] = config['lr'] / 100
            print('divide current learning rate by 10')

        total_loss = 0.0
        minibatches_idx = get_minibatches_idx(len(trainloader), minibatch_size=config['simple_train_batch_size'],
                                              shuffle=True)
        for minibatch in minibatches_idx:
            inputs = torch.Tensor(np.array([trainloader[x][0].cpu().numpy() for x in minibatch]))
            targets = torch.Tensor(np.array([trainloader[x][1].cpu().numpy() for x in minibatch]))
            inputs = inputs.to(device).squeeze(1)
            targets = targets.long().to(device).squeeze()
            optimizer.zero_grad()
            outputs = model(inputs).squeeze()
            loss = loss_function(outputs, targets)
            # Accumulate a plain float: adding the tensor itself would keep
            # every batch's autograd history alive for the whole epoch.
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        print('epoch:', epoch, 'loss:', total_loss)

    # report the 5 indices for NC
    linear_weights = model.classifier.weight.cpu().data.numpy()
    w_norm_variation, w_cos_mean, _w_cos_max = analyze_collapse_new(
        linear_weights, config, option='weights')
    # NOTE(review): simple_train_batch_new unpacks two values from
    # get_features(..., option='features'); confirm that the default option
    # used here returns the class features alone.
    class_features = get_features(trainloader, model, config)
    h_norm_variation, h_cos_mean, _h_cos_max = analyze_collapse_new(
        class_features, config, option='features')
    analyze_dual(linear_weights, class_features)
    print('w norm:', w_norm_variation, 'w cos:', w_cos_mean)
    print('h norm:', h_norm_variation, 'h cos:', h_cos_mean)





def simple_test_batch(testloader, model, config):
    """Evaluate ``model`` on ``testloader`` and summarise per-group accuracy.

    Classes ``0 .. t1-1`` are treated as the "big" group and classes
    ``t1 .. n_class-1`` as the "small" group.

    Args:
        testloader: indexable dataset; ``testloader[i][0]`` is the input tensor
            and ``testloader[i][1]`` the label tensor of sample ``i``.
        model: network to evaluate (switched to eval mode here).
        config: dict read for ``'simple_test_batch_size'`` and ``'t1'``;
            ``'n_class'`` is used when present and defaults to 10.

    Returns:
        Tuple ``(test_accuracy, big_class_acc, small_class_acc,
        test_confusion_matrix)``. A group accuracy is ``None`` when the group
        has no samples (including the case ``t1 >= n_class``). The confusion
        matrix always has shape ``(n_class, n_class)``.
    """
    model.eval()
    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device
    n_class = config.get('n_class', 10)
    t1 = config['t1']

    total = 0.0
    correct = 0.0
    y_true = []
    y_pred = []
    minibatches_idx = get_minibatches_idx(len(testloader), minibatch_size=config['simple_test_batch_size'],
                                          shuffle=False)
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for minibatch in minibatches_idx:
            inputs = torch.Tensor(np.array([testloader[x][0].cpu().numpy() for x in minibatch]))
            targets = torch.Tensor(np.array([testloader[x][1].cpu().numpy() for x in minibatch]))
            inputs = inputs.to(device).squeeze(1)
            # reshape(-1) instead of squeeze(): a batch holding a single sample
            # must stay 1-D so that size(0) and tolist() keep working.
            # Assumes one label per sample -- TODO confirm against the dataset.
            targets = targets.to(device).long().reshape(-1)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            y_true.extend(targets.cpu().numpy().tolist())
            y_pred.extend(predicted.cpu().numpy().tolist())

    test_accuracy = correct / total
    # Fix the label set so the matrix stays n_class x n_class even when some
    # classes never occur in y_true / y_pred; the slicing below relies on it.
    test_confusion_matrix = confusion_matrix(y_true, y_pred, labels=list(range(n_class)))
    per_class_correct = np.diag(test_confusion_matrix)

    big_total = np.sum(test_confusion_matrix[:t1])
    big_class_acc = np.sum(per_class_correct[:t1]) / big_total if big_total else None

    small_total = np.sum(test_confusion_matrix[t1:])
    small_class_acc = np.sum(per_class_correct[t1:]) / small_total if small_total else None

    return test_accuracy, big_class_acc, small_class_acc, test_confusion_matrix

#data_option = sys.argv[1].split('=')[1]
#model_option = sys.argv[2].split('=')[1]
#t1 = int(sys.argv[3].split('=')[1])
#R = sys.argv[4].split('=')[1]
#config = {'dir_path': '/content/drive/MyDrive', 'data': data_option, 'model': model_option,
#              't1': t1, 'R': R, 'simple_train_batch_size': 128, 'simple_test_batch_size': 100, 'epoch_num': 20,
#              'lr': 0.1, 'momentum': 0.9, 'weight_decay': 5e-4, 'fixed': 'big'}

def run_train_models(config):
    """Prepare the data, then train, checkpoint, reload and evaluate a model.

    Args:
        config: dict read for ``'data'``, ``'R'``, ``'dir_path'``,
            ``'model'``, ``'t1'`` and ``'fixed'`` (plus whatever the helpers
            called below read). ``'color_channel'``,
            ``'big_class_sample_size'`` and ``'small_class_sample_size'`` are
            written into it.

    Returns:
        None. Accuracies and confusion matrices are printed.

    Raises:
        ValueError: if ``config['data']`` is not a known dataset, or
            ``config['R']`` is neither ``'inf'`` nor a positive integer.
    """
    import os  # local import: only needed to create the checkpoint directory

    # The experiment options are taken from ``config`` (which already carries
    # them for the checkpoint name) rather than re-parsed from sys.argv.
    data_option = config['data']
    ratio = str(config['R'])

    # Samples per "big" class for each supported dataset.
    big_class_sizes = {'cifar10': 5000, 'cifar100': 500, 'fashion_mnist': 6000}
    if data_option not in big_class_sizes:
        raise ValueError('wrong data option: %r (expected one of %s)'
                         % (data_option, sorted(big_class_sizes)))

    # fixed: big/small
    config['color_channel'] = 1 if data_option == 'fashion_mnist' else 3
    big_size = big_class_sizes[data_option]
    if ratio == 'inf':
        # Infinite imbalance ratio: small classes get no samples at all.
        small_size = 0
    else:
        ratio_value = int(ratio)
        if ratio_value <= 0:
            raise ValueError("R must be 'inf' or a positive integer, got %r" % (ratio,))
        small_size = big_size // ratio_value
    config['big_class_sample_size'] = big_size
    config['small_class_sample_size'] = small_size

    model_path = config['dir_path'] + '/models/' + config['data'] + '_' + config['model'] + '_t1=' + \
        str(config['t1']) + '_R=' + ratio + "_" + config['fixed'] + '.pt'

    print('save test data')
    set_random_seed(666)
    save_test_data(config)

    print('save train data')
    set_random_seed(666)
    save_train_data(config)

    set_random_seed(666)
    print('load data from pickle')
    train_data, test_data = load_data_from_pickle(config)

    print('build model')
    model, loss_function, optimizer = build_model(config)
    print('train model')
    simple_train_batch(train_data, model, loss_function, optimizer, config)
    print('save model')
    # Make sure the checkpoint directory exists so saving cannot fail after
    # the whole training run has already completed.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print('load model')
    model.load_state_dict(torch.load(model_path))
    train_res, train_big, train_small, train_confusion_matrix = simple_test_batch(train_data, model, config)
    test_res, test_big, test_small, test_confusion_matrix = simple_test_batch(test_data, model, config)
    print('train accuracy', train_res, train_big, train_small)
    print('test accuracy', test_res, test_big, test_small)
    print('train confusion matrix\n', train_confusion_matrix)
    print('test confusion matrix\n', test_confusion_matrix)

def run_train_models_new(config):
    """Prepare the data, train with per-epoch collapse tracking, and evaluate.

    The untrained model is evaluated on the training data first (its confusion
    matrix is printed), then it is trained with ``simple_train_batch_new``,
    checkpointed, reloaded and evaluated on the train and test data.

    Args:
        config: dict read for ``'data'``, ``'R'``, ``'dir_path'``,
            ``'model'``, ``'t1'`` and ``'fixed'`` (plus whatever the helpers
            called below read). ``'color_channel'``,
            ``'big_class_sample_size'`` and ``'small_class_sample_size'`` are
            written into it.

    Returns:
        Tuple ``(w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean,
        dual, with_in)`` exactly as returned by ``simple_train_batch_new``.

    Raises:
        ValueError: if ``config['data']`` is not a known dataset, or
            ``config['R']`` is neither ``'inf'`` nor a positive integer.
    """
    import os  # local import: only needed to create the checkpoint directory

    # The experiment options are taken from ``config`` (which already carries
    # them for the checkpoint name) rather than re-parsed from sys.argv.
    data_option = config['data']
    ratio = str(config['R'])

    # Samples per "big" class for each supported dataset.
    big_class_sizes = {'cifar10': 5000, 'cifar100': 500, 'fashion_mnist': 6000}
    if data_option not in big_class_sizes:
        raise ValueError('wrong data option: %r (expected one of %s)'
                         % (data_option, sorted(big_class_sizes)))

    # fixed: big/small
    config['color_channel'] = 1 if data_option == 'fashion_mnist' else 3
    big_size = big_class_sizes[data_option]
    if ratio == 'inf':
        # Infinite imbalance ratio: small classes get no samples at all.
        small_size = 0
    else:
        ratio_value = int(ratio)
        if ratio_value <= 0:
            raise ValueError("R must be 'inf' or a positive integer, got %r" % (ratio,))
        small_size = big_size // ratio_value
    config['big_class_sample_size'] = big_size
    config['small_class_sample_size'] = small_size

    model_path = config['dir_path'] + '/models/' + config['data'] + '_' + config['model'] + '_t1=' + \
        str(config['t1']) + '_R=' + ratio + "_" + config['fixed'] + '.pt'

    # 5/25/2021
    print('save test data')
    set_random_seed(666)
    save_test_data(config)

    print('save train data')
    set_random_seed(666)
    save_train_data(config)

    set_random_seed(666)
    print('load data from pickle')
    train_data, test_data = load_data_from_pickle(config)

    print('build model')
    model, loss_function, optimizer = build_model(config)
    # Baseline: confusion matrix of the untrained model on the training data.
    train_res, train_big, train_small, train_confusion_matrix = simple_test_batch(train_data, model, config)
    print('train confusion matrix\n', train_confusion_matrix)
    print('train model')
    collapse_stats = simple_train_batch_new(train_data, model, loss_function, optimizer, config)
    w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual, with_in = collapse_stats
    print('save model')
    # Make sure the checkpoint directory exists so saving cannot fail after
    # the whole training run has already completed.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print('load model')
    model.load_state_dict(torch.load(model_path))
    train_res, train_big, train_small, train_confusion_matrix = simple_test_batch(train_data, model, config)
    test_res, test_big, test_small, test_confusion_matrix = simple_test_batch(test_data, model, config)
    print('train accuracy', train_res, train_big, train_small)
    print('test accuracy', test_res, test_big, test_small)
    print('train confusion matrix\n', train_confusion_matrix)
    print('test confusion matrix\n', test_confusion_matrix)

    return w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual, with_in



def simple_train_batch_new(trainloader, model, loss_function, optimizer, config):
    """Train ``model`` and record collapse statistics after every epoch.

    The learning rate is kept at ``config['lr']`` for the whole run: the step
    decay used by ``simple_train_batch`` is disabled here.

    Args:
        trainloader: indexable dataset; ``trainloader[i][0]`` is the input
            tensor and ``trainloader[i][1]`` the label tensor of sample ``i``.
        model: network to train; must expose ``model.classifier.weight``.
        loss_function: callable ``loss_function(outputs, targets)``.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        config: dict read for ``'epoch_num'``, ``'lr'`` and
            ``'simple_train_batch_size'``; also forwarded to the analysis
            helpers.

    Returns:
        Tuple ``(w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean,
        dual, with_in)``; each is a list of length ``config['epoch_num']``
        holding one entry per analysed epoch.
    """
    epoch_num = config['epoch_num']
    w_norm_variation = [0] * epoch_num
    w_cos_mean = [0] * epoch_num
    w_cos_max = [0] * epoch_num
    h_norm_variation = [0] * epoch_num
    h_cos_mean = [0] * epoch_num
    h_cos_max = [0] * epoch_num
    dual = [0] * epoch_num
    with_in = [0] * epoch_num

    # Analyse every K-th epoch; slot ``epoch // K`` stores that epoch's result.
    K = 1

    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device
    milestones = (int(epoch_num / 3), int(epoch_num * 2 / 3))

    for epoch in range(epoch_num):
        # Re-enter training mode every epoch: the analysis helpers called at
        # the end of the previous epoch are defined elsewhere and may have
        # switched the model to eval mode.
        model.train()

        if epoch in milestones:
            for g in optimizer.param_groups:
                g['lr'] = config['lr']  # decay disabled (was / 10 and / 100)
            # Report what actually happens; the previous message claimed a
            # division by 10 although the rate is left untouched.
            print('learning rate decay disabled, keeping lr at', config['lr'])

        total_loss = 0.0
        minibatches_idx = get_minibatches_idx(len(trainloader), minibatch_size=config['simple_train_batch_size'],
                                              shuffle=True)
        for minibatch in minibatches_idx:
            inputs = torch.Tensor(np.array([trainloader[x][0].cpu().numpy() for x in minibatch]))
            targets = torch.Tensor(np.array([trainloader[x][1].cpu().numpy() for x in minibatch]))
            inputs = inputs.to(device).squeeze(1)
            targets = targets.long().to(device).squeeze()
            optimizer.zero_grad()
            outputs = model(inputs).squeeze()
            loss = loss_function(outputs, targets)
            # Accumulate a plain float: adding the tensor itself would keep
            # every batch's autograd history alive for the whole epoch.
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        # report the 5 indices for NC
        if epoch % K == 0:
            slot = epoch // K
            linear_weights = model.classifier.weight.cpu().data.numpy()
            w_norm_variation[slot], w_cos_mean[slot], w_cos_max[slot] = analyze_collapse_new(
                linear_weights, config, option='weights')
            class_features, with_in[slot] = get_features(trainloader, model, config, option='features')
            h_norm_variation[slot], h_cos_mean[slot], h_cos_max[slot] = analyze_collapse_new(
                class_features, config, option='features')
            dual[slot] = analyze_dual(linear_weights, class_features)
            print('w norm:', w_norm_variation[slot], 'w cos:', w_cos_mean[slot])
            print('h norm:', h_norm_variation[slot], 'h cos:', h_cos_mean[slot])
            print('with_in:', with_in[slot])

        print('epoch:', epoch, 'loss:', total_loss)

    return w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual, with_in


if __name__ == '__main__':
    # model_option: VGG13 / ResNet18
    # data_option: cifar10 / fashion_mnist
    # Expected arguments: data_option=... model_option=... t1=... R=...
    # Fall back to the defaults below when they are not all supplied (e.g. in
    # a notebook, where sys.argv carries the kernel's own arguments), instead
    # of always discarding what was passed on the command line.
    default_args = ['data_option=cifar10', 'model_option=VGG13', 't1=10', 'R=1']
    cli_args = sys.argv[1:]
    if len(cli_args) != 4 or not all('=' in arg for arg in cli_args):
        sys.argv = [sys.argv[0] if sys.argv else ''] + default_args

    data_option = sys.argv[1].split('=')[1]
    model_option = sys.argv[2].split('=')[1]
    t1 = int(sys.argv[3].split('=')[1])
    R = sys.argv[4].split('=')[1]
    # cifar100 has 100 classes; the other supported datasets have 10.
    n_class = 100 if data_option == 'cifar100' else 10
    # need to specify path first
    config = {'dir_path': '/content/drive/MyDrive', 'data': data_option, 'model': model_option,
              't1': t1, 'R': R, 'simple_train_batch_size': 128, 'simple_test_batch_size': 100, 'epoch_num': 10,
              'lr': 0.01, 'momentum': 0.3, 'weight_decay': 0, 'fixed': 'big', 'n_class': n_class}
    # per-epoch values
    # w_norm_variation: norm of classifier
    # w_cos_mean: cosine of classifier
    # h_norm_variation: norm of last-layer features
    # h_cos_mean: cosine of last-layer features
    # dual: distance between normalized classifier and normalized centered last-layer features
    # with_in: within-class variations of last-layer features
    w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual, with_in = run_train_models_new(config)
