import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
import torch
from collections import defaultdict, Counter
import pandas as pd
import numpy as np
from huggingface_hub import login
from tqdm import tqdm
import os
import json
from sklearn.decomposition import PCA
from openTSNE import TSNE
import random

from .config import config  # Now this import shoudl work
from ..util import load_model
from .util import EmbeddingKVStore


class EmbeddingDataModule(pl.LightningDataModule):
    """Lightning data module serving the 'train' / 'valid' / 'test' EmbeddingDataset splits.

    Batch size and worker count are read from the 'training' section of the config.
    """

    def __init__(self):
        super().__init__()

        self.batch_size = config.get('training', 'batch_size')
        self.num_workers = config.get('training', 'num_workers')

        # Datasets are built lazily in setup(); None means "not built yet".
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """Build the datasets required by ``stage``.

        Lightning calls this hook separately for each stage (e.g. 'fit', then 'test').
        Constructing the train split runs the LLM pipeline in EmbeddingDataset, so each
        split is built only when the stage needs it and never more than once.
        Unknown stages (and None) build every split.
        """
        build_all = stage not in ('fit', 'validate', 'test')

        if (build_all or stage == 'fit') and self.train_dataset is None:
            self.train_dataset = EmbeddingDataset('train')
        if (build_all or stage in ('fit', 'validate')) and self.val_dataset is None:
            self.val_dataset = EmbeddingDataset('valid')
        if (build_all or stage == 'test') and self.test_dataset is None:
            self.test_dataset = EmbeddingDataset('test')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True)

    def val_dataloader(self):
        # NOTE(review): drop_last=True skips the final partial batch during evaluation.
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)

    def test_dataloader(self):
        # No shuffling for evaluation (Lightning warns about shuffled eval loaders).
        # Combined with drop_last, shuffling would also change which samples get
        # evaluated on every run.
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)
    
    
class EmbeddingDataset(Dataset):
    """Triplet dataset over cached LLM embeddings of one NIFTY split.

    Each item is an anchor embedding plus one embedding with the same ground-truth
    label (positive) and one with a different label (negative). Embeddings are
    generated with the configured pretrained model on first use and cached to disk
    (see ``load_embedding_data``).
    """

    # NIFTY text label -> integer class id stored as ``gt`` in the embedding cache.
    LABEL_TO_ID = {"Rise": 0, "Neutral": 1, "Fall": 2}

    def __init__(self, split : str):
        super().__init__()
        assert split in ['train', 'valid', 'test'], f"Unknown split: {split!r}"

        self.split = split

        model_name = config.get('dataset', 'pretrained_model')
        self.embedding_store = EmbeddingKVStore(model_name)

        # Only the training split runs the augmentation pipeline. It loads the LLM
        # and runs generation over every prompt, so it is invoked exactly once here.
        if split == "train":
            self.create_organized_embeddings()

        all_data = self.load_embedding_data()

        self.metadata = all_data['metadata']
        self.embeddings = all_data['embeddings']

        # Position -> NIFTY id, so integer sampler indices can be resolved.
        self.idx2id = tuple(self.embeddings.keys())

        # gt class id -> NIFTY ids carrying that label. __getitem__ samples positive
        # and negative pairs from this mapping for every split.
        self.organized_embeddings = self._group_ids_by_label()

    def _group_ids_by_label(self):
        """Group the loaded embedding ids by their integer ``gt`` label, keeping load order."""
        grouped = defaultdict(list)
        for nifty_id, entry in self.embeddings.items():
            grouped[entry['gt']].append(nifty_id)
        return dict(grouped)

    def __len__(self):
        # One sample per cached embedding.
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the anchor at ``idx`` with a randomly drawn positive and negative.

        The returned dict holds float32 arrays under 'anchor', 'positive_pair' and
        'negative_pair', and the anchor's integer label under 'gt'.
        """
        nifty_id = self.idx2id[idx]

        entry = self.embeddings[nifty_id]
        gt = entry['gt']

        positive_pair_id, negative_pair_id = self.get_random_positive_negative_pairs(nifty_id, gt)

        # Retrieve the embeddings for the positive and negative pairs
        positive_pair = self.embeddings[positive_pair_id]['vals']
        negative_pair = self.embeddings[negative_pair_id]['vals']

        return {
            "anchor": np.array(entry['vals'], dtype=np.float32),
            "positive_pair": np.array(positive_pair, dtype=np.float32),
            "negative_pair": np.array(negative_pair, dtype=np.float32),
            "gt": gt
        }

    def load_llm_model(self):
        """Load the model and tokenizer named by the 'dataset.pretrained_model' config entry."""
        model = config.get('dataset', 'pretrained_model')
        model, tokenizer = load_model(model)

        return model, tokenizer

    def create_organized_embeddings(self):
        """Build augmented prompt variants for this split and evaluate the LLM on them.

        For every prompt of the split, a list of variants is appended to
        ``self.batches``. The variants are the base case and the shuffled base, in
        the ratios given by the 'augmentation' config section.
        ``create_and_save_embedding_batches`` then runs the model over them.
        Returns None.
        """
        # Ratios for every augmentation type. The last three generators are still
        # stubs and are disabled below.
        ratio_base_case = config.get('augmentation', 'base_case')
        ratio_shuffled_base = config.get('augmentation', 'shuffled_base')
        ratio_augmented_base = config.get('augmentation', 'augmented_base')
        ratio_positive_case = config.get('augmentation', 'positive_mixes')
        ratio_negative_case = config.get('augmentation', 'negative_mixes')

        nifty_dataset = self.load_prompt_data(remove_neutral = False)

        self.batches = []

        for anchor_prompt in nifty_dataset:
            # TODO actually create the batches that are going to be put into the dataloader
            batch = []

            batch += self.create_base_case(anchor_prompt, ratio_base_case)

            batch += self.create_shuffled_base(anchor_prompt, nifty_dataset, ratio_shuffled_base)

            #batch += self.create_augmented_base(anchor_prompt, nifty_dataset, ratio_augmented_base)

            #batch += self.create_positive_mixes(anchor_prompt, nifty_dataset, ratio_positive_case)

            #batch += self.create_negative_mixes(anchor_prompt, nifty_dataset, ratio_negative_case)

            self.batches.append(batch)

        # After we have all of the batches, we can create the organized embeddings and save them
        self.create_and_save_embedding_batches()

    def create_and_save_embedding_batches(self):
        """Run the LLM on every prompt in ``self.batches`` and print answer accuracy.

        Embedding extraction and saving to ``self.embedding_store`` are currently
        disabled (commented out). Each generated answer is checked against the
        item's first ground-truth label.
        """
        model, tokenizer = self.load_llm_model()

        correct_count = 0
        total_count = 0

        for batch in self.batches:
            for save_data in batch:
                #embedding = self.get_all_layer_embedding(model, tokenizer, save_data["prompt"])

                correct_count += self.print_model_outputs(model, tokenizer, save_data["prompt"], save_data["gt_vals"][0])
                total_count += 1

                # For now, we only save the last layer
                #embedding = embedding[-1]

                #save_data["embedding"] = embedding.tolist()

                #self.embedding_store.save(save_data)

        print(f"Correct count: {correct_count}")
        print(f"Total count: {total_count}")
        # An empty split would otherwise raise ZeroDivisionError.
        accuracy = correct_count / total_count if total_count else float('nan')
        print(f"Accuracy: {accuracy}")

    # Creation of the different types of batches
    def create_base_case(self, anchor_prompt, ratio):
        """Return ``ratio`` save-data entries for the unmodified anchor prompt."""
        save_data = self.get_prompt_save_data(anchor_prompt["conversations"][0]['value'], [anchor_prompt], "base_case")

        # Independent (shallow) copies, so a later per-item mutation such as
        # attaching an embedding does not leak into every duplicate.
        return [dict(save_data) for _ in range(ratio)]

    def create_shuffled_base(self, anchor_prompt, nifty_dataset, ratio):
        """Return ``ratio`` variants of the anchor with permuted news lines and a random prior."""
        _, csv_data, news, answer = NIFTYAugmentation.get_prompt_pieces(anchor_prompt, use_random_prior=True)
        # NOTE(review): get_prompt_pieces returns a single prompt line as ``news``,
        # so this split yields one element and the permutation below is a no-op.
        # Confirm which news text was meant to be shuffled.
        split_news = NIFTYAugmentation.split_news(news)

        base_cases = []

        for _ in range(ratio):
            shuffled_news = '\n'.join(np.random.permutation(split_news))

            # Each variant also gets its own randomly chosen instruction prior.
            prior = NIFTYAugmentation.get_random_prior()

            shuffled_prompt = f"{prior}\n{csv_data}\n{shuffled_news}\n{answer}"

            base_cases.append(self.get_prompt_save_data(shuffled_prompt, [anchor_prompt], "shuffled_base"))

        return base_cases

    def create_augmented_base(self, anchor_prompt, nifty_dataset, ratio):
        pass

    def create_positive_mixes(self, anchor_prompt, nifty_dataset, ratio):
        pass

    def create_negative_mixes(self, anchor_prompt, nifty_dataset, ratio):
        pass

    def get_prompt_save_data(self, save_prompt: str, anchor_prompts_used: list, created_from: str):
        '''
        Build the record stored for ``save_prompt``.

        ``anchor_prompts_used`` are the NIFTY rows the prompt was derived from. Their
        ids and labels are recorded, together with a per-row score and its mean.
        '''
        id_list = [prompt['id'] for prompt in anchor_prompts_used]
        gt_list = [prompt['label'] for prompt in anchor_prompts_used]

        # NOTE(review): 'Rise' -> 0.0 and every other label, including 'Neutral',
        # -> 1.0. Confirm that Neutral is meant to score like Fall.
        rise_fall_scores = [0.0 if gt == "Rise" else 1.0 for gt in gt_list]
        rise_fall_score = np.mean(rise_fall_scores)

        return {
            "prompt": save_prompt,
            "anchor_ids": id_list,
            "gt_vals": gt_list,
            "rise_fall_scores": rise_fall_scores,
            "rise_fall_score": rise_fall_score,
            "created_from": created_from
        }

    def get_random_positive_negative_pairs(self, id, gt):
        """Sample a positive id and a negative id for the anchor ``id`` with label ``gt``.

        The positive shares label ``gt`` but is not ``id`` itself. The negative is
        drawn from every other label in ``self.organized_embeddings``. Candidate
        lists keep insertion order, so a seeded RNG gives reproducible draws.

        Returns:
            (positive_id, negative_id)

        Raises:
            ValueError: if no positive or no negative candidate exists.
        """
        legal_positives = [pid for pid in self.organized_embeddings.get(gt, []) if pid != id]
        legal_negatives = [
            nid
            for label, ids in self.organized_embeddings.items()
            if label != gt
            for nid in ids
        ]

        if not legal_positives:
            raise ValueError(f"No positive candidate for id {id!r} with label {gt!r}")
        if not legal_negatives:
            raise ValueError(f"No negative candidate for label {gt!r}")

        return random.choice(legal_positives), random.choice(legal_negatives)

    def load_embedding_data(self):
        """Return this split's embedding dict, generating and caching it if needed.

        The dict has 'metadata' and 'embeddings' keys, as built by
        ``generate_embeddings``.
        """
        nifty_path = config.get('dataset', 'nifty_path')

        # NOTE(review): model name and pooling strategy are read from the 'model'
        # section here but from 'dataset' elsewhere in this class. Confirm both agree.
        model_name = config.get('model', 'pretrained_model')
        embedding_pooling_strategy = config.get('model', 'embedding_pooling_strategy')

        # Hub-style names ("org/model") contain '/', which would otherwise be read as
        # a (non-existent) sub-directory of nifty_path.
        safe_model_name = str(model_name).replace('/', '_')
        split_cache_path = os.path.join(nifty_path, f'{self.split}_{safe_model_name}_{embedding_pooling_strategy}_nifty.pkl')

        if os.path.exists(split_cache_path):
            # NOTE: this unpickles the cache file; only load caches this code wrote.
            # Newer torch versions default to weights_only=True, which may refuse the
            # numpy arrays stored inside.
            return torch.load(split_cache_path)

        # Authenticate with the Hugging Face Hub using a token from the environment.
        # Tokens must never be committed to source; any previously committed token
        # should be revoked.
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            login(token=hf_token)

        # Generating embeddings
        prompt_data = self.load_prompt_data()
        embeddings = self.generate_embeddings(model_name, prompt_data)

        # Save the embedding dictionary to the cache path
        torch.save(embeddings, split_cache_path)

        return embeddings

    def load_prompt_data(self, remove_neutral = False):
        """Read ``<nifty_path>/<split>.jsonl`` into a list of dicts, optionally dropping 'Neutral' rows."""
        nifty_path = config.get('dataset', 'nifty_path')
        split_path = os.path.join(nifty_path, f'{self.split}.jsonl')

        with open(split_path, 'r', encoding='utf-8') as file:
            # Skip blank lines, e.g. a trailing empty line at the end of the file.
            data = [json.loads(line) for line in file if line.strip()]

        if remove_neutral:
            data = [row for row in data if row['label'] != 'Neutral']

        return data

    ####### Embedding generation scripts #######

    def get_all_layer_embedding(self, model, tokenizer, text):
        """Return one pooled vector per hidden-state layer for a single ``text``.

        Tokens are pooled per the 'dataset.embedding_pooling_strategy' config entry:
        'mean' averages over tokens, and 'last' takes the final token.

        Raises:
            ValueError: for an unknown pooling strategy.
        """
        inputs = tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            # Move the inputs to the model's device
            inputs = {key: value.to(model.device) for key, value in inputs.items()}

            outputs = model(**inputs, output_hidden_states=True)

        embedding_pooling_strategy = config.get('dataset', 'embedding_pooling_strategy')

        # Assumes each hidden state is (batch, seq, dim) with batch == 1 (one text).
        # Index the batch element explicitly: squeeze() would also drop the token
        # axis for a one-token input, or the layer axis for a one-layer model.
        # float() first because numpy has no bfloat16.
        hidden_states = torch.stack(outputs.hidden_states)[:, 0].detach().float().cpu().numpy()

        if embedding_pooling_strategy == 'mean':
            return np.mean(hidden_states, axis=1)
        if embedding_pooling_strategy == 'last':
            return hidden_states[:, -1, :]
        raise ValueError(f"Unknown embedding_pooling_strategy: {embedding_pooling_strategy!r}")

    def print_model_outputs(self, model, tokenizer, text, gt_text):
        """Generate up to 4 new tokens for ``text``, print them, and return whether ``gt_text`` appears.

        The match is a case-insensitive substring test on the decoded new tokens.
        """
        inputs = tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            # Move the inputs to the model's device
            inputs = {key: value.to(model.device) for key, value in inputs.items()}

            attention_mask = inputs['attention_mask']
            input_ids = inputs['input_ids']

            generated_ids = model.generate(
                input_ids,
                max_new_tokens=4,
                num_return_sequences=1,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id
            )

        # Decode only the newly generated tokens (ignore the prompt tokens)
        response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)

        print(f">>>> Ground truth: {gt_text}")
        print(f">>>> Model's response: {response}")

        correct = (gt_text.lower() in response.lower())

        if not correct:
            print(">>>> INCORRECT")
        else:
            print(">>>> CORRECT")

        return correct

    def generate_embeddings(self, model_name, dataset):
        """Embed every row of ``dataset`` with ``model_name``.

        Returns a dict with a 'metadata' entry and an 'embeddings' entry. The latter
        maps each NIFTY id to ``{"vals": float16 array, "gt": class id}``.

        Raises:
            ValueError: if a row's label is not in ``LABEL_TO_ID``.
        """
        model, tokenizer = load_model(model_name)

        # Set model to eval mode
        model.eval()

        embedding_data = {
            "metadata": {
                "model_name": model_name,
                "embedding_pooling_strategy": config.get('model', 'embedding_pooling_strategy'),
                "dataset": "nifty",
                "split": self.split
            },
            "embeddings": {}
        }

        for row in tqdm(dataset, desc=f"Generating embeddings for {model_name}:nifty"):
            nifty_id = row["id"]
            prompt = row["conversations"][0]['value']

            embedding = self.get_all_layer_embedding(model, tokenizer, prompt)

            gt = self.LABEL_TO_ID.get(row["label"])
            if gt is None:
                raise ValueError(f"Invalid label: {row['label']}")

            # Store as float16 to halve the cache size
            embedding = embedding.astype(np.float16)

            embedding_data["embeddings"][nifty_id] = {
                "vals" : embedding,
                "gt" : gt
            }

            torch.cuda.empty_cache()

        return embedding_data

class NIFTYAugmentation:
    """Helpers for splitting a NIFTY prompt into pieces and swapping its instruction prior."""

    @staticmethod
    def split_news(prompt : str):
        """Split a block of news text into one entry per line."""
        return prompt.split('\n')

    @staticmethod
    def get_prompt_pieces(prompt : dict, use_random_prior = False):
        """Split a NIFTY row's first conversation message into (prior, csv_data, news, answer).

        The text in ``prompt["conversations"][0]['value']`` is split on newlines:
        - prior: the first line (the instruction);
        - csv_data: every line between the first and the second-to-last, rejoined with '\\n';
        - news: the second-to-last line;
        - answer: the last line.

        If ``use_random_prior`` is true, the prior is replaced by a random template
        from ``PRIORS``.
        """
        full_prompt = prompt["conversations"][0]['value']
        lines = full_prompt.split('\n')

        prior = lines[0]
        # Exclude line 0 (the prior). Callers rebuild the prompt as
        # f"{prior}\n{csv_data}\n..." and would otherwise repeat the instruction.
        csv_data = '\n'.join(lines[1:-2])
        news = lines[-2]
        answer = lines[-1]

        if use_random_prior:
            prior = NIFTYAugmentation.get_random_prior()

        return prior, csv_data, news, answer

    # Paraphrased instruction templates. NOTE(review): they contain the literal
    # placeholders DATE and PCTCHANGE, which nothing in this class substitutes.
    PRIORS = [
        "Examine market data and news headlines dated DATE to project the direction of the $SPY index. If the expected shift is below PCTCHANGE, finalize with 'Neutral'. Present a response as 'Fall', 'Rise', or 'Neutral', along with the anticipated percentage change in a newline.",
        "To predict the $SPY index's direction, analyze market data and news headlines from DATE. If the expected alteration is less than PCTCHANGE, end with 'Neutral'. Offer a reply as 'Rise', 'Neutral', or 'Fall', coupled with the forecasted percentage change in a newline.",
        "Forecast the $SPY index's movement by scrutinizing market data and news headlines from DATE. If the expected adjustment is below PCTCHANGE, conclude with 'Neutral'. Deliver a response as 'Fall', 'Rise', or 'Neutral', alongside the predicted percentage change in a newline.",
        "Project the $SPY index's trajectory by assessing market data and news headlines from DATE. If the envisaged modification is under PCTCHANGE, end with 'Neutral'. Furnish a response as 'Neutral', 'Fall', or 'Rise', including the forecasted percentage change in a newline.",
        "Analyze market data and news headlines dated DATE to forecast the direction of the $SPY index. Conclude with 'Neutral' if the projected change is below PCTCHANGE. Provide a response as 'Rise', 'Fall', or 'Neutral', along with the anticipated percentage change in a newline.",
        "Assess market data and news headlines from DATE to predict the $SPY index's direction. Finalize with 'Neutral' if the expected shift is less than PCTCHANGE. Offer a reply as 'Fall', 'Neutral', or 'Rise', along with the forecasted percentage change in a newline.",
        "Predict the $SPY index's movement by examining market data and news headlines from DATE. Conclude with 'Neutral' if the anticipated adjustment is under PCTCHANGE. Present a response as 'Neutral', 'Fall', or 'Rise', with the predicted percentage change in a newline.",
        "Anticipate the direction of the $SPY index by analyzing market data and news headlines from DATE. End with 'Neutral' if the expected change is below PCTCHANGE. Supply a response as 'Rise', 'Neutral', or 'Fall', along with the forecasted percentage change in a newline.",
        "Examine market data and news headlines dated DATE to project the $SPY index's direction. Conclude with 'Neutral' if the envisaged shift is under PCTCHANGE. Deliver a response as 'Neutral', 'Rise', or 'Fall', along with the predicted percentage change in a newline.",
        "To predict the direction of the $SPY index, analyze market data and news headlines from DATE. End with 'Neutral' if the expected alteration is below PCTCHANGE. Provide a reply as 'Fall', 'Neutral', or 'Rise', along with the anticipated percentage change in a newline.",
        "Forecast the $SPY index's trajectory by scrutinizing market data and news headlines from DATE. Conclude with 'Neutral' if the expected adjustment is less than PCTCHANGE. Furnish a response as 'Rise', 'Fall', or 'Neutral', including the forecasted percentage change in a newline.",
        "Project the $SPY index's movement by assessing market data and news headlines from DATE. Conclude with 'Neutral' if the envisaged modification is under PCTCHANGE. Offer a reply as 'Neutral', 'Rise', or 'Fall', along with the predicted percentage change in a newline.",
        "Analyze market data and news headlines dated DATE to forecast the direction of the $SPY index. Deliver a response as 'Fall', 'Neutral', or 'Rise', along with the anticipated percentage change in a newline. Conclude with 'Neutral' if the projected change is below PCTCHANGE.",
        "Assess market data and news headlines from DATE to predict the $SPY index's direction. Offer a reply as 'Rise', 'Fall', or 'Neutral', along with the forecasted percentage change in a newline. Finalize with 'Neutral' if the expected shift is less than PCTCHANGE.",
        "Predict the $SPY index's movement by examining market data and news headlines from DATE. Supply a response as 'Neutral', 'Rise', or 'Fall', with the predicted percentage change in a newline. Conclude with 'Neutral' if the anticipated adjustment is under PCTCHANGE.",
        "Anticipate the direction of the $SPY index by analyzing market data and news headlines from DATE. Deliver a response as 'Fall', 'Neutral', or 'Rise', along with the forecasted percentage change in a newline. End with 'Neutral' if the expected change is below PCTCHANGE.",
        "Examine market data and news headlines dated DATE to project the $SPY index's direction. Present a response as 'Rise', 'Fall', or 'Neutral', along with the predicted percentage change in a newline. Conclude with 'Neutral' if the envisaged shift is under PCTCHANGE.",
        "To predict the direction of the $SPY index, analyze market data and news headlines from DATE. Provide a reply as 'Neutral', 'Rise', or 'Fall', along with the anticipated percentage change in a newline. End with 'Neutral' if the expected alteration is below PCTCHANGE.",
        "Forecast the $SPY index's trajectory by scrutinizing market data and news headlines from DATE. Furnish a response as 'Fall', 'Neutral', or 'Rise', including the forecasted percentage change in a newline. Conclude with 'Neutral' if the expected adjustment is less than PCTCHANGE.",
        "Project the $SPY index's movement by assessing market data and news headlines from DATE. Offer a reply as 'Rise', 'Fall', or 'Neutral', along with the predicted percentage change in a newline. Conclude with 'Neutral' if the envisaged modification is under PCTCHANGE.",
    ]

    @staticmethod
    def get_random_prior():
        '''
        Return one template from ``PRIORS``, chosen uniformly with the ``random`` module.
        '''
        return random.choice(NIFTYAugmentation.PRIORS)

    
    
######## DIMENSION REDUCTION SCRIPTS ########
    
def reduce_dimensions(embeddings, n_components = 2, method = 'pca', perplexity = 5):
    """Project ``embeddings`` (one row per sample) down to ``n_components`` dimensions.

    Args:
        embeddings: 2-D array-like, samples x features.
        n_components: target dimensionality.
        method: 'pca' (scikit-learn PCA) or 'tsne' (openTSNE).
        perplexity: t-SNE perplexity; only used when ``method == 'tsne'``.

    Returns:
        The transformed embeddings as returned by the chosen backend.

    Raises:
        ValueError: if ``method`` is not 'pca' or 'tsne'.
    """
    if method == 'pca':
        pca = PCA(n_components=n_components)
        return pca.fit_transform(embeddings)

    if method == 'tsne':
        tsne = TSNE(
            n_components=n_components,
            n_jobs=-1,  # Use all available cores
            perplexity=perplexity,
            initialization="pca",  # Initialize with PCA to speed up
            metric="euclidean"  # Default metric, can be changed based on your data
        )
        return tsne.fit(embeddings)

    # Fail loudly instead of silently returning None for a typo'd method name.
    raise ValueError(f"Unknown dimensionality-reduction method {method!r}; expected 'pca' or 'tsne'")