import os
import json
import random
from copy import copy
import torch
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
import math

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset

from collections import defaultdict

from .config import config
from .model import LlamaEmbeddingModel


import os
import torch
from torch.utils.data import Dataset
import lmdb
import pickle
import hashlib
import numpy as np

class ClassifierDataset(Dataset):
    """Class-balanced dataset of precomputed embeddings for classifier training.

    Every item is a ``(features, label_index)`` tuple. What ``features`` holds
    is selected by the ``training_class.type`` config value:

    * ``"llm"``  -- the L2-normalised precomputed embedding,
    * ``"proj"`` -- the projector output for that embedding,
    * ``"both"`` -- the embedding concatenated with the projector output.
    """

    # Label string -> integer class index returned with every item.
    # NOTE(review): EmbeddingDataset in this file encodes "Rise" as 0 and
    # "Fall" as 2 for the nifty dataset -- confirm the two orderings are
    # meant to differ.
    LABEL_TO_INDEX = {
        "Fall": 0,
        "Neutral": 1,
        "Rise": 2
    }

    def __init__(self, split : str):
        """Load the embedding store, balance the ids and restore the projector.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
        """
        assert split in ['train', 'val', 'test'], 'Invalid split'

        self.split = split
        self.type = config.get("training_class", "type")
        self.dataset = LlamaEmbeddingDataset(split)

        # Index the rows by id once so __getitem__ does not have to scan the
        # whole training space for every sample. When an id is duplicated the
        # later row wins, which is what a full scan without `break` yields.
        self._row_by_id = {
            row["id"]: row for row in self.dataset.dataset.training_space
        }

        self.used_nifty_keys = self.get_used_nifty_keys()

        self.projector = LlamaEmbeddingModel()

        # Load the projector weights from the path given in the config.
        self.weight_path = config.get("training_class", "projector_weights_path")

        # The projector is never moved off the CPU here, so map the checkpoint
        # to the CPU as well; this also lets a checkpoint saved from a GPU
        # load on a machine without one.
        checkpoint = torch.load(self.weight_path, map_location="cpu")
        self.projector.load_state_dict(checkpoint['state_dict'])

        self.projector.eval()

    def get_used_nifty_keys(self):
        """Return sample ids with the same number of ids for every label.

        Ids are grouped by label and every group is randomly down-sampled
        (via ``random.sample``) to the size of the smallest group.
        """
        training_space = self.dataset.dataset.training_space

        # Group the ids by their label.
        label_to_ids = defaultdict(list)
        for entry in training_space:
            label_to_ids[entry["label"]].append(entry["id"])

        # Size of the smallest class decides how many ids each class keeps.
        min_label_count = min(len(ids) for ids in label_to_ids.values())

        balanced_ids = []
        for ids in label_to_ids.values():
            balanced_ids.extend(random.sample(ids, min_label_count))

        return balanced_ids

    def __len__(self):
        """Number of ids kept after class balancing."""
        return len(self.used_nifty_keys)

    def __getitem__(self, idx):
        """Return ``(features, label_index)`` for the balanced id at ``idx``.

        Raises:
            ValueError: If no embedding is stored for the id, or if the
                configured ``training_class.type`` is not one of
                ``"llm"``, ``"proj"`` or ``"both"``.
            KeyError: If the row's label is not in ``LABEL_TO_INDEX``.
        """
        nifty_key = self.used_nifty_keys[idx]
        data_row = self._row_by_id[nifty_key]

        key = generate_key(self.generate_int_key(nifty_key))

        llm_embeddings = self.dataset.retrieve_embedding(key)
        if llm_embeddings is None:
            # retrieve_embedding signals a missing key with None; fail with a
            # message that names the sample instead of letting torch.tensor
            # choke on None.
            raise ValueError(f"Embedding not found for id {nifty_key}")
        llm_embeddings = torch.tensor(llm_embeddings, dtype=torch.float32)

        # Normalize the embedding to unit L2 length.
        llm_embeddings /= torch.norm(llm_embeddings, p=2)

        label_val = self.LABEL_TO_INDEX[data_row["label"]]

        if self.type == "llm":
            return (llm_embeddings, label_val)

        # The projector only produces input features here. Without no_grad the
        # returned tensor would carry an autograd graph through the projector.
        with torch.no_grad():
            projection = self.projector(llm_embeddings.unsqueeze(0)).squeeze(0)

        if self.type == "proj":
            return (projection, label_val)
        elif self.type == "both":
            model_input = torch.cat((llm_embeddings, projection), dim=0)
            return (model_input, label_val)

        raise ValueError("Invalid type")

    def generate_int_key(self, nifty_key):
        """Convert a string id into the integer key used for embedding lookup.

        ``"<prefix>_<n>"`` maps to ``n * 1_000_000_000`` (``-1`` when ``n`` is
        0) and ``"<prefix>_<n>_<k>"`` maps to ``n * 1_000_000_000 + k``.

        Raises:
            ValueError: If the id does not split into two or three parts.
        """
        split_anchor_id = nifty_key.split("_")
        if len(split_anchor_id) == 2:
            if int(split_anchor_id[1]) == 0:
                anchor_id_tensor = -1
            else:
                anchor_id_tensor = int(split_anchor_id[1]) * 1_000_000_000
        elif len(split_anchor_id) == 3:
            anchor_id_tensor = int(split_anchor_id[1]) * 1_000_000_000 + int(split_anchor_id[2])
        else:
            raise ValueError(f"Invalid ID: {nifty_key}")

        return anchor_id_tensor
        
