import json
import os
import sys

from munch import Munch
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import yaml

import models
from samplers import get_data_sampler, sample_transformation
from tasks import get_task_sampler

# Noise standard deviations, presumably for noisy linear regression (NLR)
# evaluations -- TODO confirm intended use.
# NOTE(review): nothing in this file reads this constant; build_evals uses its
# own hard-coded [0.1, 0.5] list instead.
NLR_STDS = [0.1, 0.25, 0.5, 1.0]


def get_augmented_relevant_tasks(task_name):
    """Return the task names that should be evaluated for ``task_name``.

    * ``linear_regression`` and ``linear_classification`` are always
      evaluated together.
    * Any name containing ``noisy_linear_regression`` is paired with plain
      ``linear_regression``.
    * Every other task maps to a one-element list holding itself.
    """
    linear_family = ["linear_regression", "linear_classification"]
    if task_name in linear_family:
        return linear_family
    if "noisy_linear_regression" in task_name:
        return ["linear_regression", task_name]
    return [task_name]


def get_model_from_run(run_path, step=-1, only_conf=False):
    """Load the config (and optionally the trained model) stored in a run directory.

    Args:
        run_path: directory containing ``config.yaml`` and the checkpoints.
        step: ``-1`` loads ``state.pt`` and reads its ``"model_state_dict"``
            entry; any other value loads ``model_{step}.pt``, which is used
            directly as the state dict.
        only_conf: when true, skip building/loading the model.

    Returns:
        ``(model, conf)``. ``model`` is ``None`` when ``only_conf`` is set.
        ``(None, None)`` is returned when the config cannot be read or parsed.
    """
    config_path = os.path.join(run_path, "config.yaml")
    try:
        with open(config_path) as fp:  # we don't Quinfig it to avoid inherits
            conf = Munch.fromDict(yaml.safe_load(fp))
    except Exception:
        # Best effort: callers treat (None, None) as "not a usable run".
        # `Exception` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return None, None

    if only_conf:
        return None, conf

    model = models.build_model(conf.model)

    # Load onto the CPU so checkpoints written on a GPU machine can be read
    # anywhere; callers move the model to its final device afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    if step == -1:
        state = torch.load(os.path.join(run_path, "state.pt"), map_location="cpu")
        state_dict = state["model_state_dict"]
    else:
        state_dict = torch.load(
            os.path.join(run_path, f"model_{step}.pt"), map_location="cpu"
        )
    model.load_state_dict(state_dict)

    return model, conf


# Functions for evaluation

def eval_batch(model, task_sampler, xs, xs_p=None, U=None, n_dims_truncated=None):
    """Evaluate ``model`` on one batch of a freshly sampled task.

    Args:
        model: callable ``model(xs, ys)`` exposing a ``name`` attribute.
        task_sampler: zero-argument callable returning a task object with
            ``evaluate(xs, U, n_dims_truncated)`` and ``get_metric()``.
        xs: input batch handed to the task and to the model.
        xs_p: alternative test inputs; not supported by this implementation.
        U: tensor forwarded to ``task.evaluate``; required.
        n_dims_truncated: forwarded unchanged to ``task.evaluate``.

    Returns:
        Whatever the task's metric returns for ``(pred.cpu(), ys)``.

    Raises:
        NotImplementedError: if ``xs_p`` is given.
        ValueError: if ``U`` is ``None``.
    """
    # Fail with a clear message instead of an UnboundLocalError further down.
    if xs_p is not None:
        raise NotImplementedError(
            "eval_batch does not support a separate test prompt (xs_p)."
        )
    if U is None:
        raise ValueError("U is None, please provide U for IV regression.")

    task = task_sampler()
    # Only these model families are run on the GPU; everything else stays on CPU.
    if torch.cuda.is_available() and model.name.split("_")[0] in ["gpt2", "lstm", "EncoderTF"]:
        device = "cuda"
    else:
        device = "cpu"

    ys = task.evaluate(xs, U, n_dims_truncated)
    pred = model(xs.to(device), ys.to(device)).detach()
    return task.get_metric()(pred.cpu(), ys)



