# -*- coding:utf-8 -*-
import os
import random
import torch 
import torch.nn.functional as F
from torch.autograd import Variable
from data.datasets_clean import input_dataset
from models import *
from models import resnet32
import argparse
import numpy as np
from metrics import *
from torchvision import transforms
import time
from imbalance_cifar_group import IMBALANCECIFAR10, IMBALANCECIFAR100
from focalloss import *
from loss import loss_cross_entropy, loss_peer, loss_pls, loss_nls, logit_adj
from torch.utils.data import RandomSampler, DataLoader

import warnings
warnings.simplefilter("ignore")

parser = argparse.ArgumentParser()

# Command-line options as (flag, add_argument keyword arguments), registered
# in this order so `--help` lists them in the same sequence.
_CLI_OPTIONS = (
    ('--lr', dict(type=float, default=0.05)),
    ('--tail_rate', dict(type=float, default=0.1)),
    ('--loss', dict(type=str, help='ce, gce, dmi, flc, uspl,spl,peerloss', default='ce')),
    ('--result_dir', dict(type=str, help='dir to save result txt files', default='./results')),
    ('--dataset', dict(type=str, help=' cifar10 or cifar100', default='cifar10')),
    ('--model', dict(type=str, help='resnet', default='resnet')),
    ('--n_epoch', dict(type=int, default=200)),
    ('--batch_size', dict(type=int, default=128)),
    ('--seed', dict(type=int, default=0)),
    ('--print_freq', dict(type=int, default=50)),
    ('--num_workers', dict(type=int, default=0, help='how many subprocesses to use for data loading')),
    ('--noise_mode', dict(default='imb', help='imb or sym')),
    ('--r', dict(default=0.4, type=float, help='noise ratio')),
    ('--data_path', dict(default='data/cifar-10-batches-py', type=str, help='path to dataset')),
    ('--g_idx', dict(type=int, default=2, help='num of sub-populations to consider')),
    ('--lmd', dict(default=0.0, type=float, help='lmd')),
    ('--metric', dict(type=str, help='metric', default='acc')),
    ('--conf', dict(type=str, help='no_conf, entropy', default='no_conf')),
    ('--method', dict(type=str, help='frf', default='frf')),
    ('--cluster_type', dict(type=str, help='knn, random', default='random')),
)

for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
# Loop variables are scratch; keep them out of the module namespace.
del _flag, _kwargs

# Adjust the learning rate of the SGD optimizer from a per-epoch schedule
def adjust_learning_rate(optimizer, epoch,alpha_plan,loss_type='ce'):
    """Set every param group's learning rate to ``alpha_plan[epoch]``.

    Epochs beyond the end of ``alpha_plan`` reuse its last entry, so training
    for more epochs than the schedule covers holds the final learning rate
    instead of crashing with IndexError part-way through a run.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        epoch: zero-based epoch index.
        alpha_plan: sequence of per-epoch learning rates.
        loss_type: unused; kept so existing call sites keep working.

    Raises:
        ValueError: if ``alpha_plan`` is empty.
    """
    if len(alpha_plan) == 0:
        raise ValueError('alpha_plan must contain at least one learning rate')
    lr = alpha_plan[min(epoch, len(alpha_plan) - 1)]
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def normalize_longtail(cls_num, imb_factor=0.01, imb_type='exp', total_samples=50000):
    """Return per-class sample counts for a long-tailed class distribution.

    The head class gets ``img_max = total_samples / cls_num`` samples (the
    default 50000 matches the CIFAR-10/100 training-set size) and the other
    classes are scaled down according to ``imb_type``:

    * ``'exp'``  - class ``i`` gets ``img_max * imb_factor ** (i / (cls_num - 1))``.
    * ``'step'`` - the first ``cls_num // 2`` classes get ``img_max``; all
      remaining classes get ``img_max * imb_factor``.
    * anything else - every class gets ``img_max`` (balanced).

    Every count is truncated with ``int``.

    Args:
        cls_num: number of classes; must be positive.
        imb_factor: ratio of the smallest to the largest class size.
        imb_type: ``'exp'``, ``'step'`` or any other value for balanced.
        total_samples: total number of samples the balanced split divides.

    Returns:
        A list of exactly ``cls_num`` ints.

    Raises:
        ValueError: if ``cls_num`` is not positive.
    """
    if cls_num <= 0:
        raise ValueError(f'cls_num must be positive, got {cls_num}')
    img_max = total_samples / cls_num
    if imb_type == 'exp':
        if cls_num == 1:
            # A single class is the head; the exponent below would divide by zero.
            return [int(img_max)]
        return [int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
                for cls_idx in range(cls_num)]
    if imb_type == 'step':
        n_head = cls_num // 2
        # Use cls_num - n_head for the tail (not cls_num // 2) so an odd
        # cls_num still yields one count per class.
        return [int(img_max)] * n_head + [int(img_max * imb_factor)] * (cls_num - n_head)
    return [int(img_max)] * cls_num

