import os
import math
import time
import random
import shutil
import argparse
import builtins
import warnings
import numpy as np

import torch
import torch.nn 
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torchvision.transforms as transforms 

from utils.models import linear, mlp
from cifar_models import densenet, convnet
from cifar_models.resnet import CIFAR_ResNet
from utils.utils_loss import *
from utils.utils_algo import *
from datasets.cifar10 import load_cifar10
from datasets.cifar100 import load_cifar100


def _update_confidence(args, index, probs, plabels, conf_ema_m):
    """Refresh the rows ``args.confidence[index]`` from detached model outputs.

    With a positive prior (``args.piror > 0``), each row moves towards a one-hot
    pseudo-label by an exponential moving average with momentum ``conf_ema_m``.
    The pseudo-label is the arg-max of ``probs`` after non-candidate classes
    are scaled by ``args.piror``. Otherwise each row becomes ``probs``
    restricted to the candidate set (``plabels``) and renormalised to sum to one.
    """
    with torch.no_grad():
        if args.piror > 0:
            scores = probs * (plabels + args.piror * (1 - plabels))
            _, prot_pred = scores.max(dim=1)
            pseudo_label = F.one_hot(prot_pred, plabels.shape[1]).float()
            args.confidence[index, :] = (conf_ema_m * args.confidence[index, :]
                                         + (1 - conf_ema_m) * pseudo_label)
        else:
            masked = (probs * plabels).float()
            args.confidence[index, :] = masked / masked.sum(dim=1, keepdim=True)


def _next_piror(args, epoch, margin):
    """Return the prior (``args.piror``) to use from the next epoch on.

    Before ``args.piror_start`` the current value is kept. Afterwards:
    - with ``noise_rate > 0`` and ``piror_auto == 'case1'`` it is the
      ``noise_rate`` quantile of the per-sample ``margin`` values;
    - with ``noise_rate > 0`` otherwise, it grows by ``piror_add``;
    - with ``noise_rate == 0`` it is reset to 0.

    The result is always capped at ``args.piror_max``.
    """
    piror = args.piror
    if epoch >= args.piror_start:
        if args.noise_rate > 0:
            if args.piror_auto == 'case1':
                # Clamp so that noise_rate >= 1 selects the largest margin
                # instead of indexing one past the end.
                k = min(int(len(margin) * args.noise_rate), len(margin) - 1)
                piror = sorted(margin)[k]
            else:
                piror = min(piror + args.piror_add, args.piror_max)
        else:
            piror = 0
    return min(args.piror_max, piror)


