from __future__ import print_function
import sys

sys.path.append('../../')
sys.path.append('../../pipeline')
sys.path.append('../../utils')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import argparse
import numpy as np
from math import log2
from Contrastive_loss import *

from dataloader_blurred import blurred_dataloader, get_input_size
from PreResNet_blurred import *
from default_config import get_exp_dict, window_time_dict, slide_time_dict

## For plotting the logs
# import wandb
# wandb.init(project="noisy-label-project", entity="..")
# Cap CPU-side parallelism: limit torch's intra-op thread pool and export the
# same limit through the environment variables read by the common
# BLAS/OpenMP back ends.
# NOTE(review): these variables are exported after torch/numpy have already
# been imported above; libraries that read them only at load time may not
# pick them up -- confirm whether they need to be set earlier.
num_threads = '16'
torch.set_num_threads(int(num_threads))
os.environ.update(dict.fromkeys(
    (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    ),
    num_threads,
))

## Arguments to pass
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. This converter accepts the usual
    spellings instead and rejects anything else with a clear parser error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--noise_mode', default='sym')
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=30, type=float, help='weight for unsupervised loss')
parser.add_argument('--lambda_c', default=0.025, type=float, help='weight for contrastive loss')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--d_u', default=0.7, type=float)
parser.add_argument('--tau', default=5, type=float, help='filtering coefficient')
parser.add_argument('--metric', type=str, default='JSD', help='Comparison Metric')
# The seed is handed to torch.manual_seed(), which needs an int; without a
# type a value given on the command line would arrive as a string.
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--cuda_dev', default=0, type=int)
# Parsed with _str2bool so that "--resume False" really means False.
parser.add_argument('--resume', default=False, type=_str2bool, help='Resume from the warmup checkpoint')
parser.add_argument('--data_path', default='./data/cifar10', type=str, help='path to dataset')
parser.add_argument('--dataset', default='cifar10', type=str)

parser.add_argument('--out_path', type=str, default='./output', help='Output path for result')
parser.add_argument('--database_save_dir', type=str, default='./data/CL_database/',
                    help='Should give a path to load the database of one patient.')
