import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import loss
from torch.utils.data import DataLoader
from data_list import  ImageList_idx, ImageList
import  math
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix

#data processing
def data_load(args, netF_list=None, netC_list=None):
    """Build the DataLoaders for one training stage.

    The stage is chosen by whether ``netF_list`` is non-empty:

    * Target stage (``netF_list`` non-empty): ``args.t_dset_path`` becomes the
      shuffled ``"target"`` loader and ``args.test_dset_path`` the unshuffled
      ``"test"`` loader. Both are wrapped in ``ImageList_idx``, which also
      receives ``netF_list`` and ``netC_list``.
    * Source stage (otherwise): the first 3 lines seen for each class in
      ``args.s_dset_path`` go to the ``"source_te"`` loader and all remaining
      lines to ``"source_tr"``. ``args.test_dset_path`` becomes ``"test"``.

    In the source stage the second space-separated field of each line is
    parsed as an integer class label (used as an index into a per-class
    counter of length ``args.class_num``).

    Args:
        args: namespace providing ``batch_size``, ``worker``, the dataset path
            attributes named above and, in the source stage, ``class_num``.
        netF_list: feature networks forwarded to ``ImageList_idx``. ``None``
            (the default) means an empty list; ``None`` is used instead of a
            mutable ``[]`` default so no list is shared between calls.
        netC_list: classifier networks forwarded to ``ImageList_idx``; same
            defaulting as ``netF_list``.

    Returns:
        dict mapping split name to ``DataLoader``. The ``"test"`` loader uses a
        batch size of ``3 * args.batch_size`` and no shuffling.
    """
    netF_list = [] if netF_list is None else netF_list
    netC_list = [] if netC_list is None else netC_list

    train_bs = args.batch_size
    dset_loaders = {}

    def _read_lines(path):
        # Context manager so the list file is closed as soon as it is read.
        with open(path) as fh:
            return fh.readlines()

    if len(netF_list) > 0:
        txt_tar = _read_lines(args.t_dset_path)
        txt_test = _read_lines(args.test_dset_path)

        target_set = ImageList_idx(args, txt_tar, netF_list, netC_list, transform=image_train())
        dset_loaders["target"] = DataLoader(target_set, batch_size=train_bs, shuffle=True,
                                            num_workers=args.worker, drop_last=False)
        test_set = ImageList_idx(args, txt_test, netF_list, netC_list, transform=image_test())
        dset_loaders["test"] = DataLoader(test_set, batch_size=train_bs * 3, shuffle=False,
                                          num_workers=args.worker, drop_last=False)
    else:
        txt_src = _read_lines(args.s_dset_path)
        # Hold out the first 3 samples of every class as a small source validation split.
        count = np.zeros(args.class_num)
        tr_txt = []
        te_txt = []
        for line in txt_src:
            label = int(line.strip().split(' ')[1])
            if count[label] < 3:
                count[label] += 1
                te_txt.append(line)
            else:
                tr_txt.append(line)
        txt_test = _read_lines(args.test_dset_path)

        source_tr = ImageList(args, tr_txt, transform=image_train())
        dset_loaders["source_tr"] = DataLoader(source_tr, batch_size=train_bs, shuffle=True,
                                               num_workers=args.worker, drop_last=False)
        source_te = ImageList(args, te_txt, transform=image_test())
        dset_loaders["source_te"] = DataLoader(source_te, batch_size=train_bs, shuffle=True,
                                               num_workers=args.worker, drop_last=False)
        test_set = ImageList(args, txt_test, transform=image_test())
        dset_loaders["test"] = DataLoader(test_set, batch_size=train_bs * 3, shuffle=False,
                                          num_workers=args.worker, drop_last=False)
    return dset_loaders

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Return the training-time image transform pipeline.

    Resizes to a square ``resize_size`` x ``resize_size`` image (aspect ratio is
    not preserved), takes a random ``crop_size`` crop, applies a random
    horizontal flip, converts to a tensor and normalises with the ImageNet
    channel mean/std.

    Args:
        resize_size: side length of the square resize.
        crop_size: side length of the random crop.
        alexnet: must be False; this module defines no AlexNet normalisation.

    Raises:
        NotImplementedError: if ``alexnet`` is True. Without this check the
            pipeline would reference a ``normalize`` that was never assigned.
    """
    if alexnet:
        raise NotImplementedError(
            "image_train(alexnet=True) is not supported: no AlexNet normalization is defined"
        )
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Return the deterministic evaluation-time image transform pipeline.

    Resizes to a square ``resize_size`` x ``resize_size`` image (aspect ratio is
    not preserved), takes the central ``crop_size`` crop, converts to a tensor
    and normalises with the ImageNet channel mean/std.

    Args:
        resize_size: side length of the square resize.
        crop_size: side length of the centre crop.
        alexnet: must be False; this module defines no AlexNet normalisation.

    Raises:
        NotImplementedError: if ``alexnet`` is True. Without this check the
            pipeline would reference a ``normalize`` that was never assigned.
    """
    if alexnet:
        raise NotImplementedError(
            "image_test(alexnet=True) is not supported: no AlexNet normalization is defined"
        )
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize
    ])


