import copy
from pathlib import Path

import einx
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray as xr
from loguru import logger

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.config.config_run import ConfigRun
from tabicl.core.enums import DatasetSize, DownstreamTask, SearchType, Task
from tabicl.core.get_model import get_model
from tabicl.core.get_trainer import get_trainer
from tabicl.core.run_experiment import Data
from tabicl.data.dataset_openml import OpenMLDataset
from tabicl.sweeps.hyperparameter_drawer import HyperparameterDrawer
from tabicl.utils.paths_and_filenames import PATH_TO_OPENML_DATASETS


def decision_boundary_analysis(cfg: ConfigPretrain, path_weights: Path, task: DownstreamTask):
    """Run the decision boundary analysis for one downstream task and save the results.

    Builds a classification run configuration on the ``whytrees_44156_MEDIUM.nc``
    dataset, computes the grid predictions with ``make_boundary_data`` and writes
    both the raw data (netCDF) and a plot (PNG) to
    ``cfg.output_dir / 'decision_boundary'``.

    Args:
        cfg: Pretraining configuration; supplies the output directory, device,
            model name, finetuning hyperparameters and the grid size.
        path_weights: Path to the pretrained weights to load.
        task: Downstream task to analyse. ``prepare_hyperparams`` raises
            ``ValueError`` for anything other than FINETUNE or ZEROSHOT.
    """

    output_dir = cfg.output_dir / 'decision_boundary'
    output_dir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Start decision boundary analysis for {task.value}")

    hyperparams = prepare_hyperparams(cfg, task, path_weights)

    cfg_run = ConfigRun.create(
        output_dir = output_dir,
        device = cfg.devices[0],
        cpus = None,
        model_name = cfg.model_name,
        seed = 0,
        task = Task.CLASSIFICATION,
        dataset_size = DatasetSize.MEDIUM,
        datafile_path=Path(PATH_TO_OPENML_DATASETS) / "whytrees_44156_MEDIUM.nc",
        hyperparams = hyperparams
    )
    torch.cuda.set_device(cfg_run.device)

    dataset = OpenMLDataset(cfg_run.datafile_path, cfg_run.task)

    da = make_boundary_data(cfg_run, dataset, f"{cfg.model_name.value} {task.value}", cfg.testing.decision_boundary_analysis_grid_size)
    da.to_netcdf(output_dir / f"decision_boundary_{task.value}.nc")

    fig = make_boundary_plot(da)
    try:
        fig.savefig(output_dir / f"decision_boundary_{task.value}.png")
    finally:
        # The figure is created through pyplot, which keeps it registered (and in
        # memory) until it is explicitly closed; release it once it is on disk.
        plt.close(fig)

    logger.info(f"Finished decision boundary analysis for {task.value}")
    logger.info(f"Results saved in {output_dir}")



def prepare_hyperparams(cfg: ConfigPretrain, task: DownstreamTask, path_weights: Path):
    """Derive the hyperparameters for a decision boundary run.

    Works on a deep copy of ``cfg.hyperparams_finetuning`` (the config itself is
    left untouched), draws a configuration with ``SearchType.DEFAULT`` and then
    adapts it to the requested task and the pretrained weights.

    Args:
        cfg: Pretraining configuration holding the finetuning hyperparameters.
        task: FINETUNE keeps the drawn number of epochs, ZEROSHOT sets
            ``max_epochs`` to 0.
        path_weights: Path stored under ``path_to_weights``.

    Returns:
        The drawn hyperparameters with the task- and weight-specific overrides.

    Raises:
        ValueError: If ``task`` is neither FINETUNE nor ZEROSHOT.
    """
    finetuning_space = copy.deepcopy(cfg.hyperparams_finetuning)
    drawer = HyperparameterDrawer(finetuning_space)
    drawn = drawer.draw_config(SearchType.DEFAULT)

    if task == DownstreamTask.FINETUNE:
        # Finetuning trains for the configured number of epochs.
        pass
    elif task == DownstreamTask.ZEROSHOT:
        # Zero-shot means no training at all.
        drawn['max_epochs'] = 0
    else:
        raise ValueError(f"Task {task} not supported for decision boundary analysis")

    drawn['use_pretrained_weights'] = True
    drawn['path_to_weights'] = path_weights

    return drawn


