import os
import time
import datetime
import numpy as np
from tqdm import tqdm
from collections import OrderedDict, defaultdict
from sklearn.linear_model import LogisticRegression
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import transforms

from clip import clip
from timm.models.vision_transformer import vit_base_patch16_224, vit_large_patch16_224
from timm.models.deit import deit_tiny_patch16_224, deit_small_patch16_224, deit_base_patch16_224, deit_base_patch16_384
from timm.models.deit import deit_tiny_distilled_patch16_224, deit_small_distilled_patch16_224, deit_base_distilled_patch16_224, deit_base_distilled_patch16_384
from timm.models.regnet import regnety_160

import datasets
from models import *

from utils.meter import AverageMeter
from utils.samplers import DownSampler, ClassAwareSampler, SqurerootSampler
from utils.MixedPrioritizedSampler import MixedPrioritizedSampler
from utils.losses import *
from utils.evaluator import Evaluator

def label_smooth(labels, class_count, epsilon):
    """Convert integer class labels into label-smoothed target distributions.

    Each label is one-hot encoded over ``class_count`` classes; the one-hot
    vector is scaled by ``1 - epsilon`` and ``epsilon / class_count`` is then
    added to every entry.

    Args:
        labels: Integer tensor of class indices (forwarded to ``F.one_hot``).
        class_count: Number of classes, i.e. the size of the new last dim.
        epsilon: Smoothing strength, must lie in ``[0, 1]``.

    Returns:
        Float tensor whose last dimension has size ``class_count``.
    """
    assert 0 <= epsilon <= 1.0

    one_hot = F.one_hot(labels, class_count).float()
    assert one_hot.shape[-1] == class_count

    keep_weight = 1.0 - epsilon
    uniform_share = epsilon / class_count
    return one_hot * keep_weight + uniform_share

def load_clip_to_cpu(backbone_name, prec):
    """Download (if necessary) a CLIP checkpoint and build the model on the CPU.

    Args:
        backbone_name: Key of ``clip._MODELS``, optionally prefixed with
            ``"CLIP-"`` (e.g. ``"CLIP-ViT-B/16"``).
        prec: Precision mode, one of ``"fp16"``, ``"fp32"`` or ``"amp"``.

    Returns:
        The model returned by ``clip.build_model``. It is converted to fp32
        when ``prec`` is ``"fp32"`` or ``"amp"``.

    Raises:
        AssertionError: If ``prec`` is not one of the supported values.
    """
    # Validate before doing any download / deserialisation work.
    assert prec in ["fp16", "fp32", "amp"]

    # Remove the literal "CLIP-" prefix. ``str.lstrip("CLIP-")`` is not a
    # prefix removal: it strips any leading run of the characters C, L, I, P
    # and "-", so it could also eat into the model name itself.
    prefix = "CLIP-"
    if backbone_name.startswith(prefix):
        backbone_name = backbone_name[len(prefix):]

    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # Try loading the file as a TorchScript (JIT) archive first.
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Not a JIT archive: load the saved object directly. It is used as a
        # state dict below, so it must not be treated as a module (the
        # previous ``.eval()`` call here raised AttributeError).
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = jit_model.state_dict()

    model = clip.build_model(state_dict)

    if prec == "fp32" or prec == "amp":
        # CLIP's default precision is fp16
        model.float()

    return model

def load_vit_to_cpu(backbone_name, prec):
    """Instantiate a pretrained timm backbone in eval mode.

    Despite the name, this also serves the ``"RegNetY"`` CNN backbone. The
    DeiT names map to the *distilled* DeiT variants.

    Args:
        backbone_name: One of ``"IN21K-ViT-B/16"``, ``"IN21K-ViT-L/16"``,
            ``"DeiT-Ti"``, ``"DeiT-S"``, ``"DeiT-B"``, ``"DeiT-B@384px"`` or
            ``"RegNetY"``.
        prec: Precision mode, one of ``"fp16"``, ``"fp32"`` or ``"amp"``.

    Returns:
        The pretrained model in eval mode, converted to half precision when
        ``prec`` is ``"fp16"``.

    Raises:
        AssertionError: If ``prec`` is not one of the supported values.
        NameError: If ``backbone_name`` is not supported. ``NameError`` is used
            to match the convention in this file; it also keeps handlers
            working that caught the ``UnboundLocalError`` (a ``NameError``
            subclass) this case used to trigger.
    """
    assert prec in ["fp16", "fp32", "amp"]

    if backbone_name == "IN21K-ViT-B/16":
        model = vit_base_patch16_224(pretrained=True).eval()
    elif backbone_name == "IN21K-ViT-L/16":
        model = vit_large_patch16_224(pretrained=True).eval()
    elif backbone_name == "DeiT-Ti":
        model = deit_tiny_distilled_patch16_224(pretrained=True).eval()
    elif backbone_name == "DeiT-S":
        model = deit_small_distilled_patch16_224(pretrained=True).eval()
    elif backbone_name == "DeiT-B":
        model = deit_base_distilled_patch16_224(pretrained=True).eval()
    elif backbone_name == "DeiT-B@384px":
        model = deit_base_distilled_patch16_384(pretrained=True).eval()
    elif backbone_name == "RegNetY":
        model = regnety_160(pretrained=True).eval()
    else:
        # Previously this fell through and crashed later with an opaque
        # "local variable 'model' referenced before assignment".
        raise NameError("Invalid backbone name: {!r}".format(backbone_name))

    if prec == "fp16":
        # ViT's default precision is fp32
        model.half()

    return model