def accuracy(logit, target, topk=(1,)):
    """Return the top-k precision, in percent, for each k in ``topk``.

    Args:
        logit: 2-D tensor of raw class scores, one row per sample; softmax is
            applied before ranking.
        target: 1-D tensor of true class indices, one per row of ``logit``.
        topk: the k values to report.

    Returns:
        A list with one shape-(1,) float tensor per k, in ``topk`` order.
    """
    probs = F.softmax(logit, dim=1)
    n_samples = target.size(0)
    largest_k = max(topk)

    # Ranked class indices, transposed so row j holds every sample's j-th guess.
    ranked = probs.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]


# Train the Model
def train(epoch, num_classes, train_loader,peer_loader_x, peer_loader_y, model, optimizer,loss_type, lmd = 0.0):
    """Run one training epoch and return ``(train_acc, lmd)``.

    Args:
        epoch: zero-based epoch index; forwarded to the loss functions.
        num_classes: unused in this function; kept for call-site compatibility.
        train_loader: iterable of ``(images, labels, groups, indexes)`` batches.
        peer_loader_x: loader whose batch inputs are used as peer samples.
        peer_loader_y: loader whose batch labels are used as peer labels.
            Both peer loaders are only consumed when ``loss_type == 'peerloss'``.
        model: network being trained (inputs are moved with ``.cuda()``).
        optimizer: optimizer stepped once per batch.
        loss_type: ``'ce'``, ``'peerloss'``, ``'pls'``, ``'nls'``, ``'focal'``
            or ``'logit_adj'``.
        lmd: fairness-regularisation weight; replaced by ``args.lmd`` when
            ``args.method == 'frf'``.

    Returns:
        ``(train_acc, lmd)``. ``train_acc`` is the sample-weighted top-1
        training accuracy over the epoch, in percent (0.0 for an empty loader).

    Raises:
        NotImplementedError: for an unknown ``loss_type``.

    Reads the module-level globals ``args`` and ``constraints_dict``.
    """
    train_total = 0
    train_correct = 0.0
    # Peer iterators are only needed for peer loss; don't create them otherwise.
    if loss_type == 'peerloss':
        peer_iter_x = iter(peer_loader_x)
        peer_iter_y = iter(peer_loader_y)
    for i, (images, labels, groups, indexes) in enumerate(train_loader):
        batch_size = len(indexes)

        if loss_type == 'peerloss':
            # Peer inputs and peer labels come from two independent draws.
            x_peer, _, _, _ = next(peer_iter_x)
            _, label_peer, _, _ = next(peer_iter_y)
            x_peer = x_peer.cuda()
            label_peer = label_peer.cuda()

        images = images.cuda()
        labels = labels.cuda()
        groups = groups.cuda()

        # Forward + Backward + Optimize
        logits = model(images)
        if loss_type == 'peerloss':
            logits_peer = model(x_peer)
        # Only top-1 is used; asking for top-5 would also fail with < 5 classes.
        (prec,) = accuracy(logits, labels, topk=(1,))
        # prec is this batch's accuracy in percent; weight it by batch size so a
        # smaller final batch doesn't skew the epoch accuracy.
        train_total += batch_size
        train_correct += float(prec) * batch_size
        if loss_type == 'ce':
            loss = loss_cross_entropy(epoch, logits, labels)
        elif loss_type == 'peerloss':
            loss, _ = loss_peer(epoch, logits, logits_peer, labels, label_peer)
        elif loss_type == 'pls':
            loss = loss_pls(epoch, logits, labels)
        elif loss_type == 'nls':
            loss = loss_nls(epoch, logits, labels)
        elif loss_type == 'focal':
            loss = FocalLoss(gamma=2.0)(logits, labels)
        elif loss_type == 'logit_adj':
            loss = logit_adj(args.samples_per_cls, logits, labels)
        else:
            raise NotImplementedError(f" {loss_type} Not Implemented")

        # Add Fairness Regularization (FR)
        constraints_fair = constraints_dict[args.metric]
        constraints_confidence = constraints_dict[args.conf]
        loss_reg, _ = constraints_fair(logits, groups, labels, n_groups=args.g_idx)
        if args.method == 'frf':
            lmd = args.lmd
            loss += torch.sum((lmd / args.g_idx) * torch.abs(loss_reg))
        loss += constraints_confidence(logits)
        optimizer.zero_grad()
        # NOTE(review): retain_graph=True is kept from the original; it looks
        # unnecessary, but confirm nothing in `metrics` reuses the graph first.
        loss.backward(retain_graph=True)
        optimizer.step()
        if (i+1) % args.print_freq == 0:
            # len(train_loader) is the true number of iterations per epoch.
            print ('Epoch [%d/%d], Iter [%d/%d] Training Accuracy: %.4F, Loss: %.4f'
                  %(epoch+1, args.n_epoch, i+1, len(train_loader), float(prec), loss.item()))

    train_acc = train_correct / train_total if train_total else 0.0
    return train_acc, lmd