def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a colormap covering only the ``[minval, maxval]`` part of ``cmap``.

    The source colormap is sampled at ``n`` evenly spaced positions between
    ``minval`` and ``maxval``; the samples define a new linear segmented
    colormap whose name records the source name and the two bounds.
    """
    name_template = 'trunc({n},{a:.2f},{b:.2f})'
    truncated_name = name_template.format(n=cmap.name, a=minval, b=maxval)

    sample_positions = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_positions)

    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)



def irregularity_value(pred_grid):
    """Measure how rough the predicted surface of the first class is.

    For every interior cell of ``pred_grid[:, :, 0]`` the absolute differences
    to its four direct neighbours (up, down, left, right) are added up; the
    result is the average of that sum over all interior cells. Border cells
    only serve as neighbours.
    """
    first_class = pred_grid[:, :, 0]
    interior = first_class[1:-1, 1:-1]

    # Same-shaped views of the grid, each shifted by one cell in one direction.
    neighbours = (
        first_class[:-2, 1:-1],   # up
        first_class[2:, 1:-1],    # down
        first_class[1:-1, :-2],   # left
        first_class[1:-1, 2:],    # right
    )

    deviations = [np.abs(interior - neighbour) for neighbour in neighbours]
    local_variation = deviations[0]
    for deviation in deviations[1:]:
        local_variation = local_variation + deviation

    return local_variation.sum() / local_variation.size



def make_boundary_data(cfg: ConfigRun, dataset: OpenMLDataset, name: str, gridsize: int):
    """Train on two features and predict on a regular 2-D grid.

    Takes the first split from ``dataset.split_iterator()``, keeps only
    features 0 and 3, trains and evaluates the configured model, and then
    predicts on a ``gridsize x gridsize`` grid covering ``[-2, 2)`` in both
    features.

    Args:
        cfg: Run configuration passed to ``get_model`` / ``get_trainer``.
        dataset: Dataset providing the split iterator and ``n_classes``.
        name: Label stored in the ``model`` attribute of the result.
        gridsize: Number of grid points per axis.

    Returns:
        An ``xr.Dataset`` with
        ``pred`` (dims ``x2, x1, class``): softmax of the raw grid predictions,
        ``x_test`` / ``y_test``: the (two-feature) test split, and the
        attributes ``model``, ``acc`` (test score) and ``irreg``
        (``irregularity_value`` of the grid predictions).
    """

    x_train, x_val, x_test, y_train, y_val, y_test, categorical_indicator = next(dataset.split_iterator())
    # variables 0 and 3 are the most important ones (as selected by random forest feature importance)
    features = [0, 3]
    # NOTE(review): categorical_indicator is passed on unsliced although the
    # feature matrices are reduced to two columns — confirm get_model expects that.

    x_train = x_train[:, features]
    x_val = x_val[:, features]
    x_test = x_test[:, features]

    # gridsize evenly spaced points starting at -2 with step 4 / gridsize.
    x1 = (np.arange(gridsize) / gridsize - 0.5) * 4
    x2 = (np.arange(gridsize) / gridsize - 0.5) * 4

    # Default 'xy' indexing: both meshes have shape (len(x2), len(x1)), i.e.
    # rows follow x2 and columns follow x1.
    x_mesh1, x_mesh2 = np.meshgrid(x1, x2)

    x_mesh1_col = einx.rearrange('h w -> (h w)', x_mesh1)
    x_mesh2_col = einx.rearrange('h w -> (h w)', x_mesh2)

    x_grid = np.zeros((gridsize**2, x_test.shape[1]))
    x_grid[:, 0] = x_mesh1_col
    x_grid[:, 1] = x_mesh2_col

    data = Data.from_standard_datasplits(
        x_train,
        x_val,
        x_test,
        y_train,
        y_val,
        y_test,
        task=cfg.task,
        early_stopping_data_split="VALID",
        early_stopping_max_samples=None
    )

    model = get_model(cfg, data.x_train_cut, data.y_train_cut, categorical_indicator)
    trainer = get_trainer(cfg, model, dataset.n_classes)
    trainer.train(data.x_train_cut, data.y_train_cut, data.x_val_earlystop, data.y_val_earlystop)
    prediction_metrics_test = trainer.evaluate(data.x_train_and_val, data.y_train_and_val, data.x_test, data.y_test)

    score_test = prediction_metrics_test.score

    preds_raw = trainer.predict(x_train, y_train, x_grid)
    preds = einx.softmax('n [c]', preds_raw)
    # Undo the flattening of the mesh: h indexes x2, w indexes x1.
    preds = einx.rearrange('(h w) c -> h w c', preds, w=gridsize)

    irreg_value = irregularity_value(preds)

    da = xr.Dataset(
        data_vars=dict(
            # Axis 0 of preds runs over x2 and axis 1 over x1 (see meshgrid above),
            # so the dimensions must be labelled in that order.
            pred=(['x2', 'x1', 'class'], preds),
            x_test=(['observations', 'feature'], x_test),
            y_test=(['observations'], y_test),
        ),
        coords={
            # Key must match the 'observations' dimension used by x_test / y_test,
            # otherwise the index is attached to a separate, unused dimension.
            'observations': np.arange(x_test.shape[0]),
            'feature': np.arange(x_test.shape[1]),
            'x1': x1,
            'x2': x2,
            'class': np.arange(dataset.n_classes)
        },
        attrs={
            'model': name,
            'acc': score_test,
            'irreg': irreg_value
        }
    )

    return da



def make_boundary_plot(da: xr.Dataset) -> plt.Figure:
    """Draw the decision boundary figure for a dataset from ``make_boundary_data``.

    The first class channel of ``pred`` is shown as a coloured background over
    the ``x1`` / ``x2`` grid, the test points are scattered on top (coloured by
    ``1 - y_test``), and a text box reports the ``acc`` and ``irreg``
    attributes. The ``model`` attribute becomes the title.

    Returns:
        The newly created figure; the caller decides where to save it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    base_cmap = plt.get_cmap('coolwarm')
    # The background uses a narrower slice of the colormap than the points.
    background_cmap = truncate_colormap(base_cmap, 0.1, 0.9)

    grid_x, grid_y = np.meshgrid(da.coords['x1'], da.coords['x2'])
    first_class_pred = da['pred'].data[:, :, 0]
    ax.pcolormesh(grid_x, grid_y, first_class_pred, cmap=background_cmap)

    test_points = da['x_test'].data
    point_colors = 1-da['y_test'].data
    ax.scatter(test_points[:, 0], test_points[:, 1], c=point_colors, s=0.1, cmap=base_cmap)

    ax.set_ylim(-2, 2)
    ax.set_xlim(-2, 2)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_visible(False)

    ax.set_title(f"{da.attrs['model']}", fontsize=24)

    summary = f"accuracy  : {da.attrs['acc']:.3f} \ncomplexity: {da.attrs['irreg']:.4f}"
    box_style = {'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.5}
    ax.text(
        0.05,
        0.17,
        summary,
        transform=ax.transAxes,
        fontsize=18,
        verticalalignment='top',
        bbox=box_style,
    )

    return fig

if __name__ == "__main__":

    # Ad-hoc entry point: load a finished pretraining run, redirect its output
    # directory and device, and analyse both downstream tasks with one checkpoint.
    cfg = ConfigPretrain.load(Path('outputs/runs/2024-07-24/01-09-32/config_pretrain.yaml'))
    cfg.output_dir = Path('outputs/runs/2024-07-26/09-19-48')
    cfg.devices = [torch.device('cuda:7')]
    cfg.hyperparams_finetuning['lr'] = 1.0e-5

    path_weights = Path('outputs/runs/2024-07-24/01-09-32/weights/model_step_44000.pt')

    for downstream_task in (DownstreamTask.ZEROSHOT, DownstreamTask.FINETUNE):
        decision_boundary_analysis(cfg, path_weights, task=downstream_task)

    # To redraw a plot from previously saved data without re-running the model:
    # data = xr.open_dataset('outputs/runs/2024-06-02/15-38-34/decision_boundary/decision_boundary_finetune.nc')
    # fig = make_boundary_plot(data)
    # fig.savefig('outputs/runs/2024-06-02/15-38-34/decision_boundary/decision_boundary_finetune.png')