import torch
import torch as ch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import model_editing.helpers.context_helpers as coh
from model_editing.tools import nethook


def downscale_mask(mask, tgt_size, threshold=None):
    """Shrink ``mask`` to ``tgt_size`` with non-overlapping average pooling.

    The pooling factor is derived from the size of the last dimension of
    ``mask``, which must be an exact multiple of ``tgt_size``.

    Args:
        mask: tensor accepted by ``F.avg_pool2d``; its last dimension gives
            the source size.
        tgt_size: size the last dimension is reduced to.
        threshold: when not ``None``, the pooled mask is binarised with
            ``pooled > threshold``.

    Returns:
        The average-pooled mask, or the boolean result of comparing it with
        ``threshold`` when a threshold is given.

    Raises:
        AssertionError: if the source size is not a multiple of ``tgt_size``.
    """
    scale, leftover = divmod(mask.shape[-1], tgt_size)
    # Only exact integer downscale factors are supported.
    assert leftover == 0
    averaged = F.avg_pool2d(mask, scale, stride=scale)
    return averaged if threshold is None else averaged > threshold


def target_weights(target_model):
    """Return the first parameter of ``target_model`` whose name contains 'weight'.

    Parameters are scanned in ``named_parameters()`` order. An ``IndexError``
    is raised when no parameter name contains ``'weight'``.
    """
    weight_params = []
    for param_name, param in target_model.named_parameters():
        if 'weight' in param_name:
            weight_params.append(param)
    return weight_params[0]


# (1, 64, 32, 32) x (1, 64) -> (1, 1, 32, 32)
# (1, 1, 32, 32) x (1, 64) -> (1, 64, 32, 32)
# original 1
def projected_conv(weight, direction, unfold=False):
    """Project a convolution weight onto the rows of ``direction``.

    The weight is first contracted with ``direction`` to obtain one
    coefficient per direction row, and the coefficients are then expanded
    back with the same rows. When the rows of ``direction`` are orthonormal
    this is the orthogonal projection onto their span.

    Args:
        weight: 5-D ``(g, o, i, y, x)`` or 4-D ``(o, i, y, x)`` tensor.
        direction: ``(d, i)`` tensor; with ``unfold=True`` and a 4-D weight,
            a ``(d, i*y*x)`` tensor that is reshaped to kernel layout.
        unfold: only consulted for 4-D weights. If true, the contraction runs
            over the whole ``(i, y, x)`` kernel instead of the input-channel
            axis alone.

    Returns:
        A tensor with the same shape as ``weight``.
    """
    if len(weight.shape) == 5:
        # Grouped layout: contract over the input-channel axis only.
        coeffs = torch.einsum('goiyx, di -> godyx', weight, direction)
        return torch.einsum('godyx, di -> goiyx', coeffs, direction)

    if not unfold:
        # Per spatial position, contract over the input-channel axis.
        coeffs = torch.einsum('oiyx, di -> odyx', weight, direction)
        return torch.einsum('odyx, di -> oiyx', coeffs, direction)

    # unfold=True: view every direction row as a full kernel. The two
    # spatial axes are swapped after the reshape, exactly as in the
    # flattened layout the directions are given in.
    expanded = direction.unsqueeze(2).unsqueeze(3)
    kernel_dirs = expanded.reshape(expanded.shape[0],
                                   weight.shape[1],
                                   weight.shape[2],
                                   weight.shape[3]).transpose(2, 3)
    coeffs = torch.einsum('oiyx, diyx -> od', weight, kernel_dirs)
    return torch.einsum('od, diyx -> oiyx', coeffs, kernel_dirs)