# Evaluate the Model
def evaluate(test_loader,model,save=False,epoch=0,best_acc_=0,args=None):
    """Compute top-1 accuracy on ``test_loader`` and optionally checkpoint.

    Args:
        test_loader: iterable of ``(images, labels, _)`` batches.
        model: network to evaluate. It is put in eval mode and not switched
            back. Inputs are moved to the device of its first parameter.
        save: when True, save a "best" checkpoint if ``acc > best_acc_`` and a
            "last" checkpoint when ``epoch == args.n_epoch - 1``.
        epoch: current epoch index, stored in the checkpoint.
        best_acc_: best accuracy seen so far.
        args: parsed CLI namespace. Required when ``save`` is True (reads
            ``lmd``, ``method``, ``conf`` and ``n_epoch``).

    Returns:
        ``(acc, best_acc_)``. ``acc`` is in percent (0.0 for an empty loader);
        ``best_acc_`` is updated when a new best checkpoint is written.

    Checkpoint paths are built from the module-level ``save_dir`` and
    ``sub_folder`` globals.
    """
    model.eval()
    print('previous_best', best_acc_)
    device = next(model.parameters()).device
    correct = 0
    total = 0
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for images, labels, _ in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            outputs = F.softmax(logits, dim=1)
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()

    acc = 100 * float(correct) / float(total) if total else 0.0

    if save:
        ckpt_dir = save_dir + sub_folder
        ckpt_suffix = f'lmd{args.lmd}_{args.method}_{args.conf}.pth.tar'
        state = {'state_dict': model.state_dict(),
                 'epoch': epoch,
                 'acc': acc,
        }
        if acc > best_acc_:
            save_path = ckpt_dir + 'best_' + ckpt_suffix
            torch.save(state, save_path)
            best_acc_ = acc
            print(f'model saved to {save_path}!')
        if epoch == args.n_epoch - 1:
            torch.save(state, ckpt_dir + 'last_' + ckpt_suffix)
    return acc, best_acc_


#####################################main code ################################################
start_time = time.time()
args = parser.parse_args()
torch.set_num_threads(3)
# run_id only names the output directory (<result_dir>/<dataset>/run_id<run_id>/).
run_id = 0
# Honour --seed. It was previously ignored in favour of run_id; both default to 0.
seed = args.seed
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
# Hyper Parameters (batch size comes from --batch_size)
learning_rate = args.lr
if args.dataset == 'cifar10':
    num_classes = 10
elif args.dataset == 'cifar100':
    num_classes = 100
else:
    # Fail clearly instead of with a NameError on num_classes below.
    raise ValueError(f"Unsupported dataset {args.dataset!r}; expected 'cifar10' or 'cifar100'")

# Per-class sample counts, consumed by the 'logit_adj' loss in train().
args.samples_per_cls = normalize_longtail(cls_num=num_classes, imb_factor=args.tail_rate, imb_type='exp')
# NOTE(review): train_dataset/test_dataset from input_dataset are replaced by the
# IMBALANCECIFAR* datasets below; this call effectively only supplies num_classes
# and the label printout.
train_dataset,test_dataset,num_classes,num_training_samples = input_dataset(args.dataset)
print('train_labels:', len(train_dataset.train_labels), train_dataset.train_labels[:10])
# load model
print('building model...')
model = resnet32.resnet32(num_classes=num_classes)
print('building model done')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=0.0002, momentum=0.9)

### save result and model checkpoint #######
save_dir = args.result_dir +'/' +args.dataset + '/'  +f'run_id{run_id}/'