#optimizer
def op_copy(optimizer, args):
    """Record each param group's current ``lr`` as its base rate ``lr0``.

    Also resets every group's ``lambda_`` entry to 1.0. The LR schedulers in
    this module read ``lr0`` when recomputing ``lr``. ``args`` is accepted but
    not used. The optimizer is modified in place and returned.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'], lambda_=1.0)
    return optimizer

def init_op_copy(optimizer, args):
    """Set every param group's base learning rate ``lr0`` to ``args.lr``.

    ``args.lr`` is read once per group, so it is never touched when the
    optimizer has no param groups. The optimizer is modified in place and
    returned.
    """
    for group in optimizer.param_groups:
        group.update(lr0=args.lr)
    return optimizer

def set_requires_grad(model, requires_grad=True):
    """Freeze (``False``) or unfreeze (``True``) all parameters of ``model``.

    Assigns ``requires_grad`` on every tensor yielded by ``model.parameters()``
    and returns ``None``.
    """
    for weight in model.parameters():
        weight.requires_grad = requires_grad
        
# learning rate scheduler
def dynamic_lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply an inverse-power LR decay that restarts every ``max_iter`` iterations.

    For each param group:
    ``lr = lr0 * (1 + gamma * (iter_num % max_iter) / max_iter) ** (-power)``.
    Each group's ``weight_decay``, ``momentum`` and ``nesterov`` are also set to
    1e-3, 0.9 and True. Requires ``lr0`` in every group (see ``op_copy``). The
    optimizer is modified in place and returned.
    """
    cycle_pos = iter_num % max_iter
    factor = (1 + gamma * cycle_pos / max_iter) ** (-power)
    fixed_hparams = {'weight_decay': 1e-3, 'momentum': 0.9, 'nesterov': True}
    for group in optimizer.param_groups:
        group['lr'] = group['lr0'] * factor
        group.update(fixed_hparams)
    return optimizer