class EmbeddingDataset(Dataset):
    """Tokenized base newsline plus its augmented variants and similarities.

    The split's JSON file (path taken from ``data_gen.training_space_path``)
    is loaded in full at construction. Each record is read through the keys
    ``'base_news'``, ``'augmented_newslines'`` (dicts with ``'newsline'`` and
    ``'similarity'``) and a label stored under ``'answer'`` or ``'label'``
    depending on the configured dataset.
    """

    def __init__(self, split, return_newsline_string = False):
        """
        Args:
            split: Key into the ``data_gen.training_space_path`` mapping.
            return_newsline_string: When true, items additionally carry the
                raw base and augmented strings.
        """
        self.split = split

        self.return_newsline_string = return_newsline_string

        data_path = config.get("data_gen", "training_space_path")[split]

        # Load configuration
        encoder_base_model = config.get("embedding_training", "encoder_base_model")
        self.max_token_count = config.get("embedding_training", "max_token_count")
        self.compare_batch_size = config.get("embedding_training", "compare_batch_size")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_base_model)

        # Load data
        with open(data_path, 'r') as f:
            self.data = json.load(f)

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.data)

    @staticmethod
    def _encode_label(dataset_name, anchor_label):
        """Map a label string to its class index for the given dataset.

        Raises:
            ValueError: If ``dataset_name`` has no label encoding.
            KeyError: If ``anchor_label`` is not a known label for the dataset.
        """
        if dataset_name in ("imdb", "nifty"):
            # NOTE(review): ClassifierDataset and LlamaEmbeddingDataset in this
            # file use Fall=0 / Rise=2 -- confirm this ordering is intended.
            return {
                "Rise": 0,
                "Neutral": 1,
                "Fall": 2
            }[anchor_label]
        if dataset_name == "bigdata22":
            return {
                "Fall": 0,
                "Rise": 1
            }[anchor_label]
        # Previously an unknown dataset left the encoded label unassigned and
        # the method died with UnboundLocalError when building its result.
        raise ValueError(f"Unsupported dataset for label encoding: {dataset_name!r}")

    def __getitem__(self, idx):
        """Return the tokenized record at ``idx`` as a dict.

        The dict holds ``base_*`` and ``augmented_*`` token tensors (the
        ``*_token_type_ids`` entries are ``None`` when the tokenizer does not
        produce them), ``similarity_scores`` and the encoded ``anchor_label``.
        With ``return_newsline_string`` set it also holds ``base_newsline``
        and ``augmented_newslines``.
        """
        item = self.data[idx]
        base_string = "\n".join(item.get('base_news', []))

        # Collect each augmented newsline (joined into one string) together
        # with its similarity score; a missing score defaults to 0.0.
        augmented_strings = []
        similarity_scores = []
        for aug in item.get('augmented_newslines', []):
            augmented_strings.append("\n".join(aug.get('newsline', [])))
            similarity_scores.append(aug.get('similarity', 0.0))

        # Keep at most compare_batch_size augmented entries.
        augmented_strings = augmented_strings[:self.compare_batch_size]
        similarity_scores = similarity_scores[:self.compare_batch_size]

        # Every string is padded/truncated to max_token_count tokens.
        tokenizer_kwargs = dict(
            padding='max_length',
            truncation=True,
            max_length=self.max_token_count,
            return_tensors='pt'
        )
        base_tokens = self.tokenizer(base_string, **tokenizer_kwargs)
        # NOTE(review): a record without augmented newslines passes an empty
        # list to the tokenizer here -- confirm such records cannot occur.
        augmented_tokens = self.tokenizer(augmented_strings, **tokenizer_kwargs)

        similarity_scores = torch.tensor(similarity_scores)

        # The base string was tokenized alone, so drop its batch dimension.
        base_input_ids = base_tokens['input_ids'].squeeze(0)
        base_attention_mask = base_tokens['attention_mask'].squeeze(0)
        base_token_type_ids = base_tokens.get('token_type_ids', None)
        if base_token_type_ids is not None:
            base_token_type_ids = base_token_type_ids.squeeze(0)

        # Augmented strings were tokenized as a batch; keep that dimension.
        augmented_input_ids = augmented_tokens['input_ids']
        augmented_attention_mask = augmented_tokens['attention_mask']
        augmented_token_type_ids = augmented_tokens.get('token_type_ids', None)

        # The label lives under 'answer' for bigdata22/imdb, else under 'label'.
        dataset_name = config.get("dataset", "dataset")
        label_field = "answer" if dataset_name in ("bigdata22", "imdb") else "label"
        anchor_value = self._encode_label(dataset_name, item.get(label_field))

        sample = {
            "base_input_ids": base_input_ids,
            "base_attention_mask": base_attention_mask,
            "base_token_type_ids": base_token_type_ids,
            "augmented_input_ids": augmented_input_ids,
            "augmented_attention_mask": augmented_attention_mask,
            "augmented_token_type_ids": augmented_token_type_ids,
            "similarity_scores": similarity_scores,
            "anchor_label": anchor_value
        }

        if self.return_newsline_string:
            sample["base_newsline"] = base_string
            sample["augmented_newslines"] = augmented_strings

        return sample

    

