from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from torch.autograd import Variable
from torchvision.models import densenet121
from lightly.models.modules.heads import SimCLRProjectionHead
from torch.distributions import Normal, Independent

from .our_RNN_modules import Attention, DomainDiscriminator, DomainAdversarialLoss

# Input image height in pixels. Not referenced by the classes below;
# presumably the side length images are resized to -- confirm with the data pipeline.
IMG_HEIGHT = 224
# Number of FMoW target classes; sets the output width of the linear classifiers below.
NUM_CLASSES = 62

class FMoWNetwork(nn.Module):
    '''
    DenseNet-121 model for FMoW with an optional self-supervised head.

    The encoder is the ``features`` trunk of a pretrained torchvision
    DenseNet-121 (1024 pooled channels). A linear classifier over
    ``NUM_CLASSES`` classes sits on top. When ``args.method`` is ``'simclr'``
    or ``'swav'``, the matching lightly projection head is attached as well
    (plus prototypes for SwaV). ``ssl_training`` decides whether ``forward``
    routes through that head.

    :param args: config object; ``args.method`` is read here and in ``forward``.
    :param weights: optional state dict, deep-copied and loaded right after the
        encoder and classifier are built. This happens BEFORE any SSL head
        exists, so the dict must not contain SSL-head entries (strict load).
    :param ssl_training: use the SSL head in ``forward`` when the method has one.
    '''

    def __init__(self, args, weights=None, ssl_training=False):
        super().__init__()
        self.args = args
        self.num_classes = NUM_CLASSES
        self.enc = densenet121(pretrained=True).features
        self.classifier = nn.Linear(1024, self.num_classes)
        if weights is not None:
            self.load_state_dict(deepcopy(weights))

        chosen_method = self.args.method
        if chosen_method == 'simclr':
            self.projection_head = SimCLRProjectionHead(1024, 1024, 128)
        elif chosen_method == 'swav':
            from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
            self.projection_head = SwaVProjectionHead(1024, 1024, 128)
            self.prototypes = SwaVPrototypes(128, n_prototypes=1024)
        self.ssl_training = ssl_training

    def reset_weights(self, weights):
        '''Strictly load a deep copy of ``weights`` into this module.'''
        self.load_state_dict(deepcopy(weights))

    def forward(self, x):
        '''
        Encode ``x`` and apply the head selected by method and ``ssl_training``.

        :return: SimCLR projections, SwaV prototype scores of the L2-normalised
            projection, or classifier logits (the fallback in every other case).
        '''
        pooled = torch.flatten(
            F.adaptive_avg_pool2d(F.relu(self.enc(x), inplace=True), (1, 1)), 1
        )

        chosen_method = self.args.method
        if self.ssl_training:
            if chosen_method == 'simclr':
                return self.projection_head(pooled)
            if chosen_method == 'swav':
                unit_proj = F.normalize(self.projection_head(pooled), dim=1, p=2)
                return self.prototypes(unit_proj)
        return self.classifier(pooled)




