import os
import torch.utils.data as Data
import torchvision
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import argparse
from matplotlib.pyplot import MultipleLocator

from RC_loss import RCCC_loss
from datasets import make_training_dataset
from helper import load_dataset, get_matrix, get_model, get_loss
from datasets import gen_index_dataset
from datasets import cut_bag

# Module-level default device: first CUDA device when available, CPU otherwise.
# The fallback must be "cpu" -- "gpu" is not a device type torch.device accepts,
# so the previous fallback raised at import time on machines without CUDA.
# Note that experiment() builds its own local `device` from args.gpu.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Options are declared in a table and registered in order, so the generated
    ``--help`` output lists them in the same sequence as the table.

    Returns:
        argparse.Namespace: the parsed options (set/bag configuration, dataset,
        model, loss choice, optimisation hyper-parameters, seed, GPU id and
        cut rate).
    """
    option_table = (
        ('--set_size', dict(help="number of instance contained in one set", type=int, default=6000)),
        ('--set_number', dict(help="number of sets", type=int, default=10)),
        ('--class_number', dict(help="number of class k", type=int, default=10)),
        ('--matrix', dict(
            type=str,
            help="type of prior matrix",
            choices=['dia_dominate_matrix', 'random', 'random_20', 'off_same'],
            default='dia_dominate_matrix')),
        ('--dataset', dict(type=str, help="used dataset", default='mnist')),
        ('--model', dict(type=str, help="used classification model", default='mlp')),
        ('--loss', dict(type=str, help="used method", choices=["CC", "RC"], default='CC')),
        ('--batch_size', dict(type=int, default=256)),
        ('--epoch', dict(type=int, default=100)),
        ('--pre_epoch', dict(type=int, default=0)),
        ('--learning_rate', dict(type=float, default=1e-3)),
        ('--weight_decay', dict(type=float, default=1e-5)),
        ('--seed', dict(type=int, default=0)),
        ('--gpu', dict(type=str, default='0')),
        ('--cut_rate', dict(type=float, default=1)),
    )

    cli = argparse.ArgumentParser(
        description='RCM & CCM implementation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def _evaluate_accuracy(model, data_loader, device):
    """Return the fraction of samples in ``data_loader`` that ``model`` labels correctly.

    The predicted class is the index of the maximum along dimension 1 of the
    model output; it is compared element-wise with the label batch on the CPU.
    The model is left in eval mode; callers must switch back to train mode.

    Returns:
        float: correct / total, or NaN when the loader yields no samples.
    """
    correct, seen = 0, 0
    model.eval()
    with torch.no_grad():
        for test_x, test_y in data_loader:
            logits = model(test_x.to(device)).cpu()
            predicted = torch.max(logits, 1)[1]
            correct += (predicted == test_y).sum().item()
            seen += test_y.shape[0]
    if seen == 0:
        # Avoid a ZeroDivisionError on an empty test set.
        return float('nan')
    return correct / seen


def experiment(args):
    """Train a model with the CC or RC loss and report its test accuracy.

    Seeds the torch/numpy RNGs, builds the training sets from the loaded
    dataset and the prior matrix, trains for ``args.epoch`` epochs with Adam
    and evaluates on the test split after every epoch.

    Args:
        args: namespace with the fields produced by ``get_args`` (seed, gpu,
            dataset, matrix, set_number, set_size, class_number, cut_rate,
            batch_size, model, loss, learning_rate, weight_decay, epoch,
            pre_epoch).

    Returns:
        The mean of the per-epoch test accuracies over the last 10 epochs
        (fewer if ``args.epoch`` is smaller than 10).

    Raises:
        ValueError: if ``args.loss`` is neither ``'CC'`` nor ``'RC'``.
    """
    print(args)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        device = torch.device('cuda:' + args.gpu)
    else:
        device = torch.device("cpu")

    (x_train, y_train), (x_test, y_test) = load_dataset(args.dataset)
    matrix = get_matrix(args.matrix, args.set_number)
    X, Y, S = make_training_dataset(((x_train, y_train), (x_test, y_test)), args.set_size, matrix,
                                    args.set_number, args.class_number)
    if args.cut_rate < 1:
        X, Y, S = cut_bag(X, Y, S, args.cut_rate)
    torch_dataset = gen_index_dataset(X, S, Y)
    test_dataset = Data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(torch_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=0)
    # NOTE(review): shuffling the test set does not affect accuracy; it is kept
    # because it presumably draws from the global RNG, so removing it could
    # change the training order of seeded runs -- confirm before changing.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True,
                                              num_workers=0)
    model = get_model(args.model).to(device)
    loss_fn = get_loss(args.loss, matrix, device)
    if args.loss in ('RC', 'CC'):
        loss_fn.init_prior(train_loader, args.set_number)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    test_acc_list = []
    loss = None  # last training loss seen; stays None if no batch was processed
    for epoch in range(args.epoch):
        # Evaluation at the end of the previous epoch leaves the model in eval
        # mode; switch back so dropout/batch-norm layers train properly.
        model.train()
        for b_x, b_y, _, index in train_loader:
            b_x = b_x.to(device)
            b_y, index = b_y.to(device), index.to(device)
            outputs = model(b_x)
            if args.loss == 'CC':
                loss = loss_fn.cc_loss(outputs, b_y)
            elif args.loss == 'RC':
                # The first `pre_epoch` epochs of an RC run use the CC loss.
                if epoch + 1 <= args.pre_epoch:
                    loss = loss_fn.cc_loss(outputs, b_y)
                else:
                    loss = loss_fn.rc_loss(outputs, b_y, index)
            else:
                raise ValueError("unsupported loss: {!r}".format(args.loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.loss == 'RC':
                # Re-run the batch through the updated model in eval mode and
                # hand the outputs to the loss object. `update_condifence` is
                # the method name as called here; the class is defined elsewhere.
                with torch.no_grad():
                    model.eval()
                    outputs = model(b_x)
                    loss_fn.update_condifence(index, outputs)
                    model.train()

        # ---------------------------------- test ----------------------------------
        test_acc = _evaluate_accuracy(model, test_loader, device)
        if epoch >= args.epoch - 10:
            test_acc_list.append(test_acc)
        train_loss = loss.item() if loss is not None else float('nan')
        print('Epoch: ', epoch, '| train loss: %.4f' % train_loss, '| test acc: %.4f' % test_acc)
    return np.mean(test_acc_list)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, echo the key choices, run the
    # experiment and print its summary metric.
    args = get_args()
    print("dataset: {}".format(args.dataset))
    print("loss: {}".format(args.loss))

    # experiment() returns the mean test accuracy over its final epochs, not a
    # loss value, so name and label it accordingly.
    mean_test_acc = experiment(args)
    print("Final test accuracy:{}".format(mean_test_acc))











