import enum
import os
import copy
import pickle
from numpy.core.fromnumeric import cumprod
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from utils import *
from logger import Logger
import time
import numpy as np
import warnings
import pdb
# import clip
from clip import clip
from classes import *
from trainer import *
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
from Sinkhorn_distance import SinkhornDistance, SinkhornDistance_one_to_multi, SinkhornDistance_uniform
from loss.LDAMLoss import LDAMLoss
import ot
from thop import profile

# Regularisation strength passed as `eps` to SinkhornDistance when
# `model.__init__` builds its `ot_criterion`. The name is a misspelling of
# "epsilons"; it is kept because code refers to it by this name.
eplisons = 0.1
# Cosine similarity along the last (feature) dimension.
# NOTE(review): not referenced in the visible part of this file — confirm it is still used.
fea_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8)

def load_clip_to_cpu(visual_backbone):
    """Download (if not cached) and load the CLIP checkpoint for ``visual_backbone``.

    The checkpoint is cached under ``~/.cache/clip``. Despite the function
    name, ``clip.load`` is asked to place the model on ``'cuda'``. Only the
    first element of what ``clip.load`` returns (the model) is handed back.
    """
    checkpoint_url = clip._MODELS[visual_backbone]
    cache_dir = os.path.expanduser("~/.cache/clip")
    checkpoint_path = clip._download(checkpoint_url, cache_dir)
    return clip.load(checkpoint_path, 'cuda')[0]

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    Draws ``lam ~ Beta(alpha, alpha)`` (``lam = 1`` when ``alpha <= 0``),
    pairs every sample with a randomly permuted partner from the same batch
    and blends them: ``lam * x + (1 - lam) * x[perm]``.

    Args:
        x: batch tensor; its first dimension is the batch dimension. Any rank
            >= 1 is accepted.
        y: targets indexed along the same first dimension as ``x``.
        alpha: Beta concentration parameter; ``alpha <= 0`` disables mixing.
        use_cuda: if True, the permutation index is moved to CUDA.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the mixed inputs, the original targets,
        the permuted targets and the mixing coefficient.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.cuda()

    # Index only the leading (batch) dimension so 1-D inputs work as well;
    # the former ``x[index, :]`` raised IndexError for them.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam

def mixup_criterion(y_a, y_b, lam):
    """Build the mixup loss closure.

    The returned callable takes ``(criterion, pred)`` and evaluates
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    def _mixed_loss(criterion, pred):
        loss_a = criterion(pred, y_a)
        loss_b = criterion(pred, y_b)
        return lam * loss_a + (1 - lam) * loss_b

    return _mixed_loss

class TextEncoder(nn.Module):
    """Text tower of a CLIP model exposed as a standalone module.

    Stores references to the transformer, positional embedding, final layer
    norm, text projection, dtype and token embedding of ``clip_model``.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype
        self.token_embedding = clip_model.token_embedding

    def forward(self, prompts, tokenized_prompts=None):
        """Encode prompts into projected text features.

        Args:
            prompts: token ids when ``tokenized_prompts`` is None (they are
                embedded into ``[batch_size, n_ctx, d_model]``); otherwise
                already-embedded prompts of that shape.
            tokenized_prompts: optional token ids matching ``prompts``. They are
                only used to locate the end-of-text position.

        Returns:
            For each sequence, the final-layer feature at the position of its
            highest token id, multiplied by ``text_projection``.
        """
        if tokenized_prompts is not None:
            x = prompts + self.positional_embedding.type(self.dtype)
            token_ids = tokenized_prompts
        else:
            x = self.token_embedding(prompts).type(self.dtype)  # [batch_size, n_ctx, d_model]
            x = x + self.positional_embedding.type(self.dtype)
            token_ids = prompts
        # The eot token has the highest id in each sequence, so argmax over the
        # ids gives its position.
        eot_index = token_ids.argmax(dim=-1)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]; keep the eot feature.
        return x[torch.arange(x.shape[0]), eot_index] @ self.text_projection
    

def Sinkhorn(K, u, v, max_iter=100, thresh=1e-2):
    """Batched Sinkhorn-Knopp scaling of a kernel ``K``.

    Iteratively finds vectors ``r`` and ``c`` so that
    ``T = r[:, :, None] * K * c[:, None, :]`` has column sums ``v`` exactly
    and row sums approaching ``u``.

    Args:
        K: kernel ``[batch, n, m]``. It is divided through (``u / (K @ c)``), so
            its entries are expected to be positive.
        u: target row marginals ``[batch, n]``.
        v: target column marginals ``[batch, m]``.
        max_iter: maximum number of scaling iterations (previously fixed at 100).
        thresh: stop once the mean absolute change of ``r`` falls below this
            value (previously fixed at 1e-2).

    Returns:
        The transport plan ``T`` with the same shape as ``K``.
    """
    r = torch.ones_like(u)
    c = torch.ones_like(v)
    # The transposed kernel does not change between iterations; build it once.
    K_t = K.transpose(-2, -1).contiguous()
    for _ in range(max_iter):
        r_prev = r
        r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1)
        c = v / torch.matmul(K_t, r.unsqueeze(-1)).squeeze(-1)
        if (r - r_prev).abs().mean().item() < thresh:
            break

    return torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K