parser.add_argument('--data_name', type=str, default='fNIRS_2',
                    help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
parser.add_argument('--exp_id', type=int, default=3,
                    help='The experimental id.')
parser.add_argument('--num_classes', type=int, default=2, help='Number of in-distribution classes')
parser.add_argument('--noise_ratio', type=float, default=0.4, help='percent of noise')
parser.add_argument('--window_time', type=float, default=1,
                    help='The seconds of every sample segment.')
parser.add_argument('--slide_time', type=float, default=0.5,
                    help='The sliding seconds between two sample segments.')
parser.add_argument('--num_level', type=int, default=5,
                    help='The number of levels.')
parser.add_argument('--num_samples', type=int, default=73440, help='number of samples')
parser.add_argument('--patience', type=int, default=10, help='patience fot early stopping')

args = parser.parse_args()

# Look up the patient split for this dataset / experiment id. The selected
# entry is indexed as [train, valid, test].
exp_dict = get_exp_dict(args.data_name)
exp_patient_list = exp_dict[args.exp_id]
args.train_patient_list = exp_patient_list[0]
args.valid_patient_list = exp_patient_list[1]
args.test_patient_list = exp_patient_list[2]

# The per-dataset windowing settings replace whatever was passed through
# --window_time / --slide_time.
args.window_time = window_time_dict[args.data_name]
args.slide_time = slide_time_dict[args.data_name]

# Experiment 12 is skipped for the 0.2 and 0.4 noise ratios; the process ends
# with status 1 before any output is created.
# NOTE(review): the reason for skipping this combination is not visible here.
# sys.exit is used rather than the site-provided exit(), which is intended for
# the interactive shell and is not guaranteed to exist in every runtime.
if args.exp_id == 12 and args.noise_ratio in (0.2, 0.4):
    sys.exit(1)

## GPU Setup
torch.cuda.set_device(args.cuda_dev)

# Seed every random number generator the script draws from. NumPy is included
# because the mixup coefficient in train() comes from np.random.beta; without
# it runs with the same --seed would still differ. The value is coerced to int
# first because torch.manual_seed() rejects a string seed.
_seed = int(args.seed)
random.seed(_seed)
np.random.seed(_seed)
torch.manual_seed(_seed)
torch.cuda.manual_seed_all(_seed)

## result path
# Layout: <out_path>/UNICON_experiment_2/<data_name>/<noise percent>/exp<exp_id>
res_path = os.path.join(
    args.out_path,
    'UNICON_experiment_2',
    args.data_name,
    str(int(args.noise_ratio * 100)),
    f'exp{args.exp_id}',
)
os.makedirs(res_path, exist_ok=True)

## Checkpoint Location
# Checkpoints and logs share the result directory created above.
model_save_loc = res_path

## Log files
# All handles are opened in 'w' mode, so a rerun truncates earlier logs.
_run_tag = '%s_%.1f' % (args.data_name, args.noise_ratio)
stats_log = open(f'{model_save_loc}/{_run_tag}_stats.txt', 'w')
test_log = open(f'{model_save_loc}/{_run_tag}_acc.txt', 'w')
test_loss_log = open(f'{model_save_loc}/test_loss.txt', 'w')
train_acc = open(f'{model_save_loc}/train_acc.txt', 'w')
train_loss = open(f'{model_save_loc}/train_loss.txt', 'w')


# SSL-Training
def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader):
    """Run one semi-supervised epoch that updates ``net`` while ``net2`` is frozen.

    For every labeled batch one unlabeled batch is drawn (the unlabeled loader
    is restarted whenever it runs out). Pseudo-labels for the unlabeled data
    are co-guessed by both networks, the given labels are refined with ``net``'s
    own predictions, and the network is then trained on a mixup of both sets
    plus a contrastive term and a uniform-prior penalty.

    Args:
        epoch: Current epoch index; drives the unsupervised-loss ramp-up and
            appears in the progress line.
        net: Network being trained. Called as ``features, logits = net(x)``.
        net2: Peer network, put in eval mode and used only for co-guessing.
        optimizer: Optimizer stepped once per batch after ``loss.backward()``.
        labeled_trainloader: Yields ``(x, x2, x3, x4, labels, w)`` batches.
        unlabeled_trainloader: Yields ``(u, u2, u3, u4)`` batches.

    Reads the module-level ``args``, ``criterion``, ``contrastive_criterion``
    and ``warm_up``. Writes a single progress line to stdout when the epoch
    finishes.
    """
    net2.eval()  # Freeze one network and train the other
    net.train()

    unlabeled_train_iter = iter(unlabeled_trainloader)
    num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1

    ## Loss statistics
    loss_x = 0
    loss_u = 0
    loss_ucl = 0
    num_batches = 0  # batches actually processed; stays 0 for an empty loader

    for batch_idx, (inputs_x, inputs_x2, inputs_x3, inputs_x4, labels_x, w_x) in enumerate(labeled_trainloader):
        # Use the built-in next(): iterator objects have no .next() method on
        # Python 3. Only exhaustion restarts the loader; any other error
        # propagates instead of being swallowed.
        try:
            inputs_u, inputs_u2, inputs_u3, inputs_u4 = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2, inputs_u3, inputs_u4 = next(unlabeled_train_iter)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        labels_x = torch.zeros(batch_size, args.num_classes).scatter_(1, labels_x.view(-1, 1), 1)
        w_x = w_x.view(-1, 1)

        inputs_x, inputs_x2 = inputs_x.float().cuda(), inputs_x2.float().cuda()
        inputs_x3, inputs_x4 = inputs_x3.float().cuda(), inputs_x4.float().cuda()
        labels_x, w_x = labels_x.float().cuda(), w_x.float().cuda()
        inputs_u, inputs_u2 = inputs_u.float().cuda(), inputs_u2.float().cuda()
        inputs_u3, inputs_u4 = inputs_u3.float().cuda(), inputs_u4.float().cuda()

        with torch.no_grad():
            # Label co-guessing of unlabeled samples
            _, outputs_u11 = net(inputs_u)
            _, outputs_u12 = net(inputs_u2)
            _, outputs_u21 = net2(inputs_u)
            _, outputs_u22 = net2(inputs_u2)

            ## Pseudo-label: average of both networks over both views
            pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)
                  + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4

            ptu = pu ** (1 / args.T)  ## Temperature sharpening

            targets_u = ptu / ptu.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

            ## Label refinement
            _, outputs_x = net(inputs_x)
            _, outputs_x2 = net(inputs_x2)

            px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2

            # Blend the given label with the prediction, weighted per sample by w_x
            px = w_x * labels_x + (1 - w_x) * px
            ptx = px ** (1 / args.T)  ## Temperature sharpening

            targets_x = ptx / ptx.sum(dim=1, keepdim=True)
            targets_x = targets_x.detach()

        ## Unsupervised Contrastive Loss
        f1, _ = net(inputs_u3)
        f2, _ = net(inputs_u4)
        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        loss_simCLR = contrastive_criterion(features)

        # MixMatch
        l = np.random.beta(args.alpha, args.alpha)
        l = max(l, 1 - l)  # keep the mixed sample closer to its first component
        all_inputs = torch.cat([inputs_x3, inputs_x4, inputs_u3, inputs_u4], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        ## Mixup
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        _, logits = net(mixed_input)
        # The first 2 * batch_size rows are the two labeled views.
        logits_x = logits[:batch_size * 2]
        logits_u = logits[batch_size * 2:]

        ## Combined Loss
        Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size * 2], logits_u, mixed_target[batch_size * 2:],
                                 epoch + batch_idx / num_iter, warm_up)

        ## Regularization: KL(uniform prior || mean prediction)
        prior = torch.ones(args.num_classes) / args.num_classes
        prior = prior.cuda()
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior * torch.log(prior / pred_mean))

        ## Total Loss
        loss = Lx + lamb * Lu + args.lambda_c * loss_simCLR + penalty

        ## Accumulate Loss
        loss_x += Lx.item()
        loss_u += Lu.item()
        # Accumulated so the "Contrastive Loss" field of the progress line
        # reports the real running mean instead of a constant zero.
        loss_ucl += loss_simCLR.item()
        num_batches = batch_idx + 1

        # Compute gradient and Do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Avoid dividing by zero when the labeled loader produced no batch.
    denom = max(num_batches, 1)
    sys.stdout.write('\r')
    sys.stdout.write(
        '%s:%.1f-%2d | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f  Unlabeled loss: %.2f Contrastive Loss:%.4f'
        % (args.data_name, args.noise_ratio, args.exp_id, epoch, args.num_epochs, num_batches, num_iter,
           loss_x / denom, loss_u / denom, loss_ucl / denom))
    sys.stdout.flush()


