import numpy as np
from numpy.core.numeric import roll 

from fancyimpute import KNN, IterativeImputer, MatrixFactorization

import torch
import torch.optim as optim

from neuralfaults.utils.parser import get_parser_with_args
from neuralfaults.utils.helpers import (initialize_metrics, 
                                        get_mean_metrics, 
                                        compute_metrics, get_loss_function)
from neuralfaults.utils.runner import Runner
from neuralfaults.utils.dataloader import get_dataloaders

# Parse the command-line options once; the rest of the script reads its
# configuration from `opt`.
parser = get_parser_with_args()
opt = parser.parse_args()

# When `opt.fail_quants` is non-empty, the three comma-separated options
# (`fail_quants`, `inp_quants`, `fail_quants_prob`) must be consistent with
# each other. Explicit exceptions are used instead of `assert` so the checks
# still run under `python -O` and report which value was rejected.
if len(opt.fail_quants):
    fail_quants = opt.fail_quants.split(",")
    inp_quants = opt.inp_quants.split(",")
    # float() itself raises ValueError on a non-numeric entry.
    fail_quants_prob = list(map(float, opt.fail_quants_prob.split(",")))

    # Every failing quantity has to be one of the input quantities.
    unknown_quants = [fq for fq in fail_quants if fq not in inp_quants]
    if unknown_quants:
        raise ValueError(
            "fail_quants entries missing from inp_quants: {}".format(unknown_quants)
        )

    if len(fail_quants) > len(inp_quants):
        raise ValueError(
            "fail_quants has {} entries but inp_quants only has {}".format(
                len(fail_quants), len(inp_quants)
            )
        )

    # Exactly one probability per failing quantity.
    if len(fail_quants_prob) != len(fail_quants):
        raise ValueError(
            "fail_quants_prob has {} entries, expected {} (one per fail_quants entry)".format(
                len(fail_quants_prob), len(fail_quants)
            )
        )

    # Probabilities must lie in [0, 1]; NaN is rejected by this test as well.
    invalid_probs = [fqp for fqp in fail_quants_prob if not 0.0 <= fqp <= 1.0]
    if invalid_probs:
        raise ValueError(
            "fail_quants_prob values outside [0, 1]: {}".format(invalid_probs)
        )

# Data loaders plus the dataset metadata returned alongside them.
train_loader, val_loader, metadata = get_dataloaders(opt)

# A checkpoint path for the evaluated model is mandatory.
assert len(opt.weight_path) > 0

# NOTE: torch.load unpickles the stored object (the result is used directly
# as a model here), so only load checkpoints from a trusted source.
model = torch.load(opt.weight_path).cuda()
model.eval()

# Imputation models whose name contains one of these tags need a second
# checkpoint (`opt.impute_weight_path`), loaded as `model_g`.
_impute_tags_with_weights = ('sgan', 'e2e', 'gan', 'gain', 'mrnn', 'grud', 'brits')
if any(tag in opt.impute_model for tag in _impute_tags_with_weights):
    model_g = torch.load(opt.impute_weight_path).cuda()
    model_g.eval()

# Loss function and metric accumulator, both configured from `opt`.
criterion = get_loss_function(opt)

metrics = initialize_metrics(opt)

# Evaluate `model` on the validation set. When `opt.fail_quants` is non-empty,
# each batch also carries a mask of entries to impute; those entries are
# filled by the imputer selected via `opt.impute_model` before the batch is
# passed to `model`, and the squared error of the filled values against the
# batch's original values is recorded.

imputations_losses = []

# Input quantity names; entry i names channel i along dim 1 of the input
# tensor (the imputers below index `input_tensor[:, i, :]`).
quant_names = opt.inp_quants.split(",")

# Imputers handled directly by the branches below. Any other value of
# `opt.impute_model` is dispatched to `model_g`. Previously `model_g` was
# also called for these, where it is never loaded, raising NameError.
uses_classical_imputer = (
    opt.impute_model in ('mean', 'zero')
    or any(tag in opt.impute_model for tag in ('knn', 'mf', 'mice'))
)


def _fill_masked_with_constant(input_tensor, mask_tensor, raw_value_of):
    """Overwrite masked entries of every channel with a normalised constant.

    `raw_value_of(name)` gives the un-normalised fill value for a quantity;
    it is normalised with that quantity's `metadata['mean']`/`metadata['std']`.
    `input_tensor` is modified in place.
    """
    for i, name in enumerate(quant_names):
        fill = (raw_value_of(name) - metadata['mean'][name]) / metadata['std'][name]
        # The slice is a view, so the masked assignment writes into input_tensor.
        input_tensor[:, i, :][mask_tensor[:, i, :]] = fill


