import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import loss

from util import *

def warmup_prn(args, writer, pseudo_optimizer, dset_loaders, prn, target_netF, target_netB, target_netC, before_warmup_acc):
    """Warm up the pseudo-label refinement network (PRN) before target adaptation.

    For ``args.warmup_epoch`` epochs the PRN is trained on the per-source predictions
    yielded by ``dset_loaders["target"]`` to minimise
    ``mean_batch(sum_sources(loss.KL(preds, prn(preds), d=-1)))``, i.e. it is pulled toward
    reproducing its input. The PRN is periodically evaluated on ``dset_loaders["test"]``
    with ``cal_acc_pseudo`` and the results are printed, written to ``args.out_file`` and
    logged to ``writer``.

    Args:
        args: run configuration; reads ``warmup_epoch``, ``max_epoch``, ``interval``,
            ``name``, ``src``, ``tar`` and ``out_file``.
        writer: TensorBoard ``SummaryWriter`` used for scalars.
        pseudo_optimizer: optimizer stepped on the warmup loss (presumably over the PRN
            parameters).
        dset_loaders: dict with a ``"target"`` loader yielding 3-tuples whose third item is
            fed to ``prn``, and a ``"test"`` loader passed to ``cal_acc_pseudo``.
        prn: the pseudo-label refinement network being warmed up.
        target_netF, target_netB, target_netC: target model parts; they are only switched
            to eval mode here and take no part in the computation.
        before_warmup_acc: baseline accuracy (scalar, or a sequence indexed like
            ``args.src``), shown as the "before" value in the logs.

    Returns:
        The accuracy from the last PRN evaluation (as returned by ``cal_acc_pseudo``), or
        ``before_warmup_acc`` when no warmup iteration runs. On return all four networks
        are in eval mode.
    """

    def _log(msg):
        """Echo ``msg`` to stdout and append it to ``args.out_file``."""
        print(msg)
        args.out_file.write(msg + '\n')
        args.out_file.flush()

    def _per_source(acc):
        """Return True if ``acc`` is a per-source sequence (supports ``len``), else False."""
        try:
            len(acc)
        except TypeError:
            return False
        return True

    batch_iter = None

    def _next_batch():
        """Return the next "target" batch, restarting the loader when it is exhausted."""
        nonlocal batch_iter
        if batch_iter is not None:
            try:
                return next(batch_iter)
            except StopIteration:
                pass
        batch_iter = iter(dset_loaders["target"])
        return next(batch_iter)

    target_netF.eval()
    target_netB.eval()
    target_netC.eval()
    prn.eval()

    print("Warm up prn...")

    if args.warmup_epoch <= 0:
        print("Pass warmup")
        return before_warmup_acc

    iter_per_epoch = len(dset_loaders["target"])
    max_iter = args.max_epoch * iter_per_epoch
    # The logging interval is derived from the adaptation length (args.max_epoch).
    # Clamp to 1 so the `%` check below cannot divide by zero when max_iter < args.interval.
    interval_iter = max(1, max_iter // args.interval)
    warmup_max_iter = args.warmup_epoch * iter_per_epoch

    # Returned unchanged if no evaluation ever runs (e.g. an empty target loader).
    acc_s_te = before_warmup_acc

    prn.train()
    for epoch in range(args.warmup_epoch):
        for iter_num in range(epoch * iter_per_epoch, (epoch + 1) * iter_per_epoch):
            # Only the per-source predictions are needed; inputs and logits are ignored.
            _, _, preds_test = _next_batch()
            preds_test = preds_test.cuda()

            beta = prn(preds_test)

            # Warmup consistency loss between the input predictions and the PRN output,
            # summed over sources (dim -1 after KL) and averaged over the batch.
            warmup_loss = torch.mean(torch.sum(loss.KL(preds_test, beta, d=-1), dim=-1))

            writer.add_scalar("Warmup/warmup total loss", warmup_loss, iter_num)

            pseudo_optimizer.zero_grad()
            warmup_loss.backward()
            pseudo_optimizer.step()

            writer.add_scalar("Warmup/pseudo lr", pseudo_optimizer.param_groups[0]['lr'], iter_num)

            if iter_num % interval_iter == 0 or iter_num == warmup_max_iter - 1:
                prn.eval()

                _log('Task: {}, Iter:{}/{} (PRN Warmup); [Train] warmup_loss = {:.2f} '.format(
                    args.name, iter_num, warmup_max_iter, warmup_loss))

                acc_s_te = cal_acc_pseudo(args, dset_loaders['test'], False, True, prn)

                # NOTE(review): accuracies are always logged at global step 0, so successive
                # warmup evaluations overwrite one another in TensorBoard — confirm intended.
                if _per_source(acc_s_te):
                    for i, acc in enumerate(acc_s_te):
                        name = '{} to {}'.format(args.src[i], args.tar[0])
                        _log('Task: {}, PRN evaluation (After warmup) [Test] Accuracy = {:.2f}% -> {:.2f}%  '.format(
                            name, before_warmup_acc[i], acc))
                        writer.add_scalar("PRN/test accuracy-{}".format(args.src[i]), acc, 0)
                else:
                    _log('Task: {}, PRN evaluation (After warmup) [Test] Accuracy = {:.2f}% -> {:.2f}% '.format(
                        args.name, before_warmup_acc, acc_s_te))
                    writer.add_scalar("PRN/test accuracy", acc_s_te, 0)

                prn.train()

    # Leave every network in eval mode for the caller.
    target_netF.eval()
    target_netB.eval()
    target_netC.eval()
    prn.eval()

    return acc_s_te


def _prn_refinement_loss(beta, preds_test, args):
    """Compute the PRN refinement objective for one batch.

    For each sample, sources are split into a "confident" group and an "unconfident"
    group. The confident group holds the sources whose top-1 class equals a
    representative class; the unconfident group holds every other source. Each group is
    compared (via ``loss.KL``) with a representative prediction. The two representatives
    are compared with each other, and ``beta`` is compared with the original predictions
    ``preds_test``.

    Args:
        beta: PRN output, indexed here as (batch, num_src, num_classes).
        preds_test: original per-source predictions, on the same device and shape as ``beta``.
        args: reads ``lambda_s``, ``lambda_u`` and ``lambda_d``.

    Returns:
        ``(pseudo_loss, kl_loss, conf_conc, iconf_conc, conf_iconf_dist)``. A component
        whose sample subset is empty in this batch is the float ``0.0``.
    """
    # Per-source top-1 probability and class: (batch, num_src).
    max_probs, max_cls = beta.max(dim=-1)
    num_samples, num_src = max_cls.size(0), max_cls.size(1)
    num_classes = beta.size(2)

    # Modal class across sources, broadcast to (batch, num_src).
    agree_cls, _ = torch.mode(max_cls, dim=-1)
    agree_cls = agree_cls.unsqueeze(1).repeat([1, num_src])
    # Samples where at least two sources predict the modal class.
    multi_agree = (agree_cls == max_cls).sum(dim=-1) > 1
    # Source with the highest top-1 probability.
    conf_src = max_probs.argmax(dim=-1)

    # Representative class (batch, 1). It is the most confident source's class, replaced
    # by the modal class wherever several sources agree on it.
    conf_cls = torch.gather(max_cls, 1, conf_src.unsqueeze(-1))
    conf_cls[multi_agree] = agree_cls[:, 0][multi_agree].unsqueeze(1)

    # Confident group: sources predicting the representative class.
    const_src = conf_cls.expand(num_samples, num_src) == max_cls
    is_multiple_const = torch.where(const_src.sum(dim=-1) > 1)[0]

    # Unconfident group: all remaining sources (0/1 integer mask).
    iconst_src = (~const_src) * 1
    is_iconst = torch.where(iconst_src.sum(dim=-1) != 0)[0]
    is_multiple_iconst = torch.where(iconst_src.sum(dim=-1) > 1)[0]

    const_src = const_src * 1
    # Where sources agree, the confident representative is the most confident group member.
    conf_src[multi_agree] = (max_probs * const_src).argmax(dim=-1)[multi_agree]

    # Representative source of the unconfident group.
    max_prob_iconst = max_probs * iconst_src
    iconst_src_rep = max_prob_iconst.argmax(dim=-1)
    # Guarded so topk(k=2) is not evaluated when no sample needs it (it fails for num_src < 2).
    # NOTE(review): largest=False picks the index of the 2nd-smallest masked probability. The
    # confident sources are masked to 0, so with >= 2 confident sources this selects a
    # confident source — confirm intended.
    if len(is_multiple_iconst) > 0:
        second_smallest = max_prob_iconst.topk(k=2, largest=False, dim=-1)[1][:, -1]
        iconst_src_rep[is_multiple_iconst] = second_smallest[is_multiple_iconst]
    iconst_src_rep = iconst_src_rep[is_iconst]

    # Representative predictions, (batch, 1, C) and (n_iconst, 1, C).
    conf_rep_pred = torch.gather(
        beta, 1, conf_src.unsqueeze(1).unsqueeze(2).expand(num_samples, 1, num_classes))
    beta_iconst = beta[is_iconst]
    iconf_rep_pred = torch.gather(
        beta_iconst, 1, iconst_src_rep.unsqueeze(1).unsqueeze(2).expand(beta_iconst.size(0), 1, num_classes))

    conf_conc, iconf_conc, conf_iconf_dist = 0.0, 0.0, 0.0

    # Confident-group concentration. Unconfident slots are filled with the confident
    # representative's prediction. Guarded because torch.mean over an empty selection is
    # NaN, which would poison pseudo_loss and the PRN weights.
    if len(is_multiple_const) > 0:
        pred_compare = beta * const_src.unsqueeze(2).expand(num_samples, num_src, num_classes)
        conf_rep_pred_compare = conf_rep_pred.expand(num_samples, num_src, num_classes)
        pred_compare += conf_rep_pred_compare.clone() * iconst_src.unsqueeze(2).expand(num_samples, num_src, num_classes)
        conf_conc = torch.mean(torch.sum(loss.KL(conf_rep_pred_compare[is_multiple_const].detach(),
                                                 pred_compare[is_multiple_const], d=-1), dim=-1), dim=-1)

    # Unconfident-group concentration, plus the distance between the two representatives.
    if len(is_iconst) > 0:
        pred_compare = beta * iconst_src.unsqueeze(2).expand(num_samples, num_src, num_classes)
        iconf_rep_pred_compare = iconf_rep_pred.expand(iconf_rep_pred.size(0), num_src, num_classes)
        const_mask = const_src[is_iconst]
        pred_compare = pred_compare[is_iconst] + iconf_rep_pred_compare * const_mask.unsqueeze(2).expand(
            const_mask.size(0), const_mask.size(1), num_classes)
        iconf_conc = torch.mean(torch.sum(loss.KL(iconf_rep_pred_compare.detach(), pred_compare, d=-1), dim=-1), dim=-1)
        conf_iconf_dist = torch.mean(loss.KL(conf_rep_pred[is_iconst].detach(), iconf_rep_pred))

    # Keep the PRN output close to the original per-source predictions.
    kl_loss = torch.mean(torch.sum(loss.KL(preds_test, beta, d=-1), dim=-1))

    pseudo_loss = conf_conc + args.lambda_u * iconf_conc + args.lambda_d * conf_iconf_dist + args.lambda_s * kl_loss
    return pseudo_loss, kl_loss, conf_conc, iconf_conc, conf_iconf_dist


def train_target(args, writer, optimizer, pseudo_optimizer, dset_loaders, prn, target_netF, target_netB, target_netC, after_warmup_acc):
    """Adapt the target model using PRN-refined pseudo labels, periodically refining the PRN.

    Each epoch has two phases:

    1. The target model (netF -> netB -> netC) is trained with the PRN frozen. The loss is
       ``lambda_ * KL(pseudo, pred) + entropy - diversity``, where ``pseudo`` is the PRN
       output averaged over sources.
    2. The PRN is trained for ``args.refine_epoch`` epochs with the target model frozen,
       using ``_prn_refinement_loss``. This phase runs only when
       ``args.pseudo_train_epoch_iter > 0`` and
       ``epoch % args.pseudo_train_epoch_iter == 0``.

    Weights are saved at the end when ``args.issave`` is set.

    Args:
        args: run configuration; reads ``max_epoch``, ``interval``, ``gamma``, ``epsilon``,
            ``cross_attention``, ``src_num``, ``pseudo_train_epoch_iter``, ``refine_epoch``,
            ``lambda_s``/``lambda_u``/``lambda_d``, ``name``, ``src``, ``tar``, ``out_file``,
            ``issave``, ``output_dir`` and ``savename``.
        writer: TensorBoard ``SummaryWriter``.
        optimizer: target-model optimizer; its first param group carries ``'lr'`` and
            ``'lambda_'``.
        pseudo_optimizer: optimizer stepped on the refinement loss (presumably over the PRN
            parameters).
        dset_loaders: dict with a ``"target"`` loader yielding ``(inputs, logits, preds)``
            and a ``"test"`` loader.
        prn: the pseudo-label refinement network.
        target_netF, target_netB, target_netC: target feature extractor, bottleneck and
            classifier.
        after_warmup_acc: PRN accuracy after warmup (scalar or a sequence indexed like
            ``args.src``), shown as the baseline in refinement logs.
    """

    def _log(msg):
        """Echo ``msg`` to stdout and append it to ``args.out_file``."""
        print(msg)
        args.out_file.write(msg + '\n')
        args.out_file.flush()

    def _per_source(acc):
        """Return True if ``acc`` is a per-source sequence (supports ``len``), else False."""
        try:
            len(acc)
        except TypeError:
            return False
        return True

    # One iterator is shared by both phases and re-created only when exhausted.
    batch_iter = None

    def _next_batch():
        """Return the next "target" batch, restarting the loader when it is exhausted."""
        nonlocal batch_iter
        if batch_iter is not None:
            try:
                return next(batch_iter)
            except StopIteration:
                pass
        batch_iter = iter(dset_loaders["target"])
        return next(batch_iter)

    def _set_modes(target_training, prn_training):
        """Switch the target model parts and the PRN between train and eval mode."""
        target_netF.train(target_training)
        target_netB.train(target_training)
        target_netC.train(target_training)
        prn.train(prn_training)

    iter_per_epoch = len(dset_loaders["target"])
    max_iter = args.max_epoch * iter_per_epoch
    # Clamp so the `%` checks below cannot divide by zero when max_iter < args.interval.
    interval_iter = max(1, max_iter // args.interval)

    # Most recent PRN test accuracy; reported by the refinement-phase logs.
    prn_acc = after_warmup_acc

    for epoch in range(args.max_epoch):

        # ---- Phase 1: target model training (PRN frozen) ----
        _set_modes(target_training=True, prn_training=False)

        for iter_num in range(epoch * iter_per_epoch, (epoch + 1) * iter_per_epoch):
            inputs_test, _, preds_test = _next_batch()
            # Singleton batches are skipped (presumably because BatchNorm fails on them in train mode).
            if inputs_test.size(0) == 1:
                continue

            inputs_test = inputs_test.cuda()
            preds_test = preds_test.cuda()

            static_lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
            lambda_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, gamma=args.gamma)

            if args.pseudo_train_epoch_iter > 0:
                period = args.pseudo_train_epoch_iter * iter_per_epoch
                # NOTE(review): `//` makes this true on every iteration of the first period only.
                # If a reset at the start of each period was intended, this should be `%`.
                if iter_num // period == 0:
                    init_op_copy(pseudo_optimizer, args)
                dynamic_lr_scheduler(pseudo_optimizer, iter_num=iter_num, max_iter=period)
            else:
                static_lr_scheduler(pseudo_optimizer, iter_num=iter_num, max_iter=max_iter)

            pred = target_netC(target_netB(target_netF(inputs_test)))
            pred_prob = nn.Softmax(dim=-1)(pred)
            lambda_ = optimizer.param_groups[0]['lambda_']

            # Pseudo label: PRN output averaged over sources (dim 1). No gradient reaches the PRN.
            with torch.no_grad():
                if args.cross_attention:
                    beta = prn(preds_test, pred_prob.unsqueeze(1).repeat(1, args.src_num, 1))
                else:
                    beta = prn(preds_test)
                optimal_pseudo = torch.mean(beta, dim=1)

            kl_loss = torch.mean(loss.KL(optimal_pseudo, pred_prob, d=1))

            # Mean per-sample entropy minus the entropy of the batch-mean prediction.
            msoftmax = pred_prob.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
            entropy_loss = torch.mean(loss.Entropy(pred_prob, d=-1)) - gentropy_loss

            tar_loss = lambda_ * kl_loss + entropy_loss

            writer.add_scalar("Target model/ kl loss", kl_loss, iter_num)
            writer.add_scalar("Target model/ ent loss", entropy_loss, iter_num)

            optimizer.zero_grad()
            tar_loss.backward()
            optimizer.step()

            writer.add_scalar("Target model/lr", optimizer.param_groups[0]['lr'], iter_num)

            if iter_num % interval_iter == 0 or iter_num == max_iter - 1:
                _set_modes(target_training=False, prn_training=False)

                _log('**Target model training** Task: {}, Iter:{}/{} (Target model); [Train] tar_loss = {:.2f} kl_loss = {:.2f} ent_loss = {:.2f} lambda ={:.2f} '.format(
                    args.name, iter_num, max_iter, tar_loss, kl_loss, entropy_loss, lambda_))

                for param_name, param in prn.named_parameters():
                    writer.add_histogram(param_name, param.detach().cpu().numpy(), iter_num)

                target_acc, ent_loss, classifier_loss = cal_acc(args, dset_loaders['test'], target_netF, target_netB, target_netC, False)

                if _per_source(target_acc):
                    for i, acc in enumerate(target_acc):
                        name = '{} to {}'.format(args.src[i], args.tar[0])
                        _log('[Test] Task: {},   Accuracy = {:.2f}% entropy loss = {:.5f} ce loss= {:.5f}'.format(
                            name, acc, ent_loss[i], classifier_loss[i]))
                        writer.add_scalar("Target model/test accuracy-{}".format(args.src[i]), acc, iter_num)
                        writer.add_scalar("Target model/test kl loss-{}".format(args.src[i]), classifier_loss[i], iter_num)
                        writer.add_scalar("Target model/test im loss-{}".format(args.src[i]), ent_loss[i], iter_num)
                else:
                    _log('[Test] Task: {},   Accuracy = {:.2f}% entropy loss = {:.5f} ce loss= {:.5f}'.format(
                        args.name, target_acc, ent_loss, classifier_loss))
                    writer.add_scalar("Target model/test accuracy", target_acc, iter_num)
                    writer.add_scalar("Target model/test kl loss", classifier_loss, iter_num)
                    writer.add_scalar("Target model/test im loss", ent_loss, iter_num)

                prn_acc = cal_acc_pseudo(args, dset_loaders['test'], False, True, prn)

                _set_modes(target_training=True, prn_training=False)

        # ---- Phase 2: PRN refinement (target model frozen) ----
        if args.pseudo_train_epoch_iter > 0 and epoch % args.pseudo_train_epoch_iter == 0:
            _set_modes(target_training=False, prn_training=True)

            for iter_num in range(epoch * iter_per_epoch, (epoch + args.refine_epoch) * iter_per_epoch):
                inputs_test, _, preds_test = _next_batch()
                if inputs_test.size(0) == 1:
                    continue

                inputs_test = inputs_test.cuda()
                preds_test = preds_test.cuda()

                static_lr_scheduler(pseudo_optimizer, iter_num=iter_num, max_iter=max_iter)

                with torch.no_grad():
                    pred_prob = nn.Softmax(dim=-1)(target_netC(target_netB(target_netF(inputs_test))))

                if args.cross_attention:
                    beta = prn(preds_test, pred_prob.unsqueeze(1).repeat(1, args.src_num, 1))
                else:
                    beta = prn(preds_test)

                pseudo_loss, kl_loss, conf_conc, iconf_conc, conf_iconf_dist = _prn_refinement_loss(beta, preds_test, args)

                writer.add_scalar("Pseudo labeler/ pseudo kl loss", kl_loss, iter_num)
                writer.add_scalar("Pseudo labeler/ confident concentration loss", conf_conc, iter_num)
                writer.add_scalar("Pseudo labeler/ unconfident concentration loss", iconf_conc, iter_num)
                writer.add_scalar("Pseudo labeler/ conf - unconf dist", conf_iconf_dist, iter_num)
                writer.add_scalar("Pseudo labeler/ total loss", pseudo_loss, iter_num)

                pseudo_optimizer.zero_grad()
                pseudo_loss.backward()
                pseudo_optimizer.step()

                writer.add_scalar("Pseudo labeler/pseudo lr", pseudo_optimizer.param_groups[0]['lr'], iter_num)

                # NOTE(review): when args.refine_epoch > 1, the end-of-phase check fires after the
                # first refinement epoch, not on the last refinement iteration. prn_acc is also the
                # value from the latest phase-1 evaluation, not a re-evaluation — confirm intended.
                if iter_num % interval_iter == 0 or iter_num == (epoch + 1) * iter_per_epoch - 1:
                    _set_modes(target_training=False, prn_training=False)

                    _log('**Pseudo net refinement** Epoch:{}/{} [Train] pseudo_loss = {:.2f} kl_loss = {:.2f} conf_conc = {:.2f} iconf_conc = {:.2f} conf_iconf_dist = {:.2f}  '.format(
                        epoch + 1, args.max_epoch, pseudo_loss, kl_loss, conf_conc, iconf_conc, conf_iconf_dist))

                    for param_name, param in prn.named_parameters():
                        writer.add_histogram(param_name, param.detach().cpu().numpy(), iter_num)

                    if _per_source(prn_acc):
                        for i, acc in enumerate(prn_acc):
                            name = '{} to {}'.format(args.src[i], args.tar[0])
                            _log('Task: {}, Pseudo labeler evaluation (During adaptation) [Test] Accuracy = {:.2f}% -> {:.2f}%  '.format(
                                name, after_warmup_acc[i], acc))
                            writer.add_scalar("Pseudo labeler/test accuracy-{}".format(args.src[i]), acc, iter_num)
                    else:
                        _log('Task: {}, Pseudo labeler evaluation (During adaptation) [Test] Accuracy = {:.2f}% -> {:.2f}% '.format(
                            args.name, after_warmup_acc, prn_acc))
                        writer.add_scalar("Pseudo labeler/test accuracy", prn_acc, iter_num)

                    _set_modes(target_training=False, prn_training=True)

    if args.issave:
        for prefix, module in (("target_F_", target_netF), ("target_B_", target_netB),
                               ("target_C_", target_netC), ("prn_", prn)):
            torch.save(module.state_dict(), osp.join(args.output_dir, prefix + args.savename + ".pt"))


