import os
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
import wandb

# Directory two levels up from this file, i.e. the parent of the folder that
# contains this module.
base_path = os.path.dirname(os.path.dirname(__file__))

def inner_train(mnet, hnet, train_dataloader, optimizer, inner_epochs, device, stop_loss= None, learnt_embedding=None):
    """Run the inner training loop and report the first and last epoch losses.

    For every batch, ``mnet`` is evaluated on the batch's image and text
    features. When ``hnet`` is given, it is first asked for weights (from
    ``learnt_embedding`` if provided, otherwise from the first entry of the
    batch's ``"ques_emb"``) and those weights are passed on to ``mnet``. The
    cross-entropy loss against ``"label"`` is back-propagated and
    ``optimizer`` is stepped once per batch.

    Args:
        mnet: Callable main network, invoked as
            ``mnet(image_features, text_features)`` or with an extra
            ``weights=...`` keyword when ``hnet`` is used.
        hnet: Optional hypernetwork exposing ``forward(uncond_input=...)``.
            ``None`` means ``mnet`` is called without a ``weights`` argument.
        train_dataloader: Iterable of dict batches with the keys
            ``"image_features"``, ``"text_features"``, ``"ques_emb"`` and
            ``"label"``.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once
            per batch.
        inner_epochs: Maximum number of passes over ``train_dataloader``.
        device: Device the labels are moved to before the loss is computed.
        stop_loss: Optional threshold. Training stops after the first epoch
            whose mean batch loss is strictly below it.
        learnt_embedding: Optional embedding handed to ``hnet`` instead of the
            batch's question embedding.

    Returns:
        A tuple ``(start_loss, end_loss)``: the mean batch loss of the first
        epoch and of the last epoch that actually ran. Both are NaN when no
        epoch ran (``inner_epochs <= 0``); an epoch that saw no batches has a
        NaN mean.
    """
    # Defined up front so that a zero-epoch call returns NaNs instead of
    # failing on unbound locals.
    start_loss = end_loss = float("nan")

    for epoch in range(inner_epochs):
        batch_losses = []
        for sample in train_dataloader:
            sample_image_features = sample["image_features"]
            sample_text_features = sample["text_features"]
            sample_ques_emb = sample["ques_emb"]
            # NOTE(review): only the labels are moved to `device`; the features
            # are passed through untouched - presumably mnet or the dataset
            # handles their placement. Confirm against callers.
            labels = sample["label"].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            if hnet is not None:
                if learnt_embedding is None:
                    weights = hnet.forward(uncond_input=sample_ques_emb[0])
                else:
                    weights = hnet.forward(uncond_input=learnt_embedding)
                outputs = mnet(sample_image_features, sample_text_features, weights=weights)
            else:
                outputs = mnet(sample_image_features, sample_text_features)

            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        # Compute the epoch mean once; np.mean([]) would warn, so the empty
        # case is spelled out explicitly.
        end_loss = np.mean(batch_losses) if batch_losses else float("nan")
        if epoch == 0:
            start_loss = end_loss
        if stop_loss is not None and end_loss < stop_loss:
            break

    return start_loss, end_loss

def meta_update_step(config, model, meta_epoch, weights_before, weights_after):
    """Load into ``model`` a blend of ``weights_before`` and ``weights_after``.

    The outer step size is interpolated linearly in ``meta_epoch``: it equals
    ``config["meta_stepsize_start"]`` at epoch 0 and
    ``config["meta_stepsize_final"]`` when ``meta_epoch`` reaches
    ``config["meta_epochs"]``. Each entry of ``weights_before`` is then moved
    towards the matching entry of ``weights_after`` by that fraction, and the
    result is loaded with ``model.load_state_dict``.

    Args:
        config: Mapping with the keys ``"meta_stepsize_start"``,
            ``"meta_stepsize_final"`` and ``"meta_epochs"``.
        model: Object whose ``load_state_dict`` receives the blended weights.
        meta_epoch: Index of the current outer epoch.
        weights_before: Name-to-value mapping captured before the inner loop;
            its keys decide which entries are loaded.
        weights_after: Name-to-value mapping captured after the inner loop.
    """
    final_step = config["meta_stepsize_final"]
    progress = meta_epoch / config["meta_epochs"]
    start_step = config["meta_stepsize_start"]
    outerstepsize = final_step * progress + start_step * (1 - progress)

    blended = {}
    for name in weights_before:
        delta = weights_after[name] - weights_before[name]
        blended[name] = weights_before[name] + delta * outerstepsize
    model.load_state_dict(blended)