class EvoS_FMoWNetwork(nn.Module):
    '''
    DenseNet-121 FMoW model with EvoS feature-distribution modelling.

    Components:
      * ``enc`` -- pretrained DenseNet-121 ``features`` trunk (1024 channels),
        optionally followed by a Linear-BatchNorm-ReLU ``bottleneck``;
        ``feature_dim`` is the resulting feature width.
      * ``classifier`` -- linear layer over ``NUM_CLASSES`` classes.
      * ``attention_dict[k]`` for ``k`` in ``1..args.scale`` -- ``Attention``
        modules over rows of ``k * 2 * feature_dim``. Each statistics row is
        ``[mean | logStd]``, hence the factor 2.
      * ``conv1d_dict[k]`` for ``k >= 2`` -- ``nn.Unfold`` windows of ``k``
        consecutive domains.
      * ``init_pool`` -- parameters used as the statistics of relative
        timestamps 0 and 1 (see ``memorize``).
      * ``memory_pool`` -- detached per-timestamp statistics.
      * ``domain_discriminator`` wrapped in ``domain_adv``.

    NOTE(review): ``attention_dict``, ``conv1d_dict`` and ``init_pool`` are plain
    ``dict``s, so nn.Module does not register their contents. They are absent
    from ``state_dict()`` / ``parameters()`` and unaffected by ``.train()``,
    ``.eval()`` and ``.to()``. They reach the optimizer only via ``get_parameters``.

    :param args: config object; fields read include ``dim_bottleneck_f``,
        ``num_head``, ``dim_head``, ``scale``, ``hidden_discriminator``,
        ``init_timestamp``, ``split_time`` and ``truncate``.
    :param weights: optional state dict, deep-copied and strictly loaded once
        every registered submodule has been built. A ``state_dict()`` of this
        class is therefore accepted, exactly as by ``reset_weights``.
    '''

    def __init__(self, args, weights=None):
        super(EvoS_FMoWNetwork, self).__init__()
        self.args = args
        self.num_classes = NUM_CLASSES
        self.enc = densenet121(pretrained=True).features

        if args.dim_bottleneck_f is not None:
            self.bottleneck = nn.Sequential(
                nn.Linear(1024, args.dim_bottleneck_f),
                nn.BatchNorm1d(args.dim_bottleneck_f),
                nn.ReLU()
            )
            self.feature_dim = args.dim_bottleneck_f
        else:
            self.feature_dim = 1024

        self.classifier = nn.Linear(self.feature_dim, NUM_CLASSES)

        # Width of one [mean | logStd] statistics row.
        stat_dim = 2 * self.feature_dim
        self.attention_dict = {}
        self.attention_dict[1] = Attention(dim=stat_dim, heads=args.num_head, dim_head=args.dim_head).cuda()

        self.conv1d_dict = {}
        for k in range(2, args.scale + 1):
            self.conv1d_dict[k] = nn.Unfold(kernel_size=(1, k), stride=1).cuda()  # overlapping window of k domains
            self.attention_dict[k] = Attention(dim=k * stat_dim, heads=k * args.num_head, dim_head=args.dim_head).cuda()

        self.init_pool = {0: nn.Parameter(torch.zeros([1, stat_dim], requires_grad=True)),
                          1: nn.Parameter(torch.zeros([1, stat_dim], requires_grad=True))}
        self.memory_pool = {}

        self.domain_discriminator = DomainDiscriminator(in_feature=self.feature_dim, hidden_size=self.args.hidden_discriminator).cuda()
        self.domain_adv = DomainAdversarialLoss(self.domain_discriminator).cuda()
        self.eps = 1e-6

        # Load last. Loading before classifier/discriminator existed made
        # strict loading of this class's own state_dict fail with unexpected keys.
        if weights is not None:
            self.load_state_dict(deepcopy(weights))

    def reset_weights(self, weights):
        '''Strictly load a deep copy of ``weights`` into this module.'''
        self.load_state_dict(deepcopy(weights))

    def memorize(self, timestamp, mean_logStd_t=None):
        '''
        Record detached statistics for ``timestamp`` in ``memory_pool``.

        Relative timestamps ``t <= 1`` copy ``init_pool[t]``; later ones copy
        ``mean_logStd_t``. Stored tensors are clones moved to CUDA, with no grad.
        '''
        t = timestamp - self.args.init_timestamp
        with torch.no_grad():
            source = self.init_pool[t] if t <= 1 else mean_logStd_t
            self.memory_pool[t] = source.detach().clone().cuda()

    def get_previous_mean_logStd(self, timestamp):
        '''
        Stack the memorised statistics of every relative timestamp before ``timestamp``.

        :return: ``memory_pool[0..t-1]`` concatenated along dim 0, on CUDA.
        :raises RuntimeError: if the relative timestamp is below 2.
        '''
        t = timestamp - self.args.init_timestamp
        if t < 2:
            raise RuntimeError(f' Timestamp {timestamp} does not require previous statistics!')
        # One concatenation instead of re-growing the tensor each step.
        previous_mean_logStd = torch.cat([self.memory_pool[i] for i in range(t)], dim=0)
        return previous_mean_logStd.cuda()

    def foward_for_FeatureDistritbuion(self, previous_mean_logStd):
        '''
        Predict the next domain's statistics from the previous ones.

        :param previous_mean_logStd: [t, d] stacked statistics, d = 2 * feature_dim.
        :return: ``(future_out, loss_con)``.
            ``future_out`` is the average of the scale-1 prediction and every
            scale-k (k >= 2) prediction that ran. Assuming ``Attention``
            returns [1, rows, dim], it is [1, d].
            ``loss_con`` is the sum of the scale-k consistency losses, or
            ``None`` when no k >= 2 branch ran.
        '''
        future_out = None
        num_previous_domain = previous_mean_logStd.size(0)
        d = previous_mean_logStd.shape[1]
        s = 0  # number of scale predictions accumulated into future_out
        loss_con = None
        for k in range(1, self.args.scale + 1):
            if k == 1:
                out_k = self.attention_dict[k](previous_mean_logStd)  # assumed [1, t, d]
                future_out = torch.mean(out_k, dim=1)
                s += 1
                continue

            # Skip scale k when fewer than k + 1 domains are available, or
            # when args.scale exceeds split_time - init_timestamp.
            if (num_previous_domain < k + 1) or (self.args.scale > self.args.split_time - self.args.init_timestamp):
                continue

            # [t, d] -> [1, t, 1, d] -> [1, d, 1, t]; Unfold with a (1, k)
            # kernel then yields [1, d*k, L] with L = t - k + 1 windows.
            rearranged_input = previous_mean_logStd.view(1, num_previous_domain, 1, d)
            patches = self.conv1d_dict[k](rearrange(rearranged_input, 'b w h c -> b c h w'))
            patches = torch.reshape(patches, (1, d, k, patches.shape[-1]))  # [1, d, k, L]
            patches = patches.permute(0, 3, 2, 1)  # [1, L, k, d]
            patches = rearrange(patches, 'b l k d -> b l (k d)')
            concate_patches = patches.squeeze(0)  # [L, k*d]

            out_k = self.attention_dict[k](concate_patches)  # assumed [1, L, k*d]
            last_out_k = torch.mean(out_k, dim=1)
            last_out_k = last_out_k.view(1, k, 2 * self.feature_dim)

            # The last of the k predicted rows is this scale's contribution.
            future_out = future_out + last_out_k[:, -1, :]

            # Consistency: the first k-1 predicted rows vs the last k-1 observed rows.
            temp_loss_con = torch.mean(torch.norm(
                previous_mean_logStd[-(k - 1):, :] - last_out_k.squeeze(0)[:k - 1, :], p=2, dim=1))
            # Out-of-place add keeps the autograd graph free of in-place edits.
            loss_con = temp_loss_con if loss_con is None else loss_con + temp_loss_con
            s += 1
        future_out = future_out / s
        return future_out, loss_con

    def foward_encoder(self, x):
        '''Encode ``x``: backbone -> ReLU -> global average pool -> flatten (-> bottleneck).'''
        f = self.enc(x)
        f = F.relu(f, inplace=True)
        f = F.adaptive_avg_pool2d(f, (1, 1))
        f = torch.flatten(f, 1)
        if self.args.dim_bottleneck_f is not None:
            f = self.bottleneck(f)
        return f

    def foward_classifier(self, normalized_f):
        '''Apply the linear classifier to (already normalised) features.'''
        logits = self.classifier(normalized_f)
        return logits

    def forward_evaluate(self, x, mean, std):
        '''
        Classify ``x`` after standardising its features.

        :param mean: [B, A], A is the dimension of the deep feature
        :param std: [B, A]
        :return: logits of ``(f - mean) / (std + eps)``.
        '''
        f = self.foward_encoder(x)
        f = (f - mean) / (std + self.eps)
        return self.classifier(f)

    def forward_domain_discriminator(self, f, normalized_f, previous_mean_logStd):
        '''
        Adversarial loss between current features and samples of past domains.

        For each stored domain, ``b`` pseudo-features ``mean + w * exp(logStd)``
        are drawn. ``w`` is one scalar per sample: truncated normal in
        ``[-truncate, truncate]`` when ``args.truncate`` is set, otherwise
        standard normal. The detached samples from all domains are concatenated
        and passed to ``domain_adv`` together with ``f``.

        :param f: [b, feature_dim] current features.
        :param normalized_f: unused; kept for interface compatibility.
        :param previous_mean_logStd: [l, 2 * feature_dim] stored statistics.
        :raises ValueError: if ``previous_mean_logStd`` has no rows.
        '''
        b, d = f.size()
        num_domains = previous_mean_logStd.size(0)
        if num_domains == 0:
            raise ValueError('previous_mean_logStd must contain at least one domain')

        truncate = self.args.truncate
        standard_normal = None
        if truncate is None:
            standard_normal = torch.distributions.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]))

        sampled = []
        for i in range(num_domains):
            previous_mean_i = previous_mean_logStd[i, :self.feature_dim]
            std_i = torch.exp(previous_mean_logStd[i, self.feature_dim:])

            if truncate is not None:
                w = torch.empty(b, 1, device=f.device)
                nn.init.trunc_normal_(w, mean=0.0, std=1.0, a=-truncate, b=truncate)
            else:
                w = standard_normal.sample((b,)).to(f.device)  # [b, 1]

            temp_features_i = previous_mean_i.view(1, d).expand(b, d) + w.expand(b, d) * std_i.view(1, d).expand(b, d)
            sampled.append(temp_features_i.detach())

        previous_f = torch.cat(sampled, dim=0)
        return self.domain_adv(previous_f, f)

    def rest_discriminator_lr(self, max_iters):
        '''Forward ``max_iters`` to ``domain_adv.reset``.'''
        self.domain_adv.reset(max_iters=max_iters)

    def get_parameters(self, lr):
        '''
        Build optimizer parameter groups, all at learning rate ``lr``.

        Covers the (unregistered) attention modules and ``init_pool`` entries,
        plus the encoder, classifier, domain discriminator and optional bottleneck.
        NOTE(review): parameters of ``domain_adv`` other than the wrapped
        discriminator (if any) are not included -- confirm this is intended.
        '''
        params_list = []
        for k in range(1, self.args.scale + 1):
            params_list.append({"params": self.attention_dict[k].parameters(), 'lr': 1 * lr})
        params_list.extend([
                {"params": self.enc.parameters(), 'lr': lr},
                {"params": self.classifier.parameters(), 'lr': 1 * lr},
                {"params": self.init_pool[0], 'lr': 1 * lr},
                {"params": self.init_pool[1], 'lr': 1 * lr},
                {"params": self.domain_discriminator.parameters(), 'lr': 1 * lr},
            ]
        )
        if self.args.dim_bottleneck_f is not None:
            params_list.append({"params": self.bottleneck.parameters(), 'lr': 1 * lr})
        return params_list