class LlamaEmbeddingDataset(Dataset):
    """Serve precomputed embeddings from an LMDB database.

    Indexing follows the wrapped ``SimilaritySpaceDataset``: every index
    names one anchor newsline and its augmented counterparts, and the
    matching embeddings are looked up by their hashed integer ids.
    """

    def __init__(self, split):
        """
        Args:
            split (str): One of 'train', 'test', 'val'.
        """
        self.split = split
        self.dataset = SimilaritySpaceDataset(split)
        self.lmdb_path = config.get("data_gen", "llama_precompute_path")

        self.intialize_lmdb()

    def intialize_lmdb(self):
        """Open the LMDB environment read-only and its ``embeddings`` sub-db."""
        self.env = lmdb.open(
            self.lmdb_path,
            readonly=True,
            lock=False,
            readahead=False,
            max_dbs=3,
        )

        # Open the named database without an explicit transaction.
        self.db = self.env.open_db(b'embeddings')

    def __len__(self):
        """Same length as the wrapped similarity-space dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the embeddings, similarities and label for sample ``idx``.

        Returns:
            dict with ``anchor_embedding`` and ``augmented_embeddings``
            (float32, L2-normalised), ``augmented_similarities`` (tensor on
            the configured training device) and ``anchor_label``.

        Raises:
            ValueError: If the anchor or any augmented embedding is missing.
        """
        # Only the ids and similarities are needed here, so ask the wrapped
        # dataset for them directly. Indexing it (self.dataset[idx]) would
        # also tokenize every newsline and move the tokens to the device,
        # all of which was thrown away.
        _, _, anchor_id, augmented_ids, similarity_vector = self.dataset.get_newslines_info(idx)
        anchor_key_tensor, augmented_key_tensor = self.dataset.get_ids_tensor(anchor_id, augmented_ids)

        int_anchor_key = anchor_key_tensor.tolist()[0]
        anchor_embedding = self.retrieve_embedding(generate_key(int_anchor_key))

        if anchor_embedding is None:
            raise ValueError(f"Anchor embedding not found for index {idx}")

        augmented_embeddings = []
        for augmented_key in augmented_key_tensor.tolist():
            augmented_embedding = self.retrieve_embedding(generate_key(augmented_key))

            if augmented_embedding is None:
                raise ValueError(f"Augmented embedding not found for index {idx}")

            augmented_embeddings.append(augmented_embedding)

        # Same tensor the wrapped dataset's __getitem__ builds for this key.
        augmented_similiarities = torch.tensor(similarity_vector).to(config.get("training", "device"))

        # Convert embeddings to tensors. The list is merged into one array
        # first: torch.tensor on a list of numpy arrays is very slow.
        anchor_embedding = torch.tensor(anchor_embedding, dtype=torch.float32)
        augmented_embeddings = torch.tensor(np.asarray(augmented_embeddings), dtype=torch.float32)

        # Normalize all embeddings to unit length (row-wise for the augmented
        # ones, which are indexed as [num_augmented, hidden_size]).
        anchor_embedding /= torch.norm(anchor_embedding, p=2)
        augmented_embeddings /= torch.norm(augmented_embeddings, p=2, dim=1, keepdim=True)

        anchor_label = self.get_anchor_label(int_anchor_key)

        return {
            'anchor_embedding': anchor_embedding,
            'augmented_embeddings': augmented_embeddings,
            'augmented_similarities': augmented_similiarities,
            'anchor_label': anchor_label
        }

    def _rows_by_id(self):
        """Build (once) and return an id -> training-space row index."""
        index = getattr(self, "_row_index", None)
        if index is None:
            index = {}
            for row in self.dataset.training_space:
                # First occurrence wins, like a first-match linear scan.
                index.setdefault(row.get("id"), row)
            self._row_index = index
        return index

    def get_anchor_label(self, anchor_key):
        """Return the class index for an integer anchor key.

        Returns:
            ``-1`` when the key decodes to a three-part (augmented) id,
            the label index (Fall=0, Neutral=1, Rise=2) when the decoded id
            is found in the training space, and ``None`` when it is not.
        """
        nifty_string = self.reverse_anchor_id(anchor_key)

        # Three-part ids denote augmented samples, which carry no label.
        if len(nifty_string.split("_")) == 3:
            return -1

        return_dict = {
            "Fall": 0,
            "Neutral": 1,
            "Rise": 2
        }

        row = self._rows_by_id().get(nifty_string)
        if row is None:
            # Unknown id: keep the historical "no label" result.
            return None
        return return_dict[row["label"]]

    def reverse_anchor_id(self, anchor_id_tensor):
        """Turn an integer key back into its ``nifty_<n>[_<k>]`` string id."""
        # -1 is the reserved encoding of nifty_0.
        if anchor_id_tensor == -1:
            return "nifty_0"

        # The high part is the sample number, the low part the augmentation index.
        main_part = anchor_id_tensor // 1_000_000_000
        remainder = anchor_id_tensor % 1_000_000_000

        if remainder == 0:
            # Single-level id, e.g. nifty_1
            return f"nifty_{main_part}"
        # Two-level id, e.g. nifty_1_2
        return f"nifty_{main_part}_{remainder}"

    def retrieve_embedding(self, key):
        """Return the unpickled embedding stored under ``key``, or ``None``.

        The original note on this method says the embedding has to be of
        shape (512).
        """
        with self.env.begin(write=False) as txn:
            data = txn.get(key, db=self.db)
            if data is None:
                return None
            # SECURITY: pickle.loads executes arbitrary code from its input.
            # Only point llama_precompute_path at databases you produced.
            return pickle.loads(data)

def generate_key(key_int: int):
    """
    Builds the LMDB key for an integer sample id.

    Parameters:
    - key_int: The integer id of the sample; its decimal string form is hashed.

    Returns:
    - key: The SHA-256 hex digest of that string, as ASCII bytes.
    """
    hasher = hashlib.sha256()
    hasher.update(str(key_int).encode('utf-8'))
    hex_digest = hasher.hexdigest()
    return hex_digest.encode('ascii')


