import math

import numpy as np
import torch
import torchvision.transforms.functional as F
from torch import nn

from . import cross_attn, models


def init_algorithm(args, train_dataset):

    if args.dataset == 'femnist':
        num_classes = 62
        num_train_domains = 262
        n_img_channels = 1
        spatial_dim = 4
        embed_dim = 128
        img_shape = 28
    elif args.dataset in 'tinyimg':
        num_classes = 200
        num_train_domains = 51
        n_img_channels = 3
        spatial_dim = 1
        embed_dim = 2048
        img_shape = 64
    elif args.dataset in 'cifar-c':
        num_classes = 10
        num_train_domains = 56
        n_img_channels = 3
        spatial_dim = 5
        embed_dim = 128
        img_shape = 32

    if args.algorithm == 'ARM-CML':
        n_channels = n_img_channels + args.n_context_channels
        hidden_dim = 64
        context_net = models.ContextNet(n_img_channels, args.n_context_channels,
                                        hidden_dim=hidden_dim, kernel_size=5).to(args.device)
    elif args.algorithm == 'CXDA':
        n_channels = n_img_channels

        if args.model == 'convnet':
            cross_attention = cross_attn.BlockCXDA(
                dim=128, spatial_dim=spatial_dim, supervised=args.supervised).to(args.device)
            classifier = models.MLPClassifier(
                num_classes=num_classes).to(args.device)
        else:
            cross_attention = cross_attn.BlockCXDA(
                dim=embed_dim, spatial_dim=spatial_dim, supervised=args.supervised).to(args.device)
            classifier = models.MLPClassifierForResNet(
                hidden_dim=embed_dim, num_classes=num_classes).to(args.device)
    else:
        n_channels = n_img_channels

    if args.algorithm == 'CXDA':
        return_2d_features = True
    else:
        return_2d_features = False

    # Main model
    if args.model == 'convnet':
        model = models.ConvNet(num_channels=n_channels, num_classes=num_classes,
                               smaller_model=False, return_features=False, return_2d_features=return_2d_features)
    else:
        model = models.ResNet(num_channels=n_channels, num_classes=num_classes, model_name=args.model,
                              pretrained=args.pretrained, return_features=False, return_2d_features=return_2d_features)
    model = model.to(args.device)

    loss_fn = nn.CrossEntropyLoss()

    # Algorithm
    hparams = {'optimizer': args.optimizer,
               'learning_rate': args.learning_rate,
               'weight_decay': args.weight_decay,
               'support_size': args.support_size
               }

    if args.algorithm == 'ERM':
        algorithm = ERM(model, loss_fn, args.device, hparams)

    elif args.algorithm == 'ARM-CML':
        hparams['n_context_channels'] = args.n_context_channels
        algorithm = ARM_CML(model, loss_fn, args.device, context_net, hparams)

    elif args.algorithm == 'ARM-BN':
        algorithm = ARM_BN(model, loss_fn, args.device, hparams)

    elif args.algorithm == 'CXDA':
        algorithm = CXDA(model, loss_fn, args.device,
                         classifier, cross_attention, hparams)

    return algorithm


class AverageMeter:
    """Count-weighted running average of a scalar metric.

    Attributes:
        name: Optional label for the metric.
        value: Current weighted mean of all values passed to ``update``.
        total_count: Sum of the ``count`` weights seen so far.
    """

    def __init__(self, name=None):
        # ``self`` was missing in the original signature, so the class could
        # not be instantiated at all.
        self.name = name
        self.value = 0
        self.total_count = 0

    def update(self, value, count):
        """Fold ``value`` (itself a mean over ``count`` items) into the average."""
        old_count = self.total_count
        new_count = old_count + count
        if new_count == 0:
            # Nothing accumulated yet and nothing added: avoid dividing by zero.
            return

        self.value = self.value * old_count / new_count + value * count / new_count
        self.total_count = new_count

    def reset(self):
        """Clear the accumulated average and count."""
        self.value = 0
        self.total_count = 0


