import yaml
import os

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

from . import architectures
from .sensor_failure_sim import bias_mask_T, mean_mask_T, performance_degradation_mask_T, scaling_mask_T

# Testing interface

def evaluate_models(mask_names, dataloader, bias_mult, noise_l, scaling_mult, pretrain_tasks, feat_dim, mean, std, device="mps", epoch=50):
	"""Evaluate all configured models under each simulated sensor failure.

	Every config in ``pretrain_tasks`` is loaded once into a model via
	``load_model``; the resulting models are then evaluated with
	``eval_robustness`` for each entry of ``mask_names``.

	Args:
		mask_names: Failure types to simulate; each is forwarded to
			``eval_robustness`` as ``mask_name``.
		dataloader: Test data, forwarded unchanged to ``eval_robustness``.
		bias_mult: Forwarded to ``eval_robustness``.
		noise_l: Forwarded to ``eval_robustness``.
		scaling_mult: Forwarded to ``eval_robustness``.
		pretrain_tasks: Paths (``str``) of YAML config files, one per model.
			The file name without directory and extension is used as the
			model's name in the result columns.
		feat_dim: Number of features to corrupt one at a time.
		mean: Forwarded to ``eval_robustness`` as ``mean_tensor``.
		std: Forwarded to ``eval_robustness`` as ``std_tensor``.
		device: Device the models are loaded onto and evaluated on.
		epoch: Checkpoint epoch passed to ``load_model``.

	Returns:
		A ``pandas.DataFrame`` with the per-mask results stacked row-wise and
		an ``"error"`` column holding the mask name of each row. An empty
		frame is returned when ``mask_names`` is empty.

	Raises:
		ValueError: If masks are requested but ``pretrain_tasks`` is empty.
	"""
	mask_names = list(mask_names)
	if not mask_names:
		# Nothing to evaluate, so no model needs to be loaded either.
		return pd.DataFrame()

	pretrain_tasks = list(pretrain_tasks)
	if not pretrain_tasks:
		raise ValueError("pretrain_tasks must contain at least one config path")

	# Loading does not depend on the mask, so do it once instead of once per mask.
	models = {}
	for task in pretrain_tasks:
		# NOTE: yaml.FullLoader is not meant for untrusted input; only use trusted configs.
		with open(task, "r") as config_file:
			args = yaml.load(config_file, Loader=yaml.FullLoader)
		models[task.split("/")[-1].split(".")[0]] = load_model(args, device, epoch=epoch)

	# Feature names are taken from the last config that was read.
	feat_names = args["data"]["columns_to_standardize"]

	frames = []
	for mask_name in mask_names:
		print(mask_name)
		res_df = eval_robustness(
			feat_names,
			feat_dim=feat_dim,
			mask_name=mask_name,
			models=models,
			dataloader=dataloader,
			mean_tensor=mean,
			std_tensor=std,
			criterion=nn.MSELoss(),
			device=device,
			bias_mult=bias_mult,
			noise_l=noise_l,
			scaling_mult=scaling_mult,
		)
		res_df["error"] = [mask_name] * len(res_df)
		frames.append(res_df)

	# A single concat avoids re-copying the accumulated frame on every iteration.
	return pd.concat(frames, axis=0)

# Model loading

def load_model(args, device, epoch = 50):
	"""Build a backbone plus head from configs and load trained weights.

	Args:
		args: Parsed experiment config (nested dict). Reads
			``args["pretrained_backbone"]["pretrain_config_path"]``,
			``args["head"]["model_name"]``, ``args["head"]["model_params"]``,
			``args["experiment"]["folder_path"]`` and
			``args["experiment"]["name"]``.
		device: Device the model is moved to and the checkpoint is mapped to.
		epoch: Selects the checkpoint file ``pretrain_{epoch}_epochs``.

	Returns:
		The model with ``output_net`` replaced by the configured head and the
		checkpoint's ``"model_state_dict"`` loaded. The train/eval mode is not
		changed here.

	Raises:
		FileNotFoundError: If the pretrain config cannot be opened.
		KeyError: If a required config entry is missing.
	"""
	# Instantiate the backbone with the model parameters of the pretrain config.
	# Paths are resolved relative to the parent of the current working directory.
	path_to_pretrain_conf = os.path.join("../", args["pretrained_backbone"]["pretrain_config_path"])
	# NOTE: yaml.FullLoader is not meant for untrusted input; only use trusted configs.
	with open(path_to_pretrain_conf, "r") as config_file:
		pt_config = yaml.load(config_file, Loader=yaml.FullLoader)
	model_params = pt_config["backbone"]["model_params"]
	model_params["seed"] = pt_config["experiment"]["seed"]
	model = architectures.__dict__[pt_config["backbone"]["model_name"]](**model_params).to(device)

	# Instantiate the finetune head with its own configured parameters.
	model.output_net = architectures.__dict__[args["head"]["model_name"]](**args["head"]["model_params"]).to(device)

	# Load the trained weights.
	checkpoint_path = os.path.join(f"../{args['experiment']['folder_path']}", f"{args['experiment']['name']}/model_checkpoints/pretrain_{epoch}_epochs")
	checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
	model.load_state_dict(checkpoint['model_state_dict'])
	return model