class Trainer:
    def __init__(self, cfg):
        """Pick the compute device, then build data loaders, model and evaluator.

        Args:
            cfg: Run configuration object. This method itself reads
                ``cfg.gpu``; the builder methods and ``Evaluator`` read the
                remaining options.
        """
        # CPU when CUDA is unavailable; otherwise either the default CUDA
        # device or the specific GPU index requested through ``cfg.gpu``.
        if torch.cuda.is_available():
            if cfg.gpu is None:
                device = torch.device("cuda")
            else:
                torch.cuda.set_device(cfg.gpu)
                device = torch.device(f"cuda:{cfg.gpu}")
        else:
            device = torch.device("cpu")
        self.device = device

        self.cfg = cfg
        self.build_data_loader()
        self.build_model()
        # many/med/few class index arrays are produced by build_data_loader().
        self.evaluator = Evaluator(cfg, self.many_idxs, self.med_idxs, self.few_idxs)
        # The tensorboard writer is created lazily in train().
        self._writer = None

    def build_data_loader(self):
        """Create the datasets, the many/medium/few class splits and all loaders.

        Attributes set on ``self``:
            * ``num_classes``, ``cls_num_list``, ``classnames`` (taken from
              the training dataset);
            * ``many_idxs`` (> 100 samples), ``med_idxs`` (20..100 samples) and
              ``few_idxs`` (< 20 samples);
            * ``train_loader_one`` / ``train_loader_two`` when ``cfg.ensemble``
              is set, otherwise ``train_loader``;
            * ``resampling_cls_list`` only when ``cfg.resampling`` is
              ``"squreroot"`` and ``cfg.ensemble`` is not set;
            * ``train_init_loader``, ``train_test_loader``, ``test_loader``;
            * ``accum_step`` (``cfg.batch_size // cfg.micro_batch_size``).

        Raises:
            NameError: If ``cfg.augmentation`` is not a supported value.
            AssertionError: If ``cfg.batch_size`` is not a multiple of
                ``cfg.micro_batch_size``.
        """
        cfg = self.cfg
        root = cfg.root
        resolution = cfg.resolution
        expand = cfg.expand

        # CLIP backbones get their own per-channel statistics; every other
        # backbone is normalised with 0.5 for all channels.
        if cfg.backbone.startswith("CLIP"):
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]
        else:
            mean = [0.5] * 3
            std = [0.5] * 3
        print("mean:", mean)
        print("std:", std)

        # Optional extra training augmentation (identity when not configured).
        aug_name = cfg.augmentation
        if aug_name is None:
            augmentation = transforms.Lambda(lambda x: x)
        elif aug_name == "ColorJitter":
            augmentation = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0)
        elif aug_name == "RandAugment":
            augmentation = transforms.RandAugment()
        elif aug_name == "AutoAugment":
            augmentation = transforms.AutoAugment()
        else:
            raise NameError("Invalid augmentation.")

        train_steps = [
            transforms.RandomResizedCrop(resolution),
            transforms.RandomHorizontalFlip(),
            augmentation,
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        transform_train = transforms.Compose(train_steps)

        # Deterministic transform used for head initialisation.
        plain_steps = [
            transforms.Resize(resolution),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        transform_plain = transforms.Compose(plain_steps)

        def crops_to_batch(crops):
            # Several crops of one image -> one stacked tensor.
            return torch.stack([transforms.ToTensor()(crop) for crop in crops])

        def crop_to_batch(crop):
            # A single crop, stacked so it carries the same leading crop dim.
            return torch.stack([transforms.ToTensor()(crop)])

        if cfg.test_ensemble:
            test_steps = [
                transforms.Resize(resolution + expand),
                transforms.FiveCrop(resolution),
                transforms.Lambda(crops_to_batch),
            ]
        else:
            test_steps = [
                transforms.Resize(resolution * 8 // 7),
                transforms.CenterCrop(resolution),
                transforms.Lambda(crop_to_batch),
            ]
        test_steps.append(transforms.Normalize(mean, std))
        transform_test = transforms.Compose(test_steps)

        # Same dataset class, four views: augmented train, plain train (head
        # init), train with the test transform, and the test split.
        dataset_cls = getattr(datasets, cfg.dataset)
        train_dataset = dataset_cls(root, train=True, transform=transform_train, drop_shot=cfg.drop_shot)
        train_init_dataset = dataset_cls(root, train=True, transform=transform_plain, drop_shot=cfg.drop_shot)
        train_test_dataset = dataset_cls(root, train=True, transform=transform_test, drop_shot=cfg.drop_shot)
        test_dataset = dataset_cls(root, train=False, transform=transform_test, drop_shot=cfg.drop_shot)

        self.num_classes = train_dataset.num_classes
        self.cls_num_list = train_dataset.cls_num_list
        self.classnames = train_dataset.classnames

        print("num_classes",self.num_classes)
        print("cls_num_list",sum(self.cls_num_list),self.cls_num_list)
        print("classnames",self.classnames)

        # The CIFAR100 variants listed here are split using the class counts
        # of CIFAR100_IR100 rather than their own.
        if cfg.dataset in ["CIFAR100", "CIFAR100_IR10", "CIFAR100_IR50"]:
            split_cls_num_list = datasets.CIFAR100_IR100(root, train=True).cls_num_list
        else:
            split_cls_num_list = self.cls_num_list
        split_counts = np.array(split_cls_num_list)
        self.many_idxs = (split_counts > 100).nonzero()[0]
        self.med_idxs = ((split_counts >= 20) & (split_counts <= 100)).nonzero()[0]
        self.few_idxs = (split_counts < 20).nonzero()[0]

        # "<k>_shot" head initialisation caps the init loader via DownSampler.
        shots_by_mode = {"1_shot": 1, "10_shot": 10, "100_shot": 100}
        n_shots = shots_by_mode.get(cfg.init_head)
        if n_shots is None:
            init_sampler = None
        else:
            init_sampler = DownSampler(train_init_dataset, n_max=n_shots)

        if cfg.ensemble:
            # Two loaders over the same dataset: plain shuffling and
            # class-aware ("EQ") sampling.
            self.train_loader_one = DataLoader(
                train_dataset, batch_size=cfg.micro_batch_size, shuffle=True,
                num_workers=cfg.num_workers, pin_memory=True)

            balanced_sampler = ClassAwareSampler(train_dataset, category="EQ")
            self.train_loader_two = DataLoader(
                train_dataset, batch_size=cfg.micro_batch_size, sampler=balanced_sampler,
                num_workers=cfg.num_workers, pin_memory=True)

        elif cfg.resampling:
            if cfg.resampling == "squreroot":
                # Per-class target counts: sqrt(n / n_min) * n_min, truncated
                # to int.
                class_counts = np.array(self.cls_num_list)
                smallest = min(class_counts)
                scaled_counts = np.sqrt(class_counts / smallest) * smallest
                self.resampling_cls_list = [int(n) for n in scaled_counts]
                train_sampler = SqurerootSampler(train_dataset, self.resampling_cls_list)
            else:
                # Other modes are forwarded as the sampler category
                # ("ROS", "RUS", "MID", "EQ").
                train_sampler = ClassAwareSampler(train_dataset, category=cfg.resampling)

            self.train_loader = DataLoader(
                train_dataset, batch_size=cfg.micro_batch_size, sampler=train_sampler,
                num_workers=cfg.num_workers, pin_memory=True)

        else:
            self.train_loader = DataLoader(
                train_dataset, batch_size=cfg.micro_batch_size, shuffle=True,
                num_workers=cfg.num_workers, pin_memory=True)

        # Loaders with a fixed batch size of 64 and no shuffling.
        self.train_init_loader = DataLoader(
            train_init_dataset, batch_size=64, sampler=init_sampler, shuffle=False,
            num_workers=cfg.num_workers, pin_memory=True)

        self.train_test_loader = DataLoader(
            train_test_dataset, batch_size=64, shuffle=False,
            num_workers=cfg.num_workers, pin_memory=True)

        self.test_loader = DataLoader(
            test_dataset, batch_size=64, shuffle=False,
            num_workers=cfg.num_workers, pin_memory=True)

        # Gradient accumulation: micro batches per effective batch.
        assert cfg.batch_size % cfg.micro_batch_size == 0
        self.accum_step = cfg.batch_size // cfg.micro_batch_size

    def build_model(self):
        """Build the model, the optional distillation teacher and training state.

        Steps, in order:

        1. Build ``self.model`` from ``cfg.backbone`` and move it to
           ``self.device``. With ``cfg.zero_shot`` a ``ZeroShotCLIP`` model is
           used and ``self.tuner`` / ``self.head`` are ``None``; otherwise
           they are taken from the PEFT wrapper model.
        2. If ``cfg.teacher`` is set, build ``self.teacher_model``: a
           zero-shot CLIP teacher for names starting with ``"zs_CLIP"``,
           otherwise a PEFT model restored from
           ``./output/<cfg.dataset lower-cased>_<cfg.teacher>/checkpoint.pth.tar``.
        3. Unless ``cfg.zero_shot``, ``cfg.test_train`` or ``cfg.test_only``
           is set, build the optimizer and criterion and initialise the head
           according to ``cfg.init_head``.
        4. Wrap the model in ``nn.DataParallel`` when several GPUs are
           visible and ``cfg.gpu`` is ``None``.

        Raises:
            NameError: If ``cfg.backbone`` or ``cfg.teacher`` names an
                unsupported model (``NameError`` follows the convention
                already used in this class).
            FileNotFoundError: If the teacher checkpoint file does not exist.
        """
        cfg = self.cfg
        classnames = self.classnames
        num_classes = len(classnames)

        print("Building model")
        if cfg.zero_shot:
            assert cfg.backbone.startswith("CLIP")
            print(f"Loading CLIP (backbone: {cfg.backbone})")
            clip_model = load_clip_to_cpu(cfg.backbone, cfg.prec)
            self.model = ZeroShotCLIP(clip_model)
            self.model.to(self.device)
            self.tuner = None
            self.head = None

            # The zero-shot classifier is defined by the class-name prompts.
            prompts = self.get_tokenized_prompts(classnames)
            self.model.init_text_features(prompts)

        else:
            if cfg.backbone.startswith("CLIP"):
                print(f"Loading CLIP (backbone: {cfg.backbone})")
                clip_model = load_clip_to_cpu(cfg.backbone, cfg.prec)
                self.model = PeftModelFromCLIP(cfg, clip_model, num_classes)

            elif cfg.backbone.startswith("IN21K-ViT"):
                print(f"Loading ViT (backbone: {cfg.backbone})")
                vit_model = load_vit_to_cpu(cfg.backbone, cfg.prec)
                self.model = PeftModelFromViT(cfg, vit_model, num_classes)

            elif cfg.backbone.startswith("RegNetY"):
                print(f"Loading CNN (backbone: {cfg.backbone})")
                cnn_model = load_vit_to_cpu(cfg.backbone, cfg.prec)
                self.model = PeftModelFromCNN(cfg, cnn_model, num_classes)

            elif cfg.backbone.startswith("DeiT-"):
                print(f"Loading ViT (backbone: {cfg.backbone})")
                vit_model = load_vit_to_cpu(cfg.backbone, cfg.prec)
                self.model = PeftModelFromDeiT(cfg, vit_model, num_classes)

            else:
                # Previously an unknown backbone fell through silently and
                # only failed later with an unrelated AttributeError on
                # ``self.model``.
                raise NameError("Invalid backbone name: {!r}".format(cfg.backbone))

            self.model.to(self.device)
            self.tuner = self.model.tuner
            self.head = self.model.head

        # Optional teacher model (consumed by build_criterion()).
        if cfg.teacher:
            print("Building teacher model")
            if cfg.teacher.startswith("zs_CLIP"):
                # "zs_CLIP-..." -> "CLIP-...": drop the 3-character "zs_" marker.
                clip_teacher_model = load_clip_to_cpu(cfg.teacher[3:], cfg.prec)
                self.teacher_model = ZeroShotCLIP(clip_teacher_model)
                self.teacher_model.to(self.device)
                prompts = self.get_tokenized_prompts(classnames)
                self.teacher_model.init_text_features(prompts)

            else:
                path_dataset = cfg.dataset.lower()
                path_teacher = cfg.teacher
                load_path = './output/' + path_dataset + "_" + path_teacher + '/checkpoint.pth.tar'

                if not os.path.exists(load_path):
                    raise FileNotFoundError('Checkpoint not found at "{}"'.format(load_path))
                # NOTE(security): torch.load unpickles the file, and this
                # checkpoint also stores a pickled config object. Only load
                # checkpoints from a trusted source.
                checkpoint = torch.load(load_path, map_location=self.device)

                tuner_dict = checkpoint["tuner"]
                head_dict = checkpoint["head"]
                teacher_cfg = checkpoint["cfg"]
                print("teacher_cfg")
                print("============================================")
                print(teacher_cfg)
                print("============================================")

                # The teacher architecture is selected from the teacher name
                # and built with the config stored in its checkpoint.
                if cfg.teacher.startswith("teacher_cifar") or cfg.teacher.startswith("teacher_imagenet"):
                    teacher_vit_model = load_vit_to_cpu(teacher_cfg.backbone, teacher_cfg.prec)
                    self.teacher_model = PeftModelFromViT(teacher_cfg, teacher_vit_model, num_classes)
                elif cfg.teacher.startswith("teacher_places"):
                    teacher_clip_model = load_clip_to_cpu(teacher_cfg.backbone, teacher_cfg.prec)
                    self.teacher_model = PeftModelFromCLIP(teacher_cfg, teacher_clip_model, num_classes)
                else:
                    raise NameError("invalid teacher name!")

                self.teacher_model.to(self.device)

                print("Loading weights to from {}".format(load_path))
                self.teacher_model.tuner.load_state_dict(tuner_dict)
                saved_head_shape = head_dict["weight"].shape
                teacher_head_shape = self.teacher_model.head.weight.shape
                if saved_head_shape == teacher_head_shape:
                    self.teacher_model.head.load_state_dict(head_dict)
                else:
                    # Keep the original best-effort behaviour (skip the head),
                    # but no longer do so silently.
                    print("Warning: teacher head weights not loaded, checkpoint shape {} "
                          "does not match model shape {}".format(
                              tuple(saved_head_shape), tuple(teacher_head_shape)))

        # Training-only state: optimizer, loss and head initialisation.
        if not (cfg.zero_shot or cfg.test_train or cfg.test_only):
            self.build_optimizer()
            self.build_criterion()

            if cfg.init_head == "text_feat":
                self.init_head_text_feat()
            elif cfg.init_head in ["class_mean", "1_shot", "10_shot", "100_shot"]:
                self.init_head_class_mean()
            elif cfg.init_head == "linear_probe":
                self.init_head_linear_probe()
            else:
                print("No initialization with head")

            torch.cuda.empty_cache()

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1 and cfg.gpu is None:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def build_optimizer(self):
        """Freeze the model, re-enable tuner/head gradients and build the optimizer.

        Gradient flags are set in this order: every model parameter is frozen
        except those whose name contains ``"dist_token"``; the tuner is
        unfrozen unless ``cfg.freeze_encoder`` is set; the head is always
        unfrozen.

        Sets ``self.optim`` (SGD over the tuner and head parameters, plus the
        image encoder's ``dist_token`` for ``"DeiT-"`` backbones),
        ``self.sched`` (cosine annealing over ``cfg.num_epochs``) and
        ``self.scaler`` (a ``GradScaler`` only when ``cfg.prec == "amp"``,
        otherwise ``None``).
        """
        cfg = self.cfg

        print("Turning off gradients in the model")
        for param_name, param in self.model.named_parameters():
            # Only "dist_token" parameters stay trainable at this stage.
            param.requires_grad_("dist_token" in param_name)

        if not cfg.freeze_encoder:
            print("Turning on gradients in the tuner")
            for _, param in self.tuner.named_parameters():
                param.requires_grad_(True)

        print("Turning on gradients in the head")
        for _, param in self.head.named_parameters():
            param.requires_grad_(True)

        # Sanity check: list every parameter that will receive gradients.
        for param_name, param in self.model.named_parameters():
            if param.requires_grad:
                print(param_name)

        def count_params(module):
            return sum(p.numel() for p in module.parameters())

        total_params = count_params(self.model)
        tuned_params = count_params(self.tuner)
        head_params = count_params(self.head)
        print(f"Total params: {total_params}")
        print(f"Tuned params: {tuned_params}")
        print(f"Head params: {head_params}")

        # One parameter group per trainable component; DeiT backbones add the
        # image encoder's dist_token as a third group.
        param_groups = [
            {"params": self.tuner.parameters()},
            {"params": self.head.parameters()},
        ]
        if cfg.backbone.startswith("DeiT-"):
            param_groups.append({"params": self.model.image_encoder.dist_token.parameters()})

        self.optim = torch.optim.SGD(
            param_groups,
            lr=cfg.lr, weight_decay=cfg.weight_decay, momentum=cfg.momentum)

        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, cfg.num_epochs)

        if cfg.prec == "amp":
            self.scaler = GradScaler()
        else:
            self.scaler = None

    def build_criterion(self):
        """Create the training loss(es) selected by the configuration.

        ``self.criterion`` is set from ``cfg.loss_type``; an unrecognised
        value leaves it unset. With ``cfg.ensemble`` two extra cross-entropy
        losses (``criterion_one`` / ``criterion_two``) are created. The base
        criterion is then optionally wrapped by ``LabelSmoothLoss``
        (``cfg.label_smooth``) and by ``DistillationLoss`` (``cfg.teacher``),
        in that order.

        The class counts handed to the losses are ``self.resampling_cls_list``
        when ``cfg.resampling == "squreroot"`` and ``self.cls_num_list``
        otherwise.
        """
        cfg = self.cfg
        if cfg.resampling == "squreroot":
            class_counts = self.resampling_cls_list
        else:
            class_counts = self.cls_num_list
        cls_num_list = torch.Tensor(class_counts).to(self.device)

        # Lazy factories, so only the selected loss is instantiated.
        criterion_factories = {
            "CE": lambda: nn.CrossEntropyLoss(),
            # https://arxiv.org/abs/1708.02002
            "Focal": lambda: FocalLoss(gamma=2),
            # https://arxiv.org/abs/1906.07413
            "LDAM": lambda: LDAMLoss(cls_num_list=cls_num_list, s=cfg.scale),
            # https://arxiv.org/abs/1901.05555
            "CB": lambda: ClassBalancedLoss(cls_num_list=cls_num_list, beta=0.9),
            # https://arxiv.org/abs/2103.16370
            "GRW": lambda: GeneralizedReweightLoss(cls_num_list=cls_num_list, exp_scale=1.2),
            # https://arxiv.org/abs/2007.10740
            "BS": lambda: BalancedSoftmaxLoss(cls_num_list=cls_num_list),
            # https://arxiv.org/abs/2007.07314
            "LA": lambda: LogitAdjustedLoss(cls_num_list=cls_num_list, tau=1.5),
            # https://arxiv.org/abs/2012.00321
            "LADE": lambda: LADELoss(cls_num_list=cls_num_list, remine_lambda=0.01, estim_loss_weight=0.1),
        }
        make_criterion = criterion_factories.get(cfg.loss_type)
        if make_criterion is not None:
            self.criterion = make_criterion()

        if cfg.ensemble:
            # One plain cross-entropy loss per ensemble head.
            self.criterion_one = nn.CrossEntropyLoss()
            self.criterion_two = nn.CrossEntropyLoss()

        if cfg.label_smooth:
            self.criterion = LabelSmoothLoss(self.criterion, self.num_classes, cfg.epsilon)

        if cfg.teacher:
            self.criterion = DistillationLoss(self.criterion, self.teacher_model, "hard", 0.5, 1.0, cls_num_list)
        
        
    def get_tokenized_prompts(self, classnames):
        """Tokenize one "a photo of a <class>." prompt per class name.

        Underscores in class names are replaced by spaces before formatting.

        Args:
            classnames: Iterable of class-name strings.

        Returns:
            The per-prompt outputs of ``clip.tokenize`` concatenated with
            ``torch.cat`` and moved to ``self.device``.
        """
        template = "a photo of a {}."
        prompt_texts = []
        for name in classnames:
            readable_name = name.replace("_", " ")
            prompt_texts.append(template.format(readable_name))

        token_batches = [clip.tokenize(text) for text in prompt_texts]
        return torch.cat(token_batches).to(self.device)

    @torch.no_grad()
    def init_head_text_feat(self):
        """Initialise the classifier head(s) from the class-name text features.

        The prompts from ``get_tokenized_prompts`` are encoded with
        ``self.model.encode_text`` and L2-normalised. For ``"CLIP-ViT"``
        backbones the features are additionally multiplied by the transpose of
        the image encoder's ``proj`` matrix and normalised again. The result
        is passed to ``apply_weight`` of the head, or of every head when
        ``self.head`` is an ``nn.ModuleDict``.
        """
        cfg = self.cfg
        classnames = self.classnames

        print("Initialize head with text features")
        prompts = self.get_tokenized_prompts(classnames)
        text_features = F.normalize(self.model.encode_text(prompts), dim=-1)

        if cfg.backbone.startswith("CLIP-ViT"):
            projection = self.model.image_encoder.proj
            text_features = F.normalize(text_features @ projection.t(), dim=-1)

        if isinstance(self.head, nn.ModuleDict):
            target_heads = self.head.values()
        else:
            target_heads = [self.head]
        for head in target_heads:
            head.apply_weight(text_features)

    @torch.no_grad()
    def init_head_class_mean(self):
        """Initialise the classifier head(s) with normalised per-class mean features.

        Features are extracted with the tuner disabled
        (``use_tuner=False, return_feature=True``) for every batch of
        ``self.train_init_loader``. With ``cfg.ensemble`` the same features
        are used for both ``"head1"`` and ``"head2"``. The L2-normalised class
        means are handed to ``apply_weight`` of the head(s).

        Raises:
            RuntimeError: If ``self.train_init_loader`` yields no batches.
            ValueError: If some class has no sample in the init loader, so no
                mean can be computed for it.
        """
        cfg = self.cfg
        print("Initialize head with class means")
        all_features = defaultdict(list)
        all_labels = []

        for batch in tqdm(self.train_init_loader, ascii=True):
            image = batch[0].to(self.device)
            label = batch[1].to(self.device)

            feature = self.model(image, use_tuner=False, return_feature=True)

            if cfg.ensemble:
                # Both ensemble heads are initialised from the same features.
                feature = {"head1": feature,
                           "head2": feature}

            if isinstance(feature, dict):
                for key, value in feature.items():
                    all_features[key].append(value)
            else:
                all_features["cls"].append(feature)

            all_labels.append(label)

        if not all_labels:
            raise RuntimeError("Cannot initialize head: train_init_loader yielded no batches")

        def compute_class_means(features, labels):
            """Return a (num_classes, dim) tensor of L2-normalised class means."""
            features = torch.cat(features, dim=0)
            labels = torch.cat(labels, dim=0)
            # Sort by label so each class occupies one contiguous slice.
            sorted_index = labels.argsort()
            features = features[sorted_index]
            labels = labels[sorted_index]
            unique_labels, label_counts = torch.unique(labels, return_counts=True)
            class_means = [None] * self.num_classes

            start = 0
            for cls_idx, count in zip(unique_labels.tolist(), label_counts.tolist()):
                class_means[cls_idx] = features[start: start + count].mean(dim=0, keepdim=True)
                start += count

            # A class without samples would otherwise surface as an opaque
            # TypeError inside torch.cat.
            missing = [idx for idx, mean in enumerate(class_means) if mean is None]
            if missing:
                raise ValueError(
                    "No samples in train_init_loader for class indices: {}".format(missing))

            class_means = torch.cat(class_means, dim=0)
            class_means = F.normalize(class_means, dim=-1)

            return class_means

        if isinstance(self.head, nn.ModuleDict):
            # Iterate the keys actually collected, not the variable left over
            # from the last loop iteration.
            for key in all_features:
                class_means = compute_class_means(all_features[key], all_labels)
                self.head[key].apply_weight(class_means)

        else:
            class_means = compute_class_means(all_features["cls"], all_labels)
            self.head.apply_weight(class_means)
        
    @torch.no_grad()
    def init_head_linear_probe(self):
        """Initialise the classifier head(s) with logistic-regression weights.

        Features are extracted with the tuner disabled
        (``use_tuner=False, return_feature=True``) for every batch of
        ``self.train_init_loader``. A class-balanced, L2-penalised
        ``LogisticRegression`` is fitted on them and its L2-normalised
        coefficient matrix is handed to ``apply_weight`` of the head(s). With
        ``cfg.ensemble`` the same features are used for ``"head1"`` and
        ``"head2"``.

        Raises:
            RuntimeError: If ``self.train_init_loader`` yields no batches.
        """
        cfg = self.cfg
        print("Initialize head with linear probing")
        all_features = defaultdict(list)
        all_labels = []

        for batch in tqdm(self.train_init_loader, ascii=True):
            image = batch[0].to(self.device)
            label = batch[1].to(self.device)

            feature = self.model(image, use_tuner=False, return_feature=True)

            if cfg.ensemble:
                # Both ensemble heads are initialised from the same features.
                feature = {"head1": feature,
                           "head2": feature}

            if isinstance(feature, dict):
                for key, value in feature.items():
                    all_features[key].append(value)
            else:
                all_features["cls"].append(feature)

            all_labels.append(label)

        if not all_labels:
            raise RuntimeError("Cannot initialize head: train_init_loader yielded no batches")

        def compute_class_weights(features, labels):
            """Fit the linear probe and return its normalised coefficients."""
            # scikit-learn works on CPU data.
            features = torch.cat(features, dim=0).cpu()
            labels = torch.cat(labels, dim=0).cpu()

            clf = LogisticRegression(solver="lbfgs", max_iter=100, penalty="l2", class_weight="balanced").fit(features, labels)
            # NOTE(review): coef_ has one row per class seen during fit
            # (a single row for a binary problem); this presumably relies on
            # every class appearing in the init loader - confirm.
            class_weights = torch.from_numpy(clf.coef_).to(features.dtype).to(self.device)
            class_weights = F.normalize(class_weights, dim=-1)

            return class_weights

        if isinstance(self.head, nn.ModuleDict):
            # Iterate the keys actually collected, not the variable left over
            # from the last loop iteration.
            for key in all_features:
                class_weights = compute_class_weights(all_features[key], all_labels)
                self.head[key].apply_weight(class_weights)
        else:
            class_weights = compute_class_weights(all_features["cls"], all_labels)
            self.head.apply_weight(class_weights)
        

    def train(self):
        cfg = self.cfg

        # Initialize summary writer
        writer_dir = os.path.join(cfg.output_dir, "tensorboard")
        os.makedirs(writer_dir, exist_ok=True)
        print(f"Initialize tensorboard (log_dir={writer_dir})")
        self._writer = SummaryWriter(log_dir=writer_dir)

        # Initialize average meters
        batch_time = AverageMeter()
        data_time = AverageMeter()
        loss_meter = AverageMeter(ema=True)
        acc_meter = AverageMeter(ema=True)
        cls_meters = [AverageMeter(ema=True) for _ in range(self.num_classes)]

        # Remember the starting time (for computing the elapsed time)
        time_start = time.time()

        num_epochs = cfg.num_epochs
        for epoch_idx in range(num_epochs):
            self.tuner.train()
            self.head.train()
                
            end = time.time()

            print('the {0}th epoch'.format(epoch_idx+1))
            print("==================================================================")
            current_lr = self.optim.param_groups[0]["lr"]       
            print("current_lr",current_lr)
            
            if cfg.ensemble:
                assert len(self.train_loader_one) == len(self.train_loader_two)
                num_batches = len(self.train_loader_one)
                for batch_idx, (batch1, batch2) in enumerate(zip(self.train_loader_one, self.train_loader_two)):
                    data_time.update(time.time() - end)
                    image1, image2 = batch1[0], batch2[0]
                    label1, label2 = batch1[1], batch2[1]
                    
                    image1 = image1.to(self.device)
                    image2 = image2.to(self.device)
                    label1 = label1.to(self.device)
                    label2 = label2.to(self.device)
                    

                    with autocast():
                        output1 = self.model(image1)["head1"]
                        output2 = self.model(image2)["head2"]
                        
                        if cfg.cumulative_loss_weight:
                            bbn_weight = 1 - (epoch_idx / (num_epochs - 1)) ** 2
                            loss = bbn_weight * self.criterion_one(output1, label1) + (1 - bbn_weight) * self.criterion_two(output2, label2)
                        else:
                            loss = 0.5 * self.criterion_one(output1, label1) + 0.5 * self.criterion_two(output2, label2)

                        loss_micro = loss / self.accum_step
                        self.scaler.scale(loss_micro).backward()
                    
                    if ((batch_idx + 1) % self.accum_step == 0) or (batch_idx + 1 == num_batches):
                            self.scaler.step(self.optim)
                            self.scaler.update()
                            self.optim.zero_grad()
                    
                    with torch.no_grad():
                        pred1 = output1.argmax(dim=1)
                        pred2 = output2.argmax(dim=1)
                        correct1 = pred1.eq(label1).float()
                        correct2 = pred2.eq(label2).float()
                        acc1 = correct1.mean().mul_(100.0)
                        acc2 = correct2.mean().mul_(100.0)
                        acc = (acc1 + acc2)/2

                    #current_lr = self.optim.param_groups[0]["lr"]
                    loss_meter.update(loss.item())
                    acc_meter.update(acc.item())
                    batch_time.update(time.time() - end)

                    for _c, _y in zip(correct1, label1):
                        cls_meters[_y].update(_c.mul_(100.0).item(), n=1)
                    cls_accs1 = [cls_meters[i].avg for i in range(self.num_classes)]
                    
                    for _c, _y in zip(correct2, label2):
                        cls_meters[_y].update(_c.mul_(100.0).item(), n=1)
                    cls_accs2 = [cls_meters[i].avg for i in range(self.num_classes)]

                    cls_accs = (np.array(cls_accs1) + np.array(cls_accs2)) / 2

                    mean_acc = np.mean(np.array(cls_accs))
                    many_acc = np.mean(np.array(cls_accs)[self.many_idxs])
                    med_acc = np.mean(np.array(cls_accs)[self.med_idxs])
                    few_acc = np.mean(np.array(cls_accs)[self.few_idxs])

                    meet_freq = (batch_idx + 1) % cfg.print_freq == 0
                    only_few_batches = num_batches < cfg.print_freq
                    if meet_freq or only_few_batches:
                        nb_remain = 0
                        nb_remain += num_batches - batch_idx - 1
                        nb_remain += (
                            num_epochs - epoch_idx - 1
                        ) * num_batches
                        eta_seconds = batch_time.avg * nb_remain
                        eta = str(datetime.timedelta(seconds=int(eta_seconds)))

                        info = []
                        info += [f"epoch [{epoch_idx + 1}/{num_epochs}]"]
                        info += [f"batch [{batch_idx + 1}/{num_batches}]"]
                        info += [f"time {batch_time.val:.3f} ({batch_time.avg:.3f})"]
                        info += [f"data {data_time.val:.3f} ({data_time.avg:.3f})"]
                        info += [f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})"]
                        info += [f"acc {acc_meter.val:.4f} ({acc_meter.avg:.4f})"]
                        info += [f"(mean {mean_acc:.4f} many {many_acc:.4f} med {med_acc:.4f} few {few_acc:.4f})"]
                        info += [f"lr {current_lr:.4e}"]
                        info += [f"eta {eta}"]
                        print(" ".join(info))

                    n_iter = epoch_idx * num_batches + batch_idx
                    self._writer.add_scalar("train/lr", current_lr, n_iter)
                    self._writer.add_scalar("train/loss.val", loss_meter.val, n_iter)
                    self._writer.add_scalar("train/loss.avg", loss_meter.avg, n_iter)
                    self._writer.add_scalar("train/acc.val", acc_meter.val, n_iter)
                    self._writer.add_scalar("train/acc.avg", acc_meter.avg, n_iter)
                    self._writer.add_scalar("train/mean_acc", mean_acc, n_iter)
                    self._writer.add_scalar("train/many_acc", many_acc, n_iter)
                    self._writer.add_scalar("train/med_acc", med_acc, n_iter)
                    self._writer.add_scalar("train/few_acc", few_acc, n_iter)
                    
                    end = time.time()
                
            else:
                num_batches = len(self.train_loader)
                for batch_idx, batch in enumerate(self.train_loader):
                    data_time.update(time.time() - end)

                    image = batch[0]    #torch.Size([128, 3, 224, 224])
                    label = batch[1]    #torch.Size([128])
                    
                    image = image.to(self.device)
                    label = label.to(self.device)
                    if cfg.mixup:
                        mix_index = torch.randperm(image.shape[0]).to(self.device)
                        
                        lam = np.random.beta(cfg.alpha, cfg.alpha)
                        mixed_image = lam * image + (1 - lam) * image[mix_index,:,:,:]
                        # print(mixed_image.shape)
                        image = mixed_image

                    if cfg.prec == "amp":
                        with autocast():
                            output = self.model(image)      #
                            if isinstance(output, dict):
                                output_dist = {key: output[key] for key in output.keys() if key != "cls"}
                                output = output["cls"]
                            else:
                                output_dist = None
                            
                            if cfg.teacher:                                            
                                loss = self.criterion(image, (output, output_dist) , label)    
                            else:                                                       

                                if cfg.mixup:
                                    loss = lam * self.criterion(output, label) + (1 - lam) * self.criterion(output, label[mix_index])
                                else:
                                    loss = self.criterion(output, label)
                                
                            loss_micro = loss / self.accum_step
                            self.scaler.scale(loss_micro).backward()
                            
                        if ((batch_idx + 1) % self.accum_step == 0) or (batch_idx + 1 == num_batches):
                            self.scaler.step(self.optim)
                            self.scaler.update()
                            self.optim.zero_grad()
                            
                    else:
                        output = self.model(image)
                        
                        if isinstance(output, dict):
                            output_dist = {key: output[key] for key in output.keys() if key != "cls"}
                            output = output["cls"]
                        else:
                            output_dist = None
                        
                        if cfg.teacher:                                            
                            loss = self.criterion(image, (output, output_dist) , label)    
                        else:                                                      

                            if cfg.mixup:
                                loss = lam * self.criterion(output, label) + (1 - lam) * self.criterion(output, label[mix_index])
                            else:
                                loss = self.criterion(output, label)
                                
                        loss_micro = loss / self.accum_step
                        loss_micro.backward()
                        if ((batch_idx + 1) % self.accum_step == 0) or (batch_idx + 1 == num_batches):
                            self.optim.step()
                            self.optim.zero_grad()

                    with torch.no_grad():
                        pred = output.argmax(dim=1)
                        correct = pred.eq(label).float()
                        acc = correct.mean().mul_(100.0)

                    #current_lr = self.optim.param_groups[0]["lr"]
                    loss_meter.update(loss.item())
                    acc_meter.update(acc.item())
                    batch_time.update(time.time() - end)

                    for _c, _y in zip(correct, label):
                        cls_meters[_y].update(_c.mul_(100.0).item(), n=1)
                    cls_accs = [cls_meters[i].avg for i in range(self.num_classes)]

                    mean_acc = np.mean(np.array(cls_accs))
                    many_acc = np.mean(np.array(cls_accs)[self.many_idxs])
                    med_acc = np.mean(np.array(cls_accs)[self.med_idxs])
                    few_acc = np.mean(np.array(cls_accs)[self.few_idxs])

                    meet_freq = (batch_idx + 1) % cfg.print_freq == 0
                    only_few_batches = num_batches < cfg.print_freq
                    if meet_freq or only_few_batches:
                        nb_remain = 0
                        nb_remain += num_batches - batch_idx - 1
                        nb_remain += (
                            num_epochs - epoch_idx - 1
                        ) * num_batches
                        eta_seconds = batch_time.avg * nb_remain
                        eta = str(datetime.timedelta(seconds=int(eta_seconds)))

                        info = []
                        info += [f"epoch [{epoch_idx + 1}/{num_epochs}]"]
                        info += [f"batch [{batch_idx + 1}/{num_batches}]"]
                        info += [f"time {batch_time.val:.3f} ({batch_time.avg:.3f})"]
                        info += [f"data {data_time.val:.3f} ({data_time.avg:.3f})"]
                        info += [f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})"]
                        info += [f"acc {acc_meter.val:.4f} ({acc_meter.avg:.4f})"]
                        info += [f"(mean {mean_acc:.4f} many {many_acc:.4f} med {med_acc:.4f} few {few_acc:.4f})"]
                        info += [f"lr {current_lr:.4e}"]
                        info += [f"eta {eta}"]
                        print(" ".join(info))

                    n_iter = epoch_idx * num_batches + batch_idx
                    self._writer.add_scalar("train/lr", current_lr, n_iter)
                    self._writer.add_scalar("train/loss.val", loss_meter.val, n_iter)
                    self._writer.add_scalar("train/loss.avg", loss_meter.avg, n_iter)
                    self._writer.add_scalar("train/acc.val", acc_meter.val, n_iter)
                    self._writer.add_scalar("train/acc.avg", acc_meter.avg, n_iter)
                    self._writer.add_scalar("train/mean_acc", mean_acc, n_iter)
                    self._writer.add_scalar("train/many_acc", many_acc, n_iter)
                    self._writer.add_scalar("train/med_acc", med_acc, n_iter)
                    self._writer.add_scalar("train/few_acc", few_acc, n_iter)
                    
                    end = time.time()

            self.sched.step()
            torch.cuda.empty_cache()

        print("Finish training")
        print("Note that the printed training acc is not precise.",
              "To get precise training acc, use option ``test_train True``.")

        # show elapsed time
        elapsed = round(time.time() - time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print(f"Time elapsed: {elapsed}")

        # save model
        self.save_model(cfg.output_dir)

        self.test(mode="test")

        # Close writer
        self._writer.close()

    @torch.no_grad()
    def test(self, mode="test"):
        """Run evaluation over one split and return the first reported metric.

        Every batch is read as ``(image, label, ...)`` where ``image`` is
        unpacked as a 5-D tensor ``(batch, crops, C, H, W)``. The crops are
        folded into the batch dimension for the forward pass and the
        per-crop outputs are averaged back per sample before being handed to
        ``self.evaluator``.

        Args:
            mode: ``"train"`` evaluates ``self.train_test_loader``;
                ``"test"`` evaluates ``self.test_loader``.

        Returns:
            The first value of the mapping returned by
            ``self.evaluator.evaluate()``.

        Raises:
            ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``.
            RuntimeError: if ``cfg.output_mode`` is ``"dis_token"`` or
                ``"combined"`` but the model returned no outputs other than
                ``"cls"`` to average.
        """
        # Validate first so a typo fails with a clear message instead of an
        # unbound ``data_loader`` further down.
        if mode not in ("train", "test"):
            raise ValueError(
                f'Unknown mode "{mode}"; expected "train" or "test"'
            )

        if self.tuner is not None:
            self.tuner.eval()
        if self.head is not None:
            self.head.eval()

        self.evaluator.reset()

        print(f"Evaluate on the {mode} set")
        data_loader = self.train_test_loader if mode == "train" else self.test_loader

        output_mode = self.cfg.output_mode

        for batch in tqdm(data_loader, ascii=True):
            image = batch[0].to(self.device)
            label = batch[1].to(self.device)

            # Fold the crop dimension into the batch for a single forward pass.
            _bsz, _ncrops, _c, _h, _w = image.size()
            image = image.view(_bsz * _ncrops, _c, _h, _w)

            # Reset per batch so a value from a previous iteration (or no
            # value at all) can never leak into the output_mode handling.
            output_dist = None
            output = self.model(image)
            if isinstance(output, dict):
                if "head1" in output:
                    # Two-head model: average both heads.
                    output = (output["head1"] + output["head2"]) / 2
                else:
                    output_dist = {k: v for k, v in output.items() if k != "cls"}
                    output = output["cls"]

            # "cls_token" (and any other value) keeps ``output`` untouched.
            if output_mode in ("dis_token", "combined"):
                if not output_dist:
                    raise RuntimeError(
                        f'output_mode "{output_mode}" requires the model to return '
                        'outputs in addition to "cls", but none were produced'
                    )
                dist_mean = sum(output_dist.values()) / len(output_dist)
                if output_mode == "dis_token":
                    output = dist_mean
                else:
                    output = (output + dist_mean) / 2

            # Average the per-crop outputs back to one row per sample.
            output = output.view(_bsz, _ncrops, -1).mean(dim=1)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = f"test/{k}"
            if self._writer is not None:
                self._writer.add_scalar(tag, v)

        return list(results.values())[0]

    def save_model(self, directory):
        """Save tuner/head weights and the config to ``checkpoint.pth.tar``.

        The checkpoint is a dict with the keys ``"tuner"``, ``"head"`` and
        ``"cfg"``. A leading ``"module."`` is stripped from every state-dict
        key. A component that is ``None`` on this object is stored as ``None``
        instead of raising.

        Args:
            directory: Folder to write the checkpoint into; it is created if
                it does not exist yet.
        """
        prefix = "module."

        checkpoint = {}
        for name, component in (("tuner", self.tuner), ("head", self.head)):
            # Other methods treat tuner/head as optional, so tolerate None here.
            state_dict = component.state_dict() if component is not None else None
            if state_dict:
                # remove 'module.' in state_dict's keys
                state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in state_dict.items()
                )
            checkpoint[name] = state_dict
        checkpoint["cfg"] = self.cfg

        # An empty ``directory`` means the current working directory, which
        # already exists and must not be passed to makedirs.
        if directory:
            os.makedirs(directory, exist_ok=True)

        # save model
        save_path = os.path.join(directory, "checkpoint.pth.tar")
        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved to {save_path}")

    def load_model(self, directory):
        """Restore tuner (and, when compatible, head) weights from a checkpoint.

        Reads ``checkpoint.pth.tar`` inside ``directory``, mapping tensors to
        ``self.device``. The tuner state is always loaded; the head state is
        loaded only if the stored ``"weight"`` has the same shape as
        ``self.head.weight``.

        Args:
            directory: Folder containing ``checkpoint.pth.tar``.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        ckpt_file = os.path.join(directory, "checkpoint.pth.tar")
        if not os.path.exists(ckpt_file):
            raise FileNotFoundError('Checkpoint not found at "{}"'.format(ckpt_file))

        stored = torch.load(ckpt_file, map_location=self.device)
        tuner_weights, head_weights = stored["tuner"], stored["head"]

        print("Loading weights to from {}".format(ckpt_file))
        self.tuner.load_state_dict(tuner_weights)

        # Skip the head when its stored weight shape differs from the current head.
        stored_head_shape = head_weights["weight"].shape
        if stored_head_shape == self.head.weight.shape:
            self.head.load_state_dict(head_weights)
