from nesim.configs import NesimConfig
from ..bimt.loss import BIMTConfig
from ..losses.cross_layer_correlation.loss import CrossLayerCorrelationLossConfig
from nesim.losses.nesim_loss import (
    NesimConfig,
    NesimLoss,
)
import time
from lightning.pytorch import seed_everything
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoConfig
from datasets import load_dataset
import wandb
import os
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets

# Dataset names accepted by ``get_dataset_from_name``.
possible_dataset_names = ["wikipedia", "openwebtext"]


def get_dataset_from_name(name: str, dataset_cache_dir: str):
    """
    Load one of the supported raw text datasets with ``datasets.load_dataset``.

    Args:
        name: one of ``possible_dataset_names``.
        dataset_cache_dir: existing directory passed to ``load_dataset`` as
            ``cache_dir``.

    Returns:
        The object returned by ``load_dataset`` (callers in this file index it
        with ``["train"]``).

    Raises:
        AssertionError: if ``name`` is not supported or ``dataset_cache_dir``
            does not exist.
    """
    assert name in possible_dataset_names, (
        f"Unknown dataset name: {name!r}; expected one of {possible_dataset_names}"
    )
    assert os.path.exists(dataset_cache_dir), (
        f"Invalid dataset_cache_dir: {dataset_cache_dir}"
    )

    print(f"Loading dataset: {name}")
    if name == "wikipedia":
        hub_args = ("wikipedia", "20220301.en")
    else:
        hub_args = ("Skylion007/openwebtext",)
    return load_dataset(*hub_args, cache_dir=dataset_cache_dir)


def weights_init(m):
    """
    Re-initialise the weight of ``m`` in place with Xavier-uniform if ``m`` is an ``nn.Linear``.

    Meant to be used as ``model.apply(weights_init)``. Other module types, and
    the bias of ``nn.Linear`` layers, are left untouched.
    NOTE(review): if a Linear layer shares its weight tensor with another
    module (tied weights), that shared tensor is re-initialised as well.
    """
    if isinstance(m, nn.Linear):
        # ``xavier_uniform_`` is the in-place API. The un-suffixed
        # ``xavier_uniform`` is deprecated and warns on every call.
        torch.nn.init.xavier_uniform_(m.weight)

def get_untrained_model_and_tokenizer(name="EleutherAI/gpt-neo-1.3B"):
    """
    Build a randomly initialised causal LM with the architecture of ``name``.

    Only the model *configuration* is fetched, so the model is built from
    config and carries no trained weights. The tokenizer for ``name`` is
    loaded as-is.

    Returns:
        ``(model, tokenizer)``
    """
    tokenizer = AutoTokenizer.from_pretrained(name)
    architecture = AutoConfig.from_pretrained(name)
    return AutoModelForCausalLM.from_config(architecture), tokenizer

def get_checkpoint(
    checkpoint_filename: Union[None, str] = None,
    device: str = "cpu",
    name="EleutherAI/gpt-neo-1.3B",
):
    """
    Return ``(model, tokenizer)`` for ``name``, with the model in eval mode on ``device``.

    ``checkpoint_filename`` selects the weights:

    * ``None``: fresh model whose ``nn.Linear`` weights are re-initialised
      with ``weights_init``;
    * ``"pretrained"``: the published pretrained weights for ``name``;
    * anything else: path to a ``state_dict`` file, loaded with
      ``torch.load(..., weights_only=True)`` into a fresh ``name`` model.

    In every case ``model.generation_config.pad_token_id`` is set to
    ``tokenizer.pad_token_id``.
    """
    if checkpoint_filename == "pretrained":
        # Load the published weights directly. There is no need to build
        # (and discard) a randomly initialised copy of the model first.
        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModelForCausalLM.from_pretrained(name)
        model.eval().to(device)
    else:
        model, tokenizer = get_untrained_model_and_tokenizer(name=name)
        model.eval().to(device)
        if checkpoint_filename is None:
            model.apply(weights_init)
        else:
            # map_location lets a checkpoint saved from a GPU be loaded on a
            # CPU-only host (or on a different GPU).
            state_dict = torch.load(
                checkpoint_filename, map_location=device, weights_only=True
            )
            model.load_state_dict(state_dict)

    # Set this only once the final model object exists, so the "pretrained"
    # model gets it too.
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer


