import json
import os

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

from gradiend.setups import BatchedTrainingDataset, Setup
from gradiend.setups.emotion import read_vad_templates
from gradiend.training.gradiend_training import train_for_configs


class TrainingDataset(BatchedTrainingDataset):
    """Counterfactual masked-word dataset for a single VAD dimension.

    Each item pairs a factual input (the template filled with ``masked_word``)
    with a counterfactual input (the same template filled with ``antonym``),
    plus the continuous label ``entry[target_key]`` and its binary version
    ``label > 0``.
    """

    def __init__(self, data, tokenizer, batch_size=None, is_generative=False, max_size=None, target_key='valence'):
        """Create the dataset.

        Args:
            data: templates table; rows are expected to provide ``template``,
                ``masked_word``, ``antonym``, ``arousal``, ``valence`` and
                ``target_key`` (see ``__getitem__``).
            tokenizer: tokenizer whose ``mask_token`` replaces ``[MASK]``.
            batch_size: forwarded to ``BatchedTrainingDataset``.
            is_generative: currently unused; kept for interface compatibility.
            max_size: forwarded to ``BatchedTrainingDataset``.
            target_key: column used as the item label.
        """
        super().__init__(data=data, tokenizer=tokenizer, max_size=max_size, batch_size=batch_size,
                         batch_criterion='masked_word')
        self.target_key = target_key

    def __getitem__(self, idx):
        """Return the factual/counterfactual pair and labels for ``idx``."""
        # The base class presumably maps ``idx`` to the row for the current batch layout.
        entry = super().__getitem__(idx)

        raw_text = entry['template'].replace('[MASK]', self.tokenizer.mask_token)

        arousal = entry['arousal']
        valence = entry['valence']
        label = entry[self.target_key]

        item_factual = self._create_item(raw_text, entry['masked_word'])
        item_counterfactual = self._create_item(raw_text, entry['antonym'])

        return {
            True: item_factual,
            False: item_counterfactual,
            'text': raw_text,
            'label': label,
            'binary_label': label > 0,
            'factual_target': entry['masked_word'],
            'counterfactual_target': entry['antonym'],
            'metadata': (arousal, valence),  # TODO: include dominance / target_key value
        }


def _sample_stratified(templates, max_size, seed=None):
    """Sample up to ``max_size // n_classes`` rows from each ``binary_label`` class.

    Classes smaller than the per-class quota are kept whole. Classes appear in
    sorted label order (False before True) and the index is reset.
    """
    groups = [group for _, group in templates.groupby('binary_label')]
    if not groups:
        # No rows at all: return early instead of dividing by zero below.
        return templates.reset_index(drop=True)
    per_class = max_size // len(groups)
    return pd.concat(
        [group.sample(min(len(group), per_class), random_state=seed) for group in groups],
        ignore_index=True,
    )


def _balance_binary_classes(templates, seed=None):
    """Downsample the majority ``binary_label`` class to the size of the minority class.

    The downsampled class comes first in the result; the index is reset.
    """
    is_pos = templates['binary_label'].astype(bool)
    positives = templates[is_pos]
    negatives = templates[~is_pos]
    n_pos, n_neg = len(positives), len(negatives)
    if n_pos > n_neg:
        balanced = pd.concat([positives.sample(n_neg, random_state=seed), negatives], ignore_index=True)
    else:
        balanced = pd.concat([negatives.sample(n_pos, random_state=seed), positives], ignore_index=True)
    n_per_class = min(n_pos, n_neg)
    # Report the counts after balancing, i.e. what the dataset will actually contain.
    print(f"Balanced dataset: {len(balanced)} samples, {n_per_class} positive, {n_per_class} negative")
    return balanced


def create_training_dataset(tokenizer,
                            max_size=None,
                            batch_size=None,
                            neutral_data=False,
                            split='train',
                            neutral_data_prop=0.5,
                            is_generative=False,
                            key='valence',
                            single_word_texts=False,
                            seed=None,
                            ):
    """Build a class-balanced counterfactual training dataset for one VAD dimension.

    Loads the VAD templates of ``split`` (with antonyms), keeps only rows whose
    ``masked_word`` appears in ``data/emotion/{key}_opposite_terms.json`` and
    balances the binary label ``templates[key] > 0``.

    Args:
        tokenizer: tokenizer handed to ``TrainingDataset``.
        max_size: if given, sample up to ``max_size // n_classes`` rows per binary
            class; otherwise downsample the majority class to the minority size.
        batch_size: forwarded to ``TrainingDataset``.
        neutral_data: currently unused; kept for interface compatibility.
        split: template split passed to ``read_vad_templates``.
        neutral_data_prop: currently unused; kept for interface compatibility.
        is_generative: forwarded to ``TrainingDataset``.
        key: VAD column used as label (e.g. 'valence', 'arousal', 'dominance').
        single_word_texts: replace every template by the bare ``[MASK]`` and keep
            one row per ``masked_word``.
        seed: optional ``random_state`` for all row sampling; ``None`` keeps the
            sampling non-deterministic.

    Returns:
        TrainingDataset over the selected templates.
    """
    templates = read_vad_templates(split=split, use_antonyms=True)
    templates['binary_label'] = templates[key] > 0
    with open(f'data/emotion/{key}_opposite_terms.json', 'r', encoding='utf-8') as f:
        key_terms = json.load(f)

    templates = templates[templates['masked_word'].isin(key_terms)].reset_index(drop=True)

    if single_word_texts:
        templates['template'] = "[MASK]"
        # All templates are now identical, so keep only the first row per masked word.
        templates = templates.drop_duplicates(subset=['masked_word']).reset_index(drop=True)

    if max_size is not None:
        templates = _sample_stratified(templates, max_size, seed=seed)
    else:
        templates = _balance_binary_classes(templates, seed=seed)

    key_values = templates[key].unique()
    print(f"Using {len(key_values)} unique {key} values: {key_values}")

    return TrainingDataset(templates,
                           tokenizer,
                           batch_size=batch_size,
                           is_generative=is_generative,
                           target_key=key,
                           )