# Dataset for pytorch lightning
class SimilaritySpaceDataset(Dataset):
    """Tokenized anchor/augmented newsline groups with their similarities.

    Every row of the training space contributes one sample with the original
    headlines as anchor, plus one sample per augmented newsline in which that
    augmented newsline is the anchor and the original takes its slot.
    """

    # Instruction prompts; one is picked at random for every built newsline.
    PRIORS = [
        "Examine market data and news headlines dated DATE to project the direction of the $SPY index. If the expected shift is below PCTCHANGE, finalize with 'Neutral'. Present a response as 'Fall', 'Rise', or 'Neutral', along with the anticipated percentage change in a newline.",
        "To predict the $SPY index's direction, analyze market data and news headlines from DATE. If the expected alteration is less than PCTCHANGE, end with 'Neutral'. Offer a reply as 'Rise', 'Neutral', or 'Fall', coupled with the forecasted percentage change in a newline.",
        "Forecast the $SPY index's movement by scrutinizing market data and news headlines from DATE. If the expected adjustment is below PCTCHANGE, conclude with 'Neutral'. Deliver a response as 'Fall', 'Rise', or 'Neutral', alongside the predicted percentage change in a newline.",
        "Project the $SPY index's trajectory by assessing market data and news headlines from DATE. If the envisaged modification is under PCTCHANGE, end with 'Neutral'. Furnish a response as 'Neutral', 'Fall', or 'Rise', including the forecasted percentage change in a newline.",
        "Analyze market data and news headlines dated DATE to forecast the direction of the $SPY index. Conclude with 'Neutral' if the projected change is below PCTCHANGE. Provide a response as 'Rise', 'Fall', or 'Neutral', along with the anticipated percentage change in a newline.",
        "Assess market data and news headlines from DATE to predict the $SPY index's direction. Finalize with 'Neutral' if the expected shift is less than PCTCHANGE. Offer a reply as 'Fall', 'Neutral', or 'Rise', along with the forecasted percentage change in a newline.",
        "Predict the $SPY index's movement by examining market data and news headlines from DATE. Conclude with 'Neutral' if the anticipated adjustment is under PCTCHANGE. Present a response as 'Neutral', 'Fall', or 'Rise', with the predicted percentage change in a newline.",
        "Anticipate the direction of the $SPY index by analyzing market data and news headlines from DATE. End with 'Neutral' if the expected change is below PCTCHANGE. Supply a response as 'Rise', 'Neutral', or 'Fall', along with the forecasted percentage change in a newline.",
        "Examine market data and news headlines dated DATE to project the $SPY index's direction. Conclude with 'Neutral' if the envisaged shift is under PCTCHANGE. Deliver a response as 'Neutral', 'Rise', or 'Fall', along with the predicted percentage change in a newline.",
        "To predict the direction of the $SPY index, analyze market data and news headlines from DATE. End with 'Neutral' if the expected alteration is below PCTCHANGE. Provide a reply as 'Fall', 'Neutral', or 'Rise', along with the anticipated percentage change in a newline.",
        "Forecast the $SPY index's trajectory by scrutinizing market data and news headlines from DATE. Conclude with 'Neutral' if the expected adjustment is less than PCTCHANGE. Furnish a response as 'Rise', 'Fall', or 'Neutral', including the forecasted percentage change in a newline.",
        "Project the $SPY index's movement by assessing market data and news headlines from DATE. Conclude with 'Neutral' if the envisaged modification is under PCTCHANGE. Offer a reply as 'Neutral', 'Rise', or 'Fall', along with the predicted percentage change in a newline.",
        "Analyze market data and news headlines dated DATE to forecast the direction of the $SPY index. Deliver a response as 'Fall', 'Neutral', or 'Rise', along with the anticipated percentage change in a newline. Conclude with 'Neutral' if the projected change is below PCTCHANGE.",
        "Assess market data and news headlines from DATE to predict the $SPY index's direction. Offer a reply as 'Rise', 'Fall', or 'Neutral', along with the forecasted percentage change in a newline. Finalize with 'Neutral' if the expected shift is less than PCTCHANGE.",
        "Predict the $SPY index's movement by examining market data and news headlines from DATE. Supply a response as 'Neutral', 'Rise', or 'Fall', with the predicted percentage change in a newline. Conclude with 'Neutral' if the anticipated adjustment is under PCTCHANGE.",
        "Anticipate the direction of the $SPY index by analyzing market data and news headlines from DATE. Deliver a response as 'Fall', 'Neutral', or 'Rise', along with the forecasted percentage change in a newline. End with 'Neutral' if the expected change is below PCTCHANGE.",
        "Examine market data and news headlines dated DATE to project the $SPY index's direction. Present a response as 'Rise', 'Fall', or 'Neutral', along with the predicted percentage change in a newline. Conclude with 'Neutral' if the envisaged shift is under PCTCHANGE.",
        "To predict the direction of the $SPY index, analyze market data and news headlines from DATE. Provide a reply as 'Neutral', 'Rise', or 'Fall', along with the anticipated percentage change in a newline. End with 'Neutral' if the expected alteration is below PCTCHANGE.",
        "Forecast the $SPY index's trajectory by scrutinizing market data and news headlines from DATE. Furnish a response as 'Fall', 'Neutral', or 'Rise', including the forecasted percentage change in a newline. Conclude with 'Neutral' if the expected adjustment is less than PCTCHANGE.",
        "Project the $SPY index's movement by assessing market data and news headlines from DATE. Offer a reply as 'Rise', 'Fall', or 'Neutral', along with the predicted percentage change in a newline. Conclude with 'Neutral' if the envisaged modification is under PCTCHANGE."
    ]

    def __init__(self, split : str):
        """Load the tokenizer and the training space of ``split``.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
        """
        assert split in ['train', 'val', 'test'], 'Invalid split'

        self.split = split

        # TEMP: Change to allow other tokenizers
        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Padding needs a pad token; fall back to EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.training_space = self.load_training_space()

        # The augmented count is read from the first row only; every row is
        # then indexed with it (see get_newslines_info).
        self.number_of_samples = len(self.training_space)
        self.number_of_augmented_newslines_per_sample = len(self.training_space[0]["augmented_newslines"])

    def load_training_space(self):
        """Read and return the JSON training space of the current split."""
        self.train_space_path = config.get("data_gen", "training_space_path")[self.split]

        with open(self.train_space_path, 'r') as file:
            training_space = json.load(file)

        return training_space

    def __len__(self):
        """One sample per row for the original plus one per augmented newsline."""
        return self.number_of_samples * (self.number_of_augmented_newslines_per_sample + 1)

    def __getitem__(self, idx):
        """Tokenize the sample at ``idx`` and return it as a dict of tensors.

        All tensors are moved to the ``training.device`` from the config.
        Keys: ``anchor_ids``, ``anchor_attention_mask``, ``augmented_ids``,
        ``augmented_attention_mask``, ``augmented_similarities``,
        ``anchor_key`` and ``augmented_keys``.
        """
        anchor_newsline, augmented_newslines, anchor_id, augmented_ids, simliarity_vector = self.get_newslines_info(idx)

        # Build the full prompt strings (random prior + shuffled headlines).
        anchor_newsline = self.create_newsline(anchor_newsline)
        augmented_newslines = [self.create_newsline(news_strings) for news_strings in augmented_newslines]

        # Pad/truncate everything to the same configured length so the
        # augmented tensors can be stacked.
        max_length = config.get("training", "max_length")
        anchor_tokenized = self.tokenizer(anchor_newsline, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        augmented_tokenized = [self.tokenizer(newsline, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
                               for newsline in augmented_newslines]

        # Each string was tokenized alone, so drop the batch dimension.
        anchor_input_ids = anchor_tokenized["input_ids"].squeeze(0)
        anchor_attention_mask = anchor_tokenized["attention_mask"].squeeze(0)

        augmented_input_ids = torch.stack([item["input_ids"].squeeze(0) for item in augmented_tokenized])
        augmented_attention_masks = torch.stack([item["attention_mask"].squeeze(0) for item in augmented_tokenized])

        # Integer encodings of the anchor id and the augmented ids.
        anchor_id_tensor, augmented_ids_tensor = self.get_ids_tensor(anchor_id, augmented_ids)

        device = config.get("training", "device")

        compiled_data = {
            "anchor_ids": anchor_input_ids.to(device),
            "anchor_attention_mask": anchor_attention_mask.to(device),
            "augmented_ids": augmented_input_ids.to(device),
            "augmented_attention_mask": augmented_attention_masks.to(device),
            "augmented_similarities": torch.tensor(simliarity_vector).to(device),
            "anchor_key": anchor_id_tensor.to(device),
            "augmented_keys": augmented_ids_tensor.to(device)
        }

        return compiled_data

    @staticmethod
    def _id_to_int(sample_id):
        """Encode ``<prefix>_<n>`` / ``<prefix>_<n>_<k>`` as a single integer.

        Two-part ids map to ``n * 1_000_000_000`` (``-1`` when ``n`` is 0),
        three-part ids to ``n * 1_000_000_000 + k``.

        Raises:
            ValueError: If the id does not split into two or three parts.
        """
        parts = sample_id.split("_")
        if len(parts) == 2:
            number = int(parts[1])
            return -1 if number == 0 else number * 1_000_000_000
        if len(parts) == 3:
            return int(parts[1]) * 1_000_000_000 + int(parts[2])
        raise ValueError(f"Invalid ID: {sample_id}")

    def get_ids_tensor(self, anchor_id, augmented_ids):
        """Return ``(anchor_tensor, augmented_tensor)`` of integer-encoded ids.

        The anchor tensor has one element; the augmented tensor has one
        element per entry of ``augmented_ids``.

        Raises:
            ValueError: If any id is malformed.
        """
        augmented_id_list = [self._id_to_int(augmented_id) for augmented_id in augmented_ids]
        anchor_id_int = self._id_to_int(anchor_id)

        return torch.tensor([anchor_id_int]), torch.tensor(augmented_id_list)

    def get_newslines_info(self, idx):
        """Resolve ``idx`` into headlines, ids and similarities.

        Returns:
            ``(anchor_newsline, augmented_newslines, anchor_id,
            augmented_ids, similarity)``. For an index whose position within
            its row is 0 the original headlines are the anchor; otherwise
            augmented newsline ``position - 1`` is the anchor and the original
            headlines take its slot among the augmented ones.
        """
        group_size = self.number_of_augmented_newslines_per_sample + 1
        nifty_row_idx = idx // group_size
        augmented_idx = idx % group_size

        nifty_row = self.training_space[nifty_row_idx]

        # if augmented_idx is 0, then it's the anchor
        if augmented_idx == 0:
            anchor_newsline = list(nifty_row["news_list"].keys())
            augmented_newslines = nifty_row["augmented_newslines"]
            anchor_id = nifty_row["id"]
            augmented_ids = [nifty_row["id"] + f"_{i + 1}"  for i in range(len(augmented_newslines))]
            similarity = nifty_row["anchor_similarity"]
        else:
            slot = augmented_idx - 1

            anchor_newsline = nifty_row["augmented_newslines"][slot]
            augmented_newslines = copy(nifty_row["augmented_newslines"])
            augmented_newslines[slot] = list(nifty_row["news_list"].keys())

            anchor_id = nifty_row["id"] + f"_{augmented_idx}"
            augmented_ids = [nifty_row["id"] + f"_{i + 1}"  for i in range(len(augmented_newslines))]
            augmented_ids[slot] = nifty_row["id"]

            # Work on a copy of the matrix row: writing into the stored row
            # would overwrite the training space kept in memory.
            similarity = list(nifty_row["sim_augmented_matrix"][slot])
            similarity[slot] = nifty_row["anchor_similarity"][slot]

        return anchor_newsline, augmented_newslines, anchor_id, augmented_ids, similarity

    def create_newsline(self, headline_strings):
        """Return a random prior followed by the shuffled, stripped headlines.

        The input list is left untouched; the lists handed in are the ones
        stored in the training space, so shuffling them in place would
        reorder the loaded data.
        """
        random_prior = random.choice(self.PRIORS)

        shuffled_headlines = list(headline_strings)
        random.shuffle(shuffled_headlines)

        # Strip the headlines of any extra whitespace
        cleaned_headlines = [headline.strip() for headline in shuffled_headlines]

        return random_prior + '\n' + '\n'.join(cleaned_headlines)

    def tokenize_newsline(self, newsline):
        """Tokenize one newsline without padding or truncation."""
        return self.tokenizer(newsline, return_tensors="pt")
   
class ClassifierDataModule(pl.LightningDataModule):
    """Lightning data module exposing the three ClassifierDataset splits."""

    def __init__(self):
        super().__init__()

        # All three splits are built eagerly, in train / val / test order.
        self.train_dataset = ClassifierDataset('train')
        self.val_dataset = ClassifierDataset('val')
        self.test_dataset = ClassifierDataset('test')

    def _build_loader(self, dataset, **loader_options):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        batch_size = config.get("training_class", "batch_size")
        return DataLoader(dataset, batch_size=batch_size, **loader_options)

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset)
   
class LlamaEmbeddingDataModule(pl.LightningDataModule):
    """Lightning data module exposing the three LlamaEmbeddingDataset splits."""

    def __init__(self):
        super().__init__()

        # All three splits are built eagerly, in train / val / test order.
        self.train_dataset = LlamaEmbeddingDataset('train')
        self.val_dataset = LlamaEmbeddingDataset('val')
        self.test_dataset = LlamaEmbeddingDataset('test')

    def _build_loader(self, dataset, **loader_options):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        batch_size = config.get("training", "batch_size")
        return DataLoader(dataset, batch_size=batch_size, **loader_options)

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset)
    
