# -*- coding: utf-8 -*-
"""

@author: Anonymous Author
"""

import numpy as np
import math
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

def get_output_emb(all_loader, model, classifier):
    """Collect softmax predictions and classifier embeddings over a loader.

    Each batch is expected to be a dict with an ``'image'`` tensor; images
    are cast to ``torch.cuda.FloatTensor`` (so a CUDA device is required).
    ``model`` and ``classifier`` run in eval mode under ``torch.no_grad()``.

    NOTE(review): assumes ``classifier`` stores the embedding it computed
    during its last forward pass in ``classifier.emb`` — confirm against
    the classifier implementation.

    Args:
        all_loader: iterable of dict batches with an ``'image'`` entry.
        model: feature extractor applied to the images.
        classifier: head applied to the features; must expose ``.emb``.

    Returns:
        ``(totpre, totemb)``: row-stacked softmax probabilities and
        row-stacked embeddings as NumPy arrays, or ``([], [])`` when the
        loader yields no batches.

    Side effects:
        Prints the batch index every 10 batches.
    """
    model.eval()
    classifier.eval()
    # Accumulate per-batch arrays and stack once at the end; stacking
    # inside the loop would re-copy the whole accumulator every batch.
    batch_preds = []
    batch_embs = []
    try:
        for cur_iter, itemitr in enumerate(all_loader):
            if cur_iter % 10 == 1:
                print(cur_iter)
            # .type(torch.cuda.FloatTensor) both casts and moves to the GPU.
            inputs = itemitr['image'].type(torch.cuda.FloatTensor)
            with torch.no_grad():
                logits = classifier(model(inputs))
                emb = classifier.emb
                preds = F.softmax(logits, dim=1)
            batch_preds.append(preds.cpu().numpy())
            batch_embs.append(emb.cpu().numpy())
    finally:
        # Both networks are returned to train mode, even if a batch fails.
        model.train()
        classifier.train()

    if not batch_preds:
        return [], []
    return np.vstack(batch_preds), np.vstack(batch_embs)

def get_grad_embedding(num_data, num_class, embDim, model, classifier, all_loader):# Y is model predictions for unlabeld samples
    
    model.eval()
    classifier.eval()
    embedding = np.zeros([num_data, embDim * num_class])
    
    tidx = 0
    with torch.no_grad():
        for idxs, (x, _) in enumerate(all_loader):
            x = x.cuda()
            cout = model(x)
            cout = classifier(cout)
            out = classifier.emb
            out = out.data.cpu().numpy()
            batchProbs = F.softmax(cout, dim=1).data.cpu().numpy()
            maxInds = np.argmax(batchProbs,1)
            
            for j in range(len(out)):
                for c in range(num_class):
                    if c == maxInds[j]:
                        embedding[tidx + j][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (1 - batchProbs[j][c])
                    else:
                        embedding[tidx + j][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (-1 * batchProbs[j][c])
            
            tidx += len(out)
                        
        return embedding

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps,
                                    num_cycles=7./16., last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier applied to each base learning rate at step ``s`` is
    ``s / max(1, num_warmup_steps)`` while ``s < num_warmup_steps``, and
    ``max(0, cos(pi * num_cycles * progress))`` afterwards, where
    ``progress = (s - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)``.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_factor(step):
        if step >= num_warmup_steps:
            progress = float(step - num_warmup_steps) / decay_span
            return max(0., math.cos(math.pi * num_cycles * progress))
        return float(step) / warmup_span

    return LambdaLR(optimizer, lr_factor, last_epoch=last_epoch)

def train_mlp(train_loader, model, classifier, args):
    """Train ``classifier`` on features produced by a frozen ``model``.

    Only ``classifier``'s parameters are optimised, using Nesterov SGD with
    ``args.lr``, ``args.momentum`` and ``args.weight_decay``, for
    ``args.train_eps`` epochs. A cosine schedule without warmup is stepped
    once per batch. ``model`` runs in eval mode under ``torch.no_grad()``.
    Batches are dicts with ``'image'`` and ``'target'`` entries; images are
    cast to ``torch.cuda.FloatTensor``.

    Returns:
        ``(model, classifier)`` — the same objects that were passed in.
    """
    epochs = args.train_eps
    opt = torch.optim.SGD(
        classifier.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True,
    )
    sched = get_cosine_schedule_with_warmup(opt, 0, epochs * len(train_loader))

    # Frozen feature extractor; only the head trains.
    model.eval()
    classifier.train()

    for _ in range(epochs):
        for batch in train_loader:
            images, targets = batch['image'], batch['target']
            opt.zero_grad()
            images = images.type(torch.cuda.FloatTensor).cuda()
            targets = targets.cuda(non_blocking=True)
            with torch.no_grad():
                feats = model(images)
            loss = F.cross_entropy(classifier(feats), targets.long())
            loss.backward()
            opt.step()
            sched.step()

    return model, classifier

def evaluation(test_loader, model, classifier):
    """Compute, print and return top-1 accuracy of ``classifier(model(x))``.

    Batches are dicts with ``'image'`` and ``'target'`` entries; images are
    cast to ``torch.cuda.FloatTensor``. Both networks run in eval mode
    under ``torch.no_grad()``. Without ``model.eval()``, dropout would stay
    active and BatchNorm would update its running statistics during the
    forward pass, even under ``no_grad``.

    Returns:
        Accuracy as a NumPy float (fraction of argmax predictions equal to
        the targets).

    Raises:
        ValueError: if the loader yields no samples.

    Side effects:
        Prints the batch index every 10 batches and the final accuracy.
        ``model`` is restored to its previous mode; ``classifier`` is left
        in train mode.
    """
    model_was_training = model.training
    model.eval()
    classifier.eval()
    batch_logits, testl = [], []
    try:
        for cur_iter, itemitr in enumerate(test_loader):
            if cur_iter % 10 == 1:
                print(cur_iter)
            inputs = itemitr['image'].type(torch.cuda.FloatTensor)
            labels = itemitr['target']
            with torch.no_grad():
                preds = classifier(model(inputs))
            # Collect per batch; stacking once avoids quadratic copying.
            batch_logits.append(preds.cpu().numpy())
            testl += labels.cpu().numpy().tolist()
    finally:
        model.train(model_was_training)
        classifier.train()

    if not testl:
        raise ValueError('evaluation: test_loader yielded no samples')

    tspre = np.vstack(batch_logits).argmax(axis=1)
    tacc = (tspre == np.array(testl)).sum() / len(testl)
    print('test acc: ', tacc)
    return tacc