def static_lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply a monotone inverse-power LR decay over ``max_iter`` iterations.

    For each param group:
    ``lr = lr0 * (1 + gamma * iter_num / max_iter) ** (-power)``.
    Each group's ``weight_decay``, ``momentum`` and ``nesterov`` are also set to
    1e-3, 0.9 and True. Requires ``lr0`` in every group. The optimizer is
    modified in place and returned.
    """
    factor = (1 + gamma * iter_num / max_iter) ** (-power)
    fixed_hparams = {'weight_decay': 1e-3, 'momentum': 0.9, 'nesterov': True}
    for group in optimizer.param_groups:
        group['lr'] = group['lr0'] * factor
        group.update(fixed_hparams)
    return optimizer


# hyperparameter scheduler
def lambda_scheduler(optimizer, iter_num, max_iter, gamma=1.0):
    """Set each param group's ``lambda_`` to ``gamma * exp(-iter_num / max_iter)``.

    The value decays exponentially from ``gamma`` as ``iter_num`` grows. The
    optimizer is modified in place and returned.
    """
    weight = gamma * math.exp(-iter_num / max_iter)
    for group in optimizer.param_groups:
        group.update(lambda_=weight)
    return optimizer


# Evaluation: accuracy, entropy and KL loss of model predictions on a loader
def cal_acc(args, loader, netF, netB, netC, flag=False, pseudo_labeler=False, prn=None):
    """Evaluate model predictions on ``loader`` and return accuracy metrics.

    Prediction modes:

    * ``pseudo_labeler=True``: for ``i in range(len(args.src))`` compute
      ``softmax(netC[i](netF[i](x)))`` and stack the results into a
      ``(batch, len(netC), netC[0].fc.out_features)`` tensor. That tensor is
      passed through ``prn``, and the shape of ``prn``'s output decides 2-D or
      3-D handling below.
    * ``netF`` is a list/tuple: the same stacked per-source tensor is used
      directly, so metrics are reported per source.
    * otherwise: ``softmax(netC(netF(x)))``.

    ``netB`` is accepted for interface compatibility but is not used.

    Args:
        args: namespace; ``args.src`` is read in the per-source modes.
        loader: iterable of batches. ``batch[0]`` holds inputs (moved with
            ``.cuda()``) and ``batch[1]`` holds integer class labels.
        netF, netC: a single network each, or parallel lists of per-source
            networks.
        flag: if True, return a per-class accuracy report instead of the
            default metrics (only supported for 2-D outputs).
        pseudo_labeler: enables the ``prn`` path described above.
        prn: module applied to the stacked per-source probabilities.

    Returns:
        If ``flag`` is False: ``(accuracy * 100, mean_ent, classifier_loss)``.
        ``accuracy`` is a scalar tensor for 2-D outputs and a per-source vector
        for 3-D outputs. ``mean_ent`` is the sample-mean of
        ``loss.Entropy(output, d=-1)``. ``classifier_loss`` is the mean of
        ``loss.KL(one_hot_labels, output, d=-1)``.
        If ``flag`` is True: ``(mean per-class accuracy in %, space-separated
        per-class accuracies rounded to 2 decimals)``.

    Raises:
        ValueError: if ``loader`` yields no batches, or if ``flag`` is True and
            the outputs are 3-D.
    """
    outputs_per_batch = []
    labels_per_batch = []
    with torch.no_grad():
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            if pseudo_labeler or isinstance(netF, (list, tuple)):
                # Stack per-source class probabilities: (batch, num_src, num_cls).
                # Allocated on the inputs' device to avoid a GPU->CPU->GPU round trip.
                probs = torch.zeros(
                    [len(inputs), len(netC), netC[0].fc.out_features],
                    device=inputs.device,
                )
                for src_idx in range(len(args.src)):
                    logits = netC[src_idx](netF[src_idx](inputs))
                    probs[:, src_idx, :] = nn.Softmax(dim=-1)(logits)
                outputs = prn(probs) if pseudo_labeler else probs
            else:
                outputs = nn.Softmax(dim=-1)(netC(netF(inputs)))

            outputs = outputs.float().cpu()
            labels = labels.long()
            if outputs.dim() == 3:
                # Repeat each label once per source so it aligns with (batch, num_src).
                labels = labels.unsqueeze(1).expand(-1, outputs.shape[1])
            outputs_per_batch.append(outputs)
            labels_per_batch.append(labels)

    if not outputs_per_batch:
        raise ValueError("cal_acc received an empty loader")

    all_output = torch.cat(outputs_per_batch, 0)
    all_label = torch.cat(labels_per_batch, 0)
    # predict has the same shape as all_label: (N,) or (N, num_src); no squeeze,
    # which would mis-broadcast when N == 1 or num_src == 1.
    _, predict = torch.max(all_output, -1)

    if flag:
        if all_output.dim() != 2:
            raise ValueError("cal_acc(flag=True) requires 2-D (batch, num_cls) outputs")
        matrix = confusion_matrix(all_label.numpy(), predict.numpy())
        per_class = matrix.diagonal() / matrix.sum(axis=1) * 100
        aacc = per_class.mean()
        acc = ' '.join(str(np.round(v, 2)) for v in per_class)
        return aacc, acc

    accuracy = (predict == all_label).float().mean(dim=0)

    # one_hot of (N,) or (N, num_src) labels already matches all_output's shape,
    # so no extra unsqueeze is needed (it would broadcast across samples).
    one_hot = torch.nn.functional.one_hot(all_label, num_classes=all_output.shape[-1])
    classifier_loss = torch.mean(loss.KL(one_hot, all_output, d=-1)).cpu().data
    mean_ent = torch.mean(loss.Entropy(all_output, d=-1), dim=0).cpu().data

    return accuracy * 100, mean_ent, classifier_loss

# Evaluation of predictions precomputed in the loader (optionally refined by a pseudo-label network)
def cal_acc_pseudo(args, loader, flag=False, pseudo_labeler=False, pseudo_net=None):
    """Score predictions that the loader already carries in ``batch[2]``.

    Each batch supplies labels in ``batch[1]`` and predictions in ``batch[2]``.
    When ``pseudo_labeler`` is True, the predictions are first passed through
    ``pseudo_net`` (after ``.cuda()``). Otherwise they are scored as-is. The
    images in ``batch[0]`` are not used. ``args`` is accepted but not used.

    Predictions may be 2-D ``(batch, num_cls)`` or 3-D
    ``(batch, num_src, num_cls)``. In the 3-D case each label is compared
    against every source's prediction.

    Args:
        args: unused.
        loader: iterable of batches as described above.
        flag: if True, return a per-class accuracy report (2-D outputs only).
        pseudo_labeler: apply ``pseudo_net`` to the stored predictions.
        pseudo_net: module applied to the stored predictions.

    Returns:
        If ``flag`` is False: ``accuracy * 100``, a scalar tensor for 2-D
        outputs or a per-source vector for 3-D outputs.
        If ``flag`` is True: ``(mean per-class accuracy in %, space-separated
        per-class accuracies rounded to 2 decimals)``.

    Raises:
        ValueError: if ``loader`` yields no batches, or if ``flag`` is True and
            the outputs are 3-D.
    """
    outputs_per_batch = []
    labels_per_batch = []
    with torch.no_grad():
        for data in loader:
            labels = data[1]
            preds = data[2]
            outputs = pseudo_net(preds.cuda()) if pseudo_labeler else preds

            outputs = outputs.float().cpu()
            labels = labels.long()
            if outputs.dim() == 3:
                # Repeat each label once per source so it aligns with (batch, num_src).
                labels = labels.unsqueeze(1).expand(-1, outputs.shape[1])
            outputs_per_batch.append(outputs)
            labels_per_batch.append(labels)

    if not outputs_per_batch:
        raise ValueError("cal_acc_pseudo received an empty loader")

    all_output = torch.cat(outputs_per_batch, 0)
    all_label = torch.cat(labels_per_batch, 0)
    # predict has the same shape as all_label: (N,) or (N, num_src).
    _, predict = torch.max(all_output, -1)

    if flag:
        if all_output.dim() != 2:
            raise ValueError("cal_acc_pseudo(flag=True) requires 2-D (batch, num_cls) outputs")
        matrix = confusion_matrix(all_label.numpy(), predict.numpy())
        per_class = matrix.diagonal() / matrix.sum(axis=1) * 100
        aacc = per_class.mean()
        acc = ' '.join(str(np.round(v, 2)) for v in per_class)
        return aacc, acc

    accuracy = (predict == all_label).float().mean(dim=0)
    return accuracy * 100