class SimilaritySpaceDataModule(pl.LightningDataModule):
    """Lightning data module exposing the three SimilaritySpaceDataset splits."""

    def __init__(self):
        super().__init__()

        # All three splits are built eagerly, in train / val / test order.
        self.train_dataset = SimilaritySpaceDataset('train')
        self.val_dataset = SimilaritySpaceDataset('val')
        self.test_dataset = SimilaritySpaceDataset('test')

    def _build_loader(self, dataset, **loader_options):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        batch_size = config.get("training", "batch_size")
        return DataLoader(dataset, batch_size=batch_size, **loader_options)

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset)
    
class EmbeddingDataModule(pl.LightningDataModule):
    """Lightning data module exposing the three EmbeddingDataset splits."""

    def __init__(self):
        super().__init__()

        # All three splits are built eagerly, in train / val / test order.
        self.train_dataset = EmbeddingDataset("train")
        self.val_dataset = EmbeddingDataset("val")
        self.test_dataset = EmbeddingDataset("test")

    def _build_loader(self, dataset, **loader_options):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        batch_size = config.get("embedding_training", "batch_size")
        return DataLoader(dataset, batch_size=batch_size, **loader_options)

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset)

# Creating the training spaces
# Gets the headline bank and combines them to create a training space
def generate_training_space_old(dataset : list, n_samples : int = None):
    """Legacy training-space builder based on a pairwise similarity matrix.

    For each of the first ``n_samples`` rows (all rows when ``n_samples`` is
    ``None``) a shallow copy is extended with:

    * ``"augmented_newslines"`` -- one newsline per generated newsline string,
    * ``"anchor_similarity"``   -- similarity of each newsline string to the anchor,
    * ``"sim_augmented_matrix"`` -- pairwise similarities between the strings.

    Args:
        dataset: The source rows.
        n_samples: Maximum number of rows to process, or ``None`` for all.

    Returns:
        The list of extended row copies.
    """
    training_space = []

    for i, source_row in enumerate(dataset):

        if n_samples is not None and i >= n_samples:
            break

        data_row = copy(source_row)
        # The helpers below receive the row with an empty augmented list.
        data_row["augmented_newslines"] = []

        # NOTE(review): generate_training_space calls this helper as
        # create_newsline_strings(data, headline_bank); confirm the
        # one-argument form used here is still supported.
        current_newsline_strings = create_newsline_strings(dataset)

        # Turn every newsline string into its list of headlines.
        augmented_newsline_list = [
            create_newsline_from_string(dataset, data_row, headline_string)
            for headline_string in current_newsline_strings
        ]

        anchor_similarity, sim_augmented_matrix = calculate_similarity_from_anchor_matrix(current_newsline_strings)

        data_row["anchor_similarity"] = anchor_similarity
        data_row["sim_augmented_matrix"] = sim_augmented_matrix
        data_row["augmented_newslines"] = augmented_newsline_list

        training_space.append(data_row)

    return training_space

