import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from helpers import list_of_distances
import pdb

def _train_or_test(args, model, dataloader, optimizer=None, class_specific=True, use_l1_mask=True,
                   coefs=None, log=print):
    '''
    model: the multi-gpu model
    dataloader:
    optimizer: if None, will be test evaluation
    '''
    is_train = optimizer is not None
    start = time.time()
    n_examples = 0
    n_correct = 0
    n_batches = 0
    total_cross_entropy = 0
    vl_total_cluster_cost = 0
    # separation cost is meaningful only for class_specific
    vl_total_separation_cost = 0
    vl_total_avg_separation_cost = 0


    for i,  (image_input, target, language_input, token_type, input_mask, nwords) in enumerate(dataloader):
        
        image_input = image_input.cuda()
        target = target.cuda()
        language_input = language_input.cuda()
        token_type = token_type.cuda()
        input_mask = input_mask.cuda()
        nwords = nwords.cuda()

        # torch.enable_grad() has no effect outside of no_grad()
        grad_req = torch.enable_grad() if is_train else torch.no_grad()
        with grad_req:
            # nn.Module has implemented __call__() function
            # so no need to call .forward
            #output, min_distances = model(input)
            # compute loss
            #cross_entropy = torch.nn.functional.cross_entropy(output, target)
            vl_output, vl_min_distances = model(image_input,language_input, token_type, input_mask, nwords)
            output = vl_output
            # compute loss
            
            cross_entropy = torch.nn.functional.cross_entropy(output, target)
            # nn.Module has implemented __call__() function
            # so no need to call .forward

            #vision stream
            max_dist = (model.module.prototype_shape[1]
                        * model.module.prototype_shape[2]
                        * model.module.prototype_shape[3])# 512*1*1   512*14*14

            # prototypes_of_correct_class is a tensor of shape batch_size * num_prototypes
            # calculate cluster cost
            prototypes_of_correct_class = torch.t(model.module.prototype_class_identity[:,target]).cuda()
            inverted_distances, _ = torch.max((max_dist - vl_min_distances) * prototypes_of_correct_class, dim=1)
            vl_cluster_cost = torch.mean(max_dist - inverted_distances)

            # calculate separation cost
            prototypes_of_wrong_class = 1 - prototypes_of_correct_class
            inverted_distances_to_nontarget_prototypes, _ = \
                torch.max((max_dist - vl_min_distances) * prototypes_of_wrong_class, dim=1)
            vl_separation_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes)

            # calculate avg cluster cost
            vl_avg_separation_cost = \
                torch.sum(vl_min_distances * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1)
            vl_avg_separation_cost = torch.mean(vl_avg_separation_cost)
            
            if use_l1_mask:
                l1_mask = 1 - torch.t(model.module.prototype_class_identity).cuda()
                vl_l1 = (model.module.last_layer.weight * l1_mask).norm(p=1)

            else:
                vl_l1 = model.module.last_layer.weight.norm(p=1) 

        


            # evaluation statistics
            _, predicted = torch.max(output.data, 1)
            n_examples += target.size(0)
            n_correct += (predicted == target).sum().item()

            n_batches += 1
            total_cross_entropy += cross_entropy.item()
            vl_total_cluster_cost += vl_cluster_cost.item()
            
            vl_total_separation_cost += vl_separation_cost.item()
            
            vl_total_avg_separation_cost += vl_avg_separation_cost.item()
    

        # compute gradient and do SGD step
        if is_train:
            loss = (args.coeff_ce * cross_entropy
                    + args.coeff_clst * vl_cluster_cost
                    + args.coeff_sep * vl_separation_cost
                    + args.coeff_l1 * vl_l1)
                    

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if model.module.mode == 1 and (i+1) % args.update_projector_freq == 0:
                model.module.update_rotation_matrix() 
            

    end = time.time()

    log('\ttime: \t{0}'.format(end -  start))
    log('\tcross ent: \t{0}'.format(total_cross_entropy / n_batches))
    
    log('\tvl_cluster: \t{0}'.format(vl_total_cluster_cost / n_batches))
    log('\tvl_separation:\t{0}'.format(vl_total_separation_cost / n_batches))
    log('\tvl_avg separation:\t{0}'.format(vl_total_avg_separation_cost / n_batches))

    log('\taccu: \t\t{0}%'.format(n_correct / n_examples * 100))
    log('\tvl_l1: \t\t{0}'.format(model.module.last_layer.weight.norm(p=1).item()))
    
    p = model.module.prototype_vectors.view(model.module.num_prototypes, -1).cpu()
    with torch.no_grad():
        p_avg_pair_dist = torch.mean(list_of_distances(p, p))
    log('\tp dist pair: \t{0}'.format(p_avg_pair_dist.item()))

    return n_correct / n_examples


