import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from .bases import ProxyBase
from .utils import clone_state_dict


class BiGanProxyBase(ProxyBase):
    """Shared machinery for proxies that score a pair of generators.

    The proxy keeps private deep copies of the two generators ``p`` and ``q``
    with gradients disabled, plus batch iterators over the train/validation
    splits of their datasets.  Subclasses provide the actual objective by
    overriding :meth:`_objective_fn`.

    This class reads ``self.device``, ``self.batch_size``,
    ``self._num_workers_per_dataset`` and ``self.proxy_model`` without
    assigning them; presumably ``ProxyBase`` sets them up -- confirm there.
    """

    def __init__(self, p, q, datasets_p, datasets_q, *args, **kwargs):
        """Copy the generators and datasets and open the batch iterators.

        Args:
            p: generator module; deep-copied, the caller's object is untouched.
            q: generator module; deep-copied, the caller's object is untouched.
            datasets_p: indexable pair whose item 0 is used as the training
                split and item 1 as the validation split.
            datasets_q: same layout as ``datasets_p``.
            *args, **kwargs: forwarded unchanged to ``ProxyBase.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.datasets_p = copy.deepcopy(datasets_p)
        self.datasets_q = copy.deepcopy(datasets_q)
        self._p = copy.deepcopy(p)
        self._q = copy.deepcopy(q)

        # The private generator copies are never optimised through this
        # proxy, so switch off gradient tracking on their parameters.
        for param in self._p.parameters():
            param.requires_grad = False
        for param in self._q.parameters():
            param.requires_grad = False

        self._p = self._p.to(device=self.device)
        self._q = self._q.to(device=self.device)

        # Number of samples requested from each iterator since it was last
        # (re)created; used by _get_data to decide when to restart.
        self._n_train_samples_p = 0
        self._n_val_samples_p = 0
        self._n_train_samples_q = 0
        self._n_val_samples_q = 0
        self.dataloader_train_p = self._get_dataloader(self.datasets_p[0])
        self.dataloader_val_p = self._get_dataloader(self.datasets_p[1])
        self.dataloader_train_q = self._get_dataloader(self.datasets_q[0])
        self.dataloader_val_q = self._get_dataloader(self.datasets_q[1])

    def loss_with_generator_grad(self, p, q):
        """Return the mean objective on one training batch for ``p``/``q``.

        The objective is evaluated with ``with_no_grad=False`` so the
        generator forward passes are not wrapped in ``torch.no_grad()``.
        """
        data = self._get_data(train=True)
        return self._objective_fn(data, p, q, with_no_grad=False).mean()

    def _objective_fn(self, data, p, q, with_no_grad=True):
        """Compute the objective for one batch; must be overridden."""
        raise NotImplementedError()

    def _loss_fn(self, data):
        """Return the mean objective using the proxy's own generator copies."""
        return self._objective_fn(data, self._p, self._q).mean()

    def _get_dataloader(self, dataset):
        """Return a fresh shuffled batch iterator over ``dataset``.

        Datasets whose ``annealing`` attribute is truthy are loaded in the
        main process (``num_workers=0``); all others use
        ``self._num_workers_per_dataset`` workers.
        """
        num_workers = (0 if dataset.annealing
                       else self._num_workers_per_dataset)
        dataloader = DataLoader(dataset, batch_size=self.batch_size,
                                shuffle=True, num_workers=num_workers,
                                pin_memory=True)

        return iter(dataloader)

    def _get_data(self, train=True):
        """Fetch the next ``(batch_p, batch_q)`` pair.

        Args:
            train: draw from the training iterators when true, otherwise
                from the validation iterators.

        Returns:
            A tuple ``(batch_p, batch_q)``.  In validation mode ``None`` is
            returned once both validation splits have been consumed; the
            validation iterators and counters are reset at that point.
        """
        if train:
            self._n_train_samples_p += self.batch_size
            if self._n_train_samples_p > len(self.datasets_p[0]):
                self.dataloader_train_p = self._get_dataloader(
                    self.datasets_p[0])
                # A batch is drawn from the fresh iterator just below, so the
                # count restarts at one batch.  Restarting at zero lets the
                # counter lag behind the iterator, which raises StopIteration
                # when the dataset length is a multiple of the batch size.
                self._n_train_samples_p = self.batch_size

            self._n_train_samples_q += self.batch_size
            if self._n_train_samples_q > len(self.datasets_q[0]):
                self.dataloader_train_q = self._get_dataloader(
                    self.datasets_q[0])
                self._n_train_samples_q = self.batch_size
            return (next(self.dataloader_train_p),
                    next(self.dataloader_train_q))

        self._n_val_samples_p += self.batch_size
        self._n_val_samples_q += self.batch_size
        p_ran_out = self._n_val_samples_p > len(self.datasets_p[1])
        q_ran_out = self._n_val_samples_q > len(self.datasets_q[1])
        if p_ran_out and q_ran_out:
            # Both validation splits are consumed: reset everything and
            # return None to signal that validation is over.
            self.dataloader_val_p = self._get_dataloader(self.datasets_p[1])
            self.dataloader_val_q = self._get_dataloader(self.datasets_q[1])
            self._n_val_samples_p = 0
            self._n_val_samples_q = 0
            return None
        if p_ran_out:
            # Only p is exhausted: keep pairing with fresh *validation*
            # batches.  The counter is left untouched so p_ran_out stays
            # true until q runs out as well.
            self.dataloader_val_p = self._get_dataloader(self.datasets_p[1])
        elif q_ran_out:
            self.dataloader_val_q = self._get_dataloader(self.datasets_q[1])
        return (next(self.dataloader_val_p), next(self.dataloader_val_q))

    def _load_generator_statedicts(self, generator_statedicts):
        """Load detached clones of the given state dicts into ``_p``/``_q``.

        Args:
            generator_statedicts: pair ``(p_state_dict, q_state_dict)``.
        """
        p_state_dict, q_state_dict = generator_statedicts
        new_p_state_dict = clone_state_dict(p_state_dict, detach=True)
        new_q_state_dict = clone_state_dict(q_state_dict, detach=True)

        self._p.load_state_dict(new_p_state_dict, strict=True)
        self._q.load_state_dict(new_q_state_dict, strict=True)

    def __call__(self, Xp, Xq):
        """Evaluate ``proxy_model(Xp, Xq)`` in eval mode with frozen weights.

        After the call the proxy model is put back in train mode and every
        parameter has ``requires_grad`` set to True, also when the forward
        pass raises.
        """
        # turn off gradients for proxy
        for param in self.proxy_model.parameters():
            param.requires_grad = False
        self.proxy_model.eval()
        try:
            return self.proxy_model(Xp, Xq)
        finally:
            # turn on gradients for proxy
            for param in self.proxy_model.parameters():
                param.requires_grad = True
            self.proxy_model.train()