def generate_training_space(headline_bank, n_samples):
    """Build the training space from the headline bank.

    Each of the first ``n_samples`` entries (all entries when ``n_samples``
    is ``None``) is shallow-copied and extended with ``"base_news"`` (the
    entry's text split into lines) and ``"augmented_newslines"`` (one dict
    per generated newsline string holding the newsline, its similarity score
    and the string itself).
    """
    # Read exactly as before; the value is not used further down.
    compare_batch_size = config.get("data_gen", "compare_batch_size")

    result = []

    for position, data in enumerate(headline_bank):
        if n_samples is not None and position >= n_samples:
            break

        data_row = copy(data)
        augmented = []
        data_row["augmented_newslines"] = augmented

        # The field that carries the raw text depends on the dataset.
        dataset_name = config.get("dataset", "dataset")
        if dataset_name == "bigdata22":
            raw_lines = str(data["text"]).split("\n")
            headlines_list = bigdata_clean_headlines_list(raw_lines)
        elif dataset_name == "imdb":
            headlines_list = str(data["query"]).split("\n")
        else:
            headlines_list = str(data["news"]).split("\n")

        data_row["base_news"] = headlines_list

        current_newsline_strings = create_newsline_strings(data, headline_bank)
        sim_scores = compute_weighted_similarity_scores(current_newsline_strings)

        # Scores come back as (string, score) pairs in input order.
        for headline_string, scored in zip(current_newsline_strings, sim_scores):
            created_newsline = create_newsline_from_string(headline_bank, data_row, headline_string)
            entry = {
                "newsline": created_newsline,
                "similarity": scored[1],
                "newsline_string": scored[0]
            }
            augmented.append(entry)

        result.append(data_row)

    return result

def calculate_similarity_from_anchor_matrix(newsline_strings):
    """Score each newsline string against the anchor and against every other string.

    Args:
        newsline_strings: Sequence of augmentation letter strings.

    Returns:
        A pair ``(anchor_scores, pairwise)``. ``anchor_scores[i]`` is the
        anchor similarity of string ``i``; ``pairwise[i][j]`` is the
        similarity between strings ``i`` and ``j``, with the diagonal fixed
        at 1.0.
    """
    anchor_scores = [calculate_similarity_from_anchor(entry) for entry in newsline_strings]

    count = len(newsline_strings)
    pairwise = [
        [
            # A string is always fully similar to itself.
            1.0 if row == col else calculate_simliarity_between_newslines(
                newsline_strings[row], newsline_strings[col]
            )
            for col in range(count)
        ]
        for row in range(count)
    ]

    return anchor_scores, pairwise
    
    
def compute_weighted_similarity_scores(strings):
    """Score augmentation strings by their letters and map them onto a log curve.

    Each string is scored as the sum of its letter values ('r' = 1.0,
    'a' = 0.5, everything else 0.0), the scores are divided by the largest
    score in the batch, and the result is passed through
    ``log1p(x * (e - 1))``, which maps 0 to 0 and 1 to (approximately) 1.

    Args:
        strings: Iterable of augmentation letter strings.

    Returns:
        List of ``(string, weighted_score)`` tuples in input order. When no
        string in the batch scores above zero, every weighted score is 0.0.
    """
    # Point value of each letter; unknown letters count as 0.
    point_values = {
        'r': 1.0,
        'a': 0.5,
        'n': 0.0,
        'o': 0.0,
        'd': 0.0  # Assuming 'd' has 0 points
    }

    # Step 1: total score per string
    raw_scores = [(s, sum(point_values.get(c, 0.0) for c in s)) for s in strings]

    # Step 2: batch maximum used for normalisation. ``default`` covers an
    # empty batch; an all-zero batch is handled below.
    max_score = max((score for _, score in raw_scores), default=0.0)

    # Step 3: normalise to [0, 1] and apply the logarithmic weighting
    weighted_scores = []
    for s, score in raw_scores:
        # A batch made only of n/o/d letters has max_score == 0; treat every
        # string as score 0 instead of dividing by zero.
        normalized = score / max_score if max_score > 0 else 0.0
        weighted_scores.append((s, math.log1p(normalized * (math.e - 1))))

    return weighted_scores

def calculate_simliarity_between_newslines(newsline_string1, newsline_string2):
    """Return the mean per-position similarity of two augmentation strings.

    Letters are compared position by position through a fixed score table:

        d -> delete
        o -> other headline
        r -> rephrased
        a -> ablation
        n -> negative

    Args:
        newsline_string1: First letter string; its length is the divisor.
        newsline_string2: Second letter string.

    Returns:
        Sum of the pair scores over the overlapping positions divided by
        ``len(newsline_string1)``; 0.0 when ``newsline_string1`` is empty.

    Raises:
        KeyError: If either string contains a letter outside d/o/r/a/n.
    """
    # Nothing to compare; avoids dividing by zero below.
    if not newsline_string1:
        return 0.0

    # Row = letter from string 1, column = letter from string 2. The table is
    # not symmetric (e.g. a/n differs from n/a), so argument order matters.
    type_scores_matrix = {
        "d": {"d": 0.5, "o": 0.5, "r": 0.5, "a": 0.5, "n": 0.5},
        "o": {"d": 0.5, "o": 0.0, "r": 0.0, "a": 0.0, "n": 0.0},
        "r": {"d": 0.5, "o": 0.0, "r": 1.0, "a": 0.5, "n": 0.0},
        "a": {"d": 0.5, "o": 0.0, "r": 0.5, "a": 0.4, "n": 0.5},
        "n": {"d": 0.5, "o": 0.0, "r": 0.0, "a": 0.25, "n": 1.0},
    }

    # zip stops at the shorter string; positions of string 1 without a partner
    # contribute nothing but still count in the divisor.
    similarity = 0.0
    for c1, c2 in zip(newsline_string1, newsline_string2):
        similarity += type_scores_matrix[c1][c2]

    return similarity / len(newsline_string1)


def calculate_similarity_from_anchor(headline_string):
    """Return the mean letter score of an augmentation string.

    Args:
        headline_string: Letter string made of o/r/a/n/d.

    Returns:
        Average of the per-letter scores; 0.0 for an empty string.

    Raises:
        KeyError: If the string contains a letter outside o/r/a/n/d.
    """
    # An empty string has no letters to average; avoids dividing by zero.
    if not headline_string:
        return 0.0

    type_scores = {
        "o": 0.0,
        "r": 1.0,
        "a": 0.5,
        "n": 0.0,
        "d": 0.5
    }

    return sum(type_scores[c] for c in headline_string) / len(headline_string)