class GPTNeoTrainingConfig(object):
    """
    Settings for a ``GPTNeoTraining`` run, validated on construction.

    Args:
        training_arguments: HuggingFace ``TrainingArguments`` passed to the Trainer.
        nesim_config: configuration for the ``NesimLoss`` calculator.
        bimt_config: optional BIMT configuration (stored as-is).
        cross_layer_correlation_loss_config: optional cross-layer correlation
            loss configuration (stored as-is).
        wandb_log: wandb logging flag (stored as-is).
        dataset_cache_dir: existing directory used as the ``datasets`` cache.
        dataset_name: raw dataset to load when no pre-tokenized copy is available.
        tokenized_dataset_path: optional location to load the tokenized dataset
            from, or to save it to after tokenizing.
        context_length: number of tokens in each training example.
        apply_nesim_every_n_steps: the NeSim loss is considered every N
            ``compute_loss`` calls; must be > 0.
        neighbourhood_cosine_similarity_loss_lower_bound: the NeSim term is not
            added while the first layer's loss is below this value.
        resume_wandb_id: if set, resume logging to this wandb run id.
        sample_prompts: prompts completed when generating sample text. If None,
            three built-in prompts are used. Each instance gets its own list.
        num_sample_completion_tokens: new tokens per sample; must be > 1.
        generate_sample_text_every_n_steps: sample text every N
            ``compute_loss`` calls; must be > 0.

    Raises:
        AssertionError: if validation fails.
    """

    # Default prompts, stored as an immutable tuple and copied into a new list
    # per instance, so instances never share (and mutate) one default list.
    _DEFAULT_SAMPLE_PROMPTS = (
        "The President of the United States is",
        "An apple a day",
        "The Eiffel Tower was originally meant to be a",
    )

    def __init__(
        self,
        training_arguments: TrainingArguments,
        nesim_config: NesimConfig,
        bimt_config: Union[None, BIMTConfig],
        cross_layer_correlation_loss_config: Union[
            None, CrossLayerCorrelationLossConfig
        ],
        wandb_log: bool,
        dataset_cache_dir: str,
        dataset_name: str,
        tokenized_dataset_path: Union[None, str] = None,
        context_length: int = 128,
        apply_nesim_every_n_steps=10,
        neighbourhood_cosine_similarity_loss_lower_bound=0.1,
        resume_wandb_id: Union[None, str] = None,
        sample_prompts: Union[None, list] = None,
        num_sample_completion_tokens: int = 10,
        generate_sample_text_every_n_steps: int = 1000,
    ):
        if sample_prompts is None:
            sample_prompts = list(self._DEFAULT_SAMPLE_PROMPTS)

        self.dataset_cache_dir = dataset_cache_dir
        self.training_arguments = training_arguments
        self.context_length = context_length
        self.nesim_config = nesim_config
        self.bimt_config = bimt_config
        self.cross_layer_correlation_loss_config = cross_layer_correlation_loss_config
        self.wandb_log = wandb_log
        self.apply_nesim_every_n_steps = apply_nesim_every_n_steps
        self.neighbourhood_cosine_similarity_loss_lower_bound = (
            neighbourhood_cosine_similarity_loss_lower_bound
        )
        self.tokenized_dataset_path = tokenized_dataset_path
        self.resume_wandb_id = resume_wandb_id
        self.dataset_name = dataset_name
        self.sample_prompts = sample_prompts
        self.num_sample_completion_tokens = num_sample_completion_tokens
        self.generate_sample_text_every_n_steps = generate_sample_text_every_n_steps
        self.validate()

    def validate(self):
        """Check the stored settings; raises AssertionError on an invalid value."""
        assert os.path.exists(
            self.dataset_cache_dir
        ), f"Invalid dataset_cache_dir: {self.dataset_cache_dir}"
        assert self.num_sample_completion_tokens > 1, (
            "num_sample_completion_tokens must be > 1, got "
            f"{self.num_sample_completion_tokens}"
        )
        # Both intervals are used as modulo divisors in the training loop, so a
        # zero value would otherwise only surface as a crash mid-training.
        assert self.apply_nesim_every_n_steps > 0, (
            "apply_nesim_every_n_steps must be > 0, got "
            f"{self.apply_nesim_every_n_steps}"
        )
        assert self.generate_sample_text_every_n_steps > 0, (
            "generate_sample_text_every_n_steps must be > 0, got "
            f"{self.generate_sample_text_every_n_steps}"
        )