def get_dataset_logits(mnet, hnet, dataloader, learnt_embedding=None):
    """Collect the outputs of ``mnet`` over every batch of ``dataloader``.

    Args:
        mnet: Callable main network, invoked as
            ``mnet(image_features, text_features)`` or with an extra
            ``weights=...`` keyword when ``hnet`` is used.
        hnet: Optional hypernetwork exposing ``forward(uncond_input=...)``.
        dataloader: Iterable of dict batches with the keys
            ``"image_features"``, ``"text_features"`` and ``"ques_emb"``.
        learnt_embedding: Optional embedding handed to ``hnet`` instead of the
            first entry of the batch's ``"ques_emb"``.

    Returns:
        The per-batch outputs concatenated along dimension 0.
    """
    per_batch = []
    for batch in dataloader:
        image_feats = batch["image_features"]
        text_feats = batch["text_features"]
        ques_emb = batch["ques_emb"]

        # Without a hypernetwork, mnet is called without a weights argument.
        if hnet is None:
            per_batch.append(mnet(image_feats, text_feats))
            continue

        # A learnt embedding, when supplied, replaces the batch's own one.
        hnet_input = ques_emb[0] if learnt_embedding is None else learnt_embedding
        generated = hnet.forward(uncond_input=hnet_input)
        per_batch.append(mnet(image_feats, text_feats, weights=generated))

    return torch.cat(per_batch, dim=0)

def test_accuracy(mnet, hnet, dataset, dataloader, learnt_embedding=None):
    """Compute the top-1 accuracy of ``mnet`` over ``dataloader``.

    Both networks are switched to eval mode and run under ``torch.no_grad()``.
    They are switched back to train mode before returning, even when the
    evaluation raises.

    Args:
        mnet: Main network. Must provide ``eval()``/``train()`` and be callable
            as ``mnet(image_features, text_features)`` or with an extra
            ``weights=...`` keyword when ``hnet`` is used.
        hnet: Optional hypernetwork exposing ``eval()``, ``train()`` and
            ``forward(uncond_input=...)``.
        dataset: Unused by this function; kept so existing call sites keep
            working.
        dataloader: Iterable of dict batches with the keys
            ``"image_features"``, ``"text_features"``, ``"ques_emb"`` and
            ``"label"``.
        learnt_embedding: Optional embedding handed to ``hnet`` instead of the
            first entry of the batch's ``"ques_emb"``.

    Returns:
        A tuple ``(y_pred, y_true, acc)``: the list of predicted class
        indices, the list of labels collected from the batches, and the value
        returned by ``accuracy_score(y_true, y_pred)``.
    """
    # Validation inner-loop testing
    mnet.eval()
    if hnet is not None:
        hnet.eval()

    y_pred = []
    y_true = []
    try:
        with torch.no_grad():
            for sample in dataloader:
                sample_image_features = sample["image_features"]
                sample_text_features = sample["text_features"]
                sample_ques_emb = sample["ques_emb"]
                labels = sample["label"]

                if hnet is not None:
                    if learnt_embedding is None:
                        weights = hnet.forward(uncond_input=sample_ques_emb[0])
                    else:
                        weights = hnet.forward(uncond_input=learnt_embedding)
                    similarity = mnet(sample_image_features, sample_text_features, weights=weights)
                else:
                    similarity = mnet(sample_image_features, sample_text_features)

                # Index of the highest-scoring class for every row.
                _, indices = similarity.topk(1)
                y_pred += list(np.squeeze(indices.cpu().numpy(), axis=1))

                # Predictions are moved to the host above; do the same for
                # tensor labels so the metric never sees device tensors.
                if torch.is_tensor(labels):
                    labels = labels.cpu()
                y_true += list(labels)

        acc = accuracy_score(y_true, y_pred)
    finally:
        # Always hand the networks back in train mode, including on errors.
        mnet.train()
        if hnet is not None:
            hnet.train()
    return y_pred, y_true, acc