def edit_classifier_weights(target_model, key, val, context,
                            niter=2001, piter=10, lr=0.05,
                            low_rank_insert=True, low_rank_gradient=False,
                            unfold=False, mask=None):
    """Optimise the first weight of ``target_model`` so ``target_model(key)`` matches ``val``.

    The weight returned by ``target_weights(target_model)`` is updated in
    place with Adam on an L1 loss between ``val`` and ``target_model(key)``.
    The update can be restricted to the subspace given by ``context`` (see
    ``projected_conv``).

    Args:
        target_model: module whose first ``'weight'`` parameter is edited.
        key: input fed to ``target_model``; detached when it is a tensor.
        val: target output tensor; detached before use.
        context: direction tensor handed to ``projected_conv``.
        niter: number of optimisation steps.
        piter: with ``low_rank_insert``, the weight is re-projected every
            ``piter`` steps and on the last step.
        lr: Adam learning rate.
        low_rank_insert: if true, the component of the weight outside the
            ``context`` projection is periodically reset to its initial value.
        low_rank_gradient: if true, the gradient is replaced by its
            ``context`` projection before every optimiser step.
        unfold: forwarded to ``projected_conv``.
        mask: optional mask; it is average-pooled to ``val.shape[-1]``,
            square-rooted, and multiplied into both loss operands.

    Returns:
        None. The weight is modified in place and a summary is printed.
    """

    def update_callback(it, loss, pbar=None):
        # Report on every 50th step and on the final step.
        if it % 50 == 0 or it == niter - 1:
            if pbar:
                pbar.set_description(str(loss))
            else:
                # Format (and sync on loss.item()) only when actually printed.
                print(f'lr {lr:.4f}\titer {it: 6d}/{niter: 6d}'
                      f'\tloss {loss.item():.4f}')

    val = val.detach()
    try:
        key = key.detach()
    except AttributeError:
        # ``key`` has no ``detach`` (not a tensor): use it unchanged, which
        # is what the previous catch-all handler ended up doing.
        pass

    # The mask does not change between iterations, so pool it only once.
    if mask is not None:
        loss_mask = downscale_mask(mask, val.shape[-1], None).sqrt()
    else:
        loss_mask = None

    def compute_loss():
        reps = val, target_model(key)  # val of clean imgs & vals of modified imgs
        if loss_mask is not None:
            reps = [r * loss_mask for r in reps]
        return torch.nn.functional.l1_loss(*reps) / len(val)

    # set up optimizer
    weight = target_weights(target_model)
    # Detached snapshot: only needed to report the size of the final change.
    weight_orig = weight.detach().clone()
    if low_rank_insert or low_rank_gradient:
        with torch.no_grad():
            # Part of the initial weight that the projection does not cover.
            ortho_weight = weight - projected_conv(weight, context, unfold=unfold)
    optimizer = torch.optim.Adam([weight], lr=lr)

    # Pre-bind so the summary below also works when niter == 0.
    loss = None
    loss_orig = None

    pbar = tqdm(range(niter))
    for it in pbar:
        with torch.enable_grad():
            loss = compute_loss()
            optimizer.zero_grad()
            loss.backward()

            if it == 0:
                loss_orig = loss.item()

            # Keep only the gradient component inside the context subspace.
            if low_rank_gradient:
                weight.grad[...] = projected_conv(weight.grad, context, unfold=unfold)

            optimizer.step()

            update_callback(it, loss, pbar=pbar)

            if low_rank_insert and (it % piter == 0 or it == niter - 1):
                with torch.no_grad():
                    # Restore the untouched component and keep only the
                    # optimised part that lies inside the projection.
                    weight[...] = (ortho_weight + projected_conv(weight, context, unfold=unfold))

    loss_final = loss.item() if loss is not None else None
    print("Loss (orig, final):", loss_orig, loss_final)
    print("L2 norm of weight change:", ch.norm(weight_orig - weight).item())


def context_cache(args, context_model,
                  val_loader=None,
                  caching_dir=None):
    """Run ``coh.get_cov_matrix`` with ZCA keys and discard the result.

    Nothing is returned; the call is made only for whatever side effects
    ``coh.get_cov_matrix`` has (presumably writing its cache under
    ``caching_dir`` -- confirm in ``context_helpers``).

    Args:
        args: object providing ``layer_name``.
        context_model: model forwarded to ``coh.get_cov_matrix``.
        val_loader: data loader forwarded to ``coh.get_cov_matrix``.
        caching_dir: cache location forwarded to ``coh.get_cov_matrix``.
    """
    # context_helpers: covariance matrices + ZCA whitening
    cov_options = dict(batch_size=200,
                       key_method='zca',
                       caching_dir=caching_dir,
                       layer_name=args.layer_name)
    _cov, _zca_keys = coh.get_cov_matrix(val_loader, context_model, **cov_options)


