
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from utils import *
import time
import numpy as np
# import clip
from clip import clip
from classes import *
from trainer import *
import ot

# Module-level scalar constant; it is not referenced by the code visible in this file.
# NOTE(review): the name looks like a misspelling of "epsilon(s)" -- kept as-is because
# other modules may import it by this name; confirm before renaming.
eplisons = 0.1
# Cosine-similarity module that compares vectors along the last dimension;
# `eps` is the small value PyTorch uses to avoid division by zero for near-zero norms.
fea_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8)

def load_clip_to_cpu(visual_backbone):
    """Fetch the CLIP checkpoint for ``visual_backbone`` and return the loaded model.

    The checkpoint URL is looked up in ``clip._MODELS`` and downloaded (or
    reused) under ``~/.cache/clip``.

    NOTE: despite the function name, the checkpoint is loaded onto the
    ``'cuda'`` device, not the CPU.

    Args:
        visual_backbone: key into ``clip._MODELS`` naming the backbone.

    Returns:
        The first element of the value returned by ``clip.load`` (presumably
        the model object, with the preprocessing transform being discarded --
        verify against the bundled ``clip`` module).

    Raises:
        KeyError: if ``visual_backbone`` is not a key of ``clip._MODELS``.
    """
    checkpoint_url = clip._MODELS[visual_backbone]
    cache_dir = os.path.expanduser("~/.cache/clip")
    checkpoint_path = clip._download(checkpoint_url, cache_dir)

    loaded = clip.load(checkpoint_path, 'cuda')
    return loaded[0]

class TextEncoder(nn.Module):
    """Text tower of a CLIP model, exposed as a standalone module.

    The sub-modules and parameters are taken by reference from ``clip_model``
    (no copies are made), so this module shares weights with it.
    """

    def __init__(self, clip_model):
        """Borrow the text-side components of ``clip_model``.

        Args:
            clip_model: object exposing ``transformer``, ``positional_embedding``,
                ``ln_final``, ``text_projection``, ``dtype`` and
                ``token_embedding`` attributes.
        """
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype
        self.token_embedding = clip_model.token_embedding

    def forward(self, prompts, tokenized_prompts=None):
        """Encode prompts into one projected feature vector per sequence.

        Two calling conventions are supported:

        * ``forward(token_ids)``: ``prompts`` holds token ids, which are
          embedded with ``token_embedding`` first.
        * ``forward(embeddings, tokenized_prompts)``: ``prompts`` already holds
          token embeddings and ``tokenized_prompts`` holds the matching token
          ids, used only to locate the position to read out.

        Args:
            prompts: token ids, or pre-computed token embeddings when
                ``tokenized_prompts`` is given.
            tokenized_prompts: optional token ids matching ``prompts``.

        Returns:
            Tensor with one row per input sequence: the hidden state at the
            arg-max token position multiplied by ``text_projection``.
        """
        # Identity check instead of `!= None`: comparing a tensor to None with
        # `!=` relies on the tensor's comparison fallback rather than stating intent.
        if tokenized_prompts is not None:
            x = prompts + self.positional_embedding.type(self.dtype)
            index = tokenized_prompts.argmax(dim=-1)
        else:
            x = self.token_embedding(prompts).type(self.dtype)  # [batch_size, n_ctx, d_model]
            x += self.positional_embedding.type(self.dtype)
            index = prompts.argmax(dim=-1)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # Take features from the eot embedding (eot_token is the highest
        # number in each sequence, hence the argmax above).
        x = x[torch.arange(x.shape[0]), index] @ self.text_projection

        return x