def eval_coef(model, task_sampler, xzs, U=None, n_dims_truncated=None, delta=1):
    """Estimate the model's implied coefficients by finite differences.

    The last position of the prompt is replaced by the mean of all earlier
    positions. Each of the first ``p = xzs.shape[2] // 3`` features of that
    position is then shifted by ``delta`` in turn, and the change in the
    model's prediction divided by ``delta`` is taken as the coefficient for
    that feature. The result is the squared L2 distance between these
    coefficients and the task's ``beta``.

    Args:
        model: callable ``model(xzs, ys, inds=[...])`` exposing ``name``.
        task_sampler: zero-argument callable returning a task whose
            ``evaluate(xzs, U, n_dims_truncated, return_beta=True)`` yields
            ``(ys, beta)``.
        xzs: prompt tensor; it is not modified.
        U: tensor forwarded to ``task.evaluate``; required.
        n_dims_truncated: forwarded unchanged to ``task.evaluate``.
        delta: finite-difference step size.

    Returns:
        Tensor of shape ``(xzs.shape[0], xzs.shape[1])`` with the squared
        coefficient error.

    Raises:
        ValueError: if ``U`` is ``None``.
    """
    if U is None:
        raise ValueError("U is None, please provide U for IV regression.")

    task = task_sampler()
    if torch.cuda.is_available() and model.name.split("_")[0] in ["gpt2", "lstm", "EncoderTF"]:
        device = "cuda"
    else:
        device = "cpu"

    p = xzs.shape[2] // 3

    # Labels and ground-truth coefficients come from the untouched prompt.
    ys, beta = task.evaluate(xzs, U, n_dims_truncated, return_beta=True)
    beta = beta[:, :, 0].unsqueeze(1)

    # Work on a copy so the caller's tensor is left intact, and place the mean
    # of the earlier positions at the prediction position.
    base = xzs.clone()
    base[:, -1, :] = xzs[:, :-1, :].mean(dim=1)

    query_inds = [ys.shape[1] - 1]
    ys_dev = ys.to(device)
    # The unperturbed prediction does not depend on the feature index, so it
    # is computed once instead of once per feature.
    # NOTE(review): an earlier comment stated predictions are rounded to 4
    # digits; nothing here rounds, so that would happen inside the model --
    # confirm.
    pred0 = model(base.to(device), ys_dev, inds=query_inds).detach().cpu()

    coef = torch.zeros(xzs.shape[0], xzs.shape[1], p)
    for i in range(p):
        perturbed = base.clone()
        perturbed[:, -1, i] += delta
        pred1 = model(perturbed.to(device), ys_dev, inds=query_inds).detach().cpu()
        coef[:, :, i] = (pred1 - pred0) / delta

    return torch.norm(coef - beta, dim=2) ** 2

# Functions for generating different kinds of train/test data

def gen_standard(data_sampler, n_points, b_size):
    """Sample a plain prompt; there is no separate test prompt (second item is None)."""
    return data_sampler.sample_xs(n_points, b_size), None

def gen_standard_iv(data_sampler, n_points, b_size, U=None, scale_Theta=1, dist_shift=True, n_dims_truncated=None):
    """Sample a prompt via ``data_sampler.sample_xzs``; no separate test prompt.

    All arguments after ``data_sampler`` are forwarded positionally, in
    order, to ``sample_xzs``. The second returned item is always ``None``.
    """
    sampler_args = (n_points, b_size, U, scale_Theta, dist_shift, n_dims_truncated)
    return data_sampler.sample_xzs(*sampler_args), None

