import os

from matplotlib.colors import to_rgba
from torch.utils.data import Dataset

from gradiend.setups import Setup, BatchedTrainingDataset

from torch.utils.data import Dataset
import torch

from gradiend.setups.emotion.data import read_vad, read_vad_templates, read_antonyms
from gradiend.training.gradiend_training import train_for_configs



class TrainingDataset(BatchedTrainingDataset):
    """Masked-word training dataset for the emotion (arousal/valence) setup.

    Each data row is read for a ``template`` text, the ``masked_word`` that is
    inserted into it, and ``arousal``/``valence`` scores. An optional
    ``antonym`` column supplies the counterfactual word; without it the
    counterfactual item is a copy of the factual one.

    Args:
        data: Tabular data exposing ``.columns`` (presumably a pandas
            DataFrame -- NOTE(review): confirm against BatchedTrainingDataset).
        tokenizer: Tokenizer forwarded to the base class.
        batch_size: Forwarded to the base class, which batches by
            ``masked_word``.
        is_generative: Generative training is not supported; ``True`` raises.
        max_size: Optional size cap forwarded to the base class.

    Raises:
        NotImplementedError: If ``is_generative`` is true.
    """

    def __init__(self, data, tokenizer, batch_size=None, is_generative=False, max_size=None):
        super().__init__(data=data,
                         tokenizer=tokenizer,
                         max_size=max_size,
                         batch_size=batch_size,
                         batch_criterion='masked_word',
                         max_length=80,
                         )
        self.key_source = 'masked_word'
        self.key_target = 'masked_word'

        # Item role -> data column holding the word inserted into the template.
        self.key_mapping = {
            'factual': 'masked_word',
        }

        # A counterfactual role only exists when the data provides antonyms.
        if 'antonym' in self.data.columns:
            self.key_mapping['counterfactual'] = 'antonym'

        if is_generative:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.is_generative = is_generative

        if self.is_generative:
            raise NotImplementedError("Generative training is not implemented for this dataset.")

    def __getitem__(self, idx):
        """Return the tokenized factual/counterfactual pair for one entry.

        Returns:
            dict with keys ``True`` (factual item), ``False`` (counterfactual
            item, or a copy of the factual item when no ``antonym`` column
            exists), ``'text'`` (the template), ``'factual_target'``,
            ``'counterfactual_target'``, ``'label'`` ``(arousal, valence)``,
            ``'binary_label'`` (1 where the score is > 0, else 0) and
            ``'metadata'`` ``(arousal, valence)``.
        """
        # The base class maps ``idx`` to the (batch-adjusted) entry.
        entry = super().__getitem__(idx)

        raw_text = entry['template']

        arousal = entry['arousal']
        valence = entry['valence']
        label = (arousal, valence)
        binary_label = (int(arousal > 0), int(valence > 0))

        # Only build items for roles that actually have a source column, so
        # datasets without an 'antonym' column do not raise a KeyError.
        items = {
            key: self._create_item(raw_text, entry[column])
            for key, column in self.key_mapping.items()
        }

        item_factual = items['factual']
        if 'counterfactual' in items:
            item_counterfactual = items['counterfactual']
        else:
            item_counterfactual = item_factual.copy()

        return {
            True: item_factual,
            False: item_counterfactual,
            'text': raw_text,
            'factual_target': entry['masked_word'],
            'counterfactual_target': entry.get('antonym', entry['masked_word']),
            'label': label,
            'binary_label': binary_label,
            'metadata': (arousal, valence),  # TODO: update
        }


import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import os
import ast  # For parsing string like '(0, 1)' to tuple