def create_newsline_from_string(dataset : list, data_row: dict, headline_string : str):
    """Expand an augmentation letter string into a concrete list of headlines.

    The headlines of ``data_row["news_list"]`` are consumed front to back, one
    per recognised letter:

        p -> keep the headline as is
        o -> replace it with a random headline of a random ``dataset`` entry
        r -> replace it with one of its "rephrased" variants
        a -> replace it with one of its "ablation" variants
        n -> replace it with one of its "negative" variants
        d -> drop it

    Unrecognised letters are ignored and consume nothing. Once every headline
    has been consumed the remaining letters have no effect.

    Args:
        dataset: List of bank entries, each with a "news_list" dict; only
            read for the "o" letter.
        data_row: Entry whose "news_list" maps each headline to a dict of
            variant lists.
        headline_string: Letter string to expand.

    Returns:
        List of headline strings in consumption order.

    Raises:
        KeyError: If a required variant list is missing for a headline.
        IndexError: If a variant list (or ``dataset``) chosen from is empty.
    """
    news_list = data_row["news_list"]

    pending = list(news_list.keys())
    cursor = 0  # index of the next headline that has not been consumed yet
    known_codes = ("p", "o", "r", "a", "n", "d")

    newsline_list = []

    for code in headline_string:
        if code not in known_codes or cursor >= len(pending):
            continue

        # Every recognised letter uses up exactly one headline.
        headline = pending[cursor]
        cursor += 1

        if code == "p":
            newsline_list.append(headline)
        elif code == "o":
            # Other headline: get random headline from the dataset
            random_day = random.choice(dataset)
            newsline_list.append(random.choice(list(random_day["news_list"].keys())))
        elif code == "r":
            newsline_list.append(random.choice(news_list[headline]["rephrased"]))
        elif code == "a":
            variants = news_list[headline]
            # Older banks were written with the misspelled key "abalation".
            if "ablation" in variants:
                pool = variants["ablation"]
            elif "abalation" in variants:
                pool = variants["abalation"]
            else:
                # Fail like the "rephrased"/"negative" lookups do, instead of
                # reusing the ablation picked for a previous headline.
                raise KeyError("ablation")
            newsline_list.append(random.choice(pool))
        elif code == "n":
            newsline_list.append(random.choice(news_list[headline]["negative"]))
        # code == "d": the headline is dropped, nothing is appended

    return newsline_list
    

def create_newsline_strings(dataset: list, headline_bank: list):
    """Generate a batch of random augmentation letter strings.

    One long string is built with each letter (o/r/a/n/d) repeated according
    to its configured ratio, a configured fraction of its positions is
    overwritten with letters drawn from the same ratios, and the result is
    cut into ``compare_batch_size`` buckets whose sizes are sampled from the
    "news_list" lengths found in ``headline_bank``.

    Args:
        dataset: Not used by this function; kept for call compatibility.
        headline_bank: List of entries with a "news_list"; only the lengths
            of those lists are read.

    Returns:
        List of ``compare_batch_size`` letter strings. Each letter count is
        rounded independently, so the long string can be shorter or longer
        than the sum of the bucket sizes: trailing buckets may then come out
        short or empty, or surplus letters are dropped.
    """
    compare_batch_size = config.get("data_gen", "compare_batch_size")
    ratios = config.get("data_gen", "ratios")

    letter_ratios = {
        "o": ratios["other_headline"],
        "r": ratios["rephrased"],
        "a": ratios["abalation"],
        "n": ratios["negative"],
        "d": ratios["delete"]
    }

    # Collect all the news_list lengths from the headline_bank
    news_list_lengths = [len(item["news_list"]) for item in headline_bank]

    # Sample one bucket size per string in the batch
    headline_bucket_sizes = [random.choice(news_list_lengths) for _ in range(compare_batch_size)]

    total_string_length = sum(headline_bucket_sizes)

    # Number of occurrences of each letter, from the total length and its ratio
    letter_counts = {letter: round(total_string_length * ratio) for letter, ratio in letter_ratios.items()}

    # Ordered sequence: all of one letter, then all of the next, ...
    # Kept as a list so single positions can be overwritten in place instead
    # of rebuilding the whole string on every swap.
    chars = list(''.join(letter * count for letter, count in letter_counts.items()))

    # Number of positions to overwrite
    n_string_shuffles = int(len(chars) * config.get("data_gen", "string_shuffle_ratio"))

    # Letters and matching weights for the weighted draw
    letters = list(letter_ratios.keys())
    weights = list(letter_ratios.values())

    for _ in range(n_string_shuffles):
        # Overwrite one random position with a letter drawn from the ratios
        swap_index = random.randint(0, len(chars) - 1)
        chars[swap_index] = random.choices(letters, weights=weights, k=1)[0]

    ordered_string = ''.join(chars)

    # Split the string into buckets of the sampled sizes
    buckets = []
    start = 0
    for size in headline_bucket_sizes:
        buckets.append(ordered_string[start:start + size])
        start += size

    return buckets



# Loading the NIFTY Dataset
def load_nifty(split : str):
    """Load one split of the NIFTY dataset from ``<nifty_path>/<split>.jsonl``.

    Args:
        split: One of 'train', 'val' or 'test'.

    Returns:
        List of decoded JSON records, truncated to the configured
        ``n_samples`` for this split when that value is not None.

    Raises:
        AssertionError: If ``split`` is not a known split name.
    """
    assert split in ['train', 'val', 'test'], 'Invalid split'

    nifty_path = config.get("dataset", "nifty_path")
    split_path = os.path.join(nifty_path, f"{split}.jsonl")

    # Read line by line and skip blank lines, which json.loads would reject
    # (e.g. an extra newline at the end of the file).
    with open(split_path, 'r') as file:
        data = [json.loads(line) for line in file if line.strip()]

    n_samples = config.get("data_gen", "n_samples")[split]
    if n_samples is not None:
        data = data[:n_samples]

    return data

# Loading the BigData22 Dataset
def load_bigdata22(split : str):
    """Load one split of the BigData22 dataset from ``<split>.jsonl``.

    Args:
        split: One of 'train', 'val' or 'test'.

    Returns:
        List of decoded JSON records, truncated to the configured
        ``n_samples`` for this split when that value is not None.

    Raises:
        AssertionError: If ``split`` is not a known split name.
    """
    assert split in ['train', 'val', 'test'], 'Invalid split'

    # NOTE(review): the directory is read from the "nifty_path" config key,
    # same as load_nifty - confirm this is intended for BigData22.
    nifty_path = config.get("dataset", "nifty_path")
    split_path = os.path.join(nifty_path, f"{split}.jsonl")

    # Read line by line and skip blank lines, which json.loads would reject
    # (e.g. an extra newline at the end of the file).
    with open(split_path, 'r') as file:
        data = [json.loads(line) for line in file if line.strip()]

    n_samples = config.get("data_gen", "n_samples")[split]
    if n_samples is not None:
        data = data[:n_samples]

    return data