class ERM(nn.Module):
    """Empirical risk minimization baseline and shared training wrapper.

    Holds the model, loss and optimizer, and defines the batch layout the
    adaptive subclasses also use: ``x`` is a sequence of blocks of
    ``support_size`` examples, where even-indexed blocks are support examples
    and odd-indexed blocks are query examples. ERM ignores the support examples.
    """

    def __init__(self, model, loss_fn, device, hparams, init_optim=True, **kwargs):
        """
        Args:
            model: Network whose output is passed to ``loss_fn`` as logits.
            loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
            device: Stored as ``self.device``; not used by this class itself.
            hparams: Dict with keys ``'optimizer'``, ``'learning_rate'``,
                ``'weight_decay'`` and ``'support_size'``.
            init_optim: If False, no optimizer is created here; callers may
                then call ``init_optimizers`` with their own parameter list.
            **kwargs: Ignored.
        """
        super().__init__()
        self.res_multiplier = 0
        self.model = model
        self.loss_fn = loss_fn
        self.device = device

        self.optimizer_name = hparams['optimizer']
        self.learning_rate = hparams['learning_rate']
        self.weight_decay = hparams['weight_decay']
        self.support_size = hparams['support_size']

        if init_optim:
            self.init_optimizers(self.model.parameters())

    def init_optimizers(self, params):
        """Create ``self.optimizer`` over ``params``.

        ``'adam'`` selects Adam; any other name falls back to SGD with
        momentum 0.9, restricted to parameters that require gradients.
        """
        if self.optimizer_name == 'adam':
            self.optimizer = torch.optim.Adam(params, lr=self.learning_rate,
                                              weight_decay=self.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(
                filter(lambda p: p.requires_grad, params),
                lr=self.learning_rate,
                momentum=0.9,
                weight_decay=self.weight_decay)

    def predict(self, x, group_ids=None, train=True):
        """Return ``self.model`` outputs for the first query block of ``x``.

        The first query block is ``x[support_size:2 * support_size]``, truncated
        if the batch is shorter. Slicing it directly is equivalent to selecting
        all odd-indexed blocks and keeping the first ``support_size`` rows. It
        avoids building index lists, and it skips gathering the support
        examples, which ERM never uses.

        ``group_ids`` and ``train`` are accepted for interface compatibility
        with the subclasses and are ignored.
        """
        # Query examples are sampled from one domain. Since the order of
        # domains is randomly generated, the first query block can be used.
        query_set = x[self.support_size:2 * self.support_size]
        return self.model(query_set)

    def update(self, loss):
        """Take one optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def get_acc(self, logits, labels):
        """Return the fraction of rows whose argmax over ``logits`` equals ``labels``."""
        preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
        accuracy = np.mean(preds == labels.detach().cpu().numpy().reshape(-1))
        return accuracy

    def learn(self, images, labels, group_ids=None):
        """Run one training step and return ``(logits, stats)``.

        ``labels`` must line up with the rows returned by ``predict`` (the
        query examples), since they are passed straight to ``loss_fn``.
        ``stats['objective']`` is the loss as a Python float.
        """
        self.train()

        logits = self.predict(images, group_ids=group_ids)
        loss = self.loss_fn(logits, labels)
        self.update(loss)

        stats = {
            'objective': loss.detach().item(),
        }

        return logits, stats


class ARM_CML(ERM):
    """ARM with a context network (ARM-CML).

    The context net is applied to the support examples, and its outputs are
    averaged over the support set into a single context. That context is
    concatenated to every query example along dim 1 (channels) before the
    main model runs.
    """

    def __init__(self, model, loss_fn, device, context_net, hparams=None):
        """
        Args:
            model: Main network; its input channels must equal the image
                channels plus the context net's output channels.
            loss_fn, device: As for ``ERM``.
            context_net: Network applied to the support examples.
            hparams: ``ERM`` hparams plus ``'n_context_channels'``.
        """
        hparams = {} if hparams is None else hparams
        # Skip ERM's optimizer: it would only cover ``model`` and is rebuilt
        # below over both networks anyway.
        super().__init__(model, loss_fn, device, hparams, init_optim=False)

        self.context_net = context_net
        self.n_context_channels = hparams['n_context_channels']

        params = list(self.model.parameters()) + \
            list(self.context_net.parameters())
        self.init_optimizers(params)

    def predict(self, x, group_ids=None, train=True):
        """Classify the first query block of ``x`` conditioned on the support context.

        ``group_ids`` and ``train`` are ignored.
        """
        batch_size = x.shape[0]

        # All even-indexed blocks of ``support_size`` rows are support examples.
        # NOTE(review): if a batch holds several support blocks from different
        # domains, they are all averaged into one context — confirm the layout.
        support_set_idx = [e for e in range(batch_size)
                           if (e // self.support_size) % 2 == 0]
        support_set = x[support_set_idx]
        # First query block (same rows as the first ``support_size`` odd-block rows).
        query_set = x[self.support_size:2 * self.support_size]

        # Average the context-net output over the whole support set.
        context = self.context_net(support_set).mean(dim=0, keepdim=True)

        # Broadcast the context to every query example (no copy; ``cat`` copies
        # once) and append it as extra channels.
        context = context.expand(query_set.shape[0], -1, -1, -1)
        return self.model(torch.cat([query_set, context], dim=1))


class ARM_BN(ERM):
    """ARM via normalisation-statistics adaptation (ARM-BN).

    Before the query examples are classified, the support examples are pushed
    through the model in train mode. Any BatchNorm layers in the model
    therefore normalise with, and update their running statistics from, the
    support batch.
    """

    def __init__(self, model, loss_fn, device, hparams={}):
        super().__init__(model, loss_fn, device, hparams)
        # Also assigned by ERM.__init__; repeated here for explicitness.
        self.support_size = hparams['support_size']

    def predict(self, x, group_ids=None, train=True):
        """Adapt on the support blocks of ``x``, then classify its first query block.

        When ``train`` is False the model is switched to eval mode before the
        query pass and is left in eval mode afterwards. ``group_ids`` is ignored.
        """
        net = self.model
        # Train mode for the support pass, so normalisation layers adapt to it.
        net.train()

        n_rows, _, _, _ = x.shape
        block_parity = [(row // self.support_size) % 2 for row in range(n_rows)]
        support_rows = [row for row, parity in enumerate(block_parity) if parity == 0]
        query_rows = [row for row, parity in enumerate(block_parity) if parity == 1]

        support_batch = x[support_rows]
        query_batch = x[query_rows][:self.support_size]

        # Output discarded: this pass is run only for its effect on the model state.
        net(support_batch)

        if not train:
            # Evaluation: freeze normalisation statistics for the query pass.
            net.eval()

        return net(query_batch)


class CXDA(ERM):
    """Cross-attention domain adaptation (CXDA).

    Support and query examples are embedded by ``model``. ``cross_attention``
    combines the query features with the support features, and ``classifier``
    maps the result to predictions.
    """

    def __init__(self, model, loss_fn, device, classifier, cross_attention, hparams=None):
        """
        Args:
            model: Feature extractor applied to support and query examples.
            loss_fn, device: As for ``ERM``.
            classifier: Module applied to the cross-attention output.
            cross_attention: Module combining query and support features.
            hparams: ``ERM`` hparams.
        """
        hparams = {} if hparams is None else hparams
        # Skip ERM's optimizer: it would only cover ``model`` and is rebuilt
        # below over all three modules anyway.
        super().__init__(model, loss_fn, device, hparams, init_optim=False)

        self.classifier = classifier
        self.cross_attention = cross_attention

        params = (list(self.model.parameters())
                  + list(self.classifier.parameters())
                  + list(self.cross_attention.parameters()))
        self.init_optimizers(params)

    def predict(self, x, group_ids=None, return_attention=False, train=True):
        """Classify the first query block of ``x`` using attention over its support blocks.

        Args:
            x: Batch laid out as alternating support/query blocks of
                ``support_size`` rows (even blocks support, odd blocks query).
            group_ids: Optional per-row group ids, indexed like ``x``. If
                None, None is forwarded to ``cross_attention`` for both sets.
            return_attention: If True, return ``cross_attention``'s output
                for ``return_attention=True`` instead of predictions.
            train: If False, the model is put in eval mode after the support
                pass and left in eval mode.
        """
        # Adapt the normalisation statistics of the feature extractor using the
        # support set: the support pass runs in train mode.
        self.model.train()
        batch_size = x.shape[0]

        support_set_idx = [e for e in range(batch_size)
                           if (e // self.support_size) % 2 == 0]
        support_set = x[support_set_idx]
        # First query block (same rows as the first ``support_size`` odd-block rows).
        query_slice = slice(self.support_size, 2 * self.support_size)
        query_set = x[query_slice]

        # Extract features from the support set
        context_emb = self.model(support_set)

        # We will process the query examples next,
        # so if we do evaluation, we freeze BN statistics
        if not train:
            self.model.eval()

        # Extract features from the query set
        x_emb = self.model(query_set)

        if return_attention:
            # Group ids are not passed on this path, so they are not needed.
            return self.cross_attention(
                x_emb, context_emb, train=train, return_attention=return_attention)

        # Group ids are only indexed when supplied; the default None used to
        # crash here with a TypeError. NOTE(review): confirm cross_attention
        # accepts None for these when ``supervised`` is set.
        if group_ids is None:
            support_group_ids = query_group_ids = None
        else:
            support_group_ids = group_ids[support_set_idx]
            query_group_ids = group_ids[query_slice]

        # Do cross-attention between the query example features and support example features
        out = self.cross_attention(x_emb, context_emb, train=train, return_attention=return_attention,
                                   support_group_ids=support_group_ids, query_group_ids=query_group_ids)
        # Make a prediction
        return self.classifier(out)