def _fill_masked_with_imputer(input_tensor, mask_tensor, make_imputer):
    """Impute masked entries channel by channel with a fancyimpute-style imputer.

    `make_imputer()` must return a fresh object exposing `fit_transform`; one
    is created per channel. `input_tensor` must be a CPU tensor and is
    modified in place.
    """
    for i in range(len(quant_names)):
        values = input_tensor[:, i, :].data.numpy()
        # Mark masked entries as missing (np.NaN was removed in NumPy 2.0).
        values[mask_tensor[:, i, :].data.numpy()] = np.nan
        input_tensor[:, i, :] = torch.tensor(make_imputer().fit_transform(values))


# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    for batch in val_loader:
        if len(opt.fail_quants):
            input_tensor, output_tensor, mask_tensor, delta_tensor = batch
            # Keep the values as delivered by the loader as the reference
            # for the imputation error.
            original_tensor = input_tensor.clone()
            mask_tensor = mask_tensor.bool()  # True marks an entry to impute

            if opt.impute_model == 'mean':
                # Raw value == mean, i.e. 0 after normalisation.
                _fill_masked_with_constant(
                    input_tensor, mask_tensor, lambda name: metadata['mean'][name])
            if opt.impute_model == 'zero':
                # Raw value 0, expressed in normalised units.
                _fill_masked_with_constant(
                    input_tensor, mask_tensor, lambda name: 0)
            # TODO: a 'last' (previous-value) imputer was sketched here with
            # torch.roll but is disabled/unimplemented.
            if 'knn' in opt.impute_model:
                # Optional neighbour count encoded as e.g. "knn#5"; default 3.
                k = 3
                if '#' in opt.impute_model:
                    k = int(opt.impute_model.split('#')[1])
                _fill_masked_with_imputer(
                    input_tensor, mask_tensor, lambda: KNN(k=k, verbose=False))
            if 'mf' in opt.impute_model:
                _fill_masked_with_imputer(
                    input_tensor, mask_tensor,
                    lambda: MatrixFactorization(epochs=100, learning_rate=0.001, verbose=False))
            if 'mice' in opt.impute_model:
                _fill_masked_with_imputer(
                    input_tensor, mask_tensor,
                    lambda: IterativeImputer(max_iter=100, verbose=0))

            input_tensor = input_tensor.cuda().float()
            missing = mask_tensor.cuda()
            mask_tensor = missing.float()
            original_tensor = original_tensor.cuda().float()

            if not uses_classical_imputer:
                zero_tensor = torch.zeros_like(input_tensor)
                observed_tensor = 1 - mask_tensor
                if 'sgan' in opt.impute_model:
                    # The sgan generator returns an extra error tensor.
                    imputed_tensor, err_tensor = model_g(
                        input_tensor, zero_tensor, observed_tensor, zero_tensor)
                else:
                    imputed_tensor = model_g(
                        input_tensor, zero_tensor, observed_tensor, zero_tensor)
                # Use the generator output at the masked positions only; the
                # result was previously discarded, so learned imputers had no
                # effect on the metrics.
                # NOTE(review): assumes model_g returns a tensor with the same
                # shape as input_tensor - confirm against the model definitions.
                input_tensor = torch.where(missing, imputed_tensor, input_tensor)

            # Mean squared error over the masked entries. Batches without
            # any masked entry are skipped to avoid a 0/0 = NaN sample.
            n_missing = mask_tensor.sum()
            if n_missing > 0:
                imputation_loss = (
                    (input_tensor * mask_tensor - original_tensor * mask_tensor) ** 2
                ).sum() / n_missing
                imputations_losses.append(imputation_loss.item())
        else:
            input_tensor, output_tensor = batch

        input_tensor = input_tensor.cuda().float()
        output_tensor = output_tensor.cuda().float()

        prediction_tensor = model(input_tensor)
        loss = criterion(prediction_tensor, output_tensor)
        compute_metrics(metrics, loss, prediction_tensor,
                        output_tensor, opt)


# With no recorded imputation error (e.g. `opt.fail_quants` empty) report NaN
# explicitly instead of letting np.mean([]) emit a RuntimeWarning.
print('Imputation MSE :',
      np.mean(imputations_losses) if imputations_losses else float('nan'))
print(get_mean_metrics(metrics))