sub_folder = args.loss + args.noise_mode + f'{args.r}' + '_tr' + str(args.tail_rate) + f'_groups{args.g_idx}/'
os.makedirs(save_dir + sub_folder, exist_ok=True)

if args.dataset.endswith('10'):
    transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = IMBALANCECIFAR10(
        root=args.data_path,
        imb_type='exp', imb_factor=args.tail_rate, g_idx=args.g_idx,
        noise_mode=args.noise_mode, type=args.cluster_type, noise_ratio=args.r,
        train=True, download=True, transform=transform_train
    )
    test_dataset = IMBALANCECIFAR10(
        root=args.data_path,
        train=False, download=True, transform=transform_val
    )
elif args.dataset.endswith('100'):
    transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
        ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
    ])
    # NOTE(review): unlike the CIFAR-10 branch, g_idx=args.g_idx is not passed
    # here, so --g_idx may not affect the CIFAR-100 grouping; confirm against
    # IMBALANCECIFAR100's signature.
    train_dataset = IMBALANCECIFAR100(
        root=args.data_path,
        imb_type='exp', imb_factor=args.tail_rate,
        noise_mode=args.noise_mode, type=args.cluster_type, noise_ratio=args.r,
        train=True, download=True, transform=transform_train
    )
    test_dataset = IMBALANCECIFAR100(
        root=args.data_path,
        train=False, download=True, transform=transform_val
    )


# --num_workers was previously ignored (hard-coded 0); default is still 0.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)

# Peer-loss loaders: inputs drawn with replacement, labels from an independent shuffle.
peer_sampler_x = RandomSampler(train_dataset, replacement=True)
peer_loader_x = torch.utils.data.DataLoader(dataset = train_dataset,
                                    batch_size = args.batch_size,
                                    num_workers=args.num_workers,
                                    shuffle=False,
                                    sampler=peer_sampler_x)

peer_loader_y = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

# Per-epoch learning rates: 5 warm-up epochs, then 155 / 20 / 20 epochs at decaying rates.
alpha_plan = [0.02, 0.04, 0.06, 0.08, 0.1]  + [0.1] * 155  + [0.001] * 20 + [0.00001] * 20

if args.loss == 'nls':
    # nls fine-tunes a CE-trained checkpoint at a tiny constant learning rate.
    args.n_epoch = 120
    alpha_plan = [1e-6] * 120

    # Checkpoints are written by evaluate() as {best,last}_lmd{lmd}_{method}_{conf};
    # the previous path used args.metric and so never matched the saved file.
    ce_dir = f"results/{args.dataset}/run_id{0}/ce{args.noise_mode}{args.r}_tr{args.tail_rate}_groups{2}/"
    ckpt_name = f"lmd{args.lmd}_{args.method}_{args.conf}.pth.tar"
    model_path = ce_dir + 'last_' + ckpt_name
    if not os.path.exists(model_path):
        model_path = ce_dir + 'best_' + ckpt_name
    state_dict = torch.load(model_path, map_location = "cpu")
    model.load_state_dict(state_dict['state_dict'])

model.cuda()

txtfile=save_dir + sub_folder + f'lmd{args.lmd}_{args.method}_{args.conf}.txt'
# Start from a fresh log; os.remove avoids shelling out to `rm`.
if os.path.exists(txtfile):
    os.remove(txtfile)

with open(txtfile, "a") as myfile:
    myfile.write('epoch: train_acc test_acc best_acc \n')

best_acc_ = 0.0
lmd = args.lmd
for epoch in range(args.n_epoch):
    # train models
    adjust_learning_rate(optimizer, epoch, alpha_plan, loss_type=args.loss)
    model.train()
    train_acc, lmd = train(epoch,num_classes,train_loader,peer_loader_x,peer_loader_y, model, optimizer,args.loss, lmd)

    # evaluate models
    test_acc, best_acc_ = evaluate(test_loader=test_loader, save=True, model=model,epoch=epoch,best_acc_=best_acc_,args=args)
    # save results
    print(f'[Epoch {epoch}] lmd: {lmd}')
    print('train acc on train images is ', train_acc)
    print('test acc on test images is ', test_acc)
    print('best test acc on test images is ', best_acc_)

    with open(txtfile, "a") as myfile:
        myfile.write(str(int(epoch)) + ': '  + str(train_acc) +' ' + str(test_acc) +' ' + str(best_acc_) + "\n")
print("--- %s seconds ---" % (time.time() - start_time))