class Emotion1DSetup(Setup):
    """GRADIEND setup with a single feature for one VAD dimension (``key``).

    ``evaluate`` additionally saves and shows a scatter plot of the encoded
    values against their class labels.
    """

    def __init__(self, key='valence'):
        super().__init__(key)
        self.target_key = key
        self.ctr = 0  # running index so each evaluation plot gets its own filename

    def create_training_data(self, *args, **kwargs):
        """Build the training dataset for this setup's ``target_key``."""
        return create_training_dataset(*args, key=self.target_key, **kwargs)

    @staticmethod
    def _parse_label(label):
        """Convert a class label from ``encoded_by_class`` into a float for plotting.

        Python and NumPy numeric/boolean scalars are converted directly. Anything
        else is treated as a stringified 1-tuple such as ``'(0.5,)'``; presumably
        produced by ``Setup.evaluate`` — confirm there.
        """
        if isinstance(label, (int, float, np.number, np.bool_)):
            return float(label)
        return float(label[1:-2])

    def evaluate(self, *args, **kwargs):
        """Run the base evaluation, then plot encoded values against their labels.

        The scatter plot (blue: label <= 0, red: label > 0) is saved to
        ``img/gradiend/emotion_{target_key}_evaluation_{ctr}_{score}.png`` and
        displayed with ``plt.show()``.

        Returns:
            The unmodified result dict of ``Setup.evaluate``; its ``'score'`` and
            ``'encoded_by_class'`` entries are used for the plot.
        """
        result = super().evaluate(*args, **kwargs)
        score = result['score']
        encoded_by_class = result['encoded_by_class']

        x = []
        y = []
        for label, encoded_values in encoded_by_class.items():
            x.extend(encoded_values)
            y.extend([self._parse_label(label)] * len(encoded_values))
        y_binary = [1 if val > 0 else 0 for val in y]

        # Draw on a fresh figure: with a non-interactive backend plt.show() is a
        # no-op, so repeated evaluations would otherwise pile their points into the
        # same axes and every saved image would contain earlier runs as well.
        fig = plt.figure()
        # Pin the colour range to [0, 1] so a single-class plot still maps
        # label > 0 to red instead of collapsing every point onto the first colour.
        plt.scatter(x, y, c=y_binary, cmap=ListedColormap(['blue', 'red']), vmin=0, vmax=1, alpha=1.0)

        plt.title(f"{self.target_key} Evaluation - Score: {score}")
        plt.xlabel('Encoded Feature')
        plt.ylabel(f'{self.target_key} Value')
        plt.grid()
        output = f'img/gradiend/emotion_{self.target_key}_evaluation_{self.ctr}_{score}.png'
        os.makedirs(os.path.dirname(output), exist_ok=True)
        plt.savefig(output)
        self.ctr += 1
        plt.show()
        plt.close(fig)  # release the figure so repeated evaluations do not keep figures alive

        return result

    def post_training(self, model_with_gradiend, **kwargs):
        """No post-training step for this setup."""
        pass

class ValenceSetup(Emotion1DSetup):
    """One-dimensional emotion setup targeting the ``valence`` column."""

    def __init__(self):
        super().__init__('valence')

class ArousalSetup(Emotion1DSetup):
    """One-dimensional emotion setup targeting the ``arousal`` column."""

    def __init__(self):
        super().__init__('arousal')



class Emotion2DSetup(Setup):
    """Emotion setup with a two-feature encoder (``n_features=2``) for one VAD column."""

    def __init__(self, key='valence'):
        super().__init__(key, n_features=2)
        self.ctr = 0  # evaluation counter, mirroring the 1-D setup
        self.target_key = key

    def create_training_data(self, *args, **kwargs):
        """Build the training dataset labelled by this setup's ``target_key``."""
        target = self.target_key
        return create_training_dataset(*args, key=target, **kwargs)



