#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import print_function
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

# Compute device shared by the module: the GPU when CUDA is usable, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class conv_EEG(nn.Module):
    """Two-stage 1-D convolutional classifier that returns class logits.

    Each stage is ``Conv1d(kernel 3, no padding) -> BatchNorm1d ->
    LeakyReLU(0.3)``. The result is flattened and passed to a small fully
    connected head whose first layer expects 3060 inputs, i.e. 10 channels
    times 306 positions, which is what a ``(batch, 1, 310)`` input produces
    after the two unpadded convolutions.

    Args:
        out_class: Number of output classes (width of the last linear layer).
    """

    def __init__(self, out_class):
        super().__init__()
        # Stage 1: 1 -> 5 channels.
        self.conv1 = nn.Conv1d(1, 5, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm1d(5)
        self.lr1 = nn.LeakyReLU(0.3)
        # Registered for completeness but never applied in forward().
        self.mxp = nn.MaxPool1d(3)
        # Stage 2: 5 -> 10 channels.
        self.conv2 = nn.Conv1d(5, 10, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm1d(10)
        self.lr2 = nn.LeakyReLU(0.3)
        self.fn = nn.Flatten()

        # Classification head operating on the flattened feature map.
        head_layers = [
            nn.Linear(3060, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, out_class),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through both conv stages and the head; return the logits."""
        pipeline = (
            self.conv1, self.bn1, self.lr1,
            self.conv2, self.bn2, self.lr2,
            self.fn, self.classifier,
        )
        for stage in pipeline:
            x = stage(x)
        return x

class conv_EEG_Feature(nn.Module):
    """Convolutional encoder that returns both logits and an embedding.

    The convolutional trunk is two ``Conv1d(kernel 3) -> BatchNorm1d ->
    LeakyReLU(0.3)`` stages followed by a flatten. ``features`` maps the
    3060 flattened values to a 64-dimensional embedding and ``fc`` maps that
    embedding to class logits.

    Args:
        out_class: Number of output classes (width of ``fc``).
    """

    def __init__(self, out_class):
        super().__init__()
        # Stage 1: 1 -> 5 channels.
        self.conv1 = nn.Conv1d(1, 5, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm1d(5)
        self.lr1 = nn.LeakyReLU(0.3)
        # Registered for completeness but never applied in forward().
        self.mxp = nn.MaxPool1d(3)
        # Stage 2: 5 -> 10 channels.
        self.conv2 = nn.Conv1d(5, 10, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm1d(10)
        self.lr2 = nn.LeakyReLU(0.3)
        self.fn = nn.Flatten()
        # Logit layer on top of the 64-dimensional embedding.
        self.fc = nn.Linear(64, out_class)

        # Projection from the flattened feature map to the embedding.
        embedding_layers = [
            nn.Linear(3060, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
        ]
        self.features = nn.Sequential(*embedding_layers)

    def forward(self, x):
        """Return ``(logits, embedding)``; the embedding is L2-normalised along dim 1."""
        hidden = self.lr1(self.bn1(self.conv1(x)))
        hidden = self.lr2(self.bn2(self.conv2(hidden)))
        embedding = self.features(self.fn(hidden))
        class_logits = self.fc(embedding)
        return class_logits, F.normalize(embedding, dim=1)

class PiCO(nn.Module):
    """Contrastive partial-label model with a momentum encoder and prototypes.

    Holds a query encoder, a gradient-free momentum ("key") encoder that is
    initialised as a copy of the query encoder, fixed-size queues of key
    features with their pseudo-label scores and partial labels, and one
    prototype vector per class.

    Args:
        args: Configuration object. The constructor reads ``num_class``,
            ``moco_queue`` (queue length) and ``low_dim`` (feature width);
            the other methods additionally read ``moco_m`` and ``proto_m``.
        base_encoder: Callable invoked as ``base_encoder(args.num_class)``
            returning a module whose forward yields ``(logits, features)``.
    """

    def __init__(self, args, base_encoder):
        super().__init__()

        self.encoder_q = base_encoder(args.num_class)
        # momentum encoder
        self.encoder_k = base_encoder(args.num_class)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(args.moco_queue, args.low_dim))
        self.register_buffer("queue_pseudo", torch.randn(args.moco_queue, args.num_class))
        self.register_buffer("queue_partial", torch.randn(args.moco_queue, args.num_class))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.register_buffer("prototypes", torch.zeros(args.num_class, args.low_dim))

    @torch.no_grad()
    def _momentum_update_key_encoder(self, args):
        """Move the key encoder towards the query encoder (EMA with ``args.moco_m``)."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * args.moco_m + param_q.data * (1. - args.moco_m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, labels, partial_Y, args):
        """Overwrite the oldest queue slots with the current batch.

        The three queues are treated as ring buffers of length
        ``args.moco_queue``: writing starts at ``queue_ptr`` and wraps around
        to slot 0, so the batch size does not have to divide the queue length
        (e.g. a smaller final batch of an epoch no longer aborts training).

        Args:
            keys: Key features of the batch, one row per sample.
            labels: Pseudo-label scores stored alongside each key.
            partial_Y: Partial-label rows stored alongside each key.
            args: Configuration object; ``moco_queue`` is read.
        """
        queue_size = args.moco_queue
        if keys.shape[0] > queue_size:
            # A batch larger than the queue can only leave its newest rows behind.
            keys = keys[-queue_size:]
            labels = labels[-queue_size:]
            partial_Y = partial_Y[-queue_size:]
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        head = min(batch_size, queue_size - ptr)  # rows that fit before the end
        tail = batch_size - head  # rows that wrap around to slot 0
        for queue, batch in ((self.queue, keys),
                             (self.queue_pseudo, labels),
                             (self.queue_partial, partial_Y)):
            queue[ptr:ptr + head, :] = batch[:head]
            if tail:
                queue[:tail, :] = batch[head:]

        self.queue_ptr[0] = (ptr + batch_size) % queue_size  # move pointer

    def reset_prototypes(self, prototypes):
        """Replace the stored class prototypes with ``prototypes``."""
        self.prototypes = prototypes

    def forward(self, img_q, im_k=None, partial_Y=None, args=None, eval_only=False):
        """Encode a batch and, in training mode, refresh prototypes and queues.

        Args:
            img_q: Input batch for the query encoder.
            im_k: Input batch for the momentum encoder (training only).
            partial_Y: Candidate-label mask multiplied into the softmax scores
                (training only).
            args: Configuration object; ``proto_m``, ``moco_m`` and
                ``moco_queue`` are read (training only).
            eval_only: When true, return only the query-encoder logits.

        Returns:
            The logits if ``eval_only``; otherwise the tuple
            ``(output, features, pseudo_scores, partial_target, score_prot)``
            where ``features`` concatenates query features, key features and
            the feature queue, ``pseudo_scores`` / ``partial_target`` are the
            matching concatenations of masked scores / partial labels, and
            ``score_prot`` is the softmax over query-prototype similarities.
            The queue contents used are those from before this batch is
            enqueued.
        """
        output, q = self.encoder_q(img_q)
        if eval_only:
            return output

        # using partial labels to filter out negative labels
        predicted_scores = torch.softmax(output, dim=1) * partial_Y
        _, pseudo_labels = torch.max(predicted_scores, dim=1)

        # compute prototypical logits against a detached snapshot of the prototypes
        prototypes = self.prototypes.clone().detach()
        logits_prot = torch.mm(q, prototypes.t())
        score_prot = torch.softmax(logits_prot, dim=1)

        with torch.no_grad():
            # Momentum-update the prototypes with the pseudo labels. This must
            # not be recorded by autograd: otherwise the buffer acquires a
            # grad_fn and every later batch extends the same graph, which
            # keeps old graphs alive and breaks deepcopy of the module.
            for feat, label in zip(q.detach(), pseudo_labels):
                self.prototypes[label] = self.prototypes[label] * args.proto_m + (1 - args.proto_m) * feat
            # normalize prototypes
            self.prototypes = F.normalize(self.prototypes, p=2, dim=1)

            # compute key features with the freshly updated momentum encoder
            # (no batch shuffling is applied here)
            self._momentum_update_key_encoder(args)
            _, k = self.encoder_k(im_k)

        features = torch.cat((q, k, self.queue.clone().detach()), dim=0)
        pseudo_scores = torch.cat((predicted_scores, predicted_scores, self.queue_pseudo.clone().detach()), dim=0)
        partial_target = torch.cat((partial_Y, partial_Y, self.queue_partial.clone().detach()), dim=0)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k, predicted_scores, partial_Y, args)
        return output, features, pseudo_scores, partial_target, score_prot

class PaPi(nn.Module):
    """Single-encoder partial-label model with momentum-updated class prototypes.

    Args:
        args: Configuration object. The constructor reads ``proto_m`` (initial
            prototype momentum), ``num_class`` and ``low_dim``.
        base_encoder: Callable invoked as ``base_encoder(args.num_class)``
            returning a module whose forward yields ``(logits, features)``.
    """

    def __init__(self, args, base_encoder):
        super().__init__()
        self.proto_weight = args.proto_m
        self.encoder = base_encoder(args.num_class)
        self.register_buffer("prototypes", torch.zeros(args.num_class, args.low_dim))

    def set_prototype_update_weight(self, epoch, args):
        """Linearly interpolate the prototype momentum for ``epoch``.

        ``args.pro_weight_range`` is a ``"start,end"`` string; the weight goes
        from ``start`` at epoch 0 towards ``end`` at ``args.epochs``.
        """
        bounds = args.pro_weight_range.split(",")
        start = float(bounds[0])
        end = float(bounds[1])
        self.proto_weight = 1. * epoch / args.epochs * (end - start) + start

    def forward(self, img_q, img_k=None, img_q_mix=None, img_k_mix=None, partial_Y=None, args=None,
                eval_only=False):
        """Encode the views and, in training mode, refresh the prototypes.

        Args:
            img_q, img_k: Two views of the batch.
            img_q_mix, img_k_mix: Mixed versions of the two views.
            partial_Y: Candidate-label mask multiplied into the softmax scores.
            args: Configuration object (kept for interface compatibility).
            eval_only: When true, return only the logits for ``img_q``.

        Returns:
            The logits for ``img_q`` if ``eval_only``; otherwise
            ``(output_q, output_k, logits_prot_q, logits_prot_k,
            logits_prot_q_mix, logits_prot_k_mix)`` where each ``logits_prot_*``
            is the similarity between that view's features and the prototypes
            as they were before this batch's update.
        """
        output_q, q = self.encoder(img_q)

        if eval_only:
            return output_q

        output_k, k = self.encoder(img_k)
        _, q_mix = self.encoder(img_q_mix)
        _, k_mix = self.encoder(img_k_mix)

        with torch.no_grad():
            # Pseudo labels: arg-max of the candidate-masked, row-normalised scores.
            scores_q = torch.softmax(output_q, dim=1) * partial_Y
            scores_q = scores_q / scores_q.sum(dim=1, keepdim=True)
            scores_k = torch.softmax(output_k, dim=1) * partial_Y
            scores_k = scores_k / scores_k.sum(dim=1, keepdim=True)
            _, pseudo_labels_q = torch.max(scores_q, dim=1)
            _, pseudo_labels_k = torch.max(scores_k, dim=1)

        # Similarities are computed against a detached snapshot of the prototypes.
        prototypes = self.prototypes.clone().detach()

        logits_prot_q = torch.mm(q, prototypes.t())
        logits_prot_k = torch.mm(k, prototypes.t())

        logits_prot_q_mix = torch.mm(q_mix, prototypes.t())
        logits_prot_k_mix = torch.mm(k_mix, prototypes.t())

        with torch.no_grad():
            # The prototype update must not be recorded by autograd: otherwise
            # the buffer acquires a grad_fn and every later batch extends the
            # same graph, which keeps old graphs alive and breaks deepcopy.
            weight = self.proto_weight
            for feats, labels in ((q.detach(), pseudo_labels_q), (k.detach(), pseudo_labels_k)):
                for feat, label in zip(feats, labels):
                    self.prototypes[label] = weight * self.prototypes[label] + (1 - weight) * feat

            self.prototypes = F.normalize(self.prototypes, p=2, dim=1)

        return output_q, output_k, logits_prot_q, logits_prot_k, logits_prot_q_mix, logits_prot_k_mix

class PG(nn.Module):
    """Encoder with momentum-updated class prototypes (no key encoder, no queue).

    Args:
        args: Configuration object. The constructor reads ``num_class`` and
            ``low_dim``; ``forward`` additionally reads ``proto_m``.
        base_encoder: Callable invoked as ``base_encoder(args.num_class)``
            returning a module whose forward yields ``(logits, features)``.
    """

    def __init__(self, args, base_encoder):
        super().__init__()
        self.encoder_q = base_encoder(args.num_class)
        self.register_buffer("prototypes", torch.zeros(args.num_class, args.low_dim))

    def forward(self, img_q, partial_Y=None, args=None, eval_only=False):
        """Encode a batch and, in training mode, refresh the prototypes.

        Args:
            img_q: Input batch for the encoder.
            partial_Y: Candidate-label mask multiplied into the softmax scores
                (training only).
            args: Configuration object; ``proto_m`` is read (training only).
            eval_only: When true, return only the encoder logits.

        Returns:
            The logits if ``eval_only``; otherwise ``(output, logits_prot)``
            where ``logits_prot`` is the similarity between the batch features
            and the prototypes as they were before this batch's update.
        """
        output, q = self.encoder_q(img_q)
        # for testing
        if eval_only:
            return output

        predicted_scores = torch.softmax(output, dim=1) * partial_Y
        _, pseudo_labels = torch.max(predicted_scores, dim=1)

        # compute prototypical logits against a detached snapshot of the prototypes
        prototypes = self.prototypes.clone().detach()
        logits_prot = torch.mm(q, prototypes.t())

        with torch.no_grad():
            # Momentum-update the prototypes with the pseudo labels. This must
            # not be recorded by autograd: otherwise the buffer acquires a
            # grad_fn and every later batch extends the same graph, which
            # keeps old graphs alive and breaks deepcopy of the module.
            for feat, label in zip(q.detach(), pseudo_labels):
                self.prototypes[label] = self.prototypes[label] * args.proto_m + (1 - args.proto_m) * feat
            # normalize prototypes
            self.prototypes = F.normalize(self.prototypes, p=2, dim=1)

        # return: the batch's class logits and its prototype-similarity logits
        return output, logits_prot


# If you want to use the encoder "MLP", you can use the following classes

class MLP(nn.Module):
    """Single-layer backbone: ``Linear(310, 256) -> BatchNorm1d(256) -> ReLU``."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(310, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        ]
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the backbone to ``x`` and return the 256-wide activations."""
        return self.module(x)


class conv_EEG_Feature_MLP(nn.Module):
    """MLP-backed encoder that returns both logits and an embedding.

    The ``MLP`` backbone is followed by ``features`` (256 -> 64 embedding) and
    ``fc`` (64 -> ``out_class`` logits).

    Args:
        out_class: Number of output classes (width of ``fc``).
    """

    def __init__(self, out_class):
        super().__init__()
        self.backbone = MLP()
        self.features = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )
        self.fc = nn.Linear(64, out_class)

    def forward(self, x):
        """Return ``(logits, embedding)``; the embedding is L2-normalised along dim 1.

        ``x`` is flattened to rows of 310 values, so any batch size works as
        long as the element count is a multiple of 310 (previously the batch
        size was hard-coded to 8).
        """
        x = x.reshape(-1, 310)
        x = self.backbone(x)
        features = self.features(x)
        logits = self.fc(features)
        return logits, F.normalize(features, dim=1)

class conv_EEG_MLP(nn.Module):
    """MLP-backed classifier that returns class logits.

    The ``MLP`` backbone is followed by a fully connected head
    (256 -> 64 -> ``out_class``).

    Args:
        out_class: Number of output classes (width of the last linear layer).
    """

    def __init__(self, out_class):
        super().__init__()
        self.backbone = MLP()
        self.fc = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, out_class))

    def forward(self, x):
        """Return the class logits for ``x``.

        ``x`` is flattened to rows of 310 values, so any batch size works as
        long as the element count is a multiple of 310 (previously the batch
        size was hard-coded to 8).
        """
        x = x.reshape(-1, 310)
        features = self.backbone(x)
        logits = self.fc(features)
        return logits