class model ():
    """Evaluation wrapper: CLIP visual/text encoders plus a linear adapter and classifier.

    The constructor loads the CLIP backbone named in
    ``config['model']['clip']['params']['visual_backbone']`` and builds the
    adapter and classifier. ``eval`` runs a data split through the network,
    dumps logits/features to ``training_opt['log_dir']`` and reports
    accuracy metrics computed by helpers that are not defined in this file
    (``mic_acc_cal``, ``F_measure``, ``shot_acc``).
    """

    def __init__(self, config, data, test=True):
        """Store configuration and build the networks.

        Args:
            config: mapping with at least the ``'training_opt'`` and ``'model'``
                sections; an optional ``'shuffle'`` entry is stored as ``do_shuffle``.
            data: mapping from phase name to an iterable of
                ``(inputs, labels, paths)`` batches.
            test: stored as ``self.test_mode``.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.config = config
        self.training_opt = self.config['training_opt']
        self.model_opt = self.config['model']
        self.data = data
        # Honour the caller's flag (it used to be hard-coded to True, silently
        # ignoring the argument); the default keeps the previous value.
        self.test_mode = test
        self.num_gpus = torch.cuda.device_count()
        self.do_shuffle = config['shuffle'] if 'shuffle' in config else False
        self.clip_model = load_clip_to_cpu(self.model_opt['clip']['params']['visual_backbone'])

        # Initialize model
        self.init_models()

    def init_models(self, optimizer=True):
        """Create the encoders, adapter and classifier on the GPU.

        When ``training_opt['phaseA']`` is not ``True`` a checkpoint is loaded
        from ``config['model_dir']`` and every network is frozen and switched
        to eval mode.

        Args:
            optimizer: accepted for interface compatibility; not used here.
        """
        self.model_optim_params_list = []

        print("Using", torch.cuda.device_count(), "GPUs.")

        self.visual_model = torch.nn.DataParallel(self.clip_model.visual).cuda()
        self.text_model = torch.nn.DataParallel(TextEncoder(self.clip_model)).cuda()

        feat_dim = self.model_opt['adapter']['params']['feat_dim']

        self.adapter = torch.nn.DataParallel(nn.Sequential(
            nn.Linear(feat_dim, feat_dim, bias=False),
            )).cuda()

        self.fc = nn.Linear(feat_dim, self.config['training_opt']['num_classes'], bias=False).cuda()

        if self.training_opt['phaseA'] is not True:
            self.load_model(self.config['model_dir'])

            # Freeze everything: no gradients, inference-mode layers.
            frozen_modules = (self.clip_model, self.visual_model, self.text_model,
                              self.adapter, self.fc)
            for module in frozen_modules:
                for param in module.parameters():
                    param.requires_grad = False
                module.eval()

    def batch_forward(self, inputs, phase='train', indexes=None, labels=None):
        '''
        Run one batch through the visual encoder, adapter and classifier.

        ``inputs`` is indexed as ``(batch, shot, ...)``: the first two
        dimensions are merged before the visual encoder and the resulting
        logits are averaged over the ``shot`` dimension afterwards.

        Results are stored on the instance rather than returned:
        ``self.logits`` (shape ``(batch, num_classes)``),
        ``self.ori_features`` (raw visual features) and
        ``self.adapted_features`` (normalised blend of adapter output and raw
        features).

        ``phase``, ``indexes`` and ``labels`` are accepted for interface
        compatibility and are not used.
        '''
        if self.model_opt['clip']['params']['visual_backbone'].startswith("ViT"):
            inputs = inputs.to(torch.float16)

        _bs, _shot = inputs.shape[0], inputs.shape[1]
        # Merge batch and shot dims; keep whatever image shape was supplied
        # instead of assuming 3x224x224 (identical result for 224px inputs).
        inputs = inputs.view(_bs * _shot, *inputs.shape[2:])
        image_features = self.visual_model(inputs).float()

        # Residual blend: 20% adapter output, 80% original feature.
        ratio = 0.2
        adapted = self.adapter(image_features)
        adapted = ratio * adapted + (1 - ratio) * image_features
        adapted = F.normalize(adapted)

        logits = self.fc(adapted)
        # No squeeze() here: it would drop the batch dimension for a batch of
        # one sample (or the class dimension for a single class) and break the
        # concatenation / softmax(dim=1) performed in eval().
        self.logits = logits.view(_bs, _shot, -1).mean(dim=1)

        self.ori_features = image_features
        self.adapted_features = adapted

    def shuffle_batch(self, x, y):
        """Apply one random permutation along dim 0 to both ``x`` and ``y``.

        Returns:
            Tuple ``(x, y)`` reordered with the same permutation.
        """
        index = torch.randperm(x.size(0))
        x = x[index]
        y = y[index]
        return x, y

    def train(self):
        """Run evaluation on the ``'test'`` split and print the result.

        Despite its name this method performs no training.
        """
        time.sleep(0.25)

        print('Start Test!')
        rsls_eval = self.eval(phase='test')
        print(rsls_eval)
        print('Done')

    def eval(self, phase='test', openset=False, save_feat=False):
        """Evaluate ``self.data[phase]`` and dump logits/features to disk.

        Writes ``<prefix>feat_all.pkl`` into ``training_opt['log_dir']`` where
        the prefix is ``train`` for ``'train_plain'``, otherwise the phase
        name itself. For the ``'test'`` phase the per-class accuracies are
        additionally written to ``cls_accs.pkl``.

        Args:
            phase: key into ``self.data``.
            openset: forwarded to ``F_measure``.
            save_feat: accepted for interface compatibility; features are
                always saved.

        Returns:
            Dict with keys ``<phase>_all``, ``<phase>_many``,
            ``<phase>_median``, ``<phase>_low`` and ``<phase>_fscore``.
        """
        time.sleep(0.25)

        torch.cuda.empty_cache()

        logits_all, feats_all, adapted_all, labels_all, idxs_all = [], [], [], [], []

        # Iterate over dataset
        for inputs, labels, paths in tqdm(self.data[phase]):
            inputs, labels = inputs.cuda(), labels.cuda()

            # Evaluation only: never track gradients.
            with torch.no_grad():
                self.batch_forward(inputs, phase=phase)

                logits_all.append(self.logits.cpu().numpy())
                feats_all.append(self.ori_features.cpu().numpy())
                adapted_all.append(self.adapted_features.cpu().numpy())
                labels_all.append(labels.cpu().numpy())
                # assumes `paths` is a CPU tensor (e.g. sample indices) -- TODO confirm
                idxs_all.append(paths.numpy())

        logits_np = np.concatenate(logits_all)
        feats_np = np.concatenate(feats_all)
        adapted_np = np.concatenate(adapted_all)
        labels_np = np.concatenate(labels_all)
        idxs_np = np.concatenate(idxs_all)

        typ = 'feat'
        # 'train_plain' is saved under the 'train' prefix; any other phase
        # (including unknown ones, which used to leave `name` unbound) uses
        # its own name as prefix.
        prefix = 'train' if phase == 'train_plain' else phase
        name = '{}{}_all.pkl'.format(prefix, typ)

        fname = os.path.join(self.training_opt['log_dir'], name)
        print('===> Saving feats to ' + fname)

        with open(fname, 'wb') as f:
            pickle.dump({
                            'logits': logits_np,
                            'feats': feats_np,
                            'adapted_features': adapted_np,
                            'labels': labels_np,
                            'idxs': idxs_np,
                        },
                        f, protocol=4)

        # Keep the aggregated results on the instance so output_logits()
        # saves real data instead of empty placeholders.
        self.total_logits = torch.from_numpy(logits_np)
        self.total_labels = torch.from_numpy(labels_np)
        self.total_paths = idxs_np

        _probs, preds = F.softmax(self.total_logits, dim=1).max(dim=1)

        # Samples labelled -1 are excluded from the accuracy computations.
        known = self.total_labels != -1

        # Calculate the overall accuracy and F measurement
        self.eval_acc_mic_top1 = mic_acc_cal(preds[known], self.total_labels[known])
        self.eval_f_measure = F_measure(preds, self.total_labels, openset=openset,
                                        theta=self.training_opt['open_threshold'])
        self.many_acc_top1, \
        self.median_acc_top1, \
        self.low_acc_top1, \
        self.cls_accs = shot_acc(preds[known],
                                 self.total_labels[known],
                                 self.data['train'],
                                 acc_per_cls=True)
        # Top-1 accuracy and additional string
        print_str = ['\n\n',
                     'Phase: %s'
                     % (phase),
                     '\n\n',
                     'Evaluation_accuracy_micro_top1: %.4f\t'
                     % (self.eval_acc_mic_top1),
                     '\n',
                     'Averaged F-measure: %.4f\t'
                     % (self.eval_f_measure),
                     '\n',
                     'Many_shot_accuracy_top1: %.4f\t'
                     % (self.many_acc_top1),
                     'Median_shot_accuracy_top1: %.4f\t'
                     % (self.median_acc_top1),
                     'Low_shot_accuracy_top1: %.4f'
                     % (self.low_acc_top1),
                     '\n']

        rsl = {phase + '_all': self.eval_acc_mic_top1,
               phase + '_many': self.many_acc_top1,
               phase + '_median': self.median_acc_top1,
               phase + '_low': self.low_acc_top1,
               phase + '_fscore': self.eval_f_measure}

        if phase == 'test':
            with open(os.path.join(self.training_opt['log_dir'], 'cls_accs.pkl'), 'wb') as f:
                pickle.dump(self.cls_accs, f)

        print(''.join(print_str))
        return rsl

    def load_model(self, model_dir=None):
        """Load the best weights from a checkpoint into all four networks.

        Args:
            model_dir: path to the checkpoint; defaults to
                ``training_opt['log_dir']`` when ``None``.

        The checkpoint must contain ``'state_dict_best'`` with the entries
        ``'text_model'``, ``'visual_model'``, ``'adapter'`` and ``'classifier'``.
        """
        model_dir = self.training_opt['log_dir'] if model_dir is None else model_dir
        # NOTE(review): for a non-.pth path this only prints a message and
        # then still attempts to load the path below. Confirm whether loading
        # should be skipped here or redirected to 'final_model_checkpoint.pth'
        # (the file name save_model() writes).
        if not model_dir.endswith('.pth'):
            print('No pretrained Phase A model')

        print('Validation on the best model.')
        print('Loading model from %s' % (model_dir))

        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(model_dir, map_location='cpu')

        model_state = checkpoint['state_dict_best']
        self.text_model.load_state_dict(model_state['text_model'])
        self.visual_model.load_state_dict(model_state['visual_model'])
        self.adapter.load_state_dict(model_state['adapter'])
        self.fc.load_state_dict(model_state['classifier'])

    def save_model(self, epoch, best_epoch, best_model_weights, best_acc, centroids=None):
        """Write a checkpoint to ``<log_dir>/final_model_checkpoint.pth``.

        The saved dict has the keys ``'epoch'``, ``'best_epoch'``,
        ``'state_dict_best'``, ``'best_acc'`` and ``'centroids'``.
        """
        model_states = {'epoch': epoch,
                'best_epoch': best_epoch,
                'state_dict_best': best_model_weights,
                'best_acc': best_acc,
                'centroids': centroids}

        model_dir = os.path.join(self.training_opt['log_dir'],
                                 'final_model_checkpoint.pth')

        torch.save(model_states, model_dir)

    def output_logits(self, openset=False):
        """Save ``total_logits``, ``total_labels`` and ``total_paths`` as an ``.npz``.

        The file is written to ``<log_dir>/logits_open.npz`` or
        ``<log_dir>/logits_close.npz`` depending on ``openset``. The saved
        attributes are the ones populated by ``eval``.
        """
        filename = os.path.join(self.training_opt['log_dir'],
                                'logits_%s'%('open' if openset else 'close'))
        print("Saving total logits to: %s.npz" % filename)
        np.savez(filename,
                 logits=self.total_logits.detach().cpu().numpy(),
                 labels=self.total_labels.detach().cpu().numpy(),
                 paths=self.total_paths)

