import os

import time
import numpy as np
import random
from tqdm import tqdm
import torch
import torch.nn as nn

from copy import deepcopy

import utils.architectures as architectures
from utils.architectures import *
from utils.dataset import *
from utils.validation import *
from utils.utils import *

class Reg_distill_DLF(object):
    def __init__(self, args, save_dir, **kwargs):
        self.args = args
        self.measures_name = ['rmse', 'nll']
        for key, value in args.__dict__.items():
            setattr(self, key, value)
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.save_dir = save_dir
        
        self.output_dim = 1
                
        self.train_loader = None
        self.test_loader = None
        self.num_classes = None
        self.optimizer = None
        self.scheduler = None

        self.v1= None
        self.v2= None

        self.best_epoch = {'rmse':0, 'nll':0}
        self.best_valid = {'rmse':np.infty, 'nll':np.infty }
        self.best_state_dict = {'rmse':None, 'nll':None}
        self.best_log = {'rmse':{'train':None, 'valid': None, 'test': None}, 'nll':{'train':None, 'valid': None, 'test': None}}
        self.best_results = {'rmse':{'train':None, 'valid': None, 'test': None}, 'nll':{'train':None, 'valid': None, 'test': None}}


    def _fix_seed(self):
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.backends.cudnn.deterministic = True
    

    def _make_loaders(self):
        """Build the regression data loaders and record the dataset info.

        Calls ``dataset_reg`` with the configured batch size, dataset name,
        data directory, validation ratio and data seed, then stores the first
        three returned loaders as train/valid/test and unpacks the
        five-element ``data_info``.
        """
        loader_options = dict(
            valid_ratio=self.ratio_valid,
            seed=self.data_seed,
            valid_seed=self.data_seed,
            loader=True,
            mu_mode='zero',
            loading_mode='rbf',
        )
        loaders, data_info = dataset_reg(self.batch_size, self.dataset, self.data_dir, **loader_options)

        for position, attribute in enumerate(('train_loader', 'valid_loader', 'test_loader')):
            setattr(self, attribute, loaders[position])

        (self.n_train,
         self.n_valid,
         self.n_test,
         self.input_dim,
         self.num_classes) = data_info


    def _create_model_t(self, arch, input_dim, h_vec, output_dim, freeze, activation):
        """Instantiate a teacher network and move it to the GPU.

        Args:
            arch: Name of the factory looked up in the ``architectures`` module.
            input_dim, h_vec, output_dim, activation: Forwarded to the factory
                as keyword arguments.
            freeze: When truthy, every parameter is detached in place so it no
                longer takes part in autograd.

        Returns:
            The constructed model after ``.cuda()``.
        """
        build = vars(architectures)[arch]
        net = build(
            input_dim=input_dim,
            h_vec=h_vec,
            output_dim=output_dim,
            activation=activation,
        )

        if freeze:
            for weight in net.parameters():
                weight.detach_()

        return net.cuda()


    def _define_teacher_model(self):
        """Prepare the teacher ensemble, loading it from disk or training it.

        First tries to restore ``self.num_ens`` teachers from
        ``<teacher_dir>/checkpoint_<i>_final.ckpt``.  If that fails with an
        ordinary error (for example a checkpoint is missing), the teachers are
        trained from scratch on the training loader built with
        ``valid_ratio=0.0``, saved under ``self.teacher_dir``, and the regular
        loaders are rebuilt afterwards through ``_make_loaders``.

        Sets ``self.teacher_model_list``, ``self.sigma_list`` (one residual
        scale per teacher) and ``self.model_sigma`` (their mean, reshaped to a
        one-element tensor).
        """
        self.teacher_model_list = [self._create_model_t(arch = self.arch_t, input_dim = self.input_dim, h_vec = self.teacher_h_vec, output_dim = self.output_dim, freeze = True, activation = self.activation) for i in range(self.num_ens)]
        dir_list = [os.path.join(self.teacher_dir, f"checkpoint_{i}_final.ckpt") for i in range(self.num_ens)]

        try:
            self.sigma_list = None
            sigma_list = []
            for model, tmp_dir in zip(self.teacher_model_list, dir_list):
                checkpoint = torch.load(tmp_dir, map_location = 'cpu')
                model.load_state_dict(checkpoint['state_dict'])
                model.sigma = checkpoint['sigma'].cuda()
                sigma_list.append(checkpoint['sigma'])
            self.sigma_list = torch.stack(sigma_list, dim = 0).squeeze().cuda()
            self.model_sigma = self.sigma_list.mean().reshape([-1])
            print('Teacher Loading complete!')
        except Exception as err:
            # Only ordinary failures trigger the (long) retraining fallback.
            # KeyboardInterrupt / SystemExit now propagate instead of silently
            # starting a fresh 500-epoch training run.
            print(f'Teacher loading failed ({err!r}); training teachers from scratch.')
            os.makedirs(self.teacher_dir, exist_ok=True)
            print('Teacher Training')
            # Fresh, trainable (non-frozen) teachers replace the ones built above.
            self.teacher_model_list = [self._create_model_t(arch = self.arch_t, input_dim = self.input_dim, h_vec = self.teacher_h_vec, output_dim = self.output_dim, freeze = False, activation = self.activation) for i in range(self.num_ens)]
            # Teachers are fitted on loaders created with valid_ratio = 0.0;
            # the regular loaders are restored by _make_loaders() below.
            Loaders, data_info = dataset_reg(self.batch_size, self.dataset, self.data_dir, valid_ratio = 0.0, seed = self.data_seed, valid_seed = self.data_seed, loader = True, mu_mode = 'zero', loading_mode = 'rbf')
            self.train_loader = Loaders[0]
            self.valid_loader = Loaders[1]
            self.test_loader = Loaders[2]

            self.teacher_optimizer_list = [torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay =  self.weight_decay) for model in self.teacher_model_list]
            self.teacher_scheduler_list = [torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400], gamma=0.1) for optimizer in self.teacher_optimizer_list]

            sigma_list = []
            for num, model, optimizer, scheduler in zip(tqdm(range(self.num_ens)), self.teacher_model_list, self.teacher_optimizer_list, self.teacher_scheduler_list):
                model.train()
                for epoch in range(500):
                    for inputs, targets in self.train_loader:
                        inputs, targets = inputs.cuda(), targets.cuda()
                        outputs = model(inputs)
                        loss_rmse = evaluate_rmse(outputs.squeeze(), targets.squeeze())
                        optimizer.zero_grad()
                        loss_rmse.backward()
                        optimizer.step()

                    scheduler.step()

                # Residual scale of this teacher on the training loader.
                model.eval()
                outputs, targets = evaluate_values(model, self.train_loader)
                sigma = ((outputs - targets)**2).mean().sqrt()
                sigma_list.append(sigma)
                model.sigma = sigma.reshape([-1])

                save_dict = {'epoch': 500,
                            'state_dict':model.state_dict(),
                            'sigma':model.sigma,
                            'optimizer' : optimizer.state_dict()
                            }
                filename = f'checkpoint_{num}_final.ckpt'
                checkpoint_path = os.path.join(self.teacher_dir, filename)
                torch.save(save_dict, checkpoint_path)

                # Freeze the trained teacher: detach its parameters in place.
                for param in model.parameters():
                    param.detach_()
            self.sigma_list = torch.stack(sigma_list, dim = 0).squeeze().cuda()
            self.model_sigma = self.sigma_list.mean().reshape([-1])
            self._make_loaders()
            print('Teacher Training : Done!')
            
            
    def _create_model_s(self, arch, input_dim, h_vec, mu_h_vec, phi_h_vec, latent_dim, output_dim, freeze, activation, num_ens = 0):
        """Instantiate the student network and move it to the GPU.

        Args:
            arch: Name of the factory looked up in the ``architectures`` module.
            input_dim, h_vec, mu_h_vec, phi_h_vec, latent_dim, output_dim,
            num_ens, activation: Forwarded to the factory as keyword arguments.
            freeze: When truthy, every parameter is detached in place.

        Returns:
            The constructed model after ``.cuda()``.
        """
        build = vars(architectures)[arch]
        net = build(
            input_dim=input_dim,
            h_vec=h_vec,
            mu_h_vec=mu_h_vec,
            phi_h_vec=phi_h_vec,
            latent_dim=latent_dim,
            output_dim=output_dim,
            num_ens=num_ens,
            activation=activation,
        )

        if freeze:
            for weight in net.parameters():
                weight.detach_()

        return net.cuda()

    def _define_model_and_optimizer(self):
        """Create the student and one Adam optimizer per training phase.

        The student inherits ``self.model_sigma`` as its ``sigma``.  Separate
        optimizers are created for pretraining (``pre_lr``), MMD training
        (``mmd_lr``) and distillation (``lr``); only the distillation
        optimizer gets a ``MultiStepLR`` schedule.
        """
        assert self.arch_s =="MLP_DLF"
        self.student = self._create_model_s(
            arch=self.arch_s,
            input_dim=self.input_dim,
            h_vec=self.h_vec,
            mu_h_vec=self.mu_h_vec,
            phi_h_vec=self.phi_h_vec,
            latent_dim=self.latent_dim,
            output_dim=self.output_dim,
            num_ens=self.num_ens,
            freeze=False,
            activation=self.activation,
        )
        self.student.sigma = self.model_sigma

        def make_adam(learning_rate):
            # All three phases share the parameters and the weight decay.
            return torch.optim.Adam(self.student.parameters(), lr=learning_rate, weight_decay=self.weight_decay)

        self.optimizer_s_pretrain = make_adam(self.pre_lr)
        self.optimizer_s_mmd = make_adam(self.mmd_lr)
        self.optimizer_s_DLF = make_adam(self.lr)
        self.scheduler_s_DLF = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_s_DLF, milestones=self.lr_schedule, gamma=0.1)

    
    def _calculate_sigma(self, model):
        """Estimate the residual scale of ``model`` on the training loader.

        The root of the mean squared difference between the outputs and
        targets returned by ``evaluate_values`` is stored on ``model.sigma``
        and returned, both as a one-dimensional tensor.
        """
        model.eval()
        predictions, truth = evaluate_values(model, self.train_loader)
        residual = predictions - truth
        root_mean_square = (residual ** 2).mean().sqrt()
        model.sigma = root_mean_square.reshape([-1])
        return root_mean_square.reshape([-1])
    

    def _ens_prediction(self, inputs):
        for model in self.teacher_model_list:
            model.eval()
        inputs = inputs.cuda()

        outputs_list = []

        for tmp_model in self.teacher_model_list:
            with torch.no_grad():
                tmp_outputs = tmp_model(inputs)
                outputs_list.append(tmp_outputs)
        
        outputs_list = torch.stack(outputs_list, dim =1)
        return outputs_list


    def _evaluate_values(self, model, loader):
        """Evaluate ``model`` on the full tensors behind ``loader``.

        The first two tensors of ``loader.dataset`` are used as inputs and
        targets.  The model is called once with its default latent input and
        once with standard-normal noise of the same shape as ``model.fc.z``.

        Returns:
            Tuple ``(outputs, targets, outputs_EZ, outputs_RZ)`` where
            ``outputs`` is element 1 of the first call, ``outputs_EZ`` is
            element 0 of the first call averaged over dim 1, and
            ``outputs_RZ`` is element 0 of the noise call averaged over dim 1.
        """
        model.eval()

        with torch.no_grad():
            data = loader.dataset
            features = data.tensors[0].cuda()
            labels = data.tensors[1].cuda()

            fitted = model(features)
            point_prediction = fitted[1]
            mean_fitted_latent = fitted[0].mean(1).squeeze()

            noise = torch.randn(model.fc.z.shape).cuda()
            sampled = model(features, noise)
            mean_random_latent = sampled[0].mean(1).squeeze()

        return point_prediction, labels, mean_fitted_latent, mean_random_latent
    

    def _measures_reg(self, model, loader, measures_name = ('rmse', 'nll', 'crps')):
        """Compute regression metrics for three kinds of student prediction.

        The predictions come from ``_evaluate_values``: the plain output
        (no key suffix), the ``_EZ`` output and the ``_RZ`` output.  For each
        of them the requested metrics are computed against the targets, using
        ``model.sigma`` as the noise scale (``sigma**2`` for CRPS).

        Args:
            model: Student model exposing ``sigma``.
            loader: Data loader passed on to ``_evaluate_values``.
            measures_name: Collection containing any of ``'rmse'``, ``'nll'``
                and ``'crps'``.  The default is an immutable tuple so it
                cannot be shared-and-mutated between calls.

        Returns:
            ``(results_dict, results_log)``: the metric values keyed by
            ``<metric><suffix>`` and a tab-separated ``name value`` string
            with five decimals.
        """
        mu, y, outputs_EZ, outputs_RZ = self._evaluate_values(model, loader)
        sigma = model.sigma

        mu = mu.squeeze()
        y = y.squeeze()

        results_dict = {}
        # Same insertion order as before: all metrics of the plain output,
        # then the _EZ ones, then the _RZ ones.
        for suffix, prediction in (('', mu), ('_EZ', outputs_EZ), ('_RZ', outputs_RZ)):
            if 'rmse' in measures_name:
                results_dict['rmse' + suffix] = evaluate_rmse(prediction, y).item()
            if 'nll' in measures_name:
                results_dict['nll' + suffix] = evaluate_nll_normal(prediction, y, sigma).item()
            if 'crps' in measures_name:
                # CRPS is evaluated sample by sample and then averaged.
                per_sample = [evaluate_crps(prediction[i], y[i], (sigma**2)) for i in range(len(y))]
                results_dict['crps' + suffix] = torch.mean(torch.stack(per_sample)).item()

        results_log = ''.join(name + ' {:.5f}\t'.format(item) for name, item in results_dict.items())

        return results_dict, results_log
    
    
    def _validate_and_save(self, epoch, measures_name, optimizer):
        """Evaluate the student and checkpoint it on a validation improvement.

        Train, test and validation metrics are refreshed on the instance.
        For each of ``'rmse'`` and ``'nll'`` whose validation value is
        strictly lower than the best so far, the best-epoch records are
        updated, a ``best_<metric>`` checkpoint is written and a deep copy of
        the student's state dict is kept in memory.

        Args:
            epoch: Zero-based epoch index; checkpoints store ``epoch + 1``.
            measures_name: Metric names forwarded to ``_measures_reg``.
            optimizer: Optimizer whose state is stored in the checkpoint.
        """
        student = self.student
        self.results_train, self.log_train = self._measures_reg(student, self.train_loader, measures_name)
        self.results_test, self.log_test = self._measures_reg(student, self.test_loader, measures_name)
        self.results_valid, self.log_valid = self._measures_reg(student, self.valid_loader, measures_name)

        improved = [key for key in ('rmse', 'nll') if self.results_valid[key] < self.best_valid[key]]
        if not improved:
            return

        # The test split is evaluated a second time for the "best" records.
        best_results_test, best_log_test = self._measures_reg(student, self.test_loader, measures_name)

        for key in improved:
            self.best_epoch[key] = epoch
            self.best_valid[key] = min(self.results_valid[key], self.best_valid[key])
            self.best_log[key] = {'train': self.log_train, 'valid': self.log_valid, 'test': best_log_test}
            self.best_results[key] = {'train': self.results_train, 'valid': self.results_valid, 'test': best_results_test}

            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': self.student.state_dict(),
                f'best_{key}_test': best_results_test[key],
                'optimizer': optimizer.state_dict(),
            }
            self._save_checkpoint(checkpoint, f"best_{key}")
            self.best_state_dict[key] = deepcopy(self.student.state_dict())
                
                
    #### Pretraining ####################################################################################################################################
    def _pretrain_batch(self, inputs, outputs_t_list, optimizer):
        """Take one pretraining step towards the teacher-ensemble mean.

        Element 1 of the student output is regressed (via ``evaluate_rmse``)
        onto the teacher outputs averaged over dim 1.

        Returns:
            ``{'loss_rmse': <float>}`` for the processed batch.
        """
        self.student.train()

        student_out = self.student(inputs.cuda())
        teacher_mean = outputs_t_list.mean(1).cuda()

        loss_rmse = evaluate_rmse(student_out[1].squeeze(), teacher_mean.squeeze())

        optimizer.zero_grad()
        loss_rmse.backward()
        optimizer.step()

        return {'loss_rmse': loss_rmse.item()}
    
    
    def _pretrain(self):
        """Run ``self.pre_epochs`` epochs of pretraining on the train loader.

        Each batch is first scored by the teacher ensemble and then used for
        one ``_pretrain_batch`` step; a one-line progress report is printed
        after every epoch.
        """
        print("Pretrain : Start!")
        start = time.time()
        self.pretrain_train_losses = []
        self.pretrain_val_losses = []

        for epoch in range(self.pre_epochs):
            meters = AverageMeterSet()
            for inputs, _ in self.train_loader:
                teacher_outputs = self._ens_prediction(inputs)
                batch_summary = self._pretrain_batch(inputs, teacher_outputs, self.optimizer_s_pretrain)
                meters.update('loss_rmse', batch_summary['loss_rmse'], inputs.shape[0])

            end = time.time()
            print(f'Epoch : {epoch + 1}/{self.pre_epochs}, RMSE Loss : {meters["loss_rmse"].avg:.4f}, Time : {end - start:.2f} sec / {(end - start)/(epoch+1)*(self.pre_epochs+1):.0f} sec', end='\r', flush=True)
        print("Pretrain : Done!")
        
    #### Choice of the initial parameter ################################################################################################################
    def _mmd_train_batch(self, inputs, outputs_t_list, optimizer):
        """Take one optimisation step on the NLL + MMD objective.

        The NLL term is ``nll_loss`` of student versus teacher outputs divided
        by the batch size.  The MMD term compares the student's latent
        ``fc.z`` with fresh standard-normal samples of shape
        ``(num_ens, latent_dim)`` and is weighted by ``self.lamb``.
        Gradients are clipped to norm 1 before the step.

        Returns:
            Dict with the float values ``'loss'``, ``'loss_nll'`` and
            ``'loss_MMD'``.
        """
        self.student.train()

        batch = inputs.cuda()
        student_out = self.student(batch)

        nll_term = nll_loss(student_out, outputs_t_list) / batch.shape[0]
        reference_z = torch.randn((self.num_ens, self.latent_dim)).cuda()
        mmd_term = mmd_loss(self.student.fc.z.squeeze(), reference_z, sigmas=np.array([1.]))

        total = nll_term + self.lamb*mmd_term

        optimizer.zero_grad()
        total.backward(retain_graph = True)
        torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.)
        optimizer.step()

        return {
            'loss': total.item(),
            'loss_nll': nll_term.item(),
            'loss_MMD': mmd_term.item(),
        }
    
    
    def _mmd_train(self):
        """Run ``self.mmd_epochs`` epochs of NLL + MMD training.

        Each batch is scored by the teacher ensemble and passed to
        ``_mmd_train_batch``; batch-size weighted averages of the three loss
        values are printed after every epoch.
        """
        print("MMD Train : Start!")
        start = time.time()
        self.mmd_train_losses = []
        self.mmd_val_losses = []

        tracked = ('loss', 'loss_nll', 'loss_MMD')
        for epoch in range(self.mmd_epochs):
            meters = AverageMeterSet()
            for inputs, _ in self.train_loader:
                teacher_outputs = self._ens_prediction(inputs)
                batch_summary = self._mmd_train_batch(inputs, teacher_outputs, self.optimizer_s_mmd)
                mini_batchsize = inputs.shape[0]
                for key in tracked:
                    meters.update(key, batch_summary[key], mini_batchsize)

            end = time.time()
            print(f'Epoch : {epoch + 1}/{self.mmd_epochs}, Loss : {meters["loss"].avg:.4f}, NLL : {meters["loss_nll"].avg:.4f}, MMD : {meters["loss_MMD"].avg:.4f}, Time : {end - start:.2f} sec / {(end - start)/(epoch+1)*(self.mmd_epochs+1):.0f} sec', end='\r', flush=True)
        print("MMD Train : Done!")
        
        
    #### Distillation ###################################################################################################################################
    def _EM_loss(self, outputs_s_list, outputs_t_list):
        B = outputs_t_list.shape[0]
        M = outputs_t_list.shape[1]

        mu_f = outputs_s_list[1]
        phi = outputs_s_list[3]
        sigma = outputs_s_list[4]

        q = phi.shape[1]
        I_B = torch.eye(B).to(phi.device)
        I_q = torch.eye(q,q).to(phi.device)

        Sigma_fz = phi
        diff_f = outputs_t_list.squeeze() - mu_f
        Varz = torch.linalg.inv(I_q + Sigma_fz.detach().t() @ Sigma_fz.detach() /(sigma.detach()**2))
        Ez = (Varz @ Sigma_fz.detach().t() @ diff_f.detach() / (sigma.detach()**2))
        
        self.student.fc.z.data = Ez.t().unsqueeze(2).data
        
        Ezzt = []
        for m in range(M):
            Ezzt.append(Varz + Ez[:,m].reshape(-1,1) @ Ez[:,m].reshape(-1,1).t())
        Ezzt = torch.stack(Ezzt, dim = 2)  
        
        Q = torch.zeros(1).cuda()
        first_term = 0; second_term = 0; third_term = 0
        for m in range(M):
            diff_fz_m = (diff_f[:,m]).unsqueeze(1)
            Q += -0.5*(sigma**(-2))*((diff_fz_m.t() @ diff_fz_m).squeeze() 
                                        -2*(Ez[:,m] @ Sigma_fz.t() @ diff_fz_m).squeeze()
                                        + torch.trace(Sigma_fz.t() @ Sigma_fz @ Ezzt[:,:,m]))
            
            first_term += (diff_fz_m.t() @ diff_fz_m).squeeze().item()
            second_term += -2*(Ez[:,m] @ Sigma_fz.t() @ diff_fz_m).squeeze().item()
            third_term += torch.trace(Sigma_fz.t() @ Sigma_fz @ Ezzt[:,:,m]).item()
            
        Q += -0.5*M*B*2*sigma.log()
        Q /= M

        with torch.no_grad():
            rmse_tmp = torch.zeros([1]).cuda()
            logits_s_list = mu_f + phi @ Ez
            for m in range(M):
                rmse_tmp += evaluate_rmse(outputs_t_list[:,m].squeeze(), logits_s_list[:,m].squeeze())
            rmse_tmp /= M 
        
        return -Q
    
    
    def _distill_batch(self, inputs, outputs_t_list, optimizer):
        """Take one distillation step on the EM loss.

        The loss is ``_EM_loss`` divided by the batch size; gradients are
        clipped to norm 1 before the optimizer step.

        Returns:
            ``{'loss': <float>}`` for the processed batch.
        """
        self.student.train()

        batch = inputs.cuda()
        student_out = self.student(batch)

        loss = self._EM_loss(student_out, outputs_t_list) / batch.shape[0]

        optimizer.zero_grad()
        loss.backward(retain_graph = True)
        torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.)
        optimizer.step()

        return {'loss': loss.item()}


    def _distill(self):
        """Run the distillation phase for ``self.epochs`` epochs.

        After each epoch the LR scheduler is stepped, and every
        ``self.evaluation_epochs`` epochs the student is validated on RMSE
        and NLL.  Once training ends, a final validation including CRPS is
        run and the final checkpoint is written.
        """
        print("Train Start")
        start = time.time()
        self.train_losses = []
        self.val_losses = []

        for epoch in range(self.epochs):
            meters = AverageMeterSet()
            for inputs, _ in self.train_loader:
                teacher_outputs = self._ens_prediction(inputs)
                batch_summary = self._distill_batch(inputs, teacher_outputs, self.optimizer_s_DLF)
                meters.update('loss', batch_summary['loss'], inputs.shape[0])

            self.scheduler_s_DLF.step()
            due_for_validation = ((epoch+1) % self.evaluation_epochs) == 0
            if due_for_validation:
                self._validate_and_save(epoch, ['rmse', 'nll'], self.optimizer_s_DLF)
            end = time.time()
            print(f'Epoch : {epoch + 1}/{self.epochs}, EM Loss : {meters["loss"].avg:.4f}, Time : {end - start:.2f} sec / {(end - start)/(epoch+1)*(self.epochs+1):.0f} sec', end='\r', flush=True)

        # `epoch` still holds the index of the last completed epoch here.
        self._validate_and_save(epoch, ['rmse', 'nll', 'crps'], self.optimizer_s_DLF)
        self._save_final_checkpoint(self.student, self.optimizer_s_DLF)
        print("Train : Done!")
        
        
    #### Save #########################################################################################################################################
    def _save_checkpoint(self, state, name):
        filename = 'checkpoint_{}.ckpt'.format(name)
        checkpoint_path = os.path.join(self.save_dir, filename)
        torch.save(state, checkpoint_path)


    def _save_final_checkpoint(self, model, optimizer):
        """Write the final checkpoint, then report the best-RMSE student.

        The checkpoint stores ``model``/``optimizer`` state together with the
        most recent test RMSE, NLL and CRPS.  Afterwards the student is reset
        to the best-RMSE weights, re-validated (including CRPS) and the
        ``_RZ`` test metrics are printed.
        """
        latest_test = self.results_test
        final_state = {'epoch': self.epochs, 'state_dict': model.state_dict()}
        for metric in ('rmse', 'nll', 'crps'):
            final_state['final_test_' + metric] = latest_test[metric]
        final_state['optimizer'] = optimizer.state_dict()

        torch.save(final_state, os.path.join(self.save_dir, 'checkpoint_final.ckpt'))

        # Switch to the weights with the best validation RMSE and refresh
        # the metrics stored on the instance.
        self.student.load_state_dict(self.best_state_dict['rmse'])
        self._validate_and_save(self.epochs, ['rmse', 'nll', 'crps'], self.optimizer_s_DLF)
        print(f"Test RMSE : {self.results_test['rmse_RZ']}, Test NLL : {self.results_test['nll_RZ']}, Test CRPS : {self.results_test['crps_RZ']}")
        
        

