import os

import numpy as np
import torch
from tqdm import tqdm
from torch.nn import functional as F
import pickle
from vision_datasets.registry import get_dataset
from vision_datasets.registry import split_train_into_train_val
from heads import get_classification_head
from vision_datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle

# Run from the repository root with it on the import path, e.g.: export PYTHONPATH="$PYTHONPATH:$PWD"

import time
import sys
# Checkout of the task-arithmetic project. It is appended to the import path so the
# project-local `eval` and `args` modules imported below resolve.
root = '/data/common/task-arithmetic'
sys.path.append(root)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from eval import eval_single_dataset
from args import parse_arguments


def create_log_dir(path, filename='log.txt'):
    """Return a DEBUG-level logger writing to ``path/filename`` and to the console.

    The directory ``path`` is created if it does not exist. ``logging.getLogger``
    caches loggers by name (here: ``path``), so handlers are attached only when an
    equivalent one is not already present. Calling this twice with the same
    arguments therefore does not print or write every message twice.

    Args:
        path: directory for the log file; also used as the logger name.
        filename: name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    log_file = os.path.abspath(os.path.join(path, filename))
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, so match the exact type for the console one.
    has_console_handler = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger

# Datasets whose finetuned checkpoints are merged and then evaluated.
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
# For a quick single-dataset run, use e.g.: exam_datasets = ["SUN397"]

model = 'ViT-B-32'
args = parse_arguments()
args.data_location = root + '/data'
args.model = model
args.save = root + '/task_vectors_checkpoints/' + model
args.logs_path = '../logs/' + model  # relative to the current working directory
pretrained_checkpoint = root + '/task_vectors_checkpoints/' + model + '/zeroshot.pt'

str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
log = create_log_dir(args.logs_path, 'log_fisher_{}.txt'.format(str_time_))
# 'hard' or 'soft'; selects how compute_fisher builds its target distribution.
fisher_variant = 'hard'

# Load the zero-shot checkpoint once. Neither name is mutated afterwards, so sharing
# one object avoids a second full deserialization and a second copy in memory.
# final_model is only read for its val_preprocess transform.
pretrained_model = torch.load(pretrained_checkpoint)
final_model = pretrained_model
pretrained_model_dic = pretrained_model.state_dict()

# Text-side entries of the checkpoint that are excluded from merging.
frozen = ["model.positional_embedding", "model.text_projection", "model.logit_scale", "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"]
trainable_params = {k: v for k, v in pretrained_model_dic.items() if k not in frozen}

def compute_fisher(train_dataloader, finetuned_model, classifier_head, fisher_variant='hard'):
    """Estimate a diagonal Fisher information for the parameters of ``finetuned_model``.

    For each batch, a batch-mean negative log-likelihood is back-propagated. The
    resulting gradients are squared elementwise and averaged over batches. This is
    the squared *batch-mean* gradient, an approximation of the per-example Fisher.

    Args:
        train_dataloader: iterable of batches. Each batch goes through
            ``maybe_dictionarize`` and must then expose an ``'images'`` entry.
        finetuned_model: module mapping images to features. It is moved to
            ``device`` and put in train mode, and is left that way on return.
        classifier_head: module mapping features to logits; moved to ``device``.
        fisher_variant: ``'hard'`` uses the model's own argmax prediction as the
            target label. ``'soft'`` weights every label's NLL by the model's
            detached predicted probability.

    Returns:
        dict mapping parameter name -> tensor of batch-averaged squared gradients,
        for every parameter of ``finetuned_model`` that received a gradient.

    Raises:
        ValueError: if ``fisher_variant`` is unknown or no batch was processed.
    """
    # Validate up front: an unknown variant used to surface as a NameError mid-loop.
    if fisher_variant not in ("hard", "soft"):
        raise ValueError(
            "Unknown fisher_variant {!r}; expected 'hard' or 'soft'".format(fisher_variant)
        )

    model = finetuned_model.to(device)
    model.train()
    classifier_head.to(device)
    fisher = {}
    n_b = 0

    total = len(train_dataloader)

    for step, data in tqdm(
        enumerate(train_dataloader), total=total, desc="Computing fisher"
    ):
        # Never consume more than len(train_dataloader) batches. This only matters for
        # loaders whose iteration can outrun their reported length.
        if total > 0 and step == total:
            break

        data = maybe_dictionarize(data)
        x = data['images'].to(device)
        features = model(x)
        logits = classifier_head(features)
        log_probs = torch.log_softmax(logits, -1)
        n_b += 1

        # Fisher target: the model's own predictions, not ground-truth labels.
        if fisher_variant == "hard":
            _, target_labels = logits.max(-1)
            loss = F.nll_loss(log_probs, target_labels)
        else:
            # Expected NLL under the model's detached predictive distribution.
            # -log_probs[:, c] is exactly the per-example NLL for label c, so this equals
            # running F.nll_loss once per label and weighting by probs, in a single op.
            probs = torch.softmax(logits, -1).detach()  # [b, c]
            loss = -(probs * log_probs).sum(-1).mean()

        model.zero_grad()
        loss.backward()

        for n, f in collect_squared_gradients(model).items():
            if n in fisher:
                fisher[n] += f
            else:
                fisher[n] = f

    # Release the gradients of the last batch; only the squared copies are kept.
    model.zero_grad(set_to_none=True)
    classifier_head.zero_grad(set_to_none=True)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if n_b == 0:
        raise ValueError("train_dataloader yielded no batches; cannot estimate the Fisher")
    return {n: f / n_b for n, f in fisher.items()}


def collect_squared_gradients(model):
    """Map each parameter name of ``model`` to its elementwise-squared gradient.

    Parameters whose ``.grad`` is ``None`` (e.g. untouched by the last backward pass)
    are omitted. The squares are computed on detached gradients, so they carry no
    autograd history.
    """
    return {
        name: param.grad.detach() ** 2
        for name, param in model.named_parameters()
        if param.grad is not None
    }


def fisher_weighted_average(all_params, fisher_weights, model_coeffs=None, fisher_smooth=1.0e-10):
    """Merge per-model tensors by a coefficient- and Fisher-weighted average.

    For every parameter name ``n``, the result is
    ``sum_i c_i * F_i[n] * theta_i[n] / sum_i c_i * F_i[n]``, elementwise. Here
    ``F_i = fisher_weights[i] + fisher_smooth`` and ``c_i`` is ``model_coeffs[i]``.
    With uniform coefficients (the default), ``c`` cancels out of the ratio entirely.

    Args:
        all_params: dict mapping parameter name -> list of tensors, one per model, in
            the same order as ``fisher_weights``.
        fisher_weights: list of dicts (one per model) mapping parameter name -> Fisher
            tensor with the same shape as the parameter. Every name in ``all_params``
            must be present.
        model_coeffs: optional per-model coefficients (sequence or 1-D tensor). When
            omitted, defaults to 0.3 for every model, matching the original behaviour.
        fisher_smooth: constant added to every Fisher entry so the denominator is
            never zero.

    Returns:
        dict mapping parameter name -> merged tensor.

    Raises:
        ValueError: if ``fisher_weights`` is empty, or the coefficient or per-parameter
            list lengths do not match the number of models.
        KeyError: if a parameter name has no Fisher entry for some model.
    """
    num_models = len(fisher_weights)
    if num_models == 0:
        raise ValueError("fisher_weights must contain at least one model")

    # Size the coefficients from the inputs, not from a module-level dataset list.
    if model_coeffs is None:
        model_coeffs = torch.ones(num_models) * 0.3
    else:
        model_coeffs = torch.as_tensor(model_coeffs, dtype=torch.float32)
        if model_coeffs.numel() != num_models:
            raise ValueError(
                "expected {} model coefficients, got {}".format(num_models, model_coeffs.numel())
            )

    avg_params = {}

    for n, param_list in all_params.items():
        if len(param_list) != num_models:
            raise ValueError(
                "parameter {!r} has {} tensors but there are {} Fisher dicts".format(
                    n, len(param_list), num_models
                )
            )
        params = torch.stack(param_list)  # [N, *]
        fisher = (
            torch.stack([fw[n].to(params.device) for fw in fisher_weights])
            + fisher_smooth
        )  # [N, *]

        # Reshape to [N, 1, ..., 1] so coefficients broadcast over parameter dims.
        coeff = model_coeffs.view(-1, *[1 for _ in range(params.dim() - 1)]).to(
            params.device
        )

        sum_p = (params * fisher * coeff).sum(0)
        denom = (fisher * coeff).sum(0)
        avg_params[n] = sum_p / denom

    return avg_params


def _load_finetuned(checkpoint_path):
    """Load a finetuned model checkpoint.

    Tries ``torch.load`` first and falls back to plain ``pickle`` for files torch
    cannot read. NOTE(review): both deserialize arbitrary Python objects, so only
    trusted checkpoint files should be loaded here.
    """
    try:
        return torch.load(checkpoint_path)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
        with open(checkpoint_path, 'rb') as fh:  # closes the handle, unlike open() inline
            return pickle.load(fh)


# Per-dataset Fisher dicts and finetuned state dicts, in exam_datasets order.
fisher_weights, finetuned_models = [], []
for dataset_name in exam_datasets:
    finetuned_checkpoint = root + '/task_vectors_checkpoints/' + model + '/' + dataset_name + '/finetuned.pt'
    finetuned_model = _load_finetuned(finetuned_checkpoint)

    base_dataset = get_dataset(dataset_name, final_model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
    # Presumably carves a 10% (at most 5000-sample) held-out split out of the training
    # data and exposes it as `test_dataset` — TODO confirm in vision_datasets.registry.
    dataset = split_train_into_train_val(
        base_dataset, dataset_name, args.batch_size, num_workers=2, val_fraction=0.1, max_val_samples=5000)
    valset = dataset.test_dataset

    val_dataloader = torch.utils.data.DataLoader(
        valset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
    )

    classifier_head = get_classification_head(args, dataset_name)
    fisher = compute_fisher(val_dataloader, finetuned_model, classifier_head, fisher_variant=fisher_variant)

    fisher_weights.append(fisher)
    finetuned_models.append(finetuned_model.state_dict())


# Collect, for every mergeable parameter name, one tensor per finetuned model
# (in finetuned_models order). The text-side entries listed in `frozen` are skipped.
params = {}
frozen = ["model.positional_embedding", "model.text_projection", "model.logit_scale", "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"]
for local_model in finetuned_models:
    for name, tensor in local_model.items():
        if name in frozen:
            continue
        params.setdefault(name, []).append(tensor)

# Start from a fresh copy of the zero-shot encoder and overwrite every merged
# (non-frozen) tensor with its Fisher-weighted average.
image_encoder = torch.load(pretrained_checkpoint)
merged_state = fisher_weighted_average(params, fisher_weights)
load_result = image_encoder.load_state_dict(merged_state, strict=False)
# strict=False is required because the frozen entries are not merged. It also
# silently ignores keys the encoder does not have, so surface those.
unexpected = getattr(load_result, 'unexpected_keys', None)
if unexpected:
    log.warning('Merged keys not present in the encoder: ' + str(list(unexpected)))

# The "0.3" in the file name mirrors the uniform model coefficient used by
# fisher_weighted_average.
merged_path = root + '/merged_models/' + model + f'/fisher_0.3_{fisher_variant}.pt'
os.makedirs(os.path.dirname(merged_path), exist_ok=True)
image_encoder.save(merged_path)

accs = []
for dataset in exam_datasets:
    metrics = eval_single_dataset(image_encoder, dataset, args)
    # Index directly: a missing 'top1' now raises KeyError instead of a TypeError on None * 100.
    top1 = metrics['top1'] * 100
    log.info(str(dataset) + ':' + str(top1) + '%')
    accs.append(top1)
log.info('Avg ACC:' + str(np.mean(accs)) + '%')