import pytorch_lightning as pl
from lightning.pytorch.utilities.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from datasets import Dataset
from multiprocessing import cpu_count
from datasets import load_from_disk
from torch import optim
from torchmetrics import Metric
from torch.nn.functional import log_softmax
from torch.nn import KLDivLoss
import torch
import os
from collections import Counter
import math
from typing import TypeVar, Optional, Iterator
from utils import ALPACA_IGNORE_INDEX, alpaca_preprocess

import torch
import torch.distributed as dist
import transformers
from typing import Dict, Optional, Sequence
import copy


# Covariant type variable used only to annotate the sampler's iterator type.
T_co = TypeVar('T_co', covariant=True)


class DistributedAlignmentSampler(DistributedSampler):
    """Distributed sampler that hands out whole, contiguous batches to each rank.

    The dataset is treated as a sequence of consecutive groups of
    ``batch_size`` items. Group ``b`` is assigned to rank ``b % num_replicas``,
    so a group is never split across ranks and the order inside a group is
    preserved. The index list is padded by repeating indices from the start
    of the dataset until every rank receives the same number of groups.

    There is no shuffling: ``shuffle`` and ``drop_last`` are fixed to
    ``False`` and ``seed``/``epoch`` are only stored.

    Args:
        dataset: Sized dataset to sample from (only ``len(dataset)`` is used).
        num_replicas: Number of participating processes; defaults to
            ``dist.get_world_size()``.
        rank: Rank of the current process; defaults to ``dist.get_rank()``.
        seed: Stored on the instance; not used by ``__iter__``.
        batch_size: Number of consecutive items that must stay together.

    Raises:
        RuntimeError: If ``num_replicas``/``rank`` must be inferred but the
            distributed package is unavailable.
        ValueError: If ``rank`` is outside ``[0, num_replicas - 1]`` or
            ``batch_size`` is not positive.
    """

    def __init__(self, dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, seed: int = 0, batch_size: int = 8):
        # NOTE: the parent __init__ is intentionally not called; every
        # attribute the parent's __len__/set_epoch rely on is set below.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        if batch_size <= 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # (or silently wrong sizes for negative values) below.
            raise ValueError(
                "batch_size should be a positive integer, but got {}".format(batch_size))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = False
        self.shuffle = False
        self.seed = seed
        self.batch_size = batch_size
        # Every rank gets the same number of whole batches.
        batches_per_rank = math.ceil(len(self.dataset) / (self.num_replicas * self.batch_size))  # type: ignore[arg-type]
        self.num_samples = batches_per_rank * self.batch_size
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[T_co]:
        """Yield this rank's indices, one intact batch after another."""
        indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]

        assert len(indices) == self.total_size

        # subsample (always keep the preprocessed batch intact): this rank
        # owns batches rank, rank + num_replicas, rank + 2 * num_replicas, ...
        stride = self.num_replicas * self.batch_size
        first = self.rank * self.batch_size
        shard = []
        for start in range(first, self.total_size, stride):
            shard.extend(indices[start:start + self.batch_size])
        assert len(shard) == self.num_samples
        return iter(shard)


class LogitsMetrics(Metric):
    """Classification accuracy obtained by ranking label options with generation probabilities.

    For every sample, each candidate option is scored by the product of the
    softmax probabilities the model assigns to that option's token ids at
    the first ``sample_to[i]`` generation steps. The option with the highest
    score is the predicted class; accuracy is ``correct / total``.

    States ``total`` and ``correct`` are summed across processes.
    """

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, logits, label_spaces_ids, sample_to, gold_classes):
        """Accumulate correct/total counts for one batch.

        Args:
            logits: Tensor indexed as ``[sample, step, vocab_id]``.
            label_spaces_ids: Integer tensor of shape
                ``(batch, option_num, max_seq_len)`` with the token ids of
                every option.
            sample_to: Per-sample number of leading steps to score.
            gold_classes: Per-sample index of the correct option.
        """
        device = gold_classes.device
        probs = torch.softmax(logits, dim=-1)
        batch_size, option_num, max_seq_len = label_spaces_ids.shape
        # Never look past the generated steps or past the option tokens.
        steps = min(max_seq_len, probs.shape[1])

        # gather needs int64 indices; (batch, steps, option_num) lines the
        # index tensor up with probs' (batch, steps, vocab) layout.
        token_ids = label_spaces_ids[:, :, :steps].to(device=device, dtype=torch.long)
        step_probs = probs[:batch_size, :steps, :].to(device)
        # picked[i, j, k] == probs[i, k, label_spaces_ids[i, j, k]]
        picked = step_probs.gather(2, token_ids.transpose(1, 2)).transpose(1, 2)
        picked = picked.to(torch.float)

        # Steps at or beyond sample_to[i] contribute a neutral factor of 1.
        positions = torch.arange(steps, device=device).view(1, 1, -1)
        in_range = positions < sample_to.to(device).view(-1, 1, 1)
        scores = torch.where(in_range, picked, torch.ones_like(picked))

        scores = torch.prod(scores, dim=-1)
        classes = torch.argmax(scores, dim=-1)
        assert len(classes) == len(gold_classes)
        self.correct += torch.sum(classes == gold_classes)
        self.total += len(classes)

    def compute(self):
        """Return the accumulated accuracy as a float tensor."""
        return self.correct.float() / self.total.float()
    