class Cls_distill_DLF(object):    
    def __init__(self, args, save_dir, **kwargs):
        self.args = args
        for key, value in args.__dict__.items():
            setattr(self, key, value)
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.save_dir = save_dir
    
        self.train_loader = None
        self.test_loader = None
        self.num_classes = None
        self.teacher_list = None
        self.optimizer = None
        self.scheduler = None

        self.best_valid = {'acc':0., 'nll':np.infty }
        self.best_state_dict = {'acc':None, 'nll':None}
        self.best_log = {'acc':{'train':None, 'valid': None, 'test': None}, 'nll':{'train':None, 'valid': None, 'test': None}}
        self.best_results = {'acc':{'train':None, 'valid': None, 'test': None}, 'nll':{'train':None, 'valid': None, 'test': None}}
    
    
    def _fix_seed(self):
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.backends.cudnn.deterministic = True
    

    def _make_loaders(self):
        """Build the classification loaders and the robustness test sets.

        ``batch_dataset`` supplies the train/valid/test loaders and the
        number of classes (elements 0 to 3 of its return value);
        ``batch_dataset_robust`` supplies the corrupted-set information and
        the out-of-distribution test loader for ``'svhn'``.
        """
        bundle = batch_dataset(
            self.batch_size,
            self.dataset,
            self.data_dir,
            transform_valid="test",
            augment="standard",
            n_valid=10000,
            num_workers=self.workers,
        )
        targets = ('train_loader', 'valid_loader', 'test_loader', 'num_classes')
        for position, attribute in enumerate(targets):
            setattr(self, attribute, bundle[position])

        robust_bundle = batch_dataset_robust(self.batch_size, self.dataset, self.data_dir, 'svhn')
        self.corrupted_set, self.corrupted_data_pth, self.out_test_loader = robust_bundle
        

    def _create_model_t(self, arch, num_classes, droprate, freeze, num_ens = 1):
        """Instantiate a teacher classifier and move it to the GPU.

        Args:
            arch: Name of the factory looked up in the ``architectures`` module.
            num_classes, droprate, num_ens: Forwarded to the factory as
                keyword arguments.
            freeze: When truthy, every parameter is detached in place.

        Returns:
            The constructed model after ``.cuda()``.
        """
        build = vars(architectures)[arch]
        net = build(num_classes=num_classes, droprate=droprate, num_ens=num_ens)
        if freeze:
            for weight in net.parameters():
                weight.detach_()
        return net.cuda()


    def _define_teacher_model(self):
        """Create ``self.num_ens`` frozen teachers and load their weights.

        ``self.teacher_dir`` is truncated to its first ``num_ens`` entries;
        each remaining entry must be a directory containing
        ``checkpoint_best_acc.ckpt``.

        Raises:
            ValueError: If fewer than ``num_ens`` teacher directories are
                available.  Without this check ``zip`` would stop at the
                shorter sequence and leave the remaining teachers with their
                freshly initialised weights.
        """
        self.teacher_dir = self.teacher_dir[:self.num_ens]
        if len(self.teacher_dir) < self.num_ens:
            raise ValueError(
                f"Expected {self.num_ens} teacher checkpoint directories, "
                f"got {len(self.teacher_dir)}."
            )

        self.teacher_list = [self._create_model_t(arch = self.arch_t, freeze = True, num_classes = self.num_classes, droprate = self.droprate) for i in range(self.num_ens)]

        for tmp_teacher, tmp_dir in zip(self.teacher_list, self.teacher_dir):
            checkpoint = torch.load(os.path.join(tmp_dir, "checkpoint_best_acc.ckpt"), map_location = 'cpu')
            tmp_teacher.load_state_dict(checkpoint['state_dict'])
    

    def _create_model_s(self, arch, num_classes, droprate, freeze, num_ens = 1):
        """Instantiate the student classifier and move it to the GPU.

        Besides the explicit arguments, the factory receives ``mu_h_vec``,
        ``phi_h_vec`` and ``latent_dim`` taken from the instance.

        Args:
            arch: Name of the factory looked up in the ``architectures`` module.
            num_classes, droprate, num_ens: Forwarded to the factory.
            freeze: When truthy, every parameter is detached in place.

        Returns:
            The constructed model after ``.cuda()``.
        """
        build = vars(architectures)[arch]
        net = build(
            num_classes=num_classes,
            droprate=droprate,
            num_ens=num_ens,
            mu_h_vec=self.mu_h_vec,
            phi_h_vec=self.phi_h_vec,
            latent_dim=self.latent_dim,
        )
        if freeze:
            for weight in net.parameters():
                weight.detach_()
        return net.cuda()
        

    def _define_model_and_optimizer(self):
        """Create the student and one SGD optimizer per training phase.

        Separate momentum-0.9 SGD optimizers are created for pretraining
        (``pre_lr``), MMD training (``mmd_lr``) and the main phase (``lr``);
        only the main optimizer gets a ``MultiStepLR`` schedule.
        """
        self.model = self._create_model_s(
            arch=self.arch_s,
            freeze=False,
            num_classes=self.num_classes,
            droprate=self.droprate,
            num_ens=self.num_ens,
        )

        def make_sgd(learning_rate):
            # All phases share the parameters, weight decay and momentum.
            return torch.optim.SGD(self.model.parameters(), lr=learning_rate, weight_decay=self.weight_decay, momentum=0.9)

        self.pre_optimizer = make_sgd(self.pre_lr)
        self.mmd_optimizer = make_sgd(self.mmd_lr)
        self.optimizer = make_sgd(self.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.lr_schedule, gamma=0.1)
    

    def _evaluate_probs(self, model_list, loader):
        """Collect logits, averaged class probabilities and targets.

        Two input kinds are supported:

        * a ``list`` of models: each model is run on every batch, the logits
          are stacked along dim 1, batches are concatenated along dim 0 and
          the softmax (dim 2) is averaged over dim 1;
        * a single model: it is called with the batch and fresh normal noise
          of shape ``(n_samples, latent_dim, num_classes)``; the first of its
          five outputs is concatenated over batches along dim 1 and the
          softmax (dim 2) is averaged over dim 0.

        Returns:
            ``(logits, mean_probs, targets)``.
        """
        is_ensemble = isinstance(model_list, list)
        members = model_list if is_ensemble else [model_list]
        for member in members:
            member.eval()

        # Which dim joins the batches and which dim indexes the members /
        # samples differs between the two cases.
        batch_dim, member_dim = (0, 1) if is_ensemble else (1, 0)

        with torch.no_grad():
            logits_chunks = []
            target_chunks = []
            for inputs, targets in loader:
                inputs, targets = inputs.cuda(), targets.cuda()
                if is_ensemble:
                    batch_logits = torch.stack([member(inputs) for member in members], dim=1)
                else:
                    batch_logits, _, _, _, _ = model_list(inputs, torch.randn(self.n_samples, self.latent_dim, self.num_classes).cuda())
                logits_chunks.append(batch_logits)
                target_chunks.append(targets)

            logits_each_list = torch.cat(logits_chunks, dim=batch_dim)
            probs_ens_list = nn.functional.softmax(logits_each_list, dim=2).mean(member_dim)
            targets_list = torch.cat(target_chunks, dim=0)

        return logits_each_list, probs_ens_list, targets_list
    

    def _measures_cls(self, model_list, loader, measures_name = ['acc', 'nll', 'ece']):
        """Compute classification metrics for ``model_list`` on ``loader``.

        The averaged probabilities and targets from ``_evaluate_probs`` are
        passed to ``evaluate_test_measures``; no validation probabilities or
        targets are supplied (both are ``None``).

        Returns:
            ``(results_dict, results_log)`` where the log is a tab-separated
            ``name value`` string with five decimals.
        """
        _logits, probs, labels = self._evaluate_probs(model_list, loader)
        results_dict = evaluate_test_measures(probs, labels, None, None, measures_name)
        results_log = ''.join(name + ' {:.5f}\t'.format(value) for name, value in results_dict.items())
        return results_dict, results_log


    def _validate_and_save(self, epoch, measures_name):
        """Evaluate the model and checkpoint it on a validation improvement.

        Train, validation and test metrics are refreshed on the instance.
        A strictly higher validation accuracy or a strictly lower validation
        NLL updates the matching best-records, writes a ``best_<metric>``
        checkpoint and keeps a deep copy of the state dict in memory.

        Args:
            epoch: Zero-based epoch index; checkpoints store ``epoch + 1``.
            measures_name: Metric names forwarded to ``_measures_cls``.
        """
        self.results_train, self.log_train = self._measures_cls(self.model, self.train_loader, measures_name)
        self.results_valid, log_valid = self._measures_cls(self.model, self.valid_loader, measures_name)
        self.results_test, self.log_test = self._measures_cls(self.model, self.test_loader, measures_name)

        improved = {
            'acc': self.results_valid['acc'] > self.best_valid['acc'],
            'nll': self.results_valid['nll'] < self.best_valid['nll'],
        }
        # Accuracy keeps the larger value, NLL the smaller one.
        keep_best = {'acc': max, 'nll': min}

        for key in ('acc', 'nll'):
            if not improved[key]:
                continue

            self.best_valid[key] = keep_best[key](self.results_valid[key], self.best_valid[key])
            self.best_log[key] = {'train': self.log_train, 'valid': log_valid, 'test': self.log_test}
            self.best_results[key] = {'train': self.results_train, 'valid': self.results_valid, 'test': self.results_test}

            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                f'best_{key}_test': self.results_test[key],
                'optimizer': self.optimizer.state_dict(),
            }
            self._save_checkpoint(checkpoint, f"best_{key}")
            self.best_state_dict[key] = deepcopy(self.model.state_dict())
    
    
    #### Pretraining ####################################################################################################################################
    def _pretrain(self):
        """Pretrain the student towards the mean teacher logits.

        For ``self.pre_epochs`` epochs, element 1 of the student output is
        fitted to the teacher logits averaged over the ensemble, using the
        root of the mean squared difference as loss.  The progress bar shows
        the running loss meter.
        """
        print("Pretrain : Start!")
        pbar = tqdm(range(self.pre_epochs), total = self.pre_epochs)
        for epoch in pbar:
            self.model.train()
            meters = AverageMeterSet()
            for inputs, targets in self.train_loader:
                inputs, targets = inputs.cuda(), targets.cuda()

                # Teachers are only evaluated; no gradients are needed.
                with torch.no_grad():
                    teacher_logits = [teacher(inputs) for teacher in self.teacher_list]
                teacher_mean = torch.stack(teacher_logits, dim=0).mean(dim=0)

                student_mean = self.model(inputs)[1]
                loss = ((teacher_mean - student_mean)**2).mean().sqrt()

                self.pre_optimizer.zero_grad()
                loss.backward()
                self.pre_optimizer.step()

                meters.update('loss', loss.item(), inputs.shape[0])

            pbar.set_description(f"Epoch {epoch}/{self.pre_epochs}, Loss: {meters['loss']:.4f}")

        print("Pretrain : Done!")
        
    
    #### Choice of the initial parameter ################################################################################################################
    def _mmd_train(self):
        """Train the student with a scaled squared-error term plus an MMD penalty.

        For ``self.mmd_epochs`` epochs, each batch of ``self.train_loader`` is
        evaluated by every teacher in ``self.teacher_list`` (no gradients) and
        by ``self.model``. Two terms are combined:

        * ``loss_nll``: the summed squared difference between the flattened
          teacher outputs and element 0 of the student's outputs, divided by
          ``2 * sigma_f**2`` (``sigma_f`` is element 4 of the student's
          outputs) and by ``num_ens * <student axis-1 size> * num_classes``.
        * ``loss_mmd``: ``mmd_loss`` between ``self.model.fc.z`` (flattened per
          ensemble member) and standard-normal samples of matching shape.

        The total ``loss_nll + self.lamb * loss_mmd`` is minimised with
        ``self.mmd_optimizer``; all three values are shown on the progress bar.
        """
        print("MMD Train : Start!")

        n_epochs = self.mmd_epochs
        progress = tqdm(range(n_epochs), total = n_epochs)
        for epoch in progress:
            self.model.train()
            meters = AverageMeterSet()
            for inputs, targets in self.train_loader:
                inputs = inputs.cuda()
                # Targets are moved to the GPU too, although neither loss term reads them.
                targets = targets.cuda()

                # Teachers are only queried, so no autograd graph is recorded for them.
                with torch.no_grad():
                    teacher_outputs = [teacher(inputs) for teacher in self.teacher_list]
                logits_t_list = torch.stack(teacher_outputs, dim=0)

                outputs_list = self.model(inputs)
                logits_s_list, sigma_f = outputs_list[0], outputs_list[4]

                # Size of axis 1 of the student outputs; it enters the normaliser below.
                M = logits_s_list.shape[1]

                # One row per ensemble member, with the last two axes swapped before flattening.
                flat_teacher = logits_t_list.permute(0,2,1).reshape(self.num_ens, -1)
                flat_student = logits_s_list.permute(0,2,1).reshape(self.num_ens, -1)
                residual = flat_teacher - flat_student

                # trace(R @ R^T) is the sum of all squared residual entries.
                loss_nll = torch.trace(residual @ residual.T) /  ((sigma_f ** 2) * 2)
                loss_nll /= self.num_ens*M*self.num_classes

                # Fresh standard-normal reference samples for the MMD term.
                sample_z = torch.randn((self.num_ens, self.latent_dim * self.num_classes)).cuda()
                latent_codes = self.model.fc.z.reshape(self.num_ens, -1)
                loss_mmd = mmd_loss(latent_codes, sample_z, sigmas=np.array([0.5,1.,5]))

                loss = loss_nll + self.lamb*loss_mmd

                self.mmd_optimizer.zero_grad()
                loss.backward()
                self.mmd_optimizer.step()

                # Weight the running averages by the number of samples in the batch.
                mini_batchsize = inputs.shape[0]
                for key, value in (('loss', loss), ('loss_nll', loss_nll), ('loss_mmd', loss_mmd)):
                    meters.update(key, value.item(), mini_batchsize)

            progress.set_description(f"Epoch {epoch}/{self.mmd_epochs}, Loss: {meters['loss']:.4f}, NLL: {meters['loss_nll']:.4f}, MMD: {meters['loss_mmd']:.4f}")

        print("MMD Train : Done!")
        
        
    #### Distillation ###################################################################################################################################
    def _EM_loss(self, outputs_list, logits_t_list):
        M = logits_t_list.shape[0]
        B = logits_t_list.shape[1]
        C = logits_t_list.shape[2]

        mu_f = outputs_list[1]
        phi = outputs_list[3]
        sigma = outputs_list[4]
        L_c = self.model.fc.L_c()
        p = phi.shape[1]
        
        vec_f = logits_t_list.permute(0,2,1).reshape(self.num_ens, -1)
        vec_mu_f = mu_f.permute(1,0).reshape(-1).unsqueeze(0)
        
        diff_f = vec_f - vec_mu_f
        Sigma_fz = torch.kron(L_c, phi)
        Sigma_z_f = torch.linalg.inv(torch.eye(p*C).cuda() + (sigma**(-2)) * Sigma_fz.T @ Sigma_fz)
        
        Ez = Sigma_z_f @ Sigma_fz.T @ diff_f.t() / (sigma**2)
        Ez = Ez.detach()
        self.model.fc.z.data = Ez.T.reshape(M,C,p).permute(0,2,1)

        Ezzt = []
        for m in range(M):
            Ezzt.append((Sigma_z_f +  Ez[:,m].unsqueeze(1) @ Ez[:,m].unsqueeze(0)).detach())
        Ezzt = torch.stack(Ezzt, dim = 2)
        
        Q = torch.zeros(1).cuda()
        first_term = 0; second_term = 0; third_term = 0
        for m in range(M):
            diff_f_m = diff_f[m,:].unsqueeze(1)
            Q += -0.5*(sigma**(-2))*((diff_f_m.t() @ diff_f_m).squeeze()
                                    -2*(Ez[:,m] @ Sigma_fz.t() @ diff_f_m).squeeze()
                                    + torch.trace(Sigma_fz.t() @ Sigma_fz @ Ezzt[:,:,m]))
            
            first_term += (diff_f_m.t() @ diff_f_m).squeeze().item()
            second_term += -2*(Ez[:,m] @ Sigma_fz.t() @ diff_f_m).squeeze().item()
            third_term += torch.trace(Sigma_fz.t() @ Sigma_fz @ Ezzt[:,:,m]).item()
        Q /= M*B*C
        
        Q += -0.5*2*sigma.log()

        return -Q

    def _distill(self):
        """Run the main distillation loop, optimising the EM loss against the teachers.

        For ``self.epochs`` epochs, each batch of ``self.train_loader`` is
        evaluated by every teacher in ``self.teacher_list`` (no gradients) and
        by ``self.model``. The loss is ``self._EM_loss`` of the student outputs
        and the stacked teacher outputs. Gradients are clipped to a norm of 10
        before ``self.optimizer`` steps, and ``self.scheduler`` steps once per
        epoch. An accuracy figure, taken from the softmax of element 0 of the
        student's outputs averaged over axis 0, is tracked alongside the loss.

        ``self._validate_and_save`` is called every ``self.evaluation_epochs``
        epochs with ``['acc', 'nll']`` and once more after the last epoch with
        ``['acc', 'nll', 'ece']``, after which the final checkpoint is written.

        NOTE(review): the post-loop calls reuse the loop variable ``epoch``,
        which is unbound if ``self.epochs`` is 0.
        """
        print("Distill : Start!")
        n_epochs = self.epochs
        progress = tqdm(range(n_epochs), total = n_epochs)
        for epoch in progress:
            self.model.train()
            meters = AverageMeterSet()
            for inputs, targets in self.train_loader:
                inputs = inputs.cuda()
                targets = targets.cuda()

                # Teachers are only queried, so no autograd graph is recorded for them.
                with torch.no_grad():
                    teacher_outputs = [teacher(inputs) for teacher in self.teacher_list]
                logits_t_list = torch.stack(teacher_outputs, dim=0)

                outputs_list = self.model(inputs)
                probs_s_list = nn.functional.softmax(outputs_list[0], dim=2)

                loss = self._EM_loss(outputs_list, logits_t_list)

                # Prediction = arg-max of the probabilities averaged over axis 0.
                predicted = torch.max(probs_s_list.mean(0).data, 1)[1]
                correct = predicted.eq(targets).cpu().sum().float()

                mini_batchsize = len(inputs)
                acc = correct/mini_batchsize

                # Weight the running averages by the number of samples in the batch.
                meters.update('loss', loss.item(), mini_batchsize)
                meters.update('acc', acc.item(), mini_batchsize)

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.)
                self.optimizer.step()

            progress.set_description(f"Epoch {epoch}/{self.epochs}, Loss: {meters['loss']:.4f}, Accuracy: {meters['acc']:.4f}")
            self.scheduler.step()

            # Periodic validation / best-checkpoint bookkeeping.
            if (epoch + 1) % self.evaluation_epochs == 0:
                self._validate_and_save(epoch, ['acc', 'nll'])

        # Final evaluation additionally requests 'ece' before the last checkpoint is written.
        self._validate_and_save(epoch, ['acc', 'nll','ece'])
        self._save_final_checkpoint(self.model, self.optimizer)
        print("Train : Done!")

    
    #### Save #########################################################################################################################################
    def _save_checkpoint(self, state, name):
        filename = 'checkpoint_{}.ckpt'.format(name)
        checkpoint_path = os.path.join(self.save_dir, filename)
        torch.save(state, checkpoint_path)


    def _save_final_checkpoint(self, model, optimizer):
        save_dict = {'epoch': self.epochs,
                    'state_dict':model.state_dict(),
                    'final_test_acc': self.results_test['acc'],
                    'final_test_nll': self.results_test['nll'],
                    'final_test_ece': self.results_test['ece'],
                    'optimizer' : optimizer.state_dict()
                    }
        filename = f'checkpoint_final.ckpt'
        checkpoint_path = os.path.join(self.save_dir, filename)
        
        print(f"Test Acc : {self.results_test['acc']}, Test NLL : {self.results_test['nll']}, Test ECE : {self.results_test['ece']}")
        torch.save(save_dict, checkpoint_path)