def gen_opposite_quadrants(data_sampler, n_points, b_size):
    """Build train points with one random sign pattern per batch element and
    test points that are their exact negation.

    Returns:
        ``(xs_train_pre, xs_test_post)`` with ``xs_test_post == -xs_train_pre``.
    """
    samples = data_sampler.sample_xs(n_points, b_size)
    # One sign per (batch element, feature), broadcast over all points.
    signs = torch.sign(torch.randn([b_size, 1, samples.shape[2]]))

    train = signs * samples.abs()
    return train, -train


def gen_random_quadrants(data_sampler, n_points, b_size):
    """Build train points with one random sign pattern per batch element,
    while the test points are the unmodified samples.

    Returns:
        ``(xs_train_pre, xs_test_post)`` where ``xs_test_post`` is the sampled
        tensor itself.
    """
    samples = data_sampler.sample_xs(n_points, b_size)
    # One sign per (batch element, feature), broadcast over all points.
    signs = torch.sign(torch.randn([b_size, 1, samples.shape[2]]))

    return signs * samples.abs(), samples


def gen_orthogonal_train_test(data_sampler, n_points, b_size):
    """Build test points that are orthogonal to all preceding train points.

    For each position ``i`` below ``min(n_points, n_dims)``, the sampled point
    is stripped of its component in the span of points ``0..i-1`` and then
    rescaled to its original norm. Positions at or beyond the dimension are
    left as zeros in the test tensor.

    Returns:
        ``(xs_train_pre, xs_test_post)``; the train tensor is the raw sample.
    """
    xs = data_sampler.sample_xs(n_points, b_size)
    # At most n_dims mutually orthogonal directions exist.
    n_orthogonal = min(n_points, xs.shape[2])
    xs_test_post = torch.zeros(xs.shape)

    for idx in range(n_orthogonal):
        query = xs[:, idx : idx + 1, :]
        prefix = xs[:, :idx, :]
        # Rows of Vt are orthonormal right-singular vectors of the prefix, so
        # Vt^T Vt projects onto the subspace they span.
        Vt = torch.linalg.svd(prefix, full_matrices=False)[2]
        projector = Vt.transpose(1, 2) @ Vt
        residual = query - query @ projector
        # Restore the norm of the original point.
        query_norm = query.norm(dim=2).unsqueeze(2)
        residual_norm = residual.norm(dim=2).unsqueeze(2)
        xs_test_post[:, idx : idx + 1, :] = residual * query_norm / residual_norm

    return xs, xs_test_post


def gen_overlapping_train_test(data_sampler, n_points, b_size):
    """Build test points that repeat a randomly chosen earlier train point.

    For every position ``i >= 1`` and every batch element, the test point is
    a copy of one of the train points ``0..i-1``, picked via a random
    permutation. Position 0 keeps the sampled point.

    Returns:
        ``(xs_train_pre, xs_test_post)``; the train tensor is the raw sample.
    """
    xs = data_sampler.sample_xs(n_points, b_size)
    xs_test_post = xs.clone()
    n_batch = xs.shape[0]

    for pos in range(1, n_points):
        perms = [torch.randperm(pos) for _ in range(n_batch)]
        # Where the permutation holds 0 marks the selected earlier point;
        # adding 0.0 turns the boolean mask into a float one-hot row.
        one_hot = (torch.stack(perms).unsqueeze(dim=1) == 0) + 0.0
        xs_test_post[:, pos : pos + 1, :] = one_hot @ xs[:, :pos, :]

    return xs, xs_test_post


