import time
import torch
import numpy as np
from tqdm import tqdm
import gc
import copy
from ..trainer_ddp import predict
import torch.distributed as dist

def evaluate(config, model, reference_dataloader, query_dataloader, ranks, step_size, cleanup=False):
    """Evaluate retrieval recall across all DDP ranks.

    Every rank extracts features for its share of both dataloaders with
    ``predict``. The per-rank results are all-gathered, rank 0 concatenates
    them and scores them with ``calculate_scores``, and the resulting value is
    broadcast so that every rank returns the same number.

    Args:
        config: Object exposing ``device``, ``world_size`` and ``rank``.
        model: Model handed to ``predict``; put into eval mode here.
        reference_dataloader: Loader for the reference set.
        query_dataloader: Loader for the query set.
        ranks: k values for Recall@k, forwarded to ``calculate_scores``.
        step_size: Query chunk size forwarded to ``calculate_scores``.
        cleanup: If true, rank 0 drops its gathered tensors and runs
            ``gc.collect()`` after scoring.

    Returns:
        float: rank 0's ``calculate_scores`` result, identical on every rank.

    NOTE(review): ``dist.all_gather`` requires every rank to contribute tensors
    of the same shape; if ``predict`` relies on a padding sampler (e.g.
    DistributedSampler) the gathered set may contain duplicate samples —
    confirm against ``predict``.
    NOTE(review): if scoring raises on rank 0, the other ranks block forever
    in the broadcast below.
    """
    model.eval()
    device = config.device
    world_size = config.world_size
    rank = config.rank

    # Each rank extracts features for its own portion of the data.
    reference_features, reference_ids = predict(config, model, reference_dataloader)
    query_features, query_labels = predict(config, model, query_dataloader)

    if rank == 0:
        print(f"Reference features shape: {reference_features.shape}")
        print(f"Query features shape: {query_features.shape}")

    dist.barrier()

    all_reference_features = [torch.zeros_like(reference_features) for _ in range(world_size)]
    all_reference_ids = [torch.zeros_like(reference_ids) for _ in range(world_size)]
    dist.all_gather(all_reference_features, reference_features)
    dist.all_gather(all_reference_ids, reference_ids)

    all_query_features = [torch.zeros_like(query_features) for _ in range(world_size)]
    all_query_labels = [torch.zeros_like(query_labels) for _ in range(world_size)]
    dist.all_gather(all_query_features, query_features)
    dist.all_gather(all_query_labels, query_labels)

    # The broadcast buffer must have the same dtype and shape on every rank.
    # Rank 0's score is a NumPy float64, so float64 is used everywhere
    # (previously rank 0 sent float64 while the others expected float32).
    r1_tensor = torch.zeros((), dtype=torch.float64, device=device)

    if rank == 0:
        reference_features = torch.cat(all_reference_features)
        reference_ids = torch.cat(all_reference_ids)
        query_features = torch.cat(all_query_features)
        query_labels = torch.cat(all_query_labels)

        print("Compute Scores:")
        r1 = calculate_scores(query_features, reference_features, query_labels, reference_ids, step_size=step_size, ranks=ranks)
        r1_tensor.fill_(float(r1))

        if cleanup:
            del reference_features, reference_ids, query_features, query_labels
            del all_reference_features, all_reference_ids, all_query_features, all_query_labels
            gc.collect()

    dist.broadcast(r1_tensor, src=0)
    return r1_tensor.item()