def plot_encoded_by_class(
    result,
    title="Encoded Points by Class",
    output=None,
    aggregations=('mean_abs',),
    marker_by_aggregation=None,
    use_binary_keys_for_styling=True,
):
    """Scatter-plot encodings grouped by class label.

    Two modes are used:

    * 2-D encodings with a single aggregation: one scatter of dim 1 vs dim 2,
      followed by a figure plotting each encoding dimension against each label
      dimension (annotated with ``result['counterfactual_target']``).
    * Otherwise: one figure per encoding dimension ``d`` with ``dim d`` on the
      y axis and an aggregation of the remaining dimensions on the x axis.
      Class means (if provided) are drawn as larger black-edged markers.

    Args:
        result: Evaluation result dict. Reads ``'encoded_by_class'``
            (label -> sequence of encoding vectors) and optionally
            ``'mean_by_class'`` (label -> mean vector). The 2-D mode also reads
            ``'encoded'`` (sequence of ``(encoding, label)`` pairs, both of
            length 2) and ``'counterfactual_target'`` (text per point).
        title: Figure title; per-dimension figures append `` (Dim d)``.
        output: Optional output path; parent directories are created. The 2-D
            mode saves a second figure with ``_2d`` inserted before the
            extension; the per-dimension mode saves ``<base>_dim{d}<ext>``.
        aggregations: One name, or a sequence of at most four names, among
            ``'sum'``, ``'mean'``, ``'abs_sum'``, ``'max_abs'``, ``'mean_abs'``.
        marker_by_aggregation: Currently unused; kept for API compatibility.
        use_binary_keys_for_styling: If every class key is a 2-tuple of 0/1
            (or its string repr, e.g. ``'(0, 1)'``), colour points by the
            first entry and choose the marker by the second.

    Raises:
        ValueError: If more than four aggregations are requested or an
            aggregation name is unsupported.
    """
    encoded_by_class = result['encoded_by_class']
    mean_by_class = result.get('mean_by_class', None)

    if not encoded_by_class:
        # Nothing to plot, and no first point to infer the dimensionality from.
        return

    if isinstance(aggregations, str):
        aggregations = [aggregations]

    default_markers = ['o', 's', 'D', '^', 'v', 'P', '*', 'X']
    binary_colors = ['blue', 'red']
    binary_markers = ['o', 's']

    def parse_label(label):
        # Keys may be tuples already or their string repr such as '(0, 1)'.
        if isinstance(label, tuple):
            return label
        try:
            return ast.literal_eval(label)
        except (ValueError, TypeError, SyntaxError):
            return None

    def is_binary_key(label):
        parsed = parse_label(label)
        return (
            isinstance(parsed, tuple)
            and len(parsed) == 2
            and all(x in [0, 1] for x in parsed)
        )

    binary_key_mode = use_binary_keys_for_styling and all(
        is_binary_key(k) for k in encoded_by_class
    )

    def get_color_marker(label, idx):
        if binary_key_mode:
            parsed = parse_label(label)
            # int() so that e.g. 1.0 or True can index the style lists.
            return binary_colors[int(parsed[0])], binary_markers[int(parsed[1])]
        return plt.cm.tab10(idx % 10), default_markers[idx % len(default_markers)]

    def aggregate(arr, exclude_idx, method):
        other_dims = np.delete(arr, exclude_idx, axis=1)
        if method == 'sum':
            return np.sum(other_dims, axis=1)
        elif method == 'mean':
            return np.mean(other_dims, axis=1)
        elif method == 'abs_sum':
            return np.sum(np.abs(other_dims), axis=1)
        elif method == 'max_abs':
            return np.max(np.abs(other_dims), axis=1)
        elif method == 'mean_abs':
            return np.mean(np.abs(other_dims), axis=1)
        else:
            raise ValueError(f"Unsupported aggregation: {method}")

    def save_figure(path):
        # dirname is '' for bare filenames and os.makedirs('') would raise.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(path)

    # Determine dimensionality from the first point of the first class.
    first_label = next(iter(encoded_by_class))
    first_point = np.array(encoded_by_class[first_label])[0]
    dim = len(first_point)

    # Direct 2D plot
    if dim == 2 and len(aggregations) == 1:
        plt.figure(figsize=(8, 6))
        for i, (label, points) in enumerate(encoded_by_class.items()):
            points = np.array(points)
            color, marker = get_color_marker(label, i)
            plt.scatter(points[:, 0], points[:, 1], label=label, color=color, marker=marker, alpha=1)

        plt.title(title)
        plt.xlabel("Dim 1")
        plt.ylabel("Dim 2")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()

        if output:
            save_figure(output)
        plt.show()
        # Close explicitly: this is called repeatedly during training and
        # pyplot keeps every open figure alive otherwise.
        plt.close()

        data = result['encoded']
        targets = result['counterfactual_target']
        # Split labels and encodings
        labels = [l for _, l in data]
        encodings = [e for e, _ in data]

        lab1, lab2 = zip(*labels)
        enc1, enc2 = zip(*encodings)

        fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=False)

        for i, (label_vals, ax) in enumerate(zip([lab1, lab2], axs)):
            ax.scatter(enc1, label_vals, label="Encoding Dim 1", alpha=0.7)
            ax.scatter(enc2, label_vals, label="Encoding Dim 2", alpha=0.7)

            for x1, x2, y, txt in zip(enc1, enc2, label_vals, targets):
                ax.text(x1, y, txt, fontsize=6)
                ax.text(x2, y, txt, fontsize=6)

            ax.set_xlabel("Encoding Value")
            ax.set_ylabel(f"Label Dim {i + 1}")
            ax.set_title(f"Encodings vs Label Dim {i + 1}")
            ax.legend()

        plt.tight_layout()

        if output:
            # splitext keeps the second figure from overwriting the first when
            # the output does not end in '.pdf'.
            base, ext = os.path.splitext(output)
            save_figure(f"{base}_2d{ext}")

        plt.show()
        plt.close(fig)
        return

    def adjust_alpha_color(base_color, alpha_factor):
        # Scales the RGB channels (darkening the colour); alpha is not changed.
        r, g, b, _ = to_rgba(base_color)
        return (r * alpha_factor, g * alpha_factor, b * alpha_factor)

    alpha_factors = [1.0, 0.8, 0.6, 0.4]
    if len(aggregations) > len(alpha_factors):
        raise ValueError("Too many aggregations (>4) not supported")

    for d in range(dim):
        plt.figure(figsize=(8, 6))
        for agg_idx, agg in enumerate(aggregations):
            for i, (label, points) in enumerate(encoded_by_class.items()):
                points = np.array(points)
                x_vals = aggregate(points, exclude_idx=d, method=agg)
                y_vals = points[:, d]
                base_color, marker = get_color_marker(label, i)
                shaded_color = adjust_alpha_color(base_color, alpha_factors[agg_idx])

                plt.scatter(
                    x_vals,
                    y_vals,
                    label=f"{label} ({agg})",
                    color=shaded_color,
                    marker=marker,
                    alpha=0.8
                )

        if mean_by_class:
            # enumerate so each mean marker uses its own aggregation's shade.
            for agg_idx, agg in enumerate(aggregations):
                for i, (label, mean_vector) in enumerate(mean_by_class.items()):
                    if label not in encoded_by_class:
                        continue
                    mean_x = aggregate(np.array([mean_vector]), exclude_idx=d, method=agg)[0]
                    mean_y = mean_vector[d]
                    base_color, marker = get_color_marker(label, i)
                    shaded_color = adjust_alpha_color(base_color, alpha_factors[agg_idx])
                    plt.scatter(
                        mean_x,
                        mean_y,
                        color=shaded_color,
                        marker=marker,
                        s=100,
                        edgecolor='black',
                        linewidth=1.5,
                    )

        plt.title(f"{title} (Dim {d})")
        plt.xlabel(f"Aggregated (excluding dim {d})")
        plt.ylabel(f"Dim {d}")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()

        if output:
            base, ext = os.path.splitext(output)
            save_figure(f"{base}_dim{d}{ext}")
        plt.show()
        plt.close()