def aggregate_metrics(metrics, bootstrap_trials=1000):
    """
    Summarise a (num_eval, n_points) tensor of per-example metrics.

    Returns a dict of plain lists with, for each point: the mean, the
    unbiased standard deviation, and the 5% / 95% entries of the sorted
    bootstrap means (resampling the examples with replacement).
    """
    num_eval = len(metrics)
    # Each row holds the example indices for one bootstrap resample.
    resample_idx = torch.randint(num_eval, size=(bootstrap_trials, num_eval))
    sorted_means, _ = metrics[resample_idx].mean(dim=1).sort(dim=0)

    low_rank = int(0.05 * bootstrap_trials)
    high_rank = int(0.95 * bootstrap_trials)
    summary = {
        "mean": metrics.mean(dim=0),
        "std": metrics.std(dim=0, unbiased=True),
        "bootstrap_low": sorted_means[low_rank, :],
        "bootstrap_high": sorted_means[high_rank, :],
    }
    return {key: value.tolist() for key, value in summary.items()}


def eval_model(
    model,
    task_name,
    data_name,
    n_dims,
    n_points,
    prompting_strategy,
    num_eval_examples=6400,
    batch_size=64,
    data_sampler_kwargs=None,
    task_sampler_kwargs=None,
    return_coef_mse=True,
):
    """
    Evaluate a model on a task and aggregate the per-example metrics.

    Args:
        model: model handed to ``eval_batch``.
        task_name: task to evaluate; only names containing ``"iv"`` are
            supported.
        data_name: name passed to ``get_data_sampler``.
        n_dims: input dimensionality; a random tensor ``U`` of shape
            ``(batch_size, n_points, n_dims // 3)`` is drawn per batch.
        n_points: number of points per prompt.
        prompting_strategy: selects the module-level ``gen_<strategy>``
            function used to build prompts, e.g. ``"standard_iv"``.
        num_eval_examples: total number of examples; must be a multiple of
            ``batch_size``.
        batch_size: examples per batch.
        data_sampler_kwargs: extra keyword arguments for ``get_data_sampler``.
        task_sampler_kwargs: extra keyword arguments for ``get_task_sampler``.
        return_coef_mse: accepted for backward compatibility; it does not
            change the result.

    Returns:
        The dict produced by ``aggregate_metrics``.

    Raises:
        ValueError: if ``num_eval_examples`` is not a multiple of ``batch_size``.
        NotImplementedError: if ``task_name`` does not contain ``"iv"``.
        KeyError: if no ``gen_<prompting_strategy>`` function exists.
    """
    if num_eval_examples % batch_size != 0:
        raise ValueError(
            f"num_eval_examples ({num_eval_examples}) must be a multiple of "
            f"batch_size ({batch_size})"
        )
    # Only the IV path computes metrics; refuse other tasks up front rather
    # than failing on an unbound variable inside the loop.
    if "iv" not in task_name:
        raise NotImplementedError(
            f"eval_model only supports IV tasks; got task_name={task_name!r}"
        )

    # None defaults avoid sharing one mutable dict between calls.
    data_sampler = get_data_sampler(data_name, n_dims, **(data_sampler_kwargs or {}))
    task_sampler = get_task_sampler(
        task_name, n_dims, batch_size, **(task_sampler_kwargs or {})
    )

    u_dims = n_dims // 3
    generating_func = globals()[f"gen_{prompting_strategy}"]

    all_metrics = []
    for _ in tqdm(range(num_eval_examples // batch_size)):
        U = torch.randn(batch_size, n_points, u_dims)
        xzs, xs_p = generating_func(data_sampler, n_points, batch_size, U)
        all_metrics.append(eval_batch(model, task_sampler, xzs, xs_p, U))

    return aggregate_metrics(torch.cat(all_metrics, dim=0))


def build_evals(conf, all_task_names=None, num_eval_examples=None):
    """Build the keyword arguments for every evaluation of a run.

    Args:
        conf: run configuration (mapping with attribute access). Reads
            ``model.n_dims``, ``training.curriculum.points.end``,
            ``training.batch_size``, ``training.data`` and either
            ``training.tasks`` or the legacy ``training.task``.
        all_task_names: task names to evaluate; derived from ``conf`` when
            ``None``.
        num_eval_examples: optional override added to every evaluation.

    Returns:
        Dict mapping an evaluation name to the keyword arguments for
        ``eval_model``. Noisy linear regression expands into one entry per
        noise level, named ``noisy_linear_regression_noise=<std>``.

    Raises:
        NotImplementedError: if ``all_task_names`` is ``None`` and the config
            has neither ``tasks`` nor ``task``.
    """
    n_dims = conf.model.n_dims
    n_points = conf.training.curriculum.points.end
    batch_size = conf.training.batch_size

    if all_task_names is None:
        if "tasks" in conf.training:
            all_task_names = []
            for t in conf.training.tasks:
                all_task_names += get_augmented_relevant_tasks(t.name)
            # deduplicate while keeping first-seen order, so the evaluation
            # order is the same on every run
            all_task_names = list(dict.fromkeys(all_task_names))
        elif "task" in conf.training:
            # backward compatible
            all_task_names = [conf.training.task]
        else:
            raise NotImplementedError

    all_noise_stds = []
    if "noisy_linear_regression" in all_task_names:
        # Legacy configs (or explicitly passed task names) may have no
        # `tasks` list; then only the default noise levels below are used.
        for t in conf.training.get("tasks") or []:
            if t.name == "noisy_linear_regression":
                all_noise_stds.append(t.kwargs["noise_std"])
    for noise_std in [0.1, 0.5]:
        if noise_std not in all_noise_stds:
            all_noise_stds.append(noise_std)

    data_name = conf.training.data

    base_kwargs = {
        "n_dims": n_dims,
        "n_points": n_points,
        "batch_size": batch_size,
        "data_name": data_name,
        "prompting_strategy": "standard",
    }

    if num_eval_examples is not None:
        base_kwargs["num_eval_examples"] = num_eval_examples

    evaluation_kwargs = {}

    for task_name in all_task_names:
        name = task_name

        if "noisy_linear_regression" in name:
            for noise_std in all_noise_stds:
                nlr_name = f"noisy_linear_regression_noise={noise_std}"
                evaluation_kwargs[nlr_name] = base_kwargs.copy()
                evaluation_kwargs[nlr_name]["task_name"] = "noisy_linear_regression"
                evaluation_kwargs[nlr_name]["task_sampler_kwargs"] = {
                    "normalize_w": True, "noise_std": noise_std
                }
        else:
            evaluation_kwargs[name] = base_kwargs.copy()
            evaluation_kwargs[name]["task_name"] = task_name
            if name in ["iv_regression"]:
                evaluation_kwargs[name]["prompting_strategy"] = "standard_iv"
                evaluation_kwargs[name]["task_sampler_kwargs"] = {
                    "normalize_w": True,
                }
            if name in ["linear_regression", "sparse_linear_regression"]:
                evaluation_kwargs[name]["task_sampler_kwargs"] = {
                    "normalize_w": True,
                }

    return evaluation_kwargs


def compute_evals(all_models, evaluation_kwargs, save_path=None, recompute=False, eval_coef=True):
    """Run every evaluation for every model, reusing cached results when allowed.

    Args:
        all_models: models to evaluate; each must expose ``name``.
        evaluation_kwargs: dict mapping evaluation name to the keyword
            arguments passed to ``eval_model``.
        save_path: JSON cache file; read at the start when it can be loaded
            and rewritten at the end. ``None`` disables caching.
        recompute: when true, cached entries for the requested evaluations are
            discarded and every model is evaluated again.
        eval_coef: accepted for backward compatibility; not used.

    Returns:
        Dict ``{eval_name: {model_name: metrics}}``.
    """
    all_metrics = {}
    if save_path is not None:
        try:
            with open(save_path) as fp:
                all_metrics = json.load(fp)
        except (OSError, ValueError):
            # Missing, unreadable or corrupt cache (json.JSONDecodeError is a
            # ValueError): start from scratch.
            all_metrics = {}

    for eval_name, kwargs in tqdm(evaluation_kwargs.items()):
        if eval_name in all_metrics and not recompute:
            metrics = all_metrics[eval_name]
        else:
            metrics = {}

        for model in all_models:
            if model.name in metrics and not recompute:
                continue

            metrics[model.name] = eval_model(model, **kwargs, return_coef_mse=True)

        all_metrics[eval_name] = metrics

    if save_path is not None:
        with open(save_path, "w") as fp:
            json.dump(all_metrics, fp, indent=2)

    return all_metrics


def get_run_metrics(
    run_path, step=-1, cache=True, skip_model_load=False, skip_baselines=False,
    all_task_names=None, recompute=False, additional_ridge_lams=None,
    num_eval_examples=None, eval_type="eval_pred"
):
    """Compute (or load cached) evaluation metrics for one run directory.

    Args:
        run_path: directory of the run.
        step: checkpoint step; ``-1`` means the final ``state.pt``. Also
            selects the cache file (``metrics.json`` vs ``metrics_{step}.json``).
        cache: when false, nothing is read from or written to disk.
        skip_model_load: only read the config and evaluate no models, which
            returns whatever is cached.
        skip_baselines: do not add the baselines from
            ``models.get_relevant_baselines``.
        all_task_names: forwarded to ``build_evals``.
        recompute: forwarded to ``compute_evals``.
        additional_ridge_lams: accepted for backward compatibility; not used.
        num_eval_examples: forwarded to ``build_evals``.
        eval_type: accepted for backward compatibility; not used.

    Returns:
        The dict returned by ``compute_evals``.

    Raises:
        RuntimeError: if the run's config cannot be loaded.
        NotImplementedError: if baselines are requested and the config has
            neither ``training.tasks`` nor ``training.task``.
    """
    if skip_model_load:
        model, conf = get_model_from_run(run_path, only_conf=True)
    else:
        model, conf = get_model_from_run(run_path, step)

    # get_model_from_run signals an unreadable config with (None, None).
    if conf is None:
        raise RuntimeError(f"Could not load config.yaml for run {run_path}")

    all_models = []
    if not skip_model_load:
        # eval_batch falls back to the CPU when CUDA is missing, so only move
        # the model to the GPU when one is actually available.
        if torch.cuda.is_available():
            model = model.cuda()
        all_models.append(model.eval())

        if not skip_baselines:
            if "tasks" in conf.training.keys():
                # deduplicate tasks if there are multiple of same category
                # (e.g. noisy linear regression)
                tasks_covered = set()
                for task in conf.training.tasks:
                    if task.name in tasks_covered:
                        continue
                    all_models += models.get_relevant_baselines(task.name)
                    tasks_covered.add(task.name)
            elif "task" in conf.training.keys():
                # backward compatible
                all_models += models.get_relevant_baselines(conf.training.task)
            else:
                raise NotImplementedError

    evaluation_kwargs = build_evals(
        conf,
        all_task_names=all_task_names, num_eval_examples=num_eval_examples
    )

    print(evaluation_kwargs)

    if not cache:
        save_path = None
    elif step == -1:
        save_path = os.path.join(run_path, "metrics.json")
    else:
        save_path = os.path.join(run_path, f"metrics_{step}.json")

    # NOTE: a cache-staleness check based on file modification times used to
    # sit here but had no effect; recomputation is controlled solely by the
    # `recompute` argument.
    print(f"recompute is {recompute} for run {run_path}")
    # note: pay attention to evaluation_kwargs
    return compute_evals(all_models, evaluation_kwargs, save_path, recompute)


def conf_to_model_name(conf):
    """Map a run config to a display name for its model.

    ``gpt2`` models are named by their ``(n_layer, n_head)`` pair (a
    ``KeyError`` is raised for an unlisted pair), ``EncoderTF`` maps to
    itself, and any other family uses the wandb run name.
    """
    family = conf.model.family
    if family == "EncoderTF":
        return "EncoderTF"
    if family != "gpt2":
        return conf.wandb.name

    names_by_size = {
        (3, 2): "Transformer-xs",
        (6, 4): "Transformer-small",
        (12, 8): "Transformer",
    }
    size_key = (conf.model.n_layer, conf.model.n_head)
    return names_by_size[size_key]


def baseline_names(name):
    """Translate an internal baseline model name into a display label.

    The checks run in a fixed order and the first match wins; names that
    match nothing are returned unchanged.
    """

    def _first_param_value():
        # e.g. "NN_n=3_uniform" -> "3": value of the first "key=value" part
        # after the leading identifier.
        return name.split("_")[1].split("=")[1]

    if "OLS" in name:
        label = "Least Squares"
    elif name == "averaging":
        label = "Averaging"
    elif "NN" in name:
        label = f"{_first_param_value()}-Nearest Neighbors"
    elif "lasso" in name:
        label = f"Lasso (alpha={_first_param_value()})"
    elif "gd" in name:
        label = "2-layer NN, GD"
    elif "decision_tree" in name:
        label = "Greedy Tree Learning"
    elif "xgboost" in name:
        label = "XGBoost"
    else:
        label = name
    return label


def read_run_dir(run_dir):
    """Collect one row of metadata per run found under ``run_dir``.

    The expected layout is ``run_dir/<task>/<run_id>/config.yaml``. Entries
    of ``run_dir`` that are not directories, and runs whose config cannot be
    loaded, are skipped.

    Returns:
        A DataFrame sorted by ``run_name`` with the columns ``run_id``,
        ``task``, ``model``, ``num_tasks``, ``num_examples``, ``n_dims``,
        ``n_layer``, ``n_head`` and ``run_name``. ``num_tasks`` and
        ``num_examples`` are ``-1`` when absent or ``None`` in the config.
        The frame is empty (same columns) when no run is found.
    """
    columns = [
        "run_id", "task", "model", "num_tasks", "num_examples",
        "n_dims", "n_layer", "n_head", "run_name",
    ]
    records = []
    for task in os.listdir(run_dir):
        task_dir = os.path.join(run_dir, task)
        # Stray files next to the task directories would make listdir fail.
        if not os.path.isdir(task_dir):
            continue
        for run_id in os.listdir(task_dir):
            run_path = os.path.join(task_dir, run_id)
            _, conf = get_model_from_run(run_path, only_conf=True)
            if conf is None:
                continue

            num_tasks = (
                conf.training.num_tasks if "num_tasks" in conf.training else None
            )
            num_examples = (
                conf.training.num_training_examples
                if "num_training_examples" in conf.training
                else None
            )
            records.append(
                {
                    "run_id": run_id,
                    "task": task,
                    "model": conf_to_model_name(conf),
                    "num_tasks": num_tasks if num_tasks is not None else -1,
                    "num_examples": num_examples if num_examples is not None else -1,
                    "n_dims": conf.model.n_dims,
                    "n_layer": conf.model.n_layer,
                    "n_head": conf.model.n_head,
                    "run_name": conf.wandb.name,
                }
            )

    # Passing the columns explicitly keeps "run_name" present even when no
    # run was found, so sorting an empty frame does not raise.
    df = pd.DataFrame(records, columns=columns).sort_values("run_name")
    # assert len(df) == len(df.run_name.unique())
    return df


if __name__ == "__main__":
    # Usage: pass the root directory laid out as <run_dir>/<task>/<run_id>/.
    run_dir = sys.argv[1]
    for task in os.listdir(run_dir):
        print(f"Evaluating task {task}")
        task_dir = os.path.join(run_dir, task)
        # Evaluate every run of this task with the default settings.
        for run_id in tqdm(os.listdir(task_dir)):
            metrics = get_run_metrics(os.path.join(task_dir, run_id))