def calc_sim(config, model, reference_dataloader, query_dataloader, ranks, step_size, cleanup=False):
    """Score training features and share each query's nearest-label dict across ranks.

    Every rank extracts features with ``predict`` and all-gathers them.
    Rank 0 then does the following:

    - concatenates the gathered shards;
    - computes a recall score with ``calculate_scores_train``;
    - builds a query-label -> neighbour-labels dict with ``calculate_nearest``.

    The score and the pickled dict are broadcast so that every rank returns
    the same pair.

    Args:
        config: Object exposing ``device``, ``world_size``, ``rank`` and
            ``neighbour_range``.
        model: Model handed to ``predict``; put into eval mode here.
        reference_dataloader: Loader for the reference set.
        query_dataloader: Loader for the query set.
        ranks: k values for Recall@k, forwarded to ``calculate_scores_train``.
        step_size: Query chunk size for the similarity computations.
        cleanup: If true, rank 0 deletes its concatenated tensors and runs
            ``gc.collect()`` before serialising the dict.

    Returns:
        tuple: ``(r1, near_dict)``. ``r1`` is a Python float and ``near_dict``
        is the dict returned by ``calculate_nearest`` on rank 0. Both are
        identical on every rank.

    NOTE(review): ``dist.all_gather`` requires equal tensor shapes on every
    rank; if scoring raises on rank 0 the other ranks block in the
    broadcasts below.
    """
    import pickle

    model.eval()
    device = config.device
    world_size = config.world_size
    rank = config.rank

    reference_features, reference_ids = predict(config, model, reference_dataloader)
    query_features, query_labels = predict(config, model, query_dataloader)

    if rank == 0:
        print(f"Train Reference features shape: {reference_features.shape}")
        print(f"Train Query features shape: {query_features.shape}")

    dist.barrier()

    all_reference_features = [torch.zeros_like(reference_features) for _ in range(world_size)]
    all_reference_ids = [torch.zeros_like(reference_ids) for _ in range(world_size)]
    dist.all_gather(all_reference_features, reference_features)
    dist.all_gather(all_reference_ids, reference_ids)

    all_query_features = [torch.zeros_like(query_features) for _ in range(world_size)]
    all_query_labels = [torch.zeros_like(query_labels) for _ in range(world_size)]
    dist.all_gather(all_query_features, query_features)
    dist.all_gather(all_query_labels, query_labels)

    # Broadcast buffers must agree in dtype and shape on every rank.
    # The score is pinned to float64, because rank 0's value is a NumPy
    # float64; previously the other ranks used float32.
    # The payload length is a single int64.
    r1_tensor = torch.zeros((), dtype=torch.float64, device=device)
    near_dict_size = torch.zeros(1, dtype=torch.long, device=device)

    if rank == 0:
        reference_features = torch.cat(all_reference_features)
        reference_labels = torch.cat(all_reference_ids)
        query_features = torch.cat(all_query_features)
        query_labels = torch.cat(all_query_labels)

        print("Compute Scores Train:")
        r1 = calculate_scores_train(query_features, reference_features, query_labels, reference_labels, step_size=step_size, ranks=ranks)

        near_dict = calculate_nearest(query_features=query_features,
                                      reference_features=reference_features,
                                      query_labels=query_labels,
                                      reference_labels=reference_labels,
                                      neighbour_range=config.neighbour_range,
                                      step_size=step_size)

        if cleanup:
            del reference_features, reference_labels, query_features, query_labels
            gc.collect()

        near_dict_bytes = pickle.dumps(near_dict)
        # Wrap the pickle in a writable NumPy view instead of building a
        # per-byte Python list; bytearray keeps torch.from_numpy warning-free.
        near_dict_tensor = torch.from_numpy(
            np.frombuffer(bytearray(near_dict_bytes), dtype=np.uint8)
        ).to(device)
        near_dict_size.fill_(len(near_dict_bytes))
        r1_tensor.fill_(float(r1))

    # Every rank issues the same collectives in the same order:
    # score, then payload length, then payload.
    dist.broadcast(r1_tensor, src=0)
    dist.broadcast(near_dict_size, src=0)
    if rank != 0:
        near_dict_tensor = torch.empty(int(near_dict_size.item()), dtype=torch.uint8, device=device)
    dist.broadcast(near_dict_tensor, src=0)

    if rank != 0:
        # The payload was produced by rank 0 of this same job, not by an
        # external source, so unpickling it here is not an untrusted load.
        near_dict = pickle.loads(near_dict_tensor.cpu().numpy().tobytes())

    return r1_tensor.item(), near_dict