def train(args, model, dataloader, optimizer, class_specific=False, coefs=None, log=print):
    '''Put `model` in training mode and run one optimizing pass over `dataloader`.

    Delegates to `_train_or_test` with the given optimizer and returns the
    accuracy it reports.

    Raises ValueError if `optimizer` is None. An explicit check is used
    instead of `assert`: asserts are stripped under `python -O`, and a None
    optimizer would then silently run an evaluation pass instead of training.
    '''
    if optimizer is None:
        raise ValueError('train() requires an optimizer; use test() for evaluation')

    log('\ttrain')
    model.train()
    return _train_or_test(args=args, model=model, dataloader=dataloader, optimizer=optimizer,
                          class_specific=class_specific, coefs=coefs, log=log)


def test(args, model, dataloader, class_specific=False, coefs=None, log=print):
    '''Evaluate `model` on `dataloader` without updating any parameters.

    Switches the model to eval mode, then runs `_train_or_test` with no
    optimizer; the returned value is the accuracy that pass reports.
    '''
    log('\ttest')
    model.eval()
    accuracy = _train_or_test(
        args=args,
        model=model,
        dataloader=dataloader,
        optimizer=None,
        class_specific=class_specific,
        coefs=coefs,
        log=log,
    )
    return accuracy


def last_only(model, log=print):
    '''Freeze everything except `last_layer`.

    Backbones (image / language), add-on layers, the language projection
    head and the prototype vectors stop receiving gradients; only the
    parameters of ``model.module.last_layer`` remain trainable.
    '''
    net = model.module
    frozen_parts = (
        net.image_model,
        net.language_model,
        net.add_on_layers,
        net.language_projection_head,
    )
    for part in frozen_parts:
        for weight in part.parameters():
            weight.requires_grad = False

    net.prototype_vectors.requires_grad = False

    for weight in net.last_layer.parameters():
        weight.requires_grad = True

    log('\tlast layer')


def warm_only(model, log=print):
    '''Warm-up stage: keep both backbones frozen and train everything on top.

    The image and language backbones are frozen; add-on layers, language
    projection head, prototype vectors and `last_layer` are trainable.
    '''
    net = model.module
    schedule = (
        (net.image_model, False),
        (net.language_model, False),
        (net.add_on_layers, True),
        (net.language_projection_head, True),
    )
    for part, trainable in schedule:
        for weight in part.parameters():
            weight.requires_grad = trainable

    net.prototype_vectors.requires_grad = True

    for weight in net.last_layer.parameters():
        weight.requires_grad = True

    log('\twarm')
def warm_only_vision(model, log=print):
    '''Vision-only warm-up: train the vision head while the rest is frozen.

    Trainable: add-on layers, prototype vectors, `last_layer_v`.
    Frozen: image and language backbones, language projection head,
    `last_layer_l`.

    NOTE(review): this uses `last_layer_v` / `last_layer_l`, whereas
    `_train_or_test` reads `last_layer` — confirm the model defines all of them.
    '''
    net = model.module
    module_flags = {
        'image_model': False,
        'language_model': False,
        'add_on_layers': True,
        'language_projection_head': False,
    }
    for attr_name, flag in module_flags.items():
        for weight in getattr(net, attr_name).parameters():
            weight.requires_grad = flag

    net.prototype_vectors.requires_grad = True

    head_flags = {'last_layer_v': True, 'last_layer_l': False}
    for attr_name, flag in head_flags.items():
        for weight in getattr(net, attr_name).parameters():
            weight.requires_grad = flag

    log('\twarm vision')

def joint(model, log=print):
    '''Joint stage: make every component trainable.

    Both backbones, add-on layers, language projection head, prototype
    vectors and `last_layer` all receive gradients.
    '''
    net = model.module
    components = (
        net.image_model,
        net.language_model,
        net.add_on_layers,
        net.language_projection_head,
    )
    for component in components:
        for weight in component.parameters():
            weight.requires_grad_(True)

    net.prototype_vectors.requires_grad_(True)

    for weight in net.last_layer.parameters():
        weight.requires_grad_(True)

    log('\tjoint')

def joint_vision(model, log=print):
    '''Vision-only joint stage: train the image path end to end.

    Trainable: image backbone, add-on layers, prototype vectors,
    `last_layer_v`. Frozen: language backbone, language projection head,
    `last_layer_l`.

    NOTE(review): this uses `last_layer_v` / `last_layer_l`, whereas
    `_train_or_test` reads `last_layer` — confirm the model defines all of them.
    '''
    net = model.module

    def _set(part, flag):
        # Apply one requires_grad flag to every parameter of `part`.
        for weight in part.parameters():
            weight.requires_grad = flag

    _set(net.image_model, True)
    _set(net.language_model, False)
    _set(net.add_on_layers, True)
    _set(net.language_projection_head, False)

    net.prototype_vectors.requires_grad = True

    _set(net.last_layer_v, True)
    _set(net.last_layer_l, False)

    log('\tjoint vision')