class EmotionSetup(Setup):
    """GRADIEND setup for emotion (valence/arousal) word substitution.

    Training data is built from ``read_vad_templates`` and wrapped in
    ``TrainingDataset``.
    """

    def __init__(self, id='emotion', n_features=2):
        super().__init__(id, n_features=n_features)

    def create_training_data(self,
                             tokenizer,
                             max_size=None,
                             batch_size=None,
                             split='train',
                             is_generative=False,
                             max_versions_per_masked_word=None,
                             filter_by_tokenizer=True,
                             filter_by_min_vad=False,
                             ):
        """Load VAD templates for ``split`` and wrap them in a TrainingDataset.

        Args:
            tokenizer: Tokenizer for the dataset (and for filtering, see below).
            max_size: Forwarded to ``read_vad_templates``.
            batch_size: Forwarded to ``TrainingDataset``.
            split: Data split name forwarded to ``read_vad_templates``.
            is_generative: Forwarded to ``TrainingDataset``, which raises
                ``NotImplementedError`` for generative training.
            max_versions_per_masked_word: Forwarded to ``read_vad_templates``.
                If falsy, defaults to ``batch_size * 100`` (or ``None`` when
                ``batch_size`` is ``None``).
            filter_by_tokenizer: If truthy, ``tokenizer`` is passed to
                ``read_vad_templates`` as the filter.
            filter_by_min_vad: If true, keep only rows with
                ``|valence| > 0.5`` or ``|arousal| > 0.5``.

        Returns:
            TrainingDataset
        """
        if not max_versions_per_masked_word:
            # ``batch_size`` defaults to None, which previously crashed here.
            # NOTE(review): assumes read_vad_templates treats None as "no cap".
            max_versions_per_masked_word = batch_size * 100 if batch_size is not None else None

        if filter_by_tokenizer:
            filter_by_tokenizer = tokenizer

        raw_data = read_vad_templates(split=split, max_versions_per_masked_word=max_versions_per_masked_word, max_size=max_size, filter_by_tokenizer=filter_by_tokenizer)

        if filter_by_min_vad:
            raw_data = raw_data[(raw_data['valence'].abs() > 0.5) | (raw_data['arousal'].abs() > 0.5)]
            print(f"Filtered data by min VAD: {len(raw_data)} entries remaining.")

        dataset = TrainingDataset(
            data=raw_data,
            tokenizer=tokenizer,
            batch_size=batch_size,
            is_generative=is_generative,
        )

        return dataset

    def create_eval_data(self, *args, split='val', max_size=None, **kwargs):
        """Create evaluation data via the base ``Setup.create_eval_data``.

        ``max_versions_per_masked_word`` defaults to 10 but may be overridden
        through ``kwargs``.
        """
        kwargs.setdefault('max_versions_per_masked_word', 10)
        return super().create_eval_data(*args, split=split, max_size=max_size, **kwargs)

    def evaluate(self, *args, **kwargs):
        """Delegate to the base ``Setup.evaluate`` and return its result."""
        return super().evaluate(*args, **kwargs)