# Evaluate robustness

def eval_robustness(feat_names, feat_dim, mask_name, models, dataloader, mean_tensor, std_tensor, criterion, device, bias_mult = 2, noise_l = 0.4, scaling_mult = 3):
	"""Measure each model's test loss when single features are corrupted.

	First every model is evaluated without any corruption. Then, for each
	feature index in ``range(feat_dim)``, a failure mask of type ``mask_name``
	is built and every model is evaluated with that one feature corrupted.
	Each evaluation also yields a bootstrapped MSE with a 95% interval.

	Args:
		feat_names: Row labels; ``feat_names[i]`` names feature index ``i``.
		feat_dim: Number of feature indices to corrupt one at a time.
		mask_name: One of ``"mean"``, ``"bias"``, ``"noise"``, ``"scaling"``.
		models: Mapping of model name to model.
		dataloader: Iterable of ``(x, y)`` batches, passed to ``run_test``.
		mean_tensor: Per-feature means, indexed by feature index.
		std_tensor: Per-feature standard deviations, indexed by feature index.
		criterion: Loss callable passed to ``run_test``.
		device: Device passed to ``run_test``.
		bias_mult: Multiplier on the feature's std for the ``"bias"`` mask.
		noise_l: ``noise_level`` for the ``"noise"`` mask.
		scaling_mult: ``mult`` for the ``"scaling"`` mask.

	Returns:
		A ``pandas.DataFrame`` with one row per corrupted feature plus a final
		``"no_failure"`` row. For each model ``m`` the columns are ``m``,
		``m_mse_bootstrapped``, ``m_ci_lower`` and ``m_ci_upper``.

	Raises:
		ValueError: If ``mask_name`` is not a supported failure type.
	"""
	# Fail fast: previously an unknown name only surfaced as a NameError
	# after the (expensive) baseline evaluation had already run.
	supported_masks = ("mean", "bias", "noise", "scaling")
	if mask_name not in supported_masks:
		raise ValueError(f"Unknown mask_name {mask_name!r}; expected one of {supported_masks}")

	# Baseline: performance without any simulated failure.
	loss_dict_no = {}
	for name, model in models.items():
		preds, labels, loss = run_test(model, dataloader, criterion, device, mean_tensor, std_tensor)
		bootstrap_mean, ci_lower, ci_upper = bootstrap_testset(preds, labels, iterations=200, seed = 42)
		loss_dict_no[name + "_mse_bootstrapped"] = bootstrap_mean
		loss_dict_no[name + "_ci_lower"] = ci_lower
		loss_dict_no[name + "_ci_upper"] = ci_upper
		loss_dict_no[name] = loss

	frames = []
	for feature_to_mask in range(feat_dim):
		if mask_name == "mean":
			mask = mean_mask_T(mean = mean_tensor[feature_to_mask])
		elif mask_name == "bias":
			mask = bias_mask_T(bias = std_tensor[feature_to_mask]*bias_mult)
		elif mask_name == "noise":
			mask = performance_degradation_mask_T(std = std_tensor[feature_to_mask], noise_level=noise_l)
		else:  # "scaling" (validated above)
			mask = scaling_mask_T(mult = scaling_mult)

		# Fresh dict per feature so rows can never share stale entries.
		loss_dict = {}
		for name, model in models.items():
			preds, labels, loss = run_test(model, dataloader, criterion, device, mean = mean_tensor, std = std_tensor, mask = mask, feature_to_mask = feature_to_mask)
			bootstrap_mean, ci_lower, ci_upper = bootstrap_testset(preds, labels, iterations=200, seed = 42)

			loss_dict[name] = loss
			loss_dict[name + "_mse_bootstrapped"] = bootstrap_mean
			loss_dict[name + "_ci_lower"] = ci_lower
			loss_dict[name + "_ci_upper"] = ci_upper

		frames.append(pd.DataFrame(loss_dict, index=[feat_names[feature_to_mask]]))

	# The baseline goes last; one concat replaces the per-iteration concat.
	frames.append(pd.DataFrame(loss_dict_no, index=["no_failure"]))
	return pd.concat(frames, axis = 0)