class model ():
    
    def __init__(self, config, data, test=False):
        """Set up the CLIP-based model, class weights, optimizers and logging.

        Args:
            config: experiment configuration dict. Read here: 'training_opt',
                'model', optional 'shuffle', plus (in training mode) the keys
                used by ``init_optimizers`` / ``init_criterions``.
            data: dict of data loaders. ``data['train']`` supplies the class
                counts and, in training mode, the dataset size.
            test: when True, skip optimizer, criterion and log-file setup;
                ``self.log_file`` is then None.
        """

        # NOTE(review): several members below are placed with .cuda() /
        # .to('cuda') regardless of self.device, so a GPU is effectively required.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.config = config
        self.training_opt = self.config['training_opt']
        self.model_opt = self.config['model']
        self.data = data
        self.test_mode = test
        self.num_gpus = torch.cuda.device_count()
        self.do_shuffle = config['shuffle'] if 'shuffle' in config else False
        self.clip_model = load_clip_to_cpu(self.model_opt['clip']['params']['visual_backbone'])

        # Presumably the per-class sample counts of the training set (from
        # class_count) — used below for the class-balanced weights.
        self.img_num_list = class_count(self.data['train'])

        # Sinkhorn OT criterion with cosine cost; used by batch_loss.
        self.ot_criterion = SinkhornDistance(eps=eplisons, max_iter=200, dis='cos', reduction='mean').to('cuda')

        # Setup logger
        self.logger = Logger(self.training_opt['log_dir'])

        # Class-balanced weights from the "effective number" of samples:
        # w_c proportional to (1 - beta) / (1 - beta ** n_c), rescaled so the
        # weights sum to the number of classes. Must be set before
        # init_models(), which builds self.learnable_weights from it.
        self.beta = 0.99
        effective_num = 1.0 - np.power(self.beta, self.img_num_list)
        per_cls_weights = (1.0 - self.beta) / np.array(effective_num)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(self.img_num_list)
        self.per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

        # Initialize model
        self.init_models()

        # Under training mode, initialize training steps, optimizers, schedulers
        # and criterions (no centroids are set up here).
        if not self.test_mode:

            print('Using steps for training.')
            # floor(#training samples / batch_size); train() stops each epoch
            # after this many batches.
            self.training_data_num = len(self.data['train'].dataset)
            self.epoch_steps = int(self.training_data_num  \
                                   / self.training_opt['batch_size'])

            # Initialize model optimizer and scheduler
            print('Initializing model optimizer.')
            self.scheduler_params = self.training_opt['scheduler_params']
            self.model_optimizer, \
            self.model_optimizer_scheduler = self.init_optimizers(self.model_optim_params_list)
            self.init_criterions()

            # Set up log file; an existing log.txt is deleted so each run
            # starts a fresh log.
            self.log_file = os.path.join(self.training_opt['log_dir'], 'log.txt')
            if os.path.isfile(self.log_file):
                os.remove(self.log_file)
            self.logger.log_cfg(self.config)
        else:
            self.log_file = None

    def init_models(self, optimizer=True):
        """Build the sub-networks and their optimizer parameter groups.

        Creates:
        - ``self.visual_model`` / ``self.text_model``: the CLIP towers, wrapped
          in ``DataParallel`` on CUDA.
        - ``self.adapter``: a bias-free ``feat_dim -> feat_dim`` linear layer.
        - ``self.fc``: a bias-free linear classifier.
        - ``self.learnable_weights``.

        Also fills ``self.model_optim_params_list`` with one SGD parameter group
        per part, in this order: adapter, fc, visual, text, learnable weights.

        When ``training_opt['phaseA']`` is not True, weights are loaded from
        ``config['model_dir']`` and the CLIP towers are frozen and put in eval
        mode.

        Args:
            optimizer: unused.
        """
        self.model_optim_params_list = []

        print("Using", torch.cuda.device_count(), "GPUs.")

        self.visual_model = torch.nn.DataParallel(self.clip_model.visual).cuda()
        text_model = TextEncoder(self.clip_model)

        self.text_model = torch.nn.DataParallel(text_model).cuda()

        feat_dim = self.model_opt['adapter']['params']['feat_dim']

        # Single bias-free linear layer; batch_forward mixes its output with
        # the raw feature (residual adapter).
        self.adapter = torch.nn.DataParallel(nn.Sequential(
            nn.Linear(feat_dim, feat_dim, bias=False),
            )).cuda()


        self.fc = nn.Linear(feat_dim, self.config['training_opt']['num_classes'], bias=False).cuda()

        # After phase A: load the trained checkpoint and freeze the CLIP towers.
        if self.training_opt['phaseA'] is not True:
            self.load_model(self.config['model_dir'])

            for param_name, param in self.visual_model.named_parameters():
                param.requires_grad = False

            for param_name, param in self.text_model.named_parameters():
                param.requires_grad = False

            for param_name, param in self.clip_model.named_parameters():
                param.requires_grad = False

            self.clip_model.eval()
            self.visual_model.eval()
            self.text_model.eval()

        optim_params_adapter = self.model_opt['adapter']['optim_params']
        self.model_optim_params_list.append({'params': self.adapter.parameters(),
                                                'lr': optim_params_adapter['lr'],
                                                'momentum': optim_params_adapter['momentum'],
                                                'weight_decay': optim_params_adapter['weight_decay']})

        optim_params_fc = self.model_opt['fc']['optim_params']
        self.model_optim_params_list.append({'params': self.fc.parameters(),
                                                'lr': optim_params_fc['lr'],
                                                'momentum': optim_params_fc['momentum'],
                                                'weight_decay': optim_params_fc['weight_decay']})

        # NOTE(review): the CLIP towers are listed even when frozen above
        # (requires_grad=False in that case).
        optim_params_clip = self.model_opt['clip']['optim_params']
        self.model_optim_params_list.append({'params': self.visual_model.parameters(),
                                                'lr': optim_params_clip['lr'],
                                                'momentum': optim_params_clip['momentum'],
                                                'weight_decay': optim_params_clip['weight_decay']})

        self.model_optim_params_list.append({'params': self.text_model.parameters(),
                                                'lr': optim_params_clip['lr'],
                                                'momentum': optim_params_clip['momentum'],
                                                'weight_decay': optim_params_clip['weight_decay']})

        # Initialised from the class-balanced weights computed in __init__, so
        # self.per_cls_weights must already exist. NOTE(review): this wraps the
        # same `.data`, so the two tensors likely share storage.
        self.learnable_weights = nn.Parameter(self.per_cls_weights.data, requires_grad=True)

        # Momentum / weight decay come from the fc settings; lr is fixed at 0.01.
        optim_params_lws = self.model_opt['fc']['optim_params']
        self.model_optim_params_list.append({'params': self.learnable_weights,
                                                'lr': 0.01,
                                                'momentum': optim_params_lws['momentum'],
                                                'weight_decay': optim_params_lws['weight_decay']})
        

    def init_criterions(self):
        criterion_defs = self.config['criterions']
        self.criterions = {}
        self.criterion_weights = {}

        for key, val in criterion_defs.items():
            def_file = val['def_file']
            loss_args = list(val['loss_params'].values())

            self.criterions[key] = source_import(def_file).create_loss(*loss_args).cuda()
            self.criterion_weights[key] = val['weight']
          
            if val['optim_params']:
                print('Initializing criterion optimizer.')
                optim_params = val['optim_params']
                optim_params = [{'params': self.criterions[key].parameters(),
                                'lr': optim_params['lr'],
                                'momentum': optim_params['momentum'],
                                'weight_decay': optim_params['weight_decay']}]
                # Initialize criterion optimizer and scheduler
                self.criterion_optimizer, \
                self.criterion_optimizer_scheduler = self.init_optimizers(optim_params)
            else:
                self.criterion_optimizer = None

    def init_optimizers(self, optim_params):
        
        optimizer = optim.SGD(optim_params)
        
        if self.config['coslr']:
            print("===> Using coslr eta_min={}".format(self.config['endlr']))
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, self.training_opt['num_epochs'], eta_min=self.config['endlr'])
        else:
            scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                  step_size=self.scheduler_params['step_size'],
                                                  gamma=self.scheduler_params['gamma'])
        return optimizer, scheduler

    def batch_forward(self, inputs, phase='train', indexes=None, labels=None):
        '''
        This is a general single batch running function. 
        '''

        meta_classes = obtain_labels(self.config["training_opt"]["dataset"])
        classesname, templates = meta_classes["CLASSES"], meta_classes["CUSTOM_TEMPLATES"]
        
        multi_prompt = prompt_templates

        if self.model_opt['clip']['params']['visual_backbone'].startswith("ViT"):
            inputs = inputs.to(torch.float16)

        if phase == 'test':
            _bs, _shot = inputs.shape[0], inputs.shape[1]
            # import pdb;pdb.set_trace()
            inputs = inputs.view(_bs * _shot, 3, 224, 224)

        image_features = self.visual_model(inputs).float()
        x = image_features

        ratio = 0.2
        if self.training_opt['phaseA'] is not True and phase == 'train':
            
            outputs = self.adapter(x)
            
            outputs = ratio * outputs + (1-ratio) * x
            outputs = F.normalize(outputs, dim=1)
            self.logits = self.fc(outputs)
            
        if phase == 'test':
            
            outputs = self.adapter(x)
            outputs = ratio * outputs + (1-ratio) * x
            outputs = F.normalize(outputs)
            self.logits = self.fc(outputs)
            self.logits = self.logits.view(_bs, _shot, -1).mean(dim=1).squeeze()
            # import pdb;pdb.set_trace()
            

        self.features = outputs
        self.labels = labels
            
        self.ori_features = x

    def batch_backward(self):
        # Zero out optimizer gradients
        self.model_optimizer.zero_grad()
        if self.criterion_optimizer:
            self.criterion_optimizer.zero_grad()
        # Back-propagation from loss outputs
        
        self.loss.backward()
        # Step optimizers
        self.model_optimizer.step()
        if self.criterion_optimizer:
            self.criterion_optimizer.step()
        
    def batch_loss(self, labels):
        self.loss = 0

        # First, apply performance loss
        if 'PerformanceLoss' in self.criterions.keys():

            self.loss_perf = self.criterions['PerformanceLoss'](self.logits*50, self.labels)

            aa = torch.ones(self.training_opt['batch_size']).to(self.device) / self.training_opt['batch_size']
            bb = torch.ones(self.training_opt['num_classes']).to(self.device) / self.training_opt['num_classes']
        
            _cost, _pi, C, U, V = self.ot_criterion(self.fc.weight.data, self.features, aa)
            _pi = F.normalize(_pi)

            _cst = 0

            # 对于每个样本，计算与同类和不同类样本之间的L1距离
            for i in range(self.training_opt['batch_size']):
                same_class_indices = (self.labels == self.labels[i])
                diff_class_indices = (self.labels != self.labels[i])

                same_class_samples = _pi.T[same_class_indices]
                diff_class_samples = _pi.T[diff_class_indices]

                _dis_positive = torch.cdist(_pi.T[i].softmax(-1).unsqueeze(0), same_class_samples.softmax(-1), p=1).mean()
                _dis_negative = torch.cdist(_pi.T[i].softmax(-1).unsqueeze(0), diff_class_samples.softmax(-1), p=1).mean()

                _cst += (_dis_positive / _dis_negative)
            
            _cst /= self.training_opt['batch_size']

            self.loss = self.loss + self.loss_perf.mean() + 0.1 * _cost + 0.1 * _cst
            # + 2 * _cost
            # + 0.1 * _cost + 0.1 *_cst
            # self.loss = self.loss + 10 * _cost + 1 *_cst

        # Apply loss on features if set up
        if 'FeatureLoss' in self.criterions.keys():
            self.loss_feat = self.criterions['FeatureLoss'](self.features, labels)
            self.loss_feat = self.loss_feat * self.criterion_weights['FeatureLoss']
            # Add feature loss to total loss
            self.loss += self.loss_feat

    def shuffle_batch(self, x, y):
        """Shuffle ``x`` and ``y`` along dim 0 with one shared random permutation.

        The permutation comes from ``torch.randperm`` (on the CPU), so each
        input stays aligned with its target.
        """
        perm = torch.randperm(x.size(0))
        return x[perm], y[perm]

    def train(self):
        """Train the adapter and classifier, evaluating after every epoch.

        Outline:
        1. Build the CLIP zero-shot text weights if they are not cached under
           ``fc_weights/``, then load them. ``self.fc.weight`` is overwritten
           with the L2-normalised mean over prompt templates.
        2. For every epoch:
           - step the LR scheduler(s);
           - run at most ``self.epoch_steps`` training batches;
           - update sampler priorities when the sampler supports it;
           - evaluate on ``'test'``.
        3. Keep the best-accuracy weights and save checkpoints with
           ``save_model`` / ``save_latest``; finally evaluate once more.
        """
        print_str = ['Phase: train']
        print_write(print_str, self.log_file)
        time.sleep(0.25)

        print_write(['Do shuffle??? --- ', self.do_shuffle], self.log_file)

        # Initialize best model
        best_model_weights = {}
        best_model_weights['visual_model'] = copy.deepcopy(self.visual_model.state_dict())
        best_model_weights['text_model'] = copy.deepcopy(self.text_model.state_dict())

        if self.training_opt['phaseA'] is not True:
            # Same keys as the per-epoch update below. This previously stored
            # the *adapter* state under 'classifier'.
            best_model_weights['adapter'] = copy.deepcopy(self.adapter.state_dict())
            best_model_weights['classifier'] = copy.deepcopy(self.fc.state_dict())

        best_acc = 0.0
        best_epoch = 0

        end_epoch = self.training_opt['num_epochs']

        multi_path = self._zeroshot_weights_path('multi')
        single_path = self._zeroshot_weights_path('single')
        # Both cache files are loaded below, so rebuild if either is missing.
        if not (os.path.exists(multi_path) and os.path.exists(single_path)):
            self._build_zeroshot_weights()

        self.multi_zeroshot_weights = torch.load(multi_path).to('cuda')

        # Ensemble over every prompt template. The permutation does not change
        # the mean mathematically; it is kept so the RNG stream matches
        # earlier runs. It is sized from the cached tensor instead of assuming
        # exactly 80 templates.
        random_prompt = torch.randperm(self.multi_zeroshot_weights.size(1))
        self.mean_multi = self.multi_zeroshot_weights[:, random_prompt, :].mean(1).squeeze()
        self.zeroshot_weights = torch.load(single_path).to('cuda')

        # Interpolate the current classifier with the text weights; with a
        # ratio of 1 the classifier is simply replaced by them.
        wiseft_ratio = 1
        self.fc.weight.data = (1 - wiseft_ratio) * self.fc.weight.data + wiseft_ratio * F.normalize(self.mean_multi, dim=1)

        # NOTE(review): these criteria are created here but not used in the
        # visible part of this file.
        self.ldam = LDAMLoss(cls_num_list=self.img_num_list, max_m=0.5, s=30, weight=self.per_cls_weights)
        self.mislas_criterions = LabelAwareSmoothing(self.img_num_list)
        self.balanced_softmax = BalancedSoftmaxLoss(self.img_num_list)
        self.classbalance_softmax = ClassBalancedLoss(self.img_num_list, beta=self.beta)
        self.focal = FocalLoss()

        self.adjust_logit = compute_adjustment(self.img_num_list, 1)

        for epoch in range(1, end_epoch + 1):

            torch.cuda.empty_cache()

            # NOTE(review): the schedulers are stepped at the start of each
            # epoch, before any optimizer.step(). PyTorch expects the reverse
            # order, so the first LR value is skipped. This is left unchanged
            # because it shapes the training schedule.
            self.model_optimizer_scheduler.step()
            if self.criterion_optimizer:
                self.criterion_optimizer_scheduler.step()

            # Iterate over dataset
            total_preds = []
            total_labels = []

            print(self.fc.weight)

            for step, (inputs, labels, indexes) in enumerate(self.data['train']):

                if step == self.epoch_steps:
                    break

                inputs, labels = inputs.cuda(), labels.cuda()

                # Shuffle inputs, labels and sample indexes with one shared
                # permutation (the same randperm call shuffle_batch makes).
                # This keeps the sampler update below pairing each index with
                # its own logits; previously the indexes were left unshuffled.
                perm = torch.randperm(inputs.size(0))
                inputs, labels = inputs[perm], labels[perm]
                indexes = indexes[perm]

                with torch.set_grad_enabled(True):

                    self.epoch = epoch
                    self.batch_forward(inputs,
                                       phase='train',
                                       indexes=indexes,
                                       labels=labels)
                    self.batch_loss(labels)

                    # Places_LT alternates every 10 epochs between training only
                    # the adapter and only the classifier.
                    # NOTE(review): the toggle runs after this batch's forward
                    # pass, so it may only take full effect from the next batch.
                    if self.training_opt['dataset'] == 'Places_LT':
                        if (epoch-1) // 10 % 2 == 0:
                            self.adapter.requires_grad_(True)
                            self.fc.requires_grad_(False)
                        else:
                            self.adapter.requires_grad_(False)
                            self.fc.requires_grad_(True)
                    else:
                        self.adapter.requires_grad_(True)
                        self.fc.requires_grad_(True)

                    self.batch_backward()

                    # Tracking predictions
                    _, preds = torch.max(self.logits, 1)
                    total_preds.append(torch2numpy(preds))
                    total_labels.append(torch2numpy(self.labels))

                    # Output minibatch training results
                    if step % self.training_opt['display_step'] == 0:

                        minibatch_loss_feat = self.loss_feat.item() \
                            if 'FeatureLoss' in self.criterions.keys() else None

                        # NOTE(review): reported as "performance" but this is
                        # the total loss.
                        minibatch_loss_perf = self.loss.item()

                        minibatch_loss_total = self.loss.item()
                        minibatch_acc = mic_acc_cal(preds, labels)

                        print_str = ['Epoch: [%d/%d]'
                                     % (epoch, self.training_opt['num_epochs']),
                                     'Step: %5d'
                                     % (step),
                                     'Minibatch_loss_feature: %.3f'
                                     % (minibatch_loss_feat) if minibatch_loss_feat else '',
                                     'Minibatch_loss_performance: %.3f'
                                     % (minibatch_loss_perf) if minibatch_loss_perf else '',
                                     'Minibatch_accuracy_micro: %.3f'
                                      % (minibatch_acc)]
                        print_write(print_str, self.log_file)

                        loss_info = {
                            'Epoch': epoch,
                            'Step': step,
                            'Total': minibatch_loss_total,
                            'CE': minibatch_loss_perf,
                            'feat': minibatch_loss_feat
                        }

                        self.logger.log_loss(loss_info)

                # Feed per-sample priorities back to samplers that support it.
                if hasattr(self.data['train'].sampler, 'update_weights'):
                    ptype = getattr(self.data['train'].sampler, 'ptype', 'score')
                    ws = get_priority(ptype, self.logits.detach(), labels)
                    inlist = [indexes.cpu().numpy(), ws]
                    if self.training_opt['sampler']['type'] == 'ClassPrioritySampler':
                        inlist.append(labels.cpu().numpy())
                    self.data['train'].sampler.update_weights(*inlist)

            if hasattr(self.data['train'].sampler, 'get_weights'):
                self.logger.log_ws(epoch, self.data['train'].sampler.get_weights())
            if hasattr(self.data['train'].sampler, 'reset_weights'):
                self.data['train'].sampler.reset_weights(epoch)

            # After every epoch, validation
            rsls = {'epoch': epoch}

            rsls_train = self.eval_with_preds(total_preds, total_labels)
            rsls_eval = self.eval(phase='test')
            rsls.update(rsls_train)
            rsls.update(rsls_eval)

            # Reset class weights for sampling if pri_mode is valid
            if hasattr(self.data['train'].sampler, 'reset_priority'):
                ws = get_priority(self.data['train'].sampler.ptype,
                                  self.total_logits.detach(),
                                  self.total_labels)
                self.data['train'].sampler.reset_priority(ws, self.total_labels.cpu().numpy())

            # Log results
            self.logger.log_acc(rsls)

            # Under validation, the best model need to be updated
            if self.eval_acc_mic_top1 > best_acc:
                best_epoch = epoch
                best_acc = self.eval_acc_mic_top1
                best_model_weights['visual_model'] = copy.deepcopy(self.visual_model.state_dict())
                best_model_weights['text_model'] = copy.deepcopy(self.text_model.state_dict())

                if self.training_opt['phaseA'] is not True:
                    best_model_weights['adapter'] = copy.deepcopy(self.adapter.state_dict())
                    best_model_weights['classifier'] = copy.deepcopy(self.fc.state_dict())

                self.save_model(epoch, best_epoch, best_model_weights, best_acc)

            print('===> Saving checkpoint')
            print('Best Eval All Acc: ', best_acc)
            self.save_latest(epoch)

        print()
        print('Training Complete.')

        print_str = ['Best validation accuracy is %.3f at epoch %d' % (best_acc, best_epoch)]
        print(print_str)
        print_write(print_str, self.log_file)
        # Save the best model
        self.save_model(epoch, best_epoch, best_model_weights, best_acc)

        # Test on the test set
        self.eval('test' if 'test' in self.data else 'val')

        print(self.training_opt['log_dir'])
        print('Done')

    def _zeroshot_weights_path(self, kind):
        """Return the cache path of the ``kind`` ('multi' or 'single') zero-shot weights."""
        dataset = self.config["training_opt"]["dataset"]
        backbone = self.model_opt["clip"]["params"]["visual_backbone"][:4]
        return "fc_weights/" + f'{dataset}_{backbone}_{kind}_zeroshot_weights.pt'

    def _build_zeroshot_weights(self):
        """Encode class prompts with the text model and cache them to disk.

        Writes two files (paths from ``_zeroshot_weights_path``):
        - 'multi': L2-normalised features reshaped to
          ``[num_classes, num_templates, -1]``, one per template in
          ``prompt_templates``.
        - 'single': L2-normalised features of the dataset's
          ``CUSTOM_TEMPLATES`` prompt for each class.

        Also sets ``self.multi_zeroshot_weights``, ``self.zeroshot_weights``
        and ``self.prompts_for_prints``.
        """
        meta_classes = obtain_labels(self.config["training_opt"]["dataset"])
        classesname, templates = meta_classes["CLASSES"], meta_classes["CUSTOM_TEMPLATES"]
        multi_prompt = prompt_templates

        # torch.save fails if the cache directory does not exist yet.
        os.makedirs("fc_weights", exist_ok=True)

        # The text features are only cached, never back-propagated through.
        with torch.no_grad():
            zeroshot_weights = []
            for classname in classesname:
                _tmp_class_prompt = torch.cat(
                    [clip.tokenize(prompt.format(classname)) for prompt in multi_prompt]).to('cuda')
                _zeroshot_weights = self.text_model(_tmp_class_prompt, None).float()
                _zeroshot_weights = _zeroshot_weights / _zeroshot_weights.norm(dim=-1, keepdim=True)
                zeroshot_weights.append(_zeroshot_weights.detach().cpu().numpy())

            zeroshot_weights = np.asarray(zeroshot_weights)
            self.multi_zeroshot_weights = torch.tensor(zeroshot_weights).to(self.device).reshape(
                len(classesname), len(multi_prompt), -1)
            torch.save(self.multi_zeroshot_weights, self._zeroshot_weights_path('multi'))

            prompts = torch.cat([clip.tokenize(templates.format(a)) for a in classesname]).to(self.device)
            self.prompts_for_prints = prompts
            single_weights = self.text_model(prompts, None).float()
            self.zeroshot_weights = single_weights / single_weights.norm(dim=-1, keepdim=True)
            torch.save(self.zeroshot_weights, self._zeroshot_weights_path('single'))


    def eval_with_preds(self, preds, labels):
        """Summarise training accuracy from per-batch predictions.

        ``preds`` and ``labels`` are parallel lists with one entry per batch.
        A label entry is either an array of targets or a mixup tuple
        ``(labels_a, labels_b, lam)``. Mixup batches are scored against both
        target sets with weights ``lam`` and ``1 - lam``.

        Overall, many-, median- and low-shot accuracies are combined in
        proportion to the number of samples of each kind. They are printed
        and returned under 'train_all', 'train_many', 'train_median' and
        'train_low'.
        """
        n_total = sum(len(batch_preds) for batch_preds in preds)

        # Split the batches into plain and mixup ones.
        plain_preds, plain_labels = [], []
        mix_preds, mix_targets_a, mix_targets_b, mix_weights = [], [], [], []
        for batch_preds, batch_labels in zip(preds, labels):
            if not isinstance(batch_labels, tuple):
                plain_preds.append(batch_preds)
                plain_labels.append(batch_labels)
                continue
            mix_preds.append(batch_preds)
            mix_targets_a.append(batch_labels[0])
            mix_targets_b.append(batch_labels[1])
            mix_weights.append(batch_labels[2] * np.ones_like(batch_labels[0]))

        rsl = dict.fromkeys(('train_all', 'train_many', 'train_median', 'train_low'), 0.)

        if plain_preds:
            flat_preds = np.concatenate(plain_preds)
            flat_labels = np.concatenate(plain_labels)
            top1 = mic_acc_cal(flat_preds, flat_labels)
            many, median, low = shot_acc(flat_preds, flat_labels, self.data['train'])
            share = len(flat_preds) / n_total
            rsl['train_all'] += share * top1
            rsl['train_many'] += share * many
            rsl['train_median'] += share * median
            rsl['train_low'] += share * low

        if mix_preds:
            # Every mixup prediction is scored twice: against targets a with
            # weight lam and against targets b with weight 1 - lam.
            stacked_preds = np.concatenate(mix_preds * 2)
            stacked_labels = np.concatenate(mix_targets_a + mix_targets_b)
            lam_weights = np.concatenate(mix_weights)
            lam_weights = np.concatenate([lam_weights, 1 - lam_weights])
            top1 = weighted_mic_acc_cal(stacked_preds, stacked_labels, lam_weights)
            many, median, low = weighted_shot_acc(stacked_preds, stacked_labels, lam_weights, self.data['train'])
            share = len(stacked_preds) / 2 / n_total
            rsl['train_all'] += share * top1
            rsl['train_many'] += share * many
            rsl['train_median'] += share * median
            rsl['train_low'] += share * low

        print_str = ['\n Training acc Top1: %.3f \n' % (rsl['train_all']),
                     'Many_top1: %.3f' % (rsl['train_many']),
                     'Median_top1: %.3f' % (rsl['train_median']),
                     'Low_top1: %.3f' % (rsl['train_low']),
                     '\n']
        print_write(print_str, self.log_file)

        return rsl

    def eval(self, phase='val', openset=False, save_feat=False):

        print_str = ['Phase: %s' % (phase)]
        print_write(print_str, self.log_file)
        time.sleep(0.25)

        if openset:
            print('Under openset test mode. Open threshold is %.1f' 
                  % self.training_opt['open_threshold'])
 
        torch.cuda.empty_cache()

        self.total_logits = torch.empty((0, self.training_opt['num_classes'])).cuda()
        self.total_labels = torch.empty(0, dtype=torch.long).cuda()
        self.total_paths = np.empty(0)

        get_feat_only = save_feat
        feats_all, labels_all, idxs_all, logits_all = [], [], [], []
        adapted_all = []
        featmaps_all = []
        
        # Iterate over dataset
        for inputs, labels, paths in tqdm(self.data[phase]):
            inputs, labels = inputs.cuda(), labels.cuda()

            # If on training phase, enable gradients
            with torch.set_grad_enabled(False):

                # In validation or testing
                self.batch_forward(inputs, phase=phase)
                if not get_feat_only:
                    self.total_logits = torch.cat((self.total_logits, self.logits))
                    self.total_labels = torch.cat((self.total_labels, labels))
                    self.total_paths = np.concatenate((self.total_paths, paths))

                if get_feat_only:
                    logits_all.append(self.logits.cpu().numpy())
                    feats_all.append(self.features.cpu().numpy())
                    adapted_all.append(self.ori_features.cpu().numpy())
                    labels_all.append(labels.cpu().numpy())
                    idxs_all.append(paths.numpy())

        if get_feat_only:
            typ = 'feat'
            if phase == 'train_plain':
                name = 'train{}_all.pkl'.format(typ)
            elif phase == 'test':
                name = 'test{}_all.pkl'.format(typ)
            elif phase == 'val':
                name = 'val{}_all.pkl'.format(typ)

            fname = os.path.join(self.training_opt['log_dir'], name)
            print('===> Saving feats to ' + fname)
            
            with open(fname, 'wb') as f:
                pickle.dump({
                            'logits': np.concatenate(logits_all),
                             'feats': np.concatenate(feats_all),
                             'ori_features': np.concatenate(adapted_all),
                             'labels': np.concatenate(labels_all),
                             'idxs': np.concatenate(idxs_all),
                            },
                            f, protocol=4) 
            return 
        
        probs, preds = F.softmax(self.total_logits.detach(), dim=1).max(dim=1)

        if openset:
            preds[probs < self.training_opt['open_threshold']] = -1
            self.openset_acc = mic_acc_cal(preds[self.total_labels == -1],
                                            self.total_labels[self.total_labels == -1])
            print('\n\nOpenset Accuracy: %.3f' % self.openset_acc)

        # Calculate the overall accuracy and F measurement
        self.eval_acc_mic_top1= mic_acc_cal(preds[self.total_labels != -1],
                                            self.total_labels[self.total_labels != -1])
        self.eval_f_measure = F_measure(preds, self.total_labels, openset=openset,
                                        theta=self.training_opt['open_threshold'])
        self.many_acc_top1, \
        self.median_acc_top1, \
        self.low_acc_top1, \
        self.cls_accs = shot_acc(preds[self.total_labels != -1],
                                 self.total_labels[self.total_labels != -1], 
                                 self.data['train'],
                                 acc_per_cls=True)
        # Top-1 accuracy and additional string
        print_str = ['\n\n',
                     'Phase: %s' 
                     % (phase),
                     '\n\n',
                     'Evaluation_accuracy_micro_top1: %.3f' 
                     % (self.eval_acc_mic_top1),
                     '\n',
                     'Averaged F-measure: %.3f' 
                     % (self.eval_f_measure),
                     '\n',
                     'Many_shot_accuracy_top1: %.3f' 
                     % (self.many_acc_top1),
                     'Median_shot_accuracy_top1: %.3f' 
                     % (self.median_acc_top1),
                     'Low_shot_accuracy_top1: %.3f' 
                     % (self.low_acc_top1),
                     '\n']
        
        rsl = {phase + '_all': self.eval_acc_mic_top1,
               phase + '_many': self.many_acc_top1,
               phase + '_median': self.median_acc_top1,
               phase + '_low': self.low_acc_top1,
               phase + '_fscore': self.eval_f_measure}

        if phase == 'val':
            print_write(print_str, self.log_file)
        else:
            acc_str = ["{:.1f} \t {:.1f} \t {:.1f} \t {:.1f}".format(
                self.many_acc_top1 * 100,
                self.median_acc_top1 * 100,
                self.low_acc_top1 * 100,
                self.eval_acc_mic_top1 * 100)]
            if self.log_file is not None and os.path.exists(self.log_file):
                print_write(print_str, self.log_file)
                print_write(acc_str, self.log_file)
            else:
                print(*print_str)
                print(*acc_str)
        
        if phase == 'test':
            with open(os.path.join(self.training_opt['log_dir'], 'cls_accs.pkl'), 'wb') as f:
                pickle.dump(self.cls_accs, f)
        return rsl

    def load_model(self, model_dir=None):
        """Load the best text/visual model weights from a checkpoint.

        Args:
            model_dir: Path to a ``.pth`` checkpoint file, or a directory that
                contains ``final_model_checkpoint.pth`` (the file written by
                ``save_model``). Defaults to ``self.training_opt['log_dir']``.

        The checkpoint must hold a ``'state_dict_best'`` dict with
        ``'text_model'`` and ``'visual_model'`` entries; these are loaded into
        ``self.text_model`` and ``self.visual_model`` respectively.
        """
        model_dir = self.training_opt['log_dir'] if model_dir is None else model_dir
        if not model_dir.endswith('.pth'):
            # A directory was given (the default log_dir is one). torch.load
            # cannot open a directory, so resolve it to the checkpoint file
            # that save_model() writes into that directory.
            model_dir = os.path.join(model_dir, 'final_model_checkpoint.pth')

        print('Validation on the best model.')
        print('Loading model from %s' % (model_dir))

        # Map tensors to CPU so the checkpoint loads regardless of the device
        # it was saved from.
        checkpoint = torch.load(model_dir, map_location='cpu')

        model_state = checkpoint['state_dict_best']
        self.text_model.load_state_dict(model_state['text_model'])
        self.visual_model.load_state_dict(model_state['visual_model'])
    
    def save_latest(self, epoch):
        """Write the current weights to ``latest_model_checkpoint.pth``.

        The file is written into ``self.training_opt['log_dir']`` and contains
        ``{'epoch': epoch, 'state_dict': weights}``. ``weights`` always has the
        ``'visual_model'`` and ``'text_model'`` state dicts. It also has
        ``'classifier'``, holding the state dict of ``self.adapter``, unless
        ``training_opt['phaseA']`` is exactly ``True``.

        Args:
            epoch: Epoch number stored alongside the weights.
        """
        # torch.save serialises the tensor data itself, so deep-copying the
        # state dicts beforehand only doubled peak memory on the models' device.
        model_weights = {
            'visual_model': self.visual_model.state_dict(),
            'text_model': self.text_model.state_dict(),
        }

        # Identity check kept on purpose: only a literal True skips the adapter.
        if self.training_opt['phaseA'] is not True:
            model_weights['classifier'] = self.adapter.state_dict()

        model_states = {
            'epoch': epoch,
            'state_dict': model_weights,
        }

        model_dir = os.path.join(self.training_opt['log_dir'],
                                 'latest_model_checkpoint.pth')
        torch.save(model_states, model_dir)
        
    def save_model(self, epoch, best_epoch, best_model_weights, best_acc, centroids=None):
        """Persist the best-so-far weights to ``final_model_checkpoint.pth``.

        The checkpoint is written into ``self.training_opt['log_dir']`` as a
        dict with the keys ``epoch``, ``best_epoch``, ``state_dict_best``
        (``best_model_weights`` as given), ``best_acc`` and ``centroids``.

        Args:
            epoch: Current epoch number.
            best_epoch: Epoch at which ``best_model_weights`` were obtained.
            best_model_weights: Weights to store under ``state_dict_best``.
            best_acc: Accuracy associated with the best weights.
            centroids: Optional extra payload, stored as-is (default ``None``).
        """
        payload = dict(
            epoch=epoch,
            best_epoch=best_epoch,
            state_dict_best=best_model_weights,
            best_acc=best_acc,
            centroids=centroids,
        )
        checkpoint_path = os.path.join(self.training_opt['log_dir'],
                                       'final_model_checkpoint.pth')
        torch.save(payload, checkpoint_path)
            
    def output_logits(self, openset=False):
        """Dump accumulated logits, labels and paths to an ``.npz`` archive.

        The archive is ``logits_open.npz`` when ``openset`` is truthy, else
        ``logits_close.npz``, inside ``self.training_opt['log_dir']``. It holds
        ``self.total_logits`` and ``self.total_labels`` (moved to CPU as NumPy
        arrays) under ``logits``/``labels``, plus ``self.total_paths`` under
        ``paths``.

        Args:
            openset: Selects the ``open``/``close`` filename suffix.
        """
        suffix = 'open' if openset else 'close'
        base_name = os.path.join(self.training_opt['log_dir'], 'logits_%s' % suffix)
        # np.savez appends the ".npz" extension itself.
        print("Saving total logits to: %s.npz" % base_name)
        arrays = dict(
            logits=self.total_logits.detach().cpu().numpy(),
            labels=self.total_labels.detach().cpu().numpy(),
            paths=self.total_paths,
        )
        np.savez(base_name, **arrays)

