import random
import sys
import argparse
import os.path as osp
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from scipy.spatial.distance import cdist

sys.path.append(osp.abspath('.'))
from dalib.adaptation.proto import CrossEntropyLabelSmooth, ProtoLoss
from dalib.modules.entropy import entropy
from common.utils.metric import accuracy
from common.utils.meter import AverageMeter

import scripts.utils as utils

# Compute device used throughout this script: the default CUDA device when
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def obtain_label(loader, model, args):
    """Compute centroid-refined pseudo labels for every sample in ``loader``.

    The model is run over all batches without gradients. Each sample is then
    assigned to its nearest class centroid in feature space:

    1. Initial centroids are the softmax-probability-weighted means of the
       features.
    2. Only classes that the classifier's argmax predicts for more than
       ``args.threshold`` samples are kept as candidate labels.
    3. One refinement round recomputes centroids from the hard assignments and
       re-assigns every sample.

    The argmax accuracy and the pseudo-label accuracy against the loader's
    labels are printed for monitoring.

    Args:
        loader: iterable of ``(inputs, labels)`` batches that supports ``len``.
        model: callable returning ``(logits, features)`` for a batch. It is not
            switched to eval mode here; the caller is responsible for that.
        args: namespace providing ``threshold`` and, optionally, ``distance``
            (a metric name accepted by ``scipy.spatial.distance.cdist``).
            ``distance`` defaults to ``'cosine'`` when absent.

    Returns:
        numpy integer array with one pseudo label per sample, in loader order.

    Raises:
        ValueError: if no class is predicted for more than ``args.threshold``
            samples, so that no candidate label exists.
    """
    # `distance` is not registered by parse_args() in this file, so fall back
    # to cosine instead of failing with AttributeError.
    distance = getattr(args, 'distance', 'cosine')

    all_feature = []
    all_output = []
    all_label = []
    with torch.no_grad():
        for inputs, labels in tqdm(loader, total=len(loader), desc="get target pseudo labels"):
            # Use the module-level device (as train() does) rather than a
            # hard-coded .cuda(), so this also runs on CPU-only machines.
            outputs, feature = model(inputs.to(device))
            all_feature.append(feature.cpu())
            all_output.append(outputs.cpu())
            all_label.append(labels)

    all_fea = torch.cat(all_feature, 0)
    all_prob = torch.softmax(torch.cat(all_output, 0), dim=1)
    all_label = torch.cat(all_label, 0)
    predict = all_prob.argmax(dim=1)

    model_acc = (predict == all_label).sum().item() / float(all_label.size(0))

    if distance == 'cosine':
        # Append a constant 1 to every feature before L2-normalising. This
        # keeps every norm >= 1, so the division below is always safe.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)

    fea = all_fea.float().numpy()
    prob = all_prob.float().numpy()
    num_classes = prob.shape[1]

    # Soft centroids: probability-weighted mean feature of each class.
    initc = prob.T.dot(fea) / (1e-8 + prob.sum(axis=0)[:, None])

    # Keep only classes that the argmax predicts often enough.
    cls_count = np.bincount(predict.numpy(), minlength=num_classes)
    labelset = np.where(cls_count > args.threshold)[0]
    if labelset.size == 0:
        raise ValueError(
            f"obtain_label: no class is predicted for more than {args.threshold} "
            "samples; lower the threshold to obtain pseudo labels")

    pred_label = labelset[cdist(fea, initc[labelset], distance).argmin(axis=1)]

    # One refinement round using hard (one-hot) assignments.
    one_hot = np.eye(num_classes)[pred_label]
    member_count = one_hot.sum(axis=0)
    initc = one_hot.T.dot(fea) / (1e-8 + member_count[:, None])
    # A candidate class that received no samples above would get an all-zero
    # centroid. That centroid is meaningless, and with cosine distance it gives
    # NaN, which argmin would select. Restrict to classes with members. This is
    # never empty, because every sample was assigned to some class in labelset.
    active = labelset[member_count[labelset] > 0]
    pred_label = active[cdist(fea, initc[active], distance).argmin(axis=1)]

    acc = np.sum(pred_label == all_label.numpy()) / len(fea)
    print('Accuracy = {:.2f}% -> {:.2f}%'.format(model_acc * 100, acc * 100))

    return pred_label.astype('int')