# Evaluate robustness for one model

def run_test(model, dataloader, criterion, device, mean = None, std = None, mask = None, feature_to_mask = None):
	"""Run ``model`` over ``dataloader`` and collect predictions, labels and loss.

	Args:
		model: Module to evaluate; it is switched to eval mode.
		dataloader: Iterable yielding ``(x, y)`` batches. When statistics are
			given, ``x`` is indexed with three slices, so it needs at least
			three dimensions.
		criterion: Loss callable applied once to all predictions and labels.
		device: Device the batches are moved to.
		mean: Statistics subtracted from ``x``; normalisation only happens
			when both ``mean`` and ``std`` are given.
		std: Statistics ``x`` is divided by after subtracting ``mean``.
		mask: Optional object with ``apply_mask(x, feature_to_mask)``. It is
			applied before normalisation and only when ``mean`` and ``std``
			are both given; otherwise it is ignored.
		feature_to_mask: Second argument handed to ``mask.apply_mask``.

	Returns:
		Tuple ``(preds, labels, loss)``: two Python lists and the criterion
		value as a float.
	"""
	normalise = mean is not None and std is not None
	collected_preds = []
	collected_labels = []

	model.eval()
	with torch.no_grad():
		for batch_x, batch_y in dataloader:
			batch_x, batch_y = batch_x.to(device), batch_y.to(device)

			if not normalise:
				net_input = batch_x
			else:
				if mask is not None:
					batch_x = mask.apply_mask(batch_x, feature_to_mask)
				# Standardise on the CPU, then write the result into a
				# same-dtype copy of the batch living on the target device.
				net_input = torch.clone(batch_x)
				standardised = torch.Tensor((batch_x[:, :, :].cpu() - mean) / std)
				net_input[:, :, :] = standardised.to(device)

			# Only the last entry along dim 1 of the output is scored.
			last_output = model(net_input)[:, -1]

			collected_preds.extend(np.array(last_output.cpu()).tolist())
			collected_labels.extend(np.array(batch_y.cpu()).tolist())

		total_loss = criterion(torch.tensor(collected_preds), torch.tensor(collected_labels)).item()

	return collected_preds, collected_labels, total_loss

# Bootstrap testset

def bootstrap_testset(preds, labels, iterations = 200, seed = 42):
	"""Bootstrap the test-set MSE and a 95% percentile interval.

	The samples are drawn ``iterations`` times with replacement (same size as
	the original set) and the MSE between the resampled predictions and labels
	is computed each time.

	Args:
		preds: Predictions (list, array or tensor); first axis indexes samples.
		labels: Targets with the same layout as ``preds``.
		iterations: Number of bootstrap resamples.
		seed: Seed for the NumPy random state, making the result reproducible.

	Returns:
		Tuple ``(bootstrap_mean, ci_lower, ci_upper)``: the mean of the
		resampled MSEs and their 2.5th and 97.5th percentiles.
	"""
	# as_tensor avoids the copy (and warning) torch.tensor() causes when the
	# input is already a tensor.
	preds = torch.as_tensor(preds)
	labels = torch.as_tensor(labels)

	rng = np.random.RandomState(seed=seed)
	idx = np.arange(labels.shape[0])
	mse = torch.nn.MSELoss()

	test_mses = []
	for _ in range(iterations):
		pred_idx = rng.choice(idx, size=idx.shape[0], replace=True)
		# Store plain Python floats so the NumPy reductions below do not rely
		# on implicit conversion of a list of 0-d tensors.
		test_mses.append(mse(preds[pred_idx], labels[pred_idx]).item())

	bootstrap_mean = np.mean(test_mses)
	ci_lower = np.percentile(test_mses, 2.5)
	ci_upper = np.percentile(test_mses, 97.5)
	return bootstrap_mean, ci_lower, ci_upper