class BiGanClassifierProxy(BiGanProxyBase):
    """Proxy whose objective is a cross-entropy loss on the proxy model."""

    def __init__(self, *args, **kwargs):
        """Initialise the base proxy and the per-sample cross-entropy loss.

        All arguments are forwarded to ``BiGanProxyBase.__init__``.
        """
        # Call the direct parent: ``super(BiGanProxyBase, self)`` would skip
        # BiGanProxyBase.__init__ and leave the generator copies, datasets
        # and dataloader iterators uninitialised.
        super().__init__(*args, **kwargs)
        self._CELoss = nn.CrossEntropyLoss(reduction='none')

    def _objective_fn(self, data, p, q, with_no_grad=True):
        """Return the per-sample classification loss for one batch pair.

        Args:
            data: ``((real_Xp, Yp), (real_Xq, Yq))`` as produced by
                ``_get_data``.
            p: generator applied to ``real_Xq`` to produce ``fake_Xp``.
            q: generator applied to ``real_Xp`` to produce ``fake_Xq``.
            with_no_grad: when true the generator forward passes run under
                ``torch.no_grad()``; otherwise the ambient grad mode is kept.

        Returns:
            ``0.5 * (loss_p + loss_q)``, unreduced because the loss uses
            ``reduction='none'``.
        """
        ((real_Xp, Yp), (real_Xq, Yq)) = data
        real_Xp = real_Xp.to(device=self.device)
        real_Xq = real_Xq.to(device=self.device)
        Yp = Yp.to(device=self.device)
        Yq = Yq.to(device=self.device)

        # Disabled when with_no_grad is set, otherwise whatever mode the
        # caller is already running in.
        track_grad = torch.is_grad_enabled() and not with_no_grad

        with torch.set_grad_enabled(track_grad):
            fake_Xq = q(real_Xp)
        output_q = self.proxy_model(real_Xp, fake_Xq)
        loss_q = self._CELoss(output_q, Yq)

        with torch.set_grad_enabled(track_grad):
            fake_Xp = p(real_Xq)
        output_p = self.proxy_model(fake_Xp, real_Xq)
        loss_p = self._CELoss(output_p, Yp)

        return 0.5*(loss_p + loss_q)


class BiGanWCritic(BiGanProxyBase):
    """Proxy whose objective is the difference of two proxy-model outputs.

    The proxy model's parameters are clamped to
    ``[-clamp_limit, clamp_limit]`` once at construction time.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base proxy, then clamp the proxy model's weights.

        All arguments are forwarded to ``BiGanProxyBase.__init__``; the
        clamp limit is read from ``self.proxy_model.clamp_limit``.
        """
        super().__init__(*args, **kwargs)
        self.clamp_limit = self.proxy_model.clamp_limit

        self._manipulate_weights(self.proxy_model.parameters())

    def _objective_fn(self, data, p, q, with_no_grad=True):
        """Return ``-(loss_p - loss_q)`` for one batch pair.

        Args:
            data: ``((real_Xp, _), (real_Xq, _))``; the second item of each
                pair is ignored.
            p: generator applied to ``real_Xq``.
            q: generator applied to ``real_Xp``.
            with_no_grad: when true the generator forward passes run with
                gradients disabled; otherwise the ambient grad mode is kept.
        """
        (batch_p, _), (batch_q, _) = data
        batch_p = batch_p.to(device=self.device)
        batch_q = batch_q.to(device=self.device)

        # Off when with_no_grad is set, otherwise the caller's current mode.
        track_grad = torch.is_grad_enabled() and not with_no_grad

        with torch.set_grad_enabled(track_grad):
            generated_q = q(batch_p)
        score_q = self.proxy_model(batch_p, generated_q)

        with torch.set_grad_enabled(track_grad):
            generated_p = p(batch_q)
        score_p = self.proxy_model(generated_p, batch_q)

        return -(score_p - score_q)

    def _manipulate_weights(self, parameters):
        """Clamp every given parameter in place to the clamp limit."""
        limit = self.clamp_limit
        for weight in parameters:
            weight.data.clamp_(-limit, limit)