class MultiDimEmotionSetup(EmotionSetup):
    """Emotion setup with a configurable number of GRADIEND features.

    The setup id is ``f'emotion-{n_features}'``. ``evaluate`` additionally
    plots the encodings with ``plot_encoded_by_class``.
    """

    def __init__(self, n_features=2):
        super().__init__(f'emotion-{n_features}', n_features=n_features)

    def evaluate(self, model_with_gradiend, eval_data, eval_batch_size=32, config=None, training_stats=None):
        """Evaluate via the parent setup and save a plot of the encodings.

        The plot goes to ``{config['output']}/{global_step}_training_<enc>.pdf``
        when ``config`` contains ``'output'`` (``global_step`` read from
        ``training_stats``, ``None`` if unavailable), else to
        ``img/training_<enc>.pdf``, where ``<enc>`` is ``str()`` of
        ``model_with_gradiend.gradiend.encoder[1]``.

        Returns:
            The result dict of the parent ``evaluate``.
        """
        result = super().evaluate(model_with_gradiend, eval_data, eval_batch_size=eval_batch_size, config=config, training_stats=training_stats)
        score = result['score']

        output_name = f'training_{str(model_with_gradiend.gradiend.encoder[1])}.pdf'
        if config and 'output' in config:
            base_output = config['output']
            # training_stats defaults to None; don't crash when it is omitted.
            global_step = (training_stats or {}).get('global_step', None)
            output = f'{base_output}/{global_step}_{output_name}'
        else:
            output = f'img/{output_name}'
        plot_encoded_by_class(result, title=f"Score {score}", output=output)

        return result


def multi_dim_training(configs, version=None, activation='tanh', n_features=4, n=3):
    """Train a ``MultiDimEmotionSetup`` for every entry in ``configs``.

    Each config dict is mutated in place: ``'activation'`` is set to
    ``activation`` and ``'delete_models'`` to ``True``.

    Args:
        configs: Mapping of model id -> training config dict.
        version: Forwarded to ``train_for_configs``.
        activation: Activation name written into every config.
        n_features: Number of GRADIEND features for the setup.
        n: Forwarded to ``train_for_configs`` (its meaning is defined there);
            defaults to the previously hard-coded 3.
    """
    setup = MultiDimEmotionSetup(n_features=n_features)
    for config in configs.values():
        config['activation'] = activation
        config['delete_models'] = True
    train_for_configs(setup, configs, version=version, n=n)

# Script entry point: trains the multi-dimensional emotion setup for several
# feature counts. Nothing runs on import.
if __name__ == '__main__':
    print('Test')
    configs = {
        'distilbert-base-cased': dict(eval_max_size=1000, batch_size=8, n_evaluation=250, epochs=1, source='counterfactual', target='diff', max_iterations=20000),
        'roberta-base': dict(eval_max_size=200, batch_size=4, n_evaluation=100, epochs=1, source='counterfactual', target='diff', max_iterations=20000),
    }

    multi_dim_training(configs, activation='tanh', n_features=2, version='tanh_2_v7')
    multi_dim_training(configs, activation='tanh', n_features=10, version='tanh_10_v7')
    multi_dim_training(configs, activation='tanh', n_features=5, version='tanh_5_v7')
#    multi_dim_training(configs, activation='gelu', n_features=10, version='gelu_10_v4')
#    multi_dim_training(configs, activation='relu', n_features=10, version='relu_10_v4')