def train(args, epoch, train_loader, net, loss_fn, optimizer):
    """Run one training epoch and update per-sample confidences and the prior.

    Besides the optimisation steps, this mutates ``args`` in place:
    ``args.confidence[index]`` is refreshed after every batch, and
    ``args.piror`` is re-estimated at the end of the epoch.

    Args:
        args: run configuration. Reads ``conf_ema_m``, ``confidence``,
            ``select``, ``piror``, ``piror_start``, ``piror_auto``,
            ``piror_add``, ``piror_max``, ``noise_rate`` and ``epochs``.
        epoch: current epoch number (prior schedule and logging).
        train_loader: yields ``(images, plabels, dlabels, index)`` batches.
            Its ``dataset`` must expose ``plabels`` (2-D array) and
            ``dlabels`` arrays.
        net: model whose forward returns a sequence whose first item is the
            logits.
        loss_fn: called as ``loss_fn(logits, confidence, index, select,
            piror, plabels)``; must return a scalar loss.
        optimizer: optimiser stepping ``net``'s parameters.

    Returns:
        dict with the epoch accuracy/rates (floats) and the concatenated
        per-sample indexes, candidate labels, ground-truth labels and softmax
        outputs seen during the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    conf_ema_m = args.conf_ema_m
    # Follow the model's device rather than hard-coding .cuda(): identical on
    # GPU runs, and the loop also works on CPU.
    device = next(net.parameters()).device
    # Diagnostic grid: accuracy of the arg-max after scaling non-candidate
    # classes by each of these factors.
    piror_set = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    piror_set_hits = [0] * len(piror_set)
    total_num = 0
    bingo_num = 0
    class_labels_bingo_num = 0
    margin = []
    total_indexes, total_plabels, total_dlabels, total_classfy_out = [], [], [], []

    net.train()
    for images, plabels, dlabels, index in train_loader:
        images = images.to(device)
        plabels = plabels.float().to(device)
        dlabels = dlabels.long().to(device)  # ground truth: evaluation only
        index = index.to(device)

        outputs = net(images)[0]
        classfy_out = F.softmax(outputs, dim=1)

        # ---- bookkeeping only; does not influence the optimisation ----
        with torch.no_grad():
            probs = classfy_out.detach()
            non_cand = 1 - plabels
            for jj, piror in enumerate(piror_set):
                pred = torch.max(probs * (plabels + piror * non_cand), 1)[1]
                piror_set_hits[jj] += torch.eq(pred, dlabels).sum().item()
            # Best candidate prob / best non-candidate prob (1e-9 avoids /0).
            margin += (torch.max(probs * plabels, 1)[0]
                       / (1e-9 + torch.max(probs * non_cand, 1)[0])).tolist()
            bingo_num += torch.eq(torch.max(probs, 1)[1], dlabels).sum().item()
            pred = torch.max(probs * (plabels + args.piror * non_cand), 1)[1]
            class_labels_bingo_num += torch.eq(pred, dlabels).sum().item()
        total_num += plabels.size(0)
        total_indexes.append(index.cpu().numpy())
        total_plabels.append(plabels.cpu().numpy())
        total_dlabels.append(dlabels.cpu().numpy())
        total_classfy_out.append(probs.cpu().numpy())

        # ---- loss and model update ----
        average_loss = loss_fn(outputs, args.confidence, index, args.select, args.piror, plabels)
        optimizer.zero_grad()
        average_loss.backward()
        optimizer.step()

        # Confidence is refreshed after the step, from the pre-step outputs.
        _update_confidence(args, index, probs, plabels, conf_ema_m)

    if total_num == 0:
        raise ValueError('train_loader yielded no samples')

    # Dataset-level statistics of the (noisy) candidate sets.
    temp_dlabels = train_loader.dataset.dlabels.astype('int')
    temp_plabels = train_loader.dataset.plabels
    num_sample = len(temp_dlabels)
    epoch_partial_rate = float(np.mean(np.sum(temp_plabels, axis=1)))
    epoch_bingo_rate = float(
        np.sum(temp_plabels[np.arange(num_sample), temp_dlabels] == 1.0) / num_sample)

    epoch_train_acc = bingo_num / total_num
    epoch_class_label_acc = class_labels_bingo_num / total_num
    piror_set_acc = [hits / total_num for hits in piror_set_hits]

    args.piror = _next_piror(args, epoch, margin)
    print(piror_set)
    print(piror_set_acc)

    print(f'Epoch:{epoch}/{args.epochs} Train classification acc={epoch_train_acc:.4f} UpdataACC:{epoch_class_label_acc:.4f} Piror:{args.piror:.4f}')

    return {
        'epoch_train_acc':      epoch_train_acc,
        'epoch_bingo_rate':     epoch_bingo_rate,
        'epoch_partial_rate':   epoch_partial_rate,
        'total_indexes':        np.concatenate(total_indexes),
        'total_plabels':        np.concatenate(total_plabels),
        'total_dlabels':        np.concatenate(total_dlabels),
        'total_classfy_out':    np.concatenate(total_classfy_out),
    }


def test(args, epoch, test_loader, net):
    
    bingo_num = 0
    total_num = 0
    test_probs = []
    test_preds = []
    test_labels = []
    test_hidden = []
    
    net.eval()
    for images, dlabels in test_loader:
        images = images.cuda()
        dlabels = dlabels.cuda()
        outputs, hiddens = net(images)
        outputs = F.softmax(outputs, dim=1)
        _, pred = torch.max(outputs.data, 1) 
        total_num += images.size(0)
        bingo_num += (pred == dlabels).sum().item()
        test_preds.append(pred.cpu().numpy())
        test_probs.append(outputs.detach().cpu().numpy())
        test_labels.append(dlabels.cpu().numpy())
        test_hidden.append(hiddens.detach().cpu().numpy())

    epoch_test_acc = bingo_num / total_num
    print(f'Epoch={epoch}/{args.epochs} Test accuracy={epoch_test_acc:.4f}, bingo_num={bingo_num},  total_num={total_num}')
    test_probs = np.concatenate(test_probs)
    test_preds = np.concatenate(test_preds)
    test_labels = np.concatenate(test_labels)
    test_hidden = np.concatenate(test_hidden)
    test_save = {
        'test_probs':       test_probs,
        'test_preds':       test_preds,
        'test_labels':      test_labels,
        'test_hidden':      test_hidden,
        'epoch_test_acc':   epoch_test_acc,
    }

    return epoch_test_acc, test_save


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='PLL Baseline Model')

    parser.add_argument('--piror_start_auto', action='store_true', default=False, help='whether auto select correct_start')
    parser.add_argument('--piror', default=0, type=float, help = 'for Alim, store lambda')
    parser.add_argument('--piror_auto', default='case1', type=str, help = 'adapt adjust or fix lambda')
    parser.add_argument('--piror_start', default=2000, type=int, help = 'estart (when we apply Alim)')
    parser.add_argument('--piror_add', default=0, type=float, help = 'for fix lambda or ')
    parser.add_argument('--piror_max', default=1, type=float, help = 'for fix lambda')
    parser.add_argument('--conf_ema_m', default=0.9, type=float, help = 'for case 3')
    parser.add_argument('--co_up_type', default='case1', type=str, help = 'for case 3')
    parser.add_argument('--select', default=-1, type=float, help = 'No select sample')
    parser.add_argument('--noisy_type', default='flip', type=str, help='flip or pico')

    ## input parameters
    parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name (cifar10)')
    parser.add_argument('--dataset_root', default='./dataset/CIFAR10', type=str, help='dataset root')
    parser.add_argument('--partial_rate', default=0.0, type=float, help='ambiguity level (q)')
    parser.add_argument('--noise_rate', default=0.0, type=float, help='noise level (gt may not in partial set)')
    parser.add_argument('--ood_rate', default=0.0, type=float, help='OOD level (gt not in the pre-defined label space)')
    parser.add_argument('--ood_dataset', default='none', type=str, help='OOD dataset')
    parser.add_argument('--workers', default=10, type=int, help='number of data loading workers')
    parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size')

    ## model parameters
    parser.add_argument('--encoder', default='convnet', type=str, help='encoder: resnet, mlp, ...')
    parser.add_argument('--low_dim', default=128, type=int, help='embedding dimension for resnet')
    parser.add_argument('--dropout_rate', default=0.25, type=float, help='dropout rate for convnet')
    parser.add_argument('--num_class', default=10, type=int, help='number of classes in the dataset.')
    parser.add_argument('--loss_type', help='specify a loss function', default='rc', type=str)
    parser.add_argument('--lws_weight1', help='weight for first  item in [lws, lwc]', default=1, type=float)
    parser.add_argument('--lws_weight2', help='weight for second item in [lws, lwc]', default=1, type=float)

    ## training parameters
    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--lr_adjust', default='Case1', type=str, help='Learning rate adjust manner: Case1 or Case2.')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay (default: 1e-5).')
    parser.add_argument('--epochs', default=1000, type=int, help='number of total epochs to run')
    parser.add_argument('--decaystep', help='learning rate\'s decay step', type=int, default=10) # adjust learning rate
    parser.add_argument('--decayrate', help='learning rate\'s decay rate', type=float, default=0.9)
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    parser.add_argument('--seed', help='seed', type=int, default=0)
    parser.add_argument('--savewhole', action='store_true', default=False, help='whether to save whole results')
    parser.add_argument('--save_root', help='where to save results', default='./savemodels', type=str)

    args = parser.parse_args()
    print(args)
    cudnn.benchmark = True
    torch.set_printoptions(precision=2, sci_mode=False)

    # Seed every RNG used by the run.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.cuda.set_device(args.gpu)
    random.seed(args.seed)

    print('====== Step1: Reading Data =======')
    if args.dataset == 'cifar10':
        input_channels = 3
        args.num_class = 10
        train_loader, train_givenY, test_loader = load_cifar10(args)
    elif args.dataset == 'cifar100':
        input_channels = 3
        args.num_class = 100
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])
        train_loader, train_givenY, test_loader = load_cifar100(args, transform)
    else:
        # Fail early instead of crashing later on an empty/undefined loader.
        raise ValueError(f'unsupported dataset: {args.dataset}')
    print(f'training samples: {len(train_loader.dataset)}')
    print(f'testing samples: {len(test_loader.dataset)}')

    print('====== Step2: Gaining model and optimizer =======')
    # Loss function and, for the re-weighting losses, the initial per-sample
    # confidence matrix (n_samples x n_classes).
    # NOTE(review): train() reads and writes args.confidence for every loss
    # type, so the loss types below that do not initialise it (cc, mae, mse,
    # ce, gce, phuber_ce, log, exp) will fail there — confirm intended.
    train_givenY = torch.FloatTensor(train_givenY)
    if args.loss_type in ['rc', 'proden', 'naive']:
        # Start uniform over each sample's candidate set.
        tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1])
        args.confidence = (train_givenY.float() / tempY).cuda()
        loss_fn = rc_loss
    elif args.loss_type == 'cc':
        loss_fn = cc_loss
    elif args.loss_type == 'lws':
        n, c = train_givenY.shape[0], train_givenY.shape[1]
        args.confidence = (torch.ones(n, c) / c).cuda()  # uniform over all classes
        loss_fn = lws_loss
    elif args.loss_type == 'lwc':
        n, c = train_givenY.shape[0], train_givenY.shape[1]
        args.confidence = (torch.ones(n, c) / c).cuda()  # uniform over all classes
        loss_fn = lwc_loss
    elif args.loss_type == 'mae':
        loss_fn = mae_loss
    elif args.loss_type == 'mse':
        loss_fn = mse_loss
    elif args.loss_type == 'ce':
        loss_fn = ce_loss
    elif args.loss_type == 'gce':
        loss_fn = gce_loss
    elif args.loss_type == 'phuber_ce':
        loss_fn = phuber_ce_loss
    elif args.loss_type == 'log':
        loss_fn = log_loss
    elif args.loss_type == 'exp':
        loss_fn = exp_loss
    else:
        raise ValueError(f'unsupported loss_type: {args.loss_type}')

    # encoder
    # NOTE(review): `num_features` is not defined anywhere in this script, so
    # the 'linear' and 'mlp' encoders currently raise NameError — confirm the
    # intended input size.
    if args.encoder == 'linear':
        net = linear(n_inputs=num_features, n_outputs=args.num_class)
    elif args.encoder == 'mlp':
        net = mlp(n_inputs=num_features, n_outputs=args.num_class)
    elif args.encoder == 'convnet':
        net = convnet(input_channels=input_channels, n_outputs=args.num_class, dropout_rate=args.dropout_rate)
    elif args.encoder == 'resnet':
        net = CIFAR_ResNet(feat_dim=args.low_dim, num_class=args.num_class)
    elif args.encoder == 'densenet':
        net = densenet(num_classes=args.num_class)
    else:
        raise ValueError(f'unsupported encoder: {args.encoder}')
    net.cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    print('====== Step3: Training and Evaluation =======')
    # --lr_adjust is documented (and defaults) as 'Case1'/'Case2'; compare
    # case-insensitively so the documented values actually select a schedule.
    lr_adjust = args.lr_adjust.lower()
    test_accs = []
    all_labels = []
    for epoch in range(1, args.epochs + 1):
        if lr_adjust == 'case1':
            adjust_learning_rate_V1(args, optimizer, epoch)
        elif lr_adjust == 'case2':
            adjust_learning_rate_V2(args, optimizer, epoch)
        train_save = train(args, epoch, train_loader, net, loss_fn, optimizer)
        test_acc, test_save = test(args, epoch, test_loader, net)
        test_accs.append(test_acc)
        all_labels.append({'epoch_train_acc':    train_save['epoch_train_acc'],
                           'epoch_bingo_rate':   train_save['epoch_bingo_rate'],
                           'epoch_partial_rate': train_save['epoch_partial_rate'],
                           'epoch_test_acc' :    test_save['epoch_test_acc'],
                           })
        if args.savewhole and epoch % 10 == 0:  # further save data which occupy much space
            # Copy only the per-sample arrays train() actually returned: it
            # does not return 'total_classfy_logit', which used to raise KeyError.
            # Other available keys: train_save['total_indexes'],
            # train_save['total_classfy_out'], test_save['test_*'].
            for key in ('total_plabels', 'total_dlabels', 'total_classfy_logit'):
                if key in train_save:
                    all_labels[-1][key] = train_save[key]

    print('====== Step4: Saving =======')
    # Store the confidence matrix (when this loss type has one) as a host
    # numpy array so the saved args do not hold a CUDA tensor.
    if isinstance(getattr(args, 'confidence', None), torch.Tensor):
        args.confidence = args.confidence.detach().cpu().numpy()
    save_root = args.save_root
    os.makedirs(save_root, exist_ok=True)

    ## gain suffix_name
    modelname = 'origin' if args.piror_start > args.epochs else 'correct'
    suffix_name = f'{args.dataset}_modelname:{modelname}+{args.loss_type}_plrate:{args.partial_rate}_noiserate:{args.noise_rate}_model:{args.encoder}'
    ## gain res_name
    bestacc = max(test_accs)
    res_name = f'testacc:{bestacc}'

    save_path = f'{save_root}/{suffix_name}_{res_name}_{time.time()}.npz'
    print(f'save results in {save_path}')
    np.savez_compressed(save_path,
                        args=np.array(args, dtype=object),
                        all_labels=all_labels)