def train_valence(configs, version='arousal', activation='tanh', n=3):
    """Train 1-D GRADIEND models for the VAD dimension named in ``version``.

    Args:
        configs: mapping of model name -> training config dict. The caller's
            dicts are not modified; each one is copied with ``activation`` and
            ``delete_models=False`` set.
        version: run identifier; the first of 'arousal', 'valence', 'dominance'
            (checked in that order) contained in it selects the setup.
        activation: activation name stored in every config.
        n: forwarded as ``n`` to ``train_for_configs`` (presumably the number of
            training runs — confirm there); defaults to the previous value 3.

    Raises:
        ValueError: if ``version`` names none of the supported dimensions.
    """
    if 'arousal' in version:
        setup = ArousalSetup()
    elif 'valence' in version:
        setup = ValenceSetup()
    elif 'dominance' in version:
        setup = Emotion1DSetup(key='dominance')
    else:
        raise ValueError(f"Unknown version: {version}")

    run_configs = {
        model_name: {**config, 'activation': activation, 'delete_models': False}
        for model_name, config in configs.items()
    }
    train_for_configs(setup, run_configs, version=version, n=n)


def train_dual_valence(configs, version='arousal_2', activation='gelu', n=5):
    """Train 2-D GRADIEND models for the VAD dimension named in ``version``.

    Args:
        configs: mapping of model name -> training config dict. The caller's
            dicts are not modified; each one is copied with ``activation`` and
            ``delete_models=False`` set.
        version: run identifier; the first of 'arousal', 'valence', 'dominance'
            (checked in that order) contained in it selects the key.
        activation: activation name stored in every config.
        n: forwarded as ``n`` to ``train_for_configs`` (presumably the number of
            training runs — confirm there); defaults to the previous value 5.

    Raises:
        ValueError: if ``version`` names none of the supported dimensions.
    """
    for dimension in ('arousal', 'valence', 'dominance'):
        if dimension in version:
            setup = Emotion2DSetup(key=dimension)
            break
    else:
        raise ValueError(f"Unknown version: {version}")

    run_configs = {
        model_name: {**config, 'activation': activation, 'delete_models': False}
        for model_name, config in configs.items()
    }
    train_for_configs(setup, run_configs, version=version, n=n)



if __name__ == '__main__':
    import traceback

    configs = {
        #'distilbert-base-cased': dict(eval_max_size=1000, batch_size=16, n_evaluation=250, epochs=2, supervised=True, source='counterfactual', target='diff'),
        'distilbert-base-cased': dict(eval_max_size=1000, batch_size=16, n_evaluation=250, epochs=5, supervised=False, source='counterfactual', target='diff'),
        #'roberta-base': dict(eval_max_size=400, batch_size=16, n_evaluation=250, epochs=1, source='counterfactual', target='diff', lr=1e-6),
        #'distilbert-base-cased': dict(eval_max_size=200, batch_size=4, n_evaluation=250, epochs=5, source='factual', target='diff'),
        #'bert-base-cased': dict(eval_max_size=100, batch_size=32, n_evaluation=100, epochs=5, source='counterfactual', target='diff'),
    }

    configs_factual_source = {k: {**v, 'source': 'factual'} for k, v in configs.items()}

    #train_valence(configs, version='arousal_', activation='tanh')
    #train_valence(configs, version='valence_', activation='tanh')
    try:
        train_valence(configs, version='dominance_', activation='tanh')
    except Exception as e:
        # Deliberately best-effort so the next experiment still runs, but keep the
        # full traceback: the bare message alone rarely suffices to debug a failure.
        print(f"Error during training: {e}")
        traceback.print_exc()

    #train_valence(configs, version='valence_tanh_supervised_cf', activation='tanh')
    #train_valence(configs, version='arousal_tanh_supervised_cf', activation='tanh')
    train_valence(configs, version='dominance_tanh_supervised_cf', activation='tanh')
    #train_valence(configs, version='arousal_gelu_supervised_cf', activation='gelu')
    #train_valence(configs, version='valence_gelu_supervised_cf', activation='gelu')
    #train_valence(configs, version='valence_relu_v6', activation='relu')
    #train_valence(configs, version='arousal_relu', activation='relu')

    # Remaining experiments are disabled by default; set this to True to run them.
    # (Gating them here lets a successful run finish with exit status 0.)
    run_additional_experiments = False
    if run_additional_experiments:
        train_valence(configs_factual_source, version='valence_tanh_supervised_f', activation='tanh')
        train_valence(configs_factual_source, version='arousal_tanh_supervised_f', activation='tanh')
        #train_valence(configs_factual_source, version='arousal_gelu_supervised_f', activation='gelu')
        #train_valence(configs_factual_source, version='valence_gelu_supervised_f', activation='gelu')

        not_supervised = {k: {**v, 'supervised': False} for k, v in configs.items()}
        train_valence(not_supervised, version='valence_tanh_cf', activation='tanh')
        train_valence(not_supervised, version='arousal_tanh_cf', activation='tanh')