class GPTNeoTraining:
    """
    Train a freshly initialised GPT-Neo (1.3B architecture) causal LM with an added NeSim loss.

    Usage: ``GPTNeoTraining(config).run()``. ``run`` does the following:
    obtains the tokenized dataset, builds a HuggingFace ``Trainer`` whose
    ``compute_loss`` adds the NeSim term, trains, and finishes the wandb run.
    """

    def __init__(self, config: GPTNeoTrainingConfig):
        assert isinstance(config, GPTNeoTrainingConfig)
        self.config = config

    def get_model_and_tokenizer(self, name="EleutherAI/gpt-neo-1.3B"):
        """
        Return ``(model, tokenizer)`` for a fresh ``name`` model.

        The model's ``nn.Linear`` weights are re-initialised by ``weights_init``.
        """
        model, tokenizer = get_untrained_model_and_tokenizer(name=name)
        model = model.apply(weights_init)
        return model, tokenizer

    def generate_sample_text(self, model, tokenizer, prompt, num_tokens):
        """
        Sample up to ``num_tokens`` new tokens after ``prompt`` and return the decoded text.

        ``model`` may be the bare HF model (single-GPU training) or a wrapper
        that exposes it as ``.module`` (multi-GPU training). The wrapped model
        is used when present.
        Sampling uses ``top_k=50`` and ``top_p=0.95``. The returned string
        includes the prompt, with special tokens stripped.
        """
        # Unwrap explicitly rather than catching AttributeError: a broad
        # ``except AttributeError`` would also hide errors raised inside
        # ``generate`` itself.
        generator = getattr(model, "module", model)
        generator.generation_config.pad_token_id = tokenizer.pad_token_id
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(generator.device)
        output = generator.generate(
            input_ids,
            max_new_tokens=num_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
        return tokenizer.decode(output[0], skip_special_tokens=True)

    def _load_tokenized_dataset(self, tokenize):
        """
        Return the tokenized dataset.

        Loads it from ``config.tokenized_dataset_path`` when that path exists.
        Otherwise loads the raw dataset, maps ``tokenize`` over it, and, if a
        path was configured, saves the result there.
        """
        path = self.config.tokenized_dataset_path
        if path is not None and os.path.exists(path):
            print(f"Loading pre-tokenized dataset from: {path}")
            return datasets.load_from_disk(path)

        dataset = get_dataset_from_name(
            name=self.config.dataset_name,
            dataset_cache_dir=self.config.dataset_cache_dir,
        )
        if path is not None:
            print(f"Mapping and saving tokenized dataset: {path}")
        tokenized_dataset = dataset.map(
            tokenize, batched=True, remove_columns=dataset["train"].column_names
        )
        if path is not None:
            tokenized_dataset.save_to_disk(path)
        return tokenized_dataset

    def run(self):
        """
        Run the full training job described by ``self.config``.

        Steps:
        1. Seed RNGs and build the model and tokenizer.
        2. Obtain the tokenized dataset.
        3. Set up the NeSim loss and, if configured, resume a wandb run.
        4. Train with a ``Trainer`` subclass whose ``compute_loss`` adds the
           NeSim term.

        ``self.train_step_idx`` counts ``compute_loss`` calls.
        """
        print("IN RUN... getting model and tokenizer")
        # seed everything to make training runs deterministic
        seed_everything(0)

        model, tokenizer = self.get_model_and_tokenizer(name="EleutherAI/gpt-neo-1.3B")
        print("got model and tokenizer...")

        context_length = self.config.context_length

        def tokenize(element):
            # Split each text into windows of at most ``context_length`` tokens
            # (overflow becomes extra rows). Keep only full-length windows, so
            # every kept example has exactly ``context_length`` tokens.
            outputs = tokenizer(
                element["text"],
                truncation=True,
                max_length=context_length,
                return_overflowing_tokens=True,
                return_length=True,
            )
            input_batch = [
                input_ids
                for length, input_ids in zip(outputs["length"], outputs["input_ids"])
                if length == context_length
            ]
            return {"input_ids": input_batch}

        tokenized_dataset = self._load_tokenized_dataset(tokenize)

        tokenizer.pad_token = tokenizer.eos_token

        # NOTE(review): the NeSim loss device is hard-coded to "cuda:0".
        nesim_loss_calculator = NesimLoss(
            model=model, config=self.config.nesim_config, device="cuda:0"
        )

        if self.config.resume_wandb_id is not None:
            print(f"Resuming logging on run ID: {self.config.resume_wandb_id}")
            wandb.init(resume="must", id=self.config.resume_wandb_id)

        self.train_step_idx = 0

        # Inside compute_loss, ``inner_self`` is the Trainer, while ``self`` is
        # closed over from this method and is this GPTNeoTraining instance.
        # The step counter, config and sampling helper all live on ``self``.
        class CustomTrainer(Trainer):
            def compute_loss(
                inner_self,
                model,
                inputs,
                return_outputs=False,
                num_items_in_batch=None,
            ):
                """
                Return the next-token cross-entropy, optionally plus the NeSim loss.

                Every ``apply_nesim_every_n_steps`` calls, the NeSim loss is
                added unless the lower bound suppresses it. The method also
                logs to stdout/wandb and periodically prints sample generations.

                ``num_items_in_batch`` is accepted because newer transformers
                releases pass it to ``compute_loss``. It is not used, because
                the loss here is a per-batch mean.
                NOTE(review): confirm gradient-accumulation scaling for the
                transformers version in use.
                """
                start = time.time()
                labels = inputs["input_ids"]
                outputs = model(**inputs)
                logits = outputs.logits

                # Shift by one so that position t predicts token t + 1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                # Flatten all leading dims before the cross-entropy.
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                )

                training_loss = loss.item()
                if self.train_step_idx % self.config.apply_nesim_every_n_steps == 0:
                    nesim_loss = nesim_loss_calculator.compute(reduce_mean=True)
                    nesim_loss_calculator.wandb_log()

                    # Lower bound: the NeSim term is skipped while the first layer
                    # handler's loss is below the configured bound. The check is
                    # kept as ``not (x < bound)`` so that a NaN loss still adds
                    # the term, as before.
                    neighbourhood_cosine_similarity_loss = (
                        nesim_loss_calculator.layer_handlers[0].layer_loss.get_loss()
                    )
                    below_lower_bound = (
                        neighbourhood_cosine_similarity_loss.item()
                        < self.config.neighbourhood_cosine_similarity_loss_lower_bound
                    )
                    if not below_lower_bound and nesim_loss is not None:
                        wandb.log({"combined_pyramid_loss": nesim_loss.item()})
                        loss = loss + nesim_loss.to(loss.device)
                time_taken = time.time() - start

                # Generate sample text
                if self.train_step_idx % self.config.generate_sample_text_every_n_steps == 0:
                    sample_texts = [
                        self.generate_sample_text(
                            model,
                            tokenizer,
                            prompt,
                            self.config.num_sample_completion_tokens,
                        )
                        for prompt in self.config.sample_prompts
                    ]
                else:
                    sample_texts = None

                log_data = {
                    "training_loss": training_loss,
                    "seconds_per_compute_loss": time_taken,
                }
                print(log_data)
                if sample_texts is not None:
                    log_data["sample_generations"] = sample_texts
                    print("Samples:")
                    for text in sample_texts:
                        print(text)

                # NOTE(review): ``config.wandb_log`` is not consulted here;
                # wandb.log is called on every step regardless.
                wandb.log(log_data)
                self.train_step_idx += 1
                if return_outputs:
                    return loss, outputs
                return loss

        if "test" not in tokenized_dataset:
            # No held-out split is present: carve 10% of "train" off as "test".
            tokenized_dataset = tokenized_dataset["train"].train_test_split(
                test_size=0.1
            )
            print(f"\033[91mSUCCESSFULLY DID A TRAIN TEST SPLIT\033[0m")

        print("Dataset Details:")
        print(tokenized_dataset)

        custom_trainer = CustomTrainer(
            model=model,
            args=self.config.training_arguments,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["test"],
        )
        resume_from_checkpoint = self.config.training_arguments.resume_from_checkpoint
        if resume_from_checkpoint is None:
            custom_trainer.train()
        else:
            print(f"Resuming from checkpoint:\n{resume_from_checkpoint}")
            custom_trainer.train(resume_from_checkpoint)
        wandb.finish()
        print("EXPERIMENT COMPLETE")
