import torch
import torch.nn as nn
from torch.utils.data import random_split, ConcatDataset, Subset

from dataset.cc12m import cc12m
from trainer.finetune import FinetuneCLIP, FinetuneFFN, FinetuenProj


class CondFT(FinetuneCLIP):
    """Fine-tuning trainer that adds a CC12M conditioning loss to the task loss.

    Each training step encodes the current-task batch together with a batch
    from a secondary "buffer" loader. That loader mixes the dataset's buffer
    (when one is kept separately) with CC12M samples. The two halves get
    separate contrastive losses, combined as ``cur_loss + scale * cond_loss``.
    """

    def get_iterator(self, dataset, task):
        """Build the loaders used to train on ``task``.

        Args:
            dataset: Continual-learning dataset exposing ``get_dataset``,
                ``get_buffer`` and ``transform``.
            task: Index of the current task (0 is the first task).

        Returns:
            Tuple ``(train_dataloader, buffer_loader, validset, total_batches)``.
            ``validset`` is ``None`` unless ``self.args.valid`` is set. It then
            holds samples split off from the training data (and from the
            separate buffer, if any), never samples that are trained on.
            ``total_batches`` is ``len(train_dataloader)``.
        """
        if self.args.balanced_buffer and task > 0:
            # Keep the buffer apart from the current task so the two can be
            # loaded independently.
            trainset = dataset.get_dataset(task, is_train=True, with_buffer=False)
            bufferset = dataset.get_buffer(task)
        else:
            trainset = dataset.get_dataset(task, is_train=True, with_buffer=(self.args.buffer_size > 0))
            bufferset = None

        validset = None
        if self.args.valid:
            per_task_valid = 100
            trainset, valid_train = random_split(trainset, [len(trainset) - per_task_valid, per_task_valid])
            # Validate on the held-out split, not on the remaining training data.
            validset = valid_train
            # A separate buffer only exists in the balanced_buffer branch. Hold
            # out per_task_valid samples for every earlier task from it.
            if task > 0 and bufferset is not None:
                n_buffer_valid = task * per_task_valid
                bufferset, valid_buffer = random_split(bufferset,
                                                       [len(bufferset) - n_buffer_valid, n_buffer_valid])
                validset = ConcatDataset([valid_train, valid_buffer])

        condset = cc12m(transform=dataset.transform)
        if bufferset is not None and 'first' not in self.args.method:
            # Extend the buffer with up to len(bufferset) randomly chosen CC12M
            # samples. Plain ints keep Subset indexing independent of how
            # cc12m handles tensor indices.
            index = torch.randperm(len(condset))[:len(bufferset)].tolist()
            bufferset = ConcatDataset([bufferset, Subset(condset, index)])
        else:
            # No separate buffer, or the method name contains 'first':
            # condition on the full CC12M dataset only.
            bufferset = condset

        buffer_loader = self.get_loader(bufferset)
        train_dataloader = self.get_loader(trainset)
        total_batches = len(train_dataloader)

        return train_dataloader, buffer_loader, validset, total_batches

    def compute_loss(self, batch, model, **kwargs):
        """Return the current-task contrastive loss plus the scaled conditioning loss.

        The current batch and at most ``cur_bs`` samples of the buffer batch are
        concatenated and encoded together. Each half's logits therefore range
        over both halves. Each half gets a symmetric image/text cross-entropy,
        and the result is ``cur_loss + cond_loss * self.args.scale``.

        Args:
            batch: ``(images, label, texts)`` for the current task. ``label``
                is unused.
            model: Called as ``model(images, texts)`` and expected to return
                ``(logits_per_image, logits_per_text)``. Presumably CLIP-style
                square similarity logits over the concatenated batch (TODO
                confirm).
            **kwargs: Must include ``buffer``, an ``(images, _, texts)`` batch.

        Returns:
            Scalar loss tensor.

        Raises:
            TypeError: If no ``buffer`` batch is supplied.
        """
        buffer = kwargs.get('buffer', None)
        if buffer is None:
            raise TypeError("compute_loss() requires a 'buffer' keyword argument")

        # Targets are class indices, so a single criterion serves both directions.
        loss_fn = nn.CrossEntropyLoss()

        images, _, texts = batch
        images_cond, _, texts_cond = buffer
        device = self.args.device
        cur_bs = images.size(0)

        # Keep at most cur_bs conditioning samples so that half never outnumbers
        # the current batch.
        images = torch.cat([images.to(device), images_cond[:cur_bs].to(device)])
        texts = torch.cat([texts.to(device), texts_cond[:cur_bs].to(device)])

        # Sample i is paired with text i across the whole concatenated batch.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        logits_per_image, logits_per_text = model(images, texts)

        def _symmetric_ce(rows):
            # Mean of image->text and text->image cross-entropy over `rows`.
            return (loss_fn(logits_per_image[rows], ground_truth[rows])
                    + loss_fn(logits_per_text[rows], ground_truth[rows])) / 2

        cur_loss = _symmetric_ce(slice(None, cur_bs))
        cond_loss = _symmetric_ce(slice(cur_bs, None))
        return cur_loss + cond_loss * self.args.scale


class CondFTFFN(CondFT, FinetuneFFN):
    """CondFT training whose trainable parameters are selected by FinetuneFFN."""

    def unfreeze_model(self, model):
        """Unfreeze ``model`` using FinetuneFFN's policy; returns ``None``."""
        # Resolve FinetuneFFN's implementation explicitly instead of relying on
        # the MRO, which would consult CondFT's bases first.
        ffn_unfreeze = FinetuneFFN.unfreeze_model
        ffn_unfreeze(self, model)


class CondFTProj(CondFT, FinetuenProj):
    """CondFT training whose trainable parameters are selected by FinetuenProj.

    ``FinetuenProj`` (sic) is the name as imported from ``trainer.finetune``.
    """

    def unfreeze_model(self, model):
        """Unfreeze ``model`` using FinetuenProj's policy; returns ``None``."""
        # Resolve FinetuenProj's implementation explicitly instead of relying on
        # the MRO, which would consult CondFT's bases first.
        proj_unfreeze = FinetuenProj.unfreeze_model
        proj_unfreeze(self, model)
