import torch
import warnings
from tqdm import tqdm

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from src.nlp.models.DummyDistributed import DummyDistributed

class BaseExperiment:
    """
    Shared scaffolding for experiments: model storage, distributed wrapping
    and training-data length statistics.

    The constructor only sets ``cfg`` and ``model``. Other methods also read
    ``self.train_data``, ``self.val_data``, ``self.device`` and
    ``cfg.learning.batch_size``; these are presumably provided by subclasses
    or callers -- confirm against the concrete experiment classes.
    ``evaluate`` and ``finetune_pass`` must be overridden.
    """

    def __init__(self, cfg):
        """
        Store the experiment configuration. The model starts as ``None`` and
        is attached later through ``set_model``.
        """
        self.cfg = cfg
        self.model = None

    def set_model(self, model):
        """
        Attach ``model`` to the experiment, warning when an already attached
        model is being replaced.
        """
        if self.model is not None:
            warnings.warn("Overwriting model...")
        self.model = model

    def distributed_setup(self):
        """
        Set up distributed computing.

        Calls the model's own ``_distributed_setup()`` to obtain the local and
        global rank, wraps the model in ``DistributedDataParallel``, replaces
        ``self.device`` with the local rank and rebuilds the train/validation
        dataloaders on top of a ``DistributedSampler``.
        """
        local_rank, global_rank = self.model._distributed_setup()
        self.model = DDP(self.model, device_ids=[local_rank], find_unused_parameters=True)
        self.distributed = True
        self.local_rank = local_rank
        self.global_rank = global_rank

        # ``device`` is never assigned by this class, so read it defensively:
        # a missing attribute must not abort the setup after the model has
        # already been wrapped.
        previous_device = getattr(self, "device", None)
        warnings.warn(
            f"Default provided device {previous_device} is being overridden by "
            f"distributed setup, new device {self.local_rank}."
        )
        self.device = self.local_rank

        warnings.warn("DataLoader changed to DistributedSampler.")
        batch_size = self.cfg.learning.batch_size
        # ``shuffle`` must stay False on the DataLoader because ordering is
        # delegated to the sampler.
        self.train_dataloader = DataLoader(
            self.train_data,
            batch_size=batch_size,
            sampler=DistributedSampler(self.train_data),
            shuffle=False,
        )

        # NOTE(review): DistributedSampler shuffles by default; confirm that
        # this is intended for the validation split as well.
        self.val_dataloader = DataLoader(
            self.val_data,
            batch_size=batch_size,
            sampler=DistributedSampler(self.val_data),
            shuffle=False,
        )

    def dummy_distributed(self):
        """
        Function to wrap model in a dummy instance so model.module doesn't throw errors.
        Should be used for non-distributed model runs.
        """
        self.model = DummyDistributed(self.model)

    def _resolve_tokenizer(self):
        """
        Return the model's tokenizer, looking through a ``.module`` wrapper
        when the outer object does not expose ``tokenizer`` itself.
        """
        try:
            return self.model.tokenizer
        except AttributeError:
            # A DDP wrapper keeps the original model under ``.module`` and
            # does not forward its custom attributes.
            wrapped = getattr(self.model, "module", None)
            if wrapped is None:
                raise
            return wrapped.tokenizer

    @staticmethod
    def _length_stats(lengths):
        """
        Return ``(mean, population std)`` of ``lengths``, or ``None`` when
        ``lengths`` is empty.
        """
        if not lengths:
            return None
        mean = sum(lengths) / len(lengths)
        std = (sum((x - mean) ** 2 for x in lengths) / len(lengths)) ** 0.5
        return mean, std

    def get_datalength(self):
        """
        Using the tokenizer, iterate over the entire training data
        and count the number of tokens in each sample, return as list.

        The first element of every sample is tokenized. The mean and
        population standard deviation of the token counts are printed; for
        empty training data a warning is issued instead and an empty list is
        returned.
        """
        tokenizer = self._resolve_tokenizer()
        datalength = [
            len(tokenizer(sample[0], return_tensors="pt")["input_ids"][0])
            for sample in tqdm(self.train_data)
        ]

        stats = self._length_stats(datalength)
        if stats is None:
            warnings.warn("Training data is empty; no data length statistics to report.")
        else:
            mean, std = stats
            print(f"DATA LENGTH STATS: Mean: {mean}, Std: {std}")
        return datalength

    def evaluate(self, finetuned=False):
        """
        Evaluation hook; not implemented here and must be overridden.
        """
        raise NotImplementedError

    def finetune_pass(self, batch, model, device):
        """
        Fine-tuning hook for one batch; not implemented here and must be
        overridden.
        """
        raise NotImplementedError