def calculate_scores(query_features, reference_features, query_labels, reference_labels, step_size=1000, ranks=(1, 5, 10)):
    """Print retrieval recall and hit rate; return the first recall value.

    Similarity is the dot product ``query_features @ reference_features.T``.
    It is computed ``step_size`` queries at a time and collected on the CPU.
    For each query, the ground-truth reference is looked up by label. Its
    rank is the number of references that score strictly higher.

    The printed line contains the following metrics:

    - ``Recall@k`` for each k in ``ranks``;
    - ``Recall@top1`` for a cutoff of ``R // 100`` references (top 1% of the
      gallery);
    - ``Hit_Rate``: the percentage of queries where every reference that
      outscores the ground truth appears in that query's extra label columns.

    Args:
        query_features: 2-D tensor, one embedding row per query.
        reference_features: 2-D tensor, one embedding row per reference.
        query_labels: 2-D tensor. Column 0 is the label of the query's
            ground-truth reference. Any further columns hold reference
            labels that are excused from the hit-rate check.
        reference_labels: One label per reference row (presumably 1-D).
        step_size: Number of queries per similarity chunk.
        ranks: k values for Recall@k. Any sequence is accepted; it is copied
            and never mutated.

    Returns:
        The first entry of the recall table in percent, i.e.
        Recall@ranks[0] when ``ranks`` is non-empty.

    Raises:
        ValueError: If there are no queries.
        KeyError: If a label in ``query_labels`` does not occur in
            ``reference_labels``.
    """
    topk = list(ranks)
    Q = len(query_features)
    R = len(reference_features)

    if Q == 0:
        raise ValueError("calculate_scores needs at least one query")

    query_labels_np = query_labels.cpu().numpy()
    reference_labels_np = reference_labels.cpu().numpy()

    # label -> reference row; with duplicate labels the last row wins.
    ref2index = {label: idx for idx, label in enumerate(reference_labels_np)}

    # Chunk over queries to bound the size of each matmul; exact-multiple
    # sizes no longer produce a trailing empty chunk.
    similarity = torch.cat(
        [
            (query_features[start:start + step_size] @ reference_features.T).cpu()
            for start in range(0, Q, step_size)
        ],
        dim=0,
    )

    # Extra cutoff: top 1% of the gallery.
    # NOTE(review): R // 100 is 0 when R < 100, so that column then always reads 0.
    topk.append(R // 100)

    results = np.zeros([len(topk)])
    hit_rate = 0.0

    bar = tqdm(range(Q))
    for i in bar:
        gt_sim = similarity[i, ref2index[query_labels_np[i][0]]]

        higher_sim = similarity[i, :] > gt_sim
        # Number of references scoring strictly above the ground truth.
        ranking = higher_sim.sum()
        for j, k in enumerate(topk):
            if ranking < k:
                results[j] += 1.

        # Hit: nothing outranks the ground truth except the references listed
        # in the query's extra label columns. Clearing those positions replaces
        # the per-query torch.ones(R) mask; `ranking` is already computed.
        near_idx = [ref2index[near_pos] for near_pos in query_labels_np[i][1:]]
        if near_idx:
            higher_sim[near_idx] = False
        if not higher_sim.any():
            hit_rate += 1.0

    results = results / Q * 100.
    hit_rate = hit_rate / Q * 100

    bar.close()

    # Give tqdm a moment to finish drawing before printing the summary.
    time.sleep(0.1)

    string = ['Recall@{}: {:.4f}'.format(topk[i], results[i]) for i in range(len(topk) - 1)]
    string.append('Recall@top1: {:.4f}'.format(results[-1]))
    string.append('Hit_Rate: {:.4f}'.format(hit_rate))

    print(' - '.join(string))

    return results[0]

def calculate_scores_train(query_features, reference_features, query_labels, reference_labels, step_size=1000, ranks=(1, 5, 10)):
    """Print retrieval recall for training features; return the first recall value.

    Similarity is the dot product ``query_features @ reference_features.T``.
    It is computed ``step_size`` queries at a time and collected on the CPU.
    A query's rank is the number of references scoring strictly higher than
    its ground-truth reference. The ground-truth reference is found via
    column 0 of ``query_labels``.

    The printed line contains ``Recall@k`` for each k in ``ranks``, followed by
    ``Recall@top1`` for a cutoff of ``R // 100`` references (top 1% of the
    gallery).

    Args:
        query_features: 2-D tensor, one embedding row per query.
        reference_features: 2-D tensor, one embedding row per reference.
        query_labels: 2-D tensor; only column 0 (ground-truth label) is used.
        reference_labels: One label per reference row (presumably 1-D).
        step_size: Number of queries per similarity chunk.
        ranks: k values for Recall@k. Any sequence is accepted; it is copied
            and never mutated.

    Returns:
        The first entry of the recall table in percent, i.e.
        Recall@ranks[0] when ``ranks`` is non-empty.

    Raises:
        ValueError: If there are no queries.
        KeyError: If a query label does not occur in ``reference_labels``.
    """
    topk = list(ranks)
    Q = len(query_features)
    R = len(reference_features)

    if Q == 0:
        raise ValueError("calculate_scores_train needs at least one query")

    query_labels_np = query_labels[:, 0].cpu().numpy()
    reference_labels_np = reference_labels.cpu().numpy()

    # label -> reference row; with duplicate labels the last row wins.
    ref2index = {label: idx for idx, label in enumerate(reference_labels_np)}

    # Chunk over queries to bound the size of each matmul; exact-multiple
    # sizes no longer produce a trailing empty chunk.
    similarity = torch.cat(
        [
            (query_features[start:start + step_size] @ reference_features.T).cpu()
            for start in range(0, Q, step_size)
        ],
        dim=0,
    )

    # Extra cutoff: top 1% of the gallery.
    # NOTE(review): R // 100 is 0 when R < 100, so that column then always reads 0.
    topk.append(R // 100)

    results = np.zeros([len(topk)])

    bar = tqdm(range(Q))
    for i in bar:
        gt_sim = similarity[i, ref2index[query_labels_np[i]]]
        # Number of references scoring strictly above the ground truth.
        ranking = (similarity[i, :] > gt_sim).sum()
        for j, k in enumerate(topk):
            if ranking < k:
                results[j] += 1.

    results = results / Q * 100.

    bar.close()

    # Give tqdm a moment to finish drawing before printing the summary.
    time.sleep(0.1)

    string = ['Recall@{}: {:.4f}'.format(topk[i], results[i]) for i in range(len(topk) - 1)]
    string.append('Recall@top1: {:.4f}'.format(results[-1]))

    print(' - '.join(string))

    return results[0]
   

def calculate_nearest(query_features, reference_features, query_labels, reference_labels, neighbour_range=64, step_size=1000):
    """Map each query label to the labels of its most similar references.

    Similarity is the dot product ``query_features @ reference_features.T``.
    It is computed ``step_size`` queries at a time and collected on the CPU.
    For every query, the highest-scoring reference labels are taken in
    descending order of similarity. References whose label equals the
    query's own label are then dropped, and at most ``neighbour_range``
    labels are kept.

    Args:
        query_features: 2-D tensor, one embedding row per query.
        reference_features: 2-D tensor, one embedding row per reference.
        query_labels: 2-D tensor; only column 0 (the query's own label) is used.
        reference_labels: One label per reference row (presumably 1-D).
        neighbour_range: Maximum number of neighbour labels kept per query.
        step_size: Number of queries per similarity chunk.

    Returns:
        dict mapping each query label (a Python scalar) to a list of neighbour
        labels (NumPy scalars). If several queries share a label, the last one
        wins. An empty query set yields ``{}``.
    """
    query_labels = query_labels[:, 0].cpu()

    num_queries = len(query_features)
    num_references = len(reference_features)

    if num_queries == 0:
        return {}

    similarity = torch.cat(
        [
            (query_features[start:start + step_size] @ reference_features.T).cpu()
            for start in range(0, num_queries, step_size)
        ],
        dim=0,
    )

    # Two spare candidates leave room for entries dropped because they share
    # the query's own label. Clamp to the gallery size so torch.topk does not
    # fail when there are fewer than neighbour_range + 2 references.
    k = min(neighbour_range + 2, num_references)
    topk_ids = torch.topk(similarity, k=k, dim=1).indices

    # One gather for all rows instead of a per-row Python loop.
    topk_references = reference_labels.cpu()[topk_ids]

    mask = (topk_references != query_labels.unsqueeze(1)).numpy()
    topk_references = topk_references.numpy()

    nearest_dict = dict()
    for key, refs, keep in zip(query_labels.tolist(), topk_references, mask):
        nearest_dict[key] = list(refs[keep][:neighbour_range])

    return nearest_dict
