import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from lightning.pytorch.utilities.data import DataLoader
import transformers
from multiprocessing import cpu_count
from datasets import load_from_disk
from torch import optim
from torchmetrics import Metric
import torch
import os
from utils import ALPACA_IGNORE_INDEX, alpaca_preprocess


class LogitsMetrics(Metric):
    """Accuracy of choosing the gold option by scoring candidate label tokens.

    For every sample, each candidate label (an "option") is scored by the
    likelihood the model assigns to that option's token ids over the first
    ``sample_to[i]`` generation steps. The highest-scoring option is the
    prediction, which is compared against the gold class index.
    """

    def __init__(self):
        super().__init__()
        # Running counters; both are summed when metric states are reduced.
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, logits, label_spaces_ids, sample_to, gold_classes):
        """Accumulate correct/total counts for one batch.

        Args:
            logits: Per-step vocabulary scores, indexed as
                ``[batch, step, vocab]``.
            label_spaces_ids: Token ids of every option, shaped
                ``(batch, option_num, max_seq_len)``.
            sample_to: Per-sample number of leading steps to score.
            gold_classes: Per-sample index of the correct option.
        """
        # Only steps that exist in both tensors can ever be scored, so slice
        # first: this avoids normalising the full (batch, steps, vocab) tensor.
        num_steps = min(logits.shape[1], label_spaces_ids.shape[-1])
        # Log-probabilities are summed instead of multiplying probabilities:
        # the argmax is the same, but long options no longer underflow to 0.
        log_probs = torch.log_softmax(logits[:, :num_steps].float(), dim=-1)
        device = log_probs.device

        token_ids = label_spaces_ids[:, :, :num_steps].to(device).long()
        # Gather along the vocabulary axis:
        #   option_log_probs[b, o, k] = log_probs[b, k, token_ids[b, o, k]]
        option_log_probs = torch.gather(log_probs.transpose(1, 2), 1, token_ids)

        # Steps at or beyond sample_to[b] must not influence the score; a
        # log-probability of 0 is the neutral element of the sum.
        steps = torch.arange(num_steps, device=device)
        limits = torch.as_tensor(sample_to, device=device)
        scored = steps.unsqueeze(0) < limits.unsqueeze(1)
        option_log_probs = option_log_probs.masked_fill(~scored.unsqueeze(1), 0.0)

        scores = option_log_probs.sum(dim=-1)
        classes = torch.argmax(scores, dim=-1).to(gold_classes.device)
        assert len(classes) == len(gold_classes)
        self.correct += torch.sum(classes == gold_classes)
        self.total += len(classes)

    def compute(self):
        """Return the fraction of samples whose predicted option was correct."""
        return self.correct.float() / self.total.float()
    

    