## For Standard Training 
def warmup_standard(epoch, net, optimizer, dataloader):
    """Train ``net`` for one epoch with plain cross-entropy.

    Args:
        epoch: Current epoch index (only used in the progress line).
        net: Network to train; called as ``_, logits = net(inputs)``.
        optimizer: Optimizer stepped once per batch.
        dataloader: Yields ``(inputs, labels, path)`` batches.

    When ``args.noise_mode`` is ``'asym'`` the module-level ``conf_penalty``
    term is added to the loss. One progress line, reporting the cross-entropy
    of the last batch, is written to stdout after the epoch.
    """
    net.train()
    total_iters = len(dataloader.dataset) // dataloader.batch_size + 1
    penalize_confidence = args.noise_mode == 'asym'

    for step, (inputs, labels, path) in enumerate(dataloader):
        inputs = inputs.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()
        logits = net(inputs)[1]
        ce = CEloss(logits, labels)

        # Penalize confident prediction for asymmetric noise
        objective = ce + conf_penalty(logits) if penalize_confidence else ce

        objective.backward()
        optimizer.step()

    sys.stdout.write('\r')
    progress = '%s:%d-%.1f| Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f' % (
        args.data_name, args.exp_id, args.noise_ratio, epoch, args.num_epochs, step + 1, total_iters,
        ce.item())
    sys.stdout.write(progress)
    sys.stdout.flush()