def load_imdb(split : str):
    """Load one split of the IMDB dataset from ``<split>.jsonl``.

    Args:
        split: One of 'train', 'val' or 'test'.

    Returns:
        List of decoded JSON records, truncated to the configured
        ``n_samples`` for this split when that value is not None.

    Raises:
        AssertionError: If ``split`` is not a known split name.
    """
    assert split in ['train', 'val', 'test'], 'Invalid split'

    # NOTE(review): the directory is read from the "nifty_path" config key,
    # same as load_nifty - confirm this is intended for IMDB.
    imdb_path = config.get("dataset", "nifty_path")
    split_path = os.path.join(imdb_path, f"{split}.jsonl")

    # Read line by line and skip blank lines, which json.loads would reject
    # (e.g. an extra newline at the end of the file).
    with open(split_path, 'r') as file:
        data = [json.loads(line) for line in file if line.strip()]

    n_samples = config.get("data_gen", "n_samples")[split]
    if n_samples is not None:
        data = data[:n_samples]

    return data
    
def load_headlines_bank(split : str):
    """Read the stored headlines bank of ``split`` from its configured JSON file.

    Args:
        split: Key into the configured "headlines_bank_path" mapping.

    Returns:
        The decoded JSON content of the bank file.
    """
    bank_paths = config.get("data_gen", "headlines_bank_path")

    with open(bank_paths[split], 'r') as handle:
        return json.load(handle)

def _write_headlines_bank(path, headlines_bank):
    """Write the bank as JSON to ``path`` through a temporary file and a rename."""
    # Dumping straight into ``path`` would leave a truncated, unparsable file
    # if the process dies mid-write, which breaks the resume logic. Writing a
    # sibling file first and renaming keeps the previous checkpoint intact.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as file:
        json.dump(headlines_bank, file, indent=4)
    os.replace(tmp_path, path)


def create_headline_bank(dataset: list, split: str):
    """Build, or resume building, the headlines bank for one split.

    Each dataset item's text is split into headlines, and for every headline
    rephrased, ablation and negative alternatives are generated through
    ``headline_llm``. The bank is stored as JSON at the configured path for
    ``split``; items whose "id" is already present in an existing bank file
    are skipped, so an interrupted run can be restarted.

    Args:
        dataset: List of item dicts with an "id" and the text field of the
            configured dataset ("text" for bigdata22, "query" for imdb,
            "news" otherwise).
        split: Key into the configured "headlines_bank_path" mapping.

    Returns:
        The full headlines bank (previously stored entries plus new ones).
        Each entry is a shallow copy of its item with a "news_list" dict that
        maps every kept headline to its "rephrased", "ablation" and
        "negative" alternatives.
    """
    from .headline_llm import generate_ablation, generate_alternatives, generate_negative

    headlines_bank_path = config.get("data_gen", "headlines_bank_path")[split]

    n_kept_headlines = config.get("data_gen", "n_kept_headlines")
    n_rephrased = config.get("data_gen", "n_rephrased")
    n_abalations = config.get("data_gen", "n_abalations")
    n_negatives = config.get("data_gen", "n_negatives")

    # The dataset name does not change between items, so look it up once.
    dataset_name = config.get("dataset", "dataset")

    total_headline_count = 0

    # Load existing headlines_bank if it exists
    if os.path.exists(headlines_bank_path):
        with open(headlines_bank_path, 'r') as file:
            headlines_bank = json.load(file)
        print(f"Loaded existing headlines bank with {len(headlines_bank)} items.")
    else:
        headlines_bank = []
        print("No existing headlines bank found. Starting fresh.")

    # IDs that are already in the bank, for quick lookup
    processed_ids = {entry['id'] for entry in headlines_bank}

    last_index = len(dataset) - 1

    for i, item in enumerate(dataset):
        item_id = item['id']

        if item_id in processed_ids:
            print(f"Skipping already processed item with ID: {item_id}")
            continue  # Skip this item as it's already processed

        if dataset_name == "bigdata22":
            headlines_list = str(item["text"]).split("\n")
            headlines_list = bigdata_clean_headlines_list(headlines_list)
        elif dataset_name == "imdb":
            headlines_list = str(item["query"]).split("\n")
            # If there are more than 7 lines, randomly select 7
            if len(headlines_list) > 7:
                headlines_list = random.sample(headlines_list, 7)
        else:
            headlines_list = str(item["news"]).split("\n")

        # If there are more than n_kept_headlines headlines, randomly sample n_kept_headlines
        if len(headlines_list) > n_kept_headlines:
            headlines_list = random.sample(headlines_list, n_kept_headlines)

        current_headlines = copy(item)
        current_headlines["news_list"] = {}

        for headline in headlines_list:
            rephrased_alternatives = generate_alternatives(headline, n_rephrased)
            ablation_alternatives = generate_ablation(headline, n_abalations)
            negative_alternatives = generate_negative(headline, n_negatives)

            current_headlines["news_list"][headline] = {
                "rephrased": rephrased_alternatives,
                "ablation": ablation_alternatives,
                "negative": negative_alternatives
            }

        headlines_bank.append(current_headlines)
        processed_ids.add(item_id)  # Add the new ID to the set

        total_headline_count += len(headlines_list)

        # Checkpoint on every 10th dataset index and on the last item
        if i % 10 == 0 or i == last_index:
            _write_headlines_bank(headlines_bank_path, headlines_bank)

        print(f"Processed item {i + 1}/{len(dataset)}: {split} headlines: {len(headlines_list)}", flush=True)

    print(f"Total headlines processed: {total_headline_count}")

    # Final save, so items processed after the last checkpoint are not lost
    _write_headlines_bank(headlines_bank_path, headlines_bank)

    # Ask torch to release its cached CUDA memory
    torch.cuda.empty_cache()

    return headlines_bank

def bigdata_clean_headlines_list(headlines_list):
    """Strip the header lines and the leading prefix from BigData22 text lines.

    The first 12 lines are discarded. Of the remaining lines, only those
    containing a colon are kept; everything up to and including the first
    colon is removed and the rest is whitespace-trimmed.

    Args:
        headlines_list: List of text lines.

    Returns:
        List of cleaned headline strings; empty when the input has 12 lines
        or fewer.
    """
    header_lines = 12  # the first 12 text lines do not contain headlines

    if len(headlines_list) <= header_lines:
        return []

    cleaned = []
    for raw_line in headlines_list[header_lines:]:
        # Everything before the first colon is the date prefix
        _, colon, remainder = raw_line.partition(":")
        if colon:
            cleaned.append(remainder.strip())

    return cleaned