class LitTrainingDataModule(pl.LightningDataModule):
    """Data module that serves a training set and a directory of test sets.

    Training rows are dicts whose values are, in order, the input text and the
    output text. Test rows additionally carry the list of candidate labels
    (the "label space") as a third value.
    """

    def __init__(
        self,
        train_batch_size,
        test_batch_size,
        train_set,
        test_dir,
        distributed,
        tokenizer,
        isAlpaca,
        alpaca_tokenizer_path="/work/frink/models/alpaca-7b",
    ):
        """Store configuration and, for Alpaca, load the training tokenizer.

        Args:
            train_batch_size: Batch size for the training loader.
            test_batch_size: Batch size for every test loader.
            train_set: Dataset passed to the training ``DataLoader``.
            test_dir: Directory holding one sub-directory per test dataset,
                or ``None`` to skip test data.
            distributed: Accepted for interface compatibility; not used here.
            tokenizer: Tokenizer used for non-Alpaca training batches and for
                all test batches.
            isAlpaca: Whether batches are built for the Alpaca model.
            alpaca_tokenizer_path: Location of the Alpaca tokenizer; only
                read when ``isAlpaca`` is true.
        """
        super().__init__()
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.train_set = train_set
        self.tokenizer = tokenizer
        self.test_dir = test_dir
        self.test_sets = []
        self.isAlpaca = isAlpaca
        if self.isAlpaca:
            self.alpaca_training_tokenizer = transformers.LlamaTokenizer.from_pretrained(
                alpaca_tokenizer_path,
                model_max_length=512,
                padding_side="right",
                use_fast=False,
            )
        # NOTE(review): stored but not passed to any DataLoader below.
        self.num_workers = cpu_count()

    def collate(self, batch):
        """Turn a list of training rows into model inputs.

        Non-Alpaca batches are tokenized with ``self.tokenizer`` and returned
        as ``input_ids`` / ``attention_mask`` / ``labels``, with padding
        positions in ``labels`` replaced by -100. Alpaca batches are delegated
        to ``alpaca_preprocess``.
        """
        # Relies on each row's value order being (input text, output text).
        batch = [b.values() for b in batch]
        input_text, output_text = list(zip(*batch))
        if self.isAlpaca:
            return alpaca_preprocess(input_text, output_text, self.alpaca_training_tokenizer)

        input_encoding = self.tokenizer(
            input_text,
            padding="longest",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        target_encoding = self.tokenizer(
            output_text,
            padding="longest",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        labels = target_encoding.input_ids
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": input_encoding.input_ids,
            "attention_mask": input_encoding.attention_mask,
            "labels": labels,
        }

    def test_collate(self, batch):
        """Turn a list of test rows into model inputs plus scoring metadata.

        On top of the tokenizer output, the returned batch holds:
            ``label_cls``: index of the gold output within its label space.
            ``label_spaces_ids``: token ids of every option, padded to the
                longest option in the batch.
            ``sample_to``: per sample, the token length of its shortest option.

        Raises:
            ValueError: If a gold output text is not in its label space.
        """
        # Relies on each row's value order being (input, output, label space).
        batch = [b.values() for b in batch]
        input_text, output_text, label_spaces = list(zip(*batch))

        batch = self.tokenizer(
            text=input_text,
            text_target=output_text,
            padding='longest',
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        try:
            labels_cls = torch.ShortTensor(
                [label_space.index(y) for label_space, y in zip(label_spaces, output_text)]
            )
        except ValueError as err:
            raise ValueError(
                "gold output text is not one of the options in its label space"
            ) from err

        # First pass without padding, only to learn every option's length.
        unpadded = [self.tokenizer(label_space, padding=False, return_length=True) for label_space in label_spaces]
        option_lengths = [encoded["length"] for encoded in unpadded]
        max_seq_len = max(max(lengths) for lengths in option_lengths)
        # Second pass pads every option to the batch-wide maximum so the
        # per-sample tensors can be stacked.
        label_spaces_ids = torch.stack(
            [
                self.tokenizer(
                    label_space,
                    padding="max_length",
                    max_length=max_seq_len,
                    return_tensors="pt",
                )["input_ids"]
                for label_space in label_spaces
            ],
            dim=0,
        )

        sample_to = torch.ShortTensor([min(lengths) for lengths in option_lengths])
        if self.isAlpaca:
            label_spaces_ids = label_spaces_ids[:, :, 1:]  # remove the <s> token
            sample_to -= 1

        batch["label_cls"] = labels_cls
        batch["label_spaces_ids"] = label_spaces_ids
        batch["sample_to"] = sample_to
        return batch

    def prepare_data(self) -> None:
        """Load the "observed" and "unobserved" split of every test dataset.

        Each entry of ``self.test_dir`` (except ``.DS_Store``) contributes two
        records to ``self.test_sets``. Does nothing beyond clearing the list
        when ``self.test_dir`` is ``None``.
        """
        self.test_sets = []
        if self.test_dir is None:
            return

        # Sorted so the order of test sets does not depend on the filesystem.
        test_datasets = sorted(f for f in os.listdir(self.test_dir) if f != ".DS_Store")
        for dataset in test_datasets:
            dataset_dir = os.path.join(self.test_dir, dataset)
            self.test_sets.append({
                "name": dataset,
                "setting": "Observed",
                "dataset": load_from_disk(os.path.join(dataset_dir, "observed")),
            })
            self.test_sets.append({
                "name": dataset,
                "setting": "Unobserved",
                "dataset": load_from_disk(os.path.join(dataset_dir, "unobserved")),
            })

    def setup(self, stage: str) -> None:
        """No per-stage setup is performed."""
        pass

    def train_dataloader(self):
        """Return the training loader.

        Only the batch size and collate function are passed, so every other
        ``DataLoader`` option keeps its default.
        """
        return DataLoader(self.train_set, batch_size=self.train_batch_size, collate_fn=self.collate)

    def test_dataloaders(self):
        """Wrap every loaded test set in a ``DataLoader`` and return the records.

        The ``"dataset"`` entry of each record in ``self.test_sets`` is
        replaced in place by its loader. Records that already hold a loader
        are left alone, so calling this method repeatedly is safe.
        """
        for dataset_dict in self.test_sets:
            if isinstance(dataset_dict["dataset"], DataLoader):
                continue
            dataset_dict["dataset"] = DataLoader(
                dataset_dict["dataset"],
                batch_size=self.test_batch_size,
                collate_fn=self.test_collate,
            )
        return self.test_sets

class LitT5ForConditionalGeneration(pl.LightningModule):
    """Lightning wrapper that trains a model and evaluates option accuracy.

    Training minimises the loss returned by the wrapped model. Testing
    generates with score output and feeds the per-step scores to a
    ``LogitsMetrics`` instance.
    """

    def __init__(self, model, learning_rate, weight_decay, soft_prompt, **kwargs):
        """Store the model and optimiser settings.

        Args:
            model: Model called in ``training_step`` and used for
                ``generate`` in ``test_step``.
            learning_rate: Learning rate passed to AdamW.
            weight_decay: Weight decay passed to AdamW.
            soft_prompt: Stored on the module; not otherwise used here.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.test_metric = LogitsMetrics()
        self.model = model
        self.isAlpaca = False
        self.soft_prompt = soft_prompt
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def training_step(self, batch, batch_idx):
        """Run a forward pass on ``batch``, log and return the model loss."""
        outputs = self.model(**batch)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Generate for one test batch and update the accuracy metric.

        Returns:
            A dict with ``"test_acc"``: the accuracy accumulated so far.
        """
        # Fields are read by key rather than by position, so extra keys in
        # the tokenizer output cannot shift the unpacking.
        # NOTE(review): 512 new tokens are generated per sample; the metric
        # presumably reads only the first few score steps - consider capping.
        outputs = self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=512,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # One score tensor per generated step -> stack along a new step axis.
        logits = torch.stack(outputs.scores, dim=1)
        self.test_metric.update(
            logits, batch["label_spaces_ids"], batch["sample_to"], batch["label_cls"]
        )
        # Log the metric object itself so the epoch-level value comes from the
        # accumulated counts; logging the running value every step would have
        # the epoch reduction average those running values instead.
        self.log("test_acc", self.test_metric, on_step=False, on_epoch=True)
        return {"test_acc": self.test_metric.compute()}

    def on_test_epoch_start(self) -> None:
        """Clear accumulated metric state before a test epoch begins."""
        self.test_metric.reset()
        return super().on_test_epoch_start()

    def configure_optimizers(self):
        """Return AdamW over the model parameters with the stored settings."""
        return optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
    


class LitLlamaForCausalLM(pl.LightningModule):
    """Lightning wrapper that trains a model and evaluates option accuracy.

    Training minimises the loss returned by the wrapped model. Testing
    generates with score output and feeds the per-step scores to a
    ``LogitsMetrics`` instance. This wrapper sets ``isAlpaca`` to ``True``.
    """

    def __init__(self, model, learning_rate, weight_decay, soft_prompt, **kwargs):
        """Store the model and optimiser settings.

        Args:
            model: Model called in ``training_step`` and used for
                ``generate`` in ``test_step``.
            learning_rate: Learning rate passed to AdamW.
            weight_decay: Weight decay passed to AdamW.
            soft_prompt: Stored on the module; not otherwise used here.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.test_metric = LogitsMetrics()
        self.model = model
        self.isAlpaca = True
        self.soft_prompt = soft_prompt
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def training_step(self, batch, batch_idx):
        """Run a forward pass on ``batch``, log and return the model loss."""
        outputs = self.model(**batch)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Generate for one test batch and update the accuracy metric.

        Returns:
            A dict with ``"test_acc"``: the accuracy accumulated so far.
        """
        # Fields are read by key rather than by position, so extra keys in
        # the tokenizer output cannot shift the unpacking.
        # NOTE(review): 512 new tokens are generated per sample; the metric
        # presumably reads only the first few score steps - consider capping.
        outputs = self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=512,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # One score tensor per generated step -> stack along a new step axis.
        logits = torch.stack(outputs.scores, dim=1)
        self.test_metric.update(
            logits, batch["label_spaces_ids"], batch["sample_to"], batch["label_cls"]
        )
        # Log the metric object itself so the epoch-level value comes from the
        # accumulated counts; logging the running value every step would have
        # the epoch reduction average those running values instead.
        self.log("test_acc", self.test_metric, on_step=False, on_epoch=True)
        return {"test_acc": self.test_metric.compute()}

    def on_test_epoch_start(self) -> None:
        """Clear accumulated metric state before a test epoch begins."""
        self.test_metric.reset()
        return super().on_test_epoch_start()

    def configure_optimizers(self):
        """Return AdamW over the model parameters with the stored settings."""
        return optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