## For Training Accuracy
def warmup_val(epoch, net, optimizer, dataloader):
    """Measure accuracy and mean cross-entropy of ``net`` over ``dataloader``.

    The accuracy is printed, and one line each is appended to the module-level
    ``train_loss`` and ``train_acc`` log files.

    Args:
        epoch: Current epoch index (only used in the printed message).
        net: Network to evaluate; called as ``_, logits = net(inputs)``.
        optimizer: Unused; kept so existing callers keep working. No gradients
            are produced under ``torch.no_grad()``, so there is nothing to reset.
        dataloader: Yields ``(inputs, labels, path)`` batches.

    Returns:
        Accuracy in percent.
    """
    # NOTE(review): the network stays in train mode for this measurement, so
    # mode-dependent layers run in their training behaviour even though
    # gradients are disabled -- confirm this is intended.
    net.train()
    total = 0
    correct = 0
    loss_x = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels, path in dataloader:
            inputs, labels = inputs.cuda(), labels.cuda()
            _, outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            loss = CEloss(outputs, labels)
            loss_x += loss.item()

            total += labels.size(0)
            correct += predicted.eq(labels).cpu().sum().item()
            num_batches += 1

    acc = 100. * correct / total
    print("\n| Train Epoch #%d\t Accuracy: %.2f%%\n" % (epoch, acc))

    # One record per line, matching the test logs; without the newline the
    # values of successive epochs run together and cannot be parsed.
    train_loss.write(str(loss_x / num_batches) + '\n')
    train_acc.write(str(acc) + '\n')
    train_acc.flush()
    train_loss.flush()

    return acc


## Test Accuracy
def test(epoch, net1, net2):
    """Evaluate the two-network ensemble on the module-level ``test_loader``.

    The logits of both networks are summed; accuracy and mean cross-entropy
    are computed on that sum. The accuracy is printed and appended to
    ``test_log``; the mean loss is appended to ``test_loss_log``.

    Args:
        epoch: Current epoch index (only used in the printed message).
        net1: First network; called as ``_, logits = net1(inputs)``.
        net2: Second network, called the same way.

    Returns:
        Accuracy in percent.
    """
    for model in (net1, net2):
        model.eval()

    n_correct = 0
    n_seen = 0
    loss_sum = 0
    n_batches = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()

            # Ensemble by summing the logits of both networks.
            ensemble_logits = net1(inputs)[1] + net2(inputs)[1]
            predicted = ensemble_logits.max(dim=1)[1]
            loss_sum += CEloss(ensemble_logits, targets).item()

            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).cpu().sum().item()
            n_batches += 1

    acc = 100. * n_correct / n_seen
    print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" % (epoch, acc))

    test_log.write(str(acc) + '\n')
    test_log.flush()
    test_loss_log.write(str(loss_sum / n_batches) + '\n')
    test_loss_log.flush()
    return acc


# KL divergence
def kl_divergence(p, q):
    """Return the row-wise KL divergence ``sum_j p_j * log(p_j / q_j)``.

    A small constant is added to numerator and denominator so that zero
    entries do not produce ``log(0)`` or a division by zero. The reduction
    runs over ``dim=1``, giving one value per row.
    """
    eps = 1e-10
    log_ratio = torch.log((p + eps) / (q + eps))
    return torch.sum(p * log_ratio, dim=1)


## Jensen-Shannon Divergence 
class Jensen_Shannon(nn.Module):
    """Jensen-Shannon divergence between two batches of distributions.

    ``forward(p, q)`` returns, per row, the average of the KL divergences of
    ``p`` and ``q`` from their midpoint ``(p + q) / 2``. The module holds no
    parameters or state.
    """

    def __init__(self):
        super().__init__()

    def forward(self, p, q):
        midpoint = (p + q) / 2
        half_kl_p = 0.5 * kl_divergence(p, midpoint)
        half_kl_q = 0.5 * kl_divergence(q, midpoint)
        return half_kl_p + half_kl_q


## Calculate JSD
def Calculate_JSD(model1, model2, num_samples):
    """Score every sample of ``eval_loader`` by the JSD between prediction and label.

    For each batch the softmax outputs of both models are averaged, and the
    Jensen-Shannon divergence between that average and the one-hot label is
    stored. Scores are written in the order the loader yields the samples.

    Args:
        model1: First network; ``model1(inputs)[1]`` must be the logits.
        model2: Second network, used the same way.
        num_samples: Length of the returned score vector.

    Returns:
        A CPU tensor of length ``num_samples``; entry ``i`` is the divergence
        of the ``i``-th sample yielded by the module-level ``eval_loader``.
    """
    # NOTE(review): the models' train/eval mode is not changed here, so the
    # scores depend on whatever mode the caller left them in -- confirm.
    JS_dist = Jensen_Shannon()
    JSD = torch.zeros(num_samples)

    # Running write position. It must not be derived from
    # batch_idx * current_batch_size: a smaller final batch would then be
    # written at the wrong offset, overwriting earlier scores and leaving the
    # tail of JSD at zero.
    offset = 0

    for inputs, targets, index in eval_loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        batch_size = inputs.size(0)

        ## Get outputs of both network
        with torch.no_grad():
            out1 = torch.softmax(model1(inputs)[1], dim=1)
            out2 = torch.softmax(model2(inputs)[1], dim=1)

        ## Get the Prediction
        out = (out1 + out2) / 2

        ## Divergence calculator to record the diff. between ground truth and output prob. dist.
        dist = JS_dist(out, F.one_hot(targets, num_classes=args.num_classes))
        JSD[offset:offset + batch_size] = dist
        offset += batch_size

    return JSD