def train(train_source_iter, train_target_iter, model, proto_criterion, optimizer, lr_scheduler, epoch, args):
    """Run one training epoch of ``args.iters_per_epoch`` iterations.

    Each iteration draws a labelled source batch and an unlabelled target
    batch, forwards them through ``model`` in a single call, and minimises::

        cross_entropy(source logits, source labels)
        + args.ent_weight * mean entropy of target softmax
        + args.trade_off * proto_criterion(head weights, target features)

    The optimizer and the learning-rate scheduler are both stepped once per
    iteration. Running averages of the total loss, the transfer loss and the
    source accuracy are printed at the end of the epoch.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` source
            batches. ``next`` is called ``args.iters_per_epoch`` times, so it
            must not run out early (presumably a cycling iterator — confirm).
        train_target_iter: iterator yielding ``(inputs, _)`` target batches;
            the labels are ignored.
        model: module returning ``(logits, features)`` and exposing
            ``model.head.weight``, which is used as the class prototypes.
        proto_criterion: callable ``(prototypes, target_features) -> loss``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: current epoch index, used only in the log line.
        args: namespace providing ``iters_per_epoch``, ``ent_weight``,
            ``trade_off`` and ``epochs``.
    """
    losses = AverageMeter('Total Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Source Acc', ':3.1f')

    # switch to train mode
    model.train()
    for _ in tqdm(range(args.iters_per_epoch), total=args.iters_per_epoch):

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)
        x_t, _ = next(train_target_iter)
        x_t = x_t.to(device)
        n_s, n_t = x_s.size(0), x_t.size(0)

        # Forward both domains in one pass, then split the outputs by the real
        # batch sizes. chunk(2) would mis-split the batch whenever the source
        # and target batches differ in size.
        y, f = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = y.split([n_s, n_t], dim=0)
        f_t = f[n_s:]

        cls_loss = F.cross_entropy(y_s, labels_s)
        # Entropy of the target predictions, weighted by args.ent_weight.
        entropy_loss = entropy(torch.softmax(y_t, dim=1)).mean()
        cls_loss = cls_loss + entropy_loss * args.ent_weight

        # Classifier weights serve as class prototypes. They are detached, so
        # the transfer loss does not backpropagate into the head weights.
        prototypes = model.head.weight.detach()
        transfer_loss = proto_criterion(prototypes, f_t)

        loss = cls_loss + transfer_loss * args.trade_off
        cls_acc = accuracy(y_s, labels_s)[0]
        # Store a plain float so the meter does not accumulate device tensors.
        cls_accs.update(float(cls_acc), n_s)
        losses.update(loss.item(), n_s)
        trans_losses.update(transfer_loss.item(), n_t)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    print(f"Epoch: [{epoch}/{args.epochs}], {losses}, {trans_losses}, {cls_accs}")


def parse_args(parser):
    """Register this script's command-line options on ``parser`` and parse them.

    Args:
        parser: an ``argparse.ArgumentParser``. The options below are added to
            it, so it must not already define any of them.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``, which
        reads ``sys.argv``.
    """
    # NOTE(review): obtain_label() reads `args.distance`, which is not
    # registered here. Confirm whether the caller adds it to `parser`.

    # dataset parameters
    parser.add_argument('--root', metavar='DIR', default='/home/username/datasets',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--source_epoch', type=int, default=5)
    parser.add_argument('--smooth', type=float, default=0.1, help='label smoothing for source training.')
    parser.add_argument('--threshold', type=int, default=0,
                        help='a class is kept as a pseudo-label candidate only if more than '
                             'this many target samples are predicted as it')
    parser.add_argument('--ent_weight', type=float, default=0.0,
                        help='weight of the entropy loss on target predictions')
    # The options below are not read by the code visible in this module;
    # presumably they configure the prototype loss — confirm against the caller.
    parser.add_argument('--nav_t', type=float, default=1.0)
    parser.add_argument('--s_par', type=float, default=0.5)
    parser.add_argument("--assign_type", type=str, default='prob')
    parser.add_argument("--cost_type", type=str, default='cos')
    parser.add_argument("--balance_type", type=str, default='proto')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N', help='mini-batch size')
    parser.add_argument('-lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('-wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=12345, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='/tmp/DAmetric_lib_logs/cdan',
                        help="Where to save logs, checkpoints and debugging images.")
    # The two help sentences are joined by implicit string concatenation, so
    # the first must end with a space.
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    return parser.parse_args()