class LitAlignmentDataModule(pl.LightningDataModule):
    """Data module producing alignment training batches and labelled test loaders.

    Training data is regrouped so that every batch is
    ``[prototype, paraphrase_1, ..., paraphrase_{batch_size - 1}]`` for a
    single ``(task_id, instance_id)`` pair. Test data is loaded from
    ``test_dir/<dataset>/{observed,unobserved}``.

    Args:
        train_batch_size: Size of each training batch (prototype included).
        test_batch_size: Batch size of the test loaders.
        train_set: Iterable of rows with at least the keys ``task_id``,
            ``instance_id`` and ``prototype``.
        test_dir: Directory containing one sub-directory per test dataset,
            or ``None`` to skip test data.
        distributed: If true, training uses ``DistributedAlignmentSampler``.
        tokenizer: Tokenizer used for non-Alpaca inputs and for label spaces.
        isAlpaca: Selects the ``alpaca_preprocess`` code path.
        alpaca_tokenizer_path: Location of the Llama tokenizer loaded when
            ``isAlpaca`` is true. Defaults to the previously hard-coded path.
    """

    def __init__(self, train_batch_size, test_batch_size, train_set, test_dir, distributed, tokenizer, isAlpaca,
                 alpaca_tokenizer_path="/work/frink/models/alpaca-7b"):
        super().__init__()
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.train_set = train_set
        self.tokenizer = tokenizer
        self.test_dir = test_dir
        self.isAlpaca = isAlpaca
        if self.isAlpaca:
            self.alpaca_training_tokenizer = transformers.LlamaTokenizer.from_pretrained(
                alpaca_tokenizer_path,
                model_max_length=512,
                padding_side="right",
                use_fast=False,
            )
        self.distributed = distributed

        # NOTE(review): stored but not passed to any DataLoader in this class.
        self.num_workers = cpu_count()

    def train_collate(self, batch):
        """Tokenize one alignment group (prototype first, then its paraphrases).

        Rows are unpacked positionally, so each row's values must be ordered
        as input text, output text, prototype flag, task id, instance id.
        """
        batch = [b.values() for b in batch]
        input_text, output_text, prototype, task_id, instance_id = list(zip(*batch))
        # A batch must hold exactly one (task, instance) group whose first
        # element is the prototype and whose remaining elements are not.
        assert len(set(task_id)) == 1 and len(set(instance_id)) == 1
        assert prototype[0] == True and len(set(prototype[1:])) == 1 and prototype[-1] == False
        assert len(input_text) == len(output_text)
        if not self.isAlpaca:
            batch = self.tokenizer(
                text=input_text,
                text_target=output_text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            # Mask padded label positions with -100.
            batch["labels"][batch["labels"] == self.tokenizer.pad_token_id] = -100
        else:
            # NOTE(review): uses self.tokenizer here, while test_collate uses
            # self.alpaca_training_tokenizer -- confirm this is intended.
            batch = alpaca_preprocess(input_text, output_text, self.tokenizer)

        return batch

    def test_collate(self, batch):
        """Tokenize a test batch and attach the label-space tensors used for scoring.

        Adds to the tokenized batch:
            ``label_cls``: index of the gold output inside its label space.
            ``label_spaces_ids``: token ids of every option, padded to the
                longest option in the batch.
            ``sample_to``: per-sample token length of the shortest option.

        Raises:
            ValueError: If a gold output is not contained in its label space.
        """
        batch = [b.values() for b in batch]
        input_text, output_text, label_spaces = list(zip(*batch))
        assert len(input_text) == len(output_text) == len(label_spaces)
        if not self.isAlpaca:
            batch = self.tokenizer(text=input_text, text_target=output_text, padding='longest', truncation=True,
                                   return_tensors="pt", max_length=512)
        else:
            batch = alpaca_preprocess(input_text, output_text, self.alpaca_training_tokenizer)

        try:
            labels_cls = torch.ShortTensor([label_space.index(y) for label_space, y in zip(label_spaces, output_text)])
        except ValueError as err:
            raise ValueError(
                "Could not map outputs to their label spaces: label_spaces={!r}, output_text={!r}".format(
                    label_spaces, output_text
                )
            ) from err

        # First pass: unpadded token lengths of every option of every sample.
        option_lengths = [
            self.tokenizer(label_space, padding=False, return_length=True)["length"]
            for label_space in label_spaces
        ]
        max_seq_len = max(max(lengths) for lengths in option_lengths)
        # Second pass: pad every option to the longest one so they can be stacked.
        label_spaces_ids = torch.stack([
            self.tokenizer(label_space, padding="max_length", max_length=max_seq_len,
                           return_tensors="pt")["input_ids"]
            for label_space in label_spaces
        ], dim=0)

        sample_to = torch.ShortTensor([min(lengths) for lengths in option_lengths])
        if self.isAlpaca:
            label_spaces_ids = label_spaces_ids[:, :, 1:]  # remove the <s> token
            sample_to -= 1

        batch["label_cls"] = labels_cls
        batch["label_spaces_ids"] = label_spaces_ids
        batch["sample_to"] = sample_to
        return batch

    def prepare_data(self) -> None:
        """Load the test sets and regroup the train set into alignment batches.

        NOTE(review): this assigns ``self.test_sets`` and ``self.train_set``.
        Lightning may invoke ``prepare_data`` on a subset of processes only;
        confirm every process that builds dataloaders also runs this method.

        Raises:
            ValueError: If, for some task, the number of paraphrases of
                instance 0 is not a multiple of ``train_batch_size - 1``.
        """
        self.test_sets = []

        # Prepare Test Data
        if self.test_dir is not None:
            test_datasets = [f for f in os.listdir(self.test_dir) if f != ".DS_Store"]
            for dataset in test_datasets:
                dataset_dir = os.path.join(self.test_dir, dataset)
                for setting, subdir in (("Observed", "observed"), ("Unobserved", "unobserved")):
                    self.test_sets.append({
                        "name": dataset,
                        "setting": setting,
                        "dataset": load_from_disk(os.path.join(dataset_dir, subdir))
                    })

        # Prepare Train Data
        batch_size_minus_one = self.train_batch_size - 1
        train_stats = Counter(
            data["task_id"] for data in self.train_set
            if data["instance_id"] == 0 and data["prototype"] == False
        )

        # Check that the paraphrases of each task fit into batches regularly
        for key, count in train_stats.items():
            if count % batch_size_minus_one != 0:
                raise ValueError("Train set task {} has {} instances, which is not divisible by batch_size ({}) - 1!".format(
                    key, count, self.train_batch_size
                ))

        # Group rows by task and instance: one prototype and a list of
        # paraphrases per (task_id, instance_id).
        prototype_train_dict, paraphrase_train_dict = {}, {}
        for train_data in self.train_set:
            task_id, instance_id = train_data["task_id"], train_data["instance_id"]
            if train_data["prototype"]:
                prototype_train_dict.setdefault(task_id, {})[instance_id] = train_data
            else:
                paraphrase_train_dict.setdefault(task_id, {}).setdefault(instance_id, []).append(train_data)

        # Each batch of the train set is [prototype, paraphrase1, paraphrase2, ...]
        processed_train_set = []
        for task_id, prototypes in prototype_train_dict.items():
            for instance_id, prototype in prototypes.items():
                paraphrases = paraphrase_train_dict[task_id][instance_id]
                for i in range(len(paraphrases) // batch_size_minus_one):
                    chunk = paraphrases[i * batch_size_minus_one:(i + 1) * batch_size_minus_one]
                    processed_train_set.append(prototype)
                    processed_train_set.extend(chunk)
        self.train_set = Dataset.from_list(processed_train_set)

    def setup(self, stage: str) -> None:
        """No per-stage setup; all work happens in ``prepare_data``."""
        pass

    def train_dataloader(self):
        """Return the training loader; order is never shuffled so groups stay intact."""
        if self.distributed:
            sampler = DistributedAlignmentSampler(self.train_set, batch_size=self.train_batch_size)
        else:
            sampler = SequentialSampler(self.train_set)
        return DataLoader(self.train_set, batch_size=self.train_batch_size, collate_fn=self.train_collate, shuffle=False, sampler=sampler)

    def test_dataloaders(self):
        """Wrap every test set in a ``DataLoader`` (in place) and return ``self.test_sets``.

        Entries that were already wrapped by an earlier call are left
        untouched, so calling this method repeatedly is safe.
        """
        for dataset_dict in self.test_sets:
            if isinstance(dataset_dict["dataset"], DataLoader):
                continue
            dataset_dict["dataset"] = DataLoader(dataset_dict["dataset"], batch_size=self.test_batch_size, collate_fn=self.test_collate)
        return self.test_sets
    
    

class AlignmentT5ForConditionalGeneration(pl.LightningModule):
    """Lightning wrapper that trains a seq2seq model with an optional alignment loss.

    With a non-zero ``alignment_weight`` the loss is
    ``(1 - w) * CE(prototype) + w * KL``, where the KL term compares the
    log-probabilities of every non-first batch row with those of the first
    row (the prototype).

    Args:
        model: Model whose forward output exposes ``loss`` and ``logits`` and
            which provides ``generate``.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        alignment_weight: Weight ``w`` of the KL term; ``0`` disables it.
        soft_prompt: Stored on the instance; not used by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model, learning_rate, weight_decay, alignment_weight, soft_prompt, **kwargs):
        super().__init__()
        self.test_metric = LogitsMetrics()
        self.model = model
        self.isAlpaca = False
        self.soft_prompt = soft_prompt
        self.alignment_weight = alignment_weight
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.kl_loss = KLDivLoss(reduction="batchmean", log_target=True)

    def training_step(self, batch, batch_idx):
        """Return the training loss and log its cross-entropy and KL components."""
        if self.alignment_weight == 0:
            loss = self.model(**batch).loss
            self.log("train_loss", loss)
            self.log("train_loss_ce", loss)
            self.log("train_loss_kl", 0)
            return loss

        # Cross-entropy is computed on the prototype (first row) only.
        main_inputs = {
            "input_ids": batch["input_ids"][0, :].unsqueeze(0),
            "attention_mask": batch["attention_mask"][0, :].unsqueeze(0),
            "labels": batch["labels"][0, :].unsqueeze(0),
        }
        main_loss = self.model(**main_inputs).loss

        outputs = self.model(**batch)
        log_probs = log_softmax(outputs.logits, dim=-1)
        # Repeat the prototype distribution once per paraphrase row.
        prototype_dist = log_probs[0, :, :].unsqueeze(0).repeat(log_probs.shape[0] - 1, 1, 1)
        paraphrase_dist = log_probs[1:, :, :]
        alignment_loss = self.kl_loss(paraphrase_dist, prototype_dist)

        loss = (1 - self.alignment_weight) * main_loss + self.alignment_weight * alignment_loss
        self.log("train_loss", loss)
        self.log("train_loss_ce", main_loss)
        self.log("train_loss_kl", alignment_loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Generate, update the accuracy metric and return the running accuracy."""
        # Read fields by key rather than by position so the result does not
        # depend on the insertion order of the collated batch.
        outputs = self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=512,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # One score tensor per generated step -> (batch, steps, vocab).
        logits = torch.stack(outputs.scores, dim=1)
        self.test_metric.update(logits, batch["label_spaces_ids"], batch["sample_to"], batch["label_cls"])
        # The running value is only returned; logging it per step would make
        # the epoch-level "test_acc" an average of running accuracies.
        acc = self.test_metric.compute()
        return {"test_acc": acc}

    def on_test_epoch_start(self) -> None:
        """Reset the accuracy metric before each test run."""
        self.test_metric.reset()
        return super().on_test_epoch_start()

    def on_test_epoch_end(self) -> None:
        """Log the accuracy accumulated over the whole test epoch."""
        self.log("test_acc", self.test_metric.compute())
        return super().on_test_epoch_end()

    def configure_optimizers(self):
        """Optimize all model parameters with AdamW."""
        return optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
    
    

class AlignmentLlamaForCausalLM(pl.LightningModule):
    """Lightning wrapper that trains a causal LM with an optional alignment loss.

    With a non-zero effective weight ``w`` the loss is
    ``(1 - w) * CE(prototype) + w * KL``, where the KL term compares the
    log-probabilities of every non-first batch row with those of the first
    row (the prototype). ``negation_alignment`` flips the sign of ``w``.

    Args:
        model: Model whose forward output exposes ``loss`` and ``logits`` and
            which provides ``generate``.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        alignment_weight: Magnitude of the KL weight; ``0`` disables the term.
        soft_prompt: If true, logits for positions in front of the last
            ``input_ids.shape[-1]`` positions are dropped before the KL term.
        negation_alignment: If true, the alignment weight is negated.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model, learning_rate, weight_decay, alignment_weight, soft_prompt, negation_alignment, **kwargs):
        super().__init__()
        self.test_metric = LogitsMetrics()
        self.model = model
        self.isAlpaca = True
        self.soft_prompt = soft_prompt

        sign = -1 if negation_alignment else 1
        self.alignment_weight = alignment_weight * sign

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.kl_loss = KLDivLoss(reduction="batchmean", log_target=True)

    def training_step(self, batch, batch_idx):
        """Return the training loss and log its cross-entropy and KL components."""
        if self.alignment_weight == 0:
            loss = self.model(**batch).loss
            self.log("train_loss", loss)
            self.log("train_loss_ce", loss)
            self.log("train_loss_kl", 0)
            return loss

        input_ids = batch["input_ids"]
        # Cross-entropy is computed on the prototype (first row) only.
        main_inputs = {
            "input_ids": input_ids[0, :].unsqueeze(0),
            "attention_mask": batch["attention_mask"][0, :].unsqueeze(0),
            "labels": batch["labels"][0, :].unsqueeze(0),
        }
        main_loss = self.model(**main_inputs).loss

        logits = self.model(**batch).logits
        if self.soft_prompt:
            # remove logits predicted from the soft prompt: keep only the
            # trailing positions that correspond to input_ids
            seq_len_with_soft_prompt = logits.shape[1]
            logits = logits[:, (seq_len_with_soft_prompt - input_ids.shape[-1]):, :]

        log_probs = log_softmax(logits, dim=-1)
        # Repeat the prototype distribution once per paraphrase row.
        prototype_dist = log_probs[0, :, :].unsqueeze(0).repeat(log_probs.shape[0] - 1, 1, 1)
        paraphrase_dist = log_probs[1:, :, :]
        alignment_loss = self.kl_loss(paraphrase_dist, prototype_dist)

        loss = (1 - self.alignment_weight) * main_loss + self.alignment_weight * alignment_loss
        self.log("train_loss", loss)
        self.log("train_loss_ce", main_loss)
        self.log("train_loss_kl", alignment_loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Generate, update the accuracy metric and return the running accuracy."""
        # Read fields by key rather than by position so the result does not
        # depend on the insertion order of the collated batch.
        outputs = self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=512,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # One score tensor per generated step -> (batch, steps, vocab).
        logits = torch.stack(outputs.scores, dim=1)
        self.test_metric.update(logits, batch["label_spaces_ids"], batch["sample_to"], batch["label_cls"])
        # The running value is only returned; logging it per step would make
        # the epoch-level "test_acc" an average of running accuracies.
        acc = self.test_metric.compute()
        return {"test_acc": acc}

    def on_test_epoch_start(self) -> None:
        """Reset the accuracy metric before each test run."""
        self.test_metric.reset()
        return super().on_test_epoch_start()

    def on_test_epoch_end(self) -> None:
        """Log the accuracy accumulated over the whole test epoch."""
        self.log("test_acc", self.test_metric.compute())
        return super().on_test_epoch_end()

    def configure_optimizers(self):
        """Optimize all model parameters with AdamW."""
        return optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