## Unsupervised Loss coefficient adjustment 
def linear_rampup(current, warm_up, rampup_length=16):
    """Return the unsupervised-loss weight for the current training progress.

    The weight is 0 until ``current`` reaches ``warm_up``, then grows linearly
    over ``rampup_length`` epochs and stays at ``args.lambda_u`` afterwards.
    """
    progress = (current - warm_up) / rampup_length
    ramp = np.clip(progress, 0.0, 1.0)
    return args.lambda_u * float(ramp)


class SemiLoss(object):
    """Semi-supervised loss pair used in the mixup training step.

    Calling the object returns ``(Lx, Lu, weight)``:
      * ``Lx``: soft-target cross-entropy of the labeled logits,
      * ``Lu``: mean squared error between the softmax of the unlabeled
        logits and their targets,
      * ``weight``: the ramp-up coefficient from ``linear_rampup``.
    """

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        unlabeled_probs = torch.softmax(outputs_u, dim=1)
        labeled_log_probs = F.log_softmax(outputs_x, dim=1)

        Lx = -(labeled_log_probs * targets_x).sum(dim=1).mean()
        Lu = ((unlabeled_probs - targets_u) ** 2).mean()

        return Lx, Lu, linear_rampup(epoch, warm_up)


class NegEntropy(object):
    """Negative-entropy penalty on a batch of logits.

    Calling the object returns the batch mean of ``sum_j p_j * log(p_j)``,
    where ``p`` is the softmax over ``dim=1``. The value is at most zero and
    reaches zero for one-hot predictions, so adding it to a loss penalises
    over-confident outputs.
    """

    def __call__(self, outputs):
        # log_softmax is used instead of softmax(...).log(): when a
        # probability underflows to exactly 0 the latter evaluates
        # 0 * log(0) = 0 * -inf = NaN, which would poison the whole loss.
        log_probs = F.log_softmax(outputs, dim=1)
        probs = torch.softmax(outputs, dim=1)
        return torch.mean(torch.sum(log_probs * probs, dim=1))


def create_model(input_size):
    """Build a ``PreActResNet18`` for ``input_size`` and move it to the GPU.

    The number of classes is read from ``args.num_classes``; the ``low_dim``
    argument of the network is fixed at 128.
    """
    return PreActResNet18(
        input_size=input_size,
        num_classes=args.num_classes,
        low_dim=128,
    ).cuda()


## Choose Warmup period based on Dataset
warm_up = 10

## Call the dataloader
loader = blurred_dataloader(root_dir=model_save_loc, log=stats_log)
# The dataset decides the input size, class count and sample count; the
# latter two replace the command-line values.
input_size, args.num_classes, args.num_samples = get_input_size(args)

print('| Building net')
# Two identically configured networks, built in order (net1 first).
net1, net2 = (create_model(input_size) for _ in range(2))
cudnn.benchmark = True

## Semi-Supervised Loss
criterion = SemiLoss()

## Optimizer and Scheduler
# One SGD optimizer and one cosine schedule per network.
optimizer1, optimizer2 = (
    optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    for model in (net1, net2)
)
scheduler1, scheduler2 = (
    optim.lr_scheduler.CosineAnnealingLR(opt, T_max=280, eta_min=2e-4)
    for opt in (optimizer1, optimizer2)
)

## Loss Functions
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
MSE_loss = nn.MSELoss(reduction='none')
contrastive_criterion = SupConLoss()

# The confidence penalty only exists for asymmetric noise.
if args.noise_mode == 'asym':
    conf_penalty = NegEntropy()

## Resume from the warmup checkpoint
model_name_1 = 'Net1_warmup.pth'
model_name_2 = 'Net2_warmup.pth'