def edit_classifier(args, train_data,
                    context_model,
                    target_model=None,
                    val_loader=None,
                    caching_dir=None):
    """Rewrite or fine-tune a classifier from original/modified image pairs.

    With ``args.mode_rewrite == 'editing'`` the first weight of
    ``target_model`` is optimised by ``edit_classifier_weights`` so that the
    layer's output on modified images matches its output on the original
    images. Any other mode fine-tunes with SGD and cross-entropy so that the
    modified images are classified as the (single) training label.

    Args:
        args: options object. Fields read here: ``ntrain``, ``mode_rewrite``,
            ``layer_name``, ``rank``, ``nsteps``, ``nsteps_proj``, ``lr``,
            ``restrict_rank``, ``use_mask``, ``arch``, ``layernum``.
        train_data: mapping with tensors ``'imgs'``, ``'modified_imgs'``,
            ``'masks'`` and (fine-tuning only) ``'labels'``.
        context_model: full model; run on CUDA inputs.
        target_model: module whose weight is edited; required for the
            ``'editing'`` and ``'finetune_local'`` modes.
        val_loader: data loader forwarded to ``coh.get_cov_matrix``.
        caching_dir: cache location forwarded to ``coh.get_cov_matrix``.

    Returns:
        ``context_model``.

    Raises:
        AssertionError: if ``args.ntrain`` exceeds the number of images, if
            ``target_model``/``ZM_k`` is missing in ``'editing'`` mode, or if
            the fine-tuning labels are not all identical.
    """
    assert args.ntrain <= len(train_data['imgs'])
    # First half: original images, second half: their modified versions.
    cp_imgs = ch.cat([train_data['imgs'][:args.ntrain], train_data['modified_imgs'][:args.ntrain]]).float()
    cp_masks = ch.cat([train_data['masks'][:args.ntrain], train_data['masks'][:args.ntrain]]).float()

    Nims = len(cp_imgs)

    if args.mode_rewrite == 'editing':
        # context_helpers: covariance matrices + ZCA whitening
        _, ZM_k = coh.get_cov_matrix(val_loader, context_model,
                                     batch_size=500,
                                     key_method='zca',
                                     caching_dir=caching_dir,
                                     layer_name=args.layer_name)

        assert (target_model is not None) and (ZM_k is not None)
        # Gather the context key from the modified images and their masks.
        # NOTE(review): this uses every sample in train_data, not just the
        # first args.ntrain used for cp_imgs -- confirm that is intended.
        context_k = coh.get_context_key(train_data['modified_imgs'].float(),
                                        train_data['masks'],
                                        context_model, ZM_k, rank=args.rank,
                                        layer_name=args.layer_name)

        # The forward pass is run only to fill coh.features (presumably
        # through hooks registered on context_model -- confirm).
        with ch.no_grad():
            context_model(cp_imgs.cuda())

        # Keys: layer inputs for the modified images (second half).
        kstar = coh.features[args.layer_name + '_pre']
        kstar = kstar[Nims//2:].detach().clone()
        # Values: layer outputs for the original images (first half).
        vstar = coh.features[args.layer_name + '_post'][:Nims//2].detach().clone()

        # Collapse the mask's dim 1 to a single channel by taking the maximum.
        mstar = ch.max(cp_masks[:Nims//2], dim=1, keepdims=True)[0]

        edit_classifier_weights(target_model, kstar, vstar,
                                context_k, niter=args.nsteps,
                                piter=args.nsteps_proj, lr=args.lr,
                                low_rank_insert=args.restrict_rank,
                                mask=mstar.cuda() if args.use_mask else None)
    else:
        # Name of the first layer of the sub-network to fine-tune.
        if args.arch == 'resnet50':
            first_layer = f'layer{args.layernum + 1}.final.conv3'
        elif args.arch == 'resnet18':
            first_layer = f'layer{args.layernum + 1}.final.conv2'
        elif args.arch.startswith('vgg'):
            first_layer = f'layer{args.layernum}.conv'
        else:
            first_layer = f'visual.layer{args.layernum + 1}.final'

        if args.mode_rewrite == 'finetune_local':
            # Only the single target weight is trained.
            edit_params = [target_weights(target_model)]
        else:
            edit_model = nethook.subsequence(context_model,
                                             first_layer=first_layer,
                                             share_weights=True)

            edit_params = edit_model.parameters()

            if args.arch.startswith('clip'):
                # For CLIP, train only parameters whose name contains 'visual'.
                edit_params = []
                for name, param in edit_model.named_parameters():
                    if 'visual' in name:
                        edit_params.append(param)

        optimizer = ch.optim.SGD(edit_params, lr=args.lr)
        compute_loss = torch.nn.CrossEntropyLoss()
        pbar = tqdm(range(args.nsteps))

        imgs = train_data['modified_imgs'][:args.ntrain].float()
        # All training samples must share one label, used as the target.
        target_label = np.unique(train_data['labels'][:args.ntrain].numpy())
        assert len(target_label) == 1

        tgts = ch.tensor([target_label[0]] * len(imgs))

        with torch.enable_grad():
            for _ in pbar:
                loss = compute_loss(context_model(imgs.cuda()), tgts.cuda())
                optimizer.zero_grad()
                loss.backward()
                pbar.set_description(str(loss))
                optimizer.step()

    return context_model


def layer_locate(args, input_data, context_model):
    """Rank layers of ``context_model`` by their attribution score.

    Scores come from ``IntegratedGradientsWithRef.shap_values`` computed on
    the images, reference images and labels in ``input_data`` (all moved to
    CUDA). Every entry is printed; entries whose name contains ``'conv'``,
    ``'fc'``, ``'bn'``, ``'relu'`` or ``'pool'`` are dropped before ranking.

    Args:
        args: object providing ``k`` and ``dataset``.
        input_data: mapping with tensors ``'imgs'``, ``'ref_imgs'`` and
            ``'labels'``.
        context_model: model handed to the explainer.

    Returns:
        List of the remaining layer names, highest score first.
    """
    from explaining_tool import IntegratedGradientsWithRef

    # logit prob contrast
    explainer = IntegratedGradientsWithRef(context_model, k=args.k, exp_obj='logit', dataset_name=args.dataset)
    scores = explainer.shap_values(input_tensor=input_data['imgs'].cuda(),
                                   ref_tensor=input_data['ref_imgs'].cuda(),
                                   sparse_labels=input_data['labels'].cuda())

    excluded_tags = ('conv', 'fc', 'bn', 'relu', 'pool')
    dropped = []
    for name, score in scores.items():
        if any(tag in name for tag in excluded_tags):
            print('deleting layer {}'.format(name))
            dropped.append(name)

        print('layer name: {} with similarity ratio {:.5f}'.format(name, score))

    # Remove after the loop: the dict must not change size while iterated.
    for name in dropped:
        del scores[name]

    return sorted(scores, key=scores.get, reverse=True)