# When resuming, training restarts right after the warm-up stage with the
# weights stored under the 'net' key of the warm-up checkpoints.
start_epoch = warm_up if args.resume else 0
if args.resume:
    for _net, _ckpt_name in ((net1, model_name_1), (net2, model_name_2)):
        _state = torch.load(os.path.join(model_save_loc, _ckpt_name))
        _net.load_state_dict(_state['net'])

best_acc = 0
early_stop_counter = 0

## Warmup and SSL-Training
def _build_checkpoint(net, model_number, acc, epoch):
    """Return the checkpoint dict saved for one network."""
    # NOTE(review): 'Pytorch version' and 'Dataset' are fixed strings and are
    # not derived from the running environment or from args -- confirm
    # whether anything reads them before changing.
    return {
        'net': net.state_dict(),
        'Model_number': model_number,
        'Noise_Ratio': args.noise_ratio,
        'Loss Function': 'CrossEntropyLoss',
        'Optimizer': 'SGD',
        'Noise_mode': args.noise_mode,
        'Accuracy': acc,
        'Pytorch version': '1.4.0',
        'Dataset': 'CIFAR10',
        'Batch Size': args.batch_size,
        'epoch': epoch,
    }


def _score_and_select():
    """Compute per-sample JSD scores and the fraction of samples below the cut-off."""
    prob = Calculate_JSD(net2, net1, args.num_samples)
    threshold = torch.mean(prob)
    # When the mean divergence is above d_u, move the cut-off towards the
    # minimum score by 1 / tau of the gap.
    if threshold.item() > args.d_u:
        threshold = threshold - (threshold - torch.min(prob)) / args.tau
    SR = torch.sum(prob < threshold).item() / args.num_samples
    return prob, SR


for epoch in range(start_epoch, args.num_epochs + 1):
    # Rebuilt every epoch. test() and Calculate_JSD() read these two loaders
    # through their module-level names.
    test_loader = loader.run(args, 0, 'test')
    eval_loader = loader.run(args, 0, 'eval_train')

    ## Warmup Stage
    if epoch < warm_up:
        # Built only here: the warm-up loader is not used in the SSL stage,
        # and it was previously constructed twice per warm-up epoch.
        warmup_trainloader = loader.run(args, 0, 'warmup')

        print('Warmup Model')
        warmup_standard(epoch, net1, optimizer1, warmup_trainloader)

        print('\nWarmup Model')
        warmup_standard(epoch, net2, optimizer2, warmup_trainloader)

    else:
        ## Calculate JSD values and Filter Rate
        prob, SR = _score_and_select()

        print('Train Net1\n')
        labeled_trainloader, unlabeled_trainloader = loader.run(args, SR, 'train', prob=prob)  # Uniform Selection
        train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader)  # train net1

        ## Recompute the scores, since net1 has just been updated
        prob, SR = _score_and_select()

        print('\nTrain Net2')
        labeled_trainloader, unlabeled_trainloader = loader.run(args, SR, 'train', prob=prob)  # Uniform Selection
        train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader)  # train net2

    acc = test(epoch, net1, net2)
    scheduler1.step()
    scheduler2.step()

    if acc > best_acc:
        # A new best accuracy resets the early-stopping counter. Warm-up and
        # SSL checkpoints are stored under different file names.
        early_stop_counter = 0
        if epoch < warm_up:
            model_name_1, model_name_2 = 'Net1_warmup.pth', 'Net2_warmup.pth'
        else:
            model_name_1, model_name_2 = 'Net1.pth', 'Net2.pth'

        print("Save the Model-----")
        checkpoint1 = _build_checkpoint(net1, 1, acc, epoch)
        checkpoint2 = _build_checkpoint(net2, 2, acc, epoch)

        torch.save(checkpoint1, os.path.join(model_save_loc, model_name_1))
        torch.save(checkpoint2, os.path.join(model_save_loc, model_name_2))
        best_acc = acc
    elif epoch >= warm_up:
        # Only epochs after the warm-up stage count towards early stopping.
        early_stop_counter += 1

    # Stops once the counter exceeds the patience, i.e. after patience + 1
    # epochs without a new best accuracy since the last improvement.
    if early_stop_counter > args.patience:
        print(f'Validation acc score did not improve for {args.patience} epochs. Early stopping.')
        break
