from math import dist
import os
import sys
from socket import EAI_SOCKTYPE
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorflow as tf
import numpy as np
import pandas as pd

from tqdm import tqdm
from tabulate import tabulate
from utils import check_dir, device
from config import args

from models.cka import kernel_HSIC
from models.augmentation import DataAugmentation
from models.model_helpers import get_model
from models.losses import cross_entropy_loss
from models.model_utils import (CheckPointer)

from data.meta_dataset_reader import (MetaDatasetEpisodeReader, TRAIN_METADATASET_NAMES, ALL_METADATASET_NAMES)

# Import-time side effect: switch TensorFlow out of eager mode for the whole process.
# The Meta-Dataset episode reader is queried through a tf.compat.v1.Session in the
# __main__ section (get_test_task(session, ...)); that session-based usage presumably
# needs graph mode, hence the global switch here.
tf.compat.v1.disable_eager_execution()


def get_backbone():
    '''Return the pretrained backbone, ready for inference.

    The network is built by ``get_model(None, args)``; its 'best' checkpoint is then
    restored through ``CheckPointer`` with ``strict=False`` (missing/unexpected keys are
    tolerated), and the model is switched to eval mode before being returned.
    '''
    net = get_model(None, args)
    checkpointer = CheckPointer(args, net, optimizer=None)
    checkpointer.restore_model(ckpt='best', strict=False)
    net.eval()
    return net


def get_optimizer(model, atten_lr, head_lr, weight_decay):
    '''Build an Adadelta optimizer over the trainable heads of ``model``.

    Args:
        model: object exposing ``query_head``, ``key_head``, ``value_head`` and ``transform_head``.
        atten_lr: learning rate for the query/key/value heads (one param group each).
        head_lr: default learning rate, used by the ``transform_head`` group.
        weight_decay: weight decay shared by every group.
    Returns:
        torch.optim.Adadelta with four param groups in the order query, key, value, transform.
        Parameters of ``model`` outside these four heads are not optimised.
    '''
    attention_groups = [
        {'params': head.parameters(), 'lr': atten_lr}
        for head in (model.query_head, model.key_head, model.value_head)
    ]
    transform_group = {'params': model.transform_head.parameters()}
    return torch.optim.Adadelta(attention_groups + [transform_group],
                                lr=head_lr,
                                weight_decay=weight_decay)


def compute_prototypes(embeddings:torch.Tensor, labels:torch.Tensor):
    '''Average the embeddings of each class into a class prototype.

    Args:
        embeddings: [n_embeddings, c, h, w] (any trailing shape is accepted; it is kept as-is)
        labels: [n_embeddings, ] integer class indices; classes are taken to be 0..max(labels)
    Returns:
        prototypes: [max(labels) + 1, c, h, w]; entry k is the mean of the embeddings labelled k.
        A class index with no samples gets an all-zero prototype instead of NaN.
    '''
    n_cls = int(labels.max()) + 1
    # torch.arange replaces the deprecated (end-inclusive) torch.range and is built directly
    # on the labels' device/dtype.
    class_ids = torch.arange(n_cls, device=labels.device, dtype=labels.dtype).unsqueeze(1)   # [n_cls, 1]
    membership = class_ids.eq(labels.reshape(1, -1)).type_as(embeddings)                      # [n_cls, n_embeddings]
    # Clamp so an empty class divides 0 by 1 rather than producing 0/0 = NaN.
    counts = membership.sum(dim=1, keepdim=True).clamp(min=1)
    flatten_prototypes = torch.matmul(membership, embeddings.flatten(1)) / counts
    return flatten_prototypes.reshape(n_cls, *embeddings.shape[1:])


class AttentionHead(nn.Module):
    '''Episode-specific attention head applied on top of frozen backbone feature maps.

    Query projections of the input features attend over key/value projections of the
    augmented context features. The attended reconstruction is added to the input (scaled by
    ``args['scale_reconst']``), batch-normalised, average-pooled to 1x1 and passed through a
    1x1 transform. Logits are ``_LOGIT_SCALE`` times the cosine similarity between these
    embeddings and class prototypes computed from the context set.

    The residual fusion ``x + scale * reconst`` requires ``in_dim == args['out_dim']``.
    '''

    # Multiplier applied to cosine similarities to turn them into logits.
    _LOGIT_SCALE = 10

    def __init__(self, args, in_dim:int=512) -> None:
        '''
        Args:
            args: configuration mapping; 'out_dim' is read here, while 'gain' and
                'scale_reconst' are read by reset_params() and the forward methods;
            in_dim: number of channels of the input feature maps;
        '''
        super(AttentionHead, self).__init__()
        # Keep the configuration this head was built with instead of reading a module global.
        self.args = args
        self.in_dim = in_dim
        self.out_dim = args['out_dim']

        self.query_head = nn.Conv2d(in_channels=self.in_dim, out_channels=self.out_dim, kernel_size=1, stride=1, bias=False)
        self.key_head = nn.Conv2d(in_channels=self.in_dim, out_channels=self.out_dim, kernel_size=1, stride=1, bias=False)
        self.value_head = nn.Conv2d(in_channels=self.in_dim, out_channels=self.out_dim, kernel_size=1, stride=1, bias=False)

        self.bn = nn.BatchNorm2d(self.out_dim)

        self.avgpool_fn = F.adaptive_avg_pool2d
        self.transform_head = nn.Conv2d(in_channels=self.out_dim, out_channels=self.out_dim, kernel_size=1, stride=1, bias=False)

    @staticmethod
    def _identity_kernel(like:torch.Tensor, rows:int, cols:int) -> torch.Tensor:
        '''Return a [rows, cols, 1, 1] identity 1x1-conv kernel on the dtype/device of ``like``.'''
        return torch.eye(rows, cols, dtype=like.dtype, device=like.device)[:, :, None, None]

    def reset_params(self) -> None:
        '''Restore the per-episode initialisation.

        The query/key/value heads become ``args['gain']``-scaled identity projections and the
        transform head becomes the identity. BatchNorm gets unit weight, zero bias and reset
        running statistics, so statistics accumulated in a previous episode cannot leak into
        eval-mode predictions of the next one. New weights are created on the current
        weights' device and dtype.
        '''
        gain = self.args['gain']
        for head in (self.query_head, self.key_head, self.value_head):
            head.weight = nn.Parameter(self._identity_kernel(head.weight, self.out_dim, self.in_dim) * gain)
        self.transform_head.weight = nn.Parameter(
            self._identity_kernel(self.transform_head.weight, self.out_dim, self.out_dim))
        nn.init.constant_(self.bn.weight, 1.)
        nn.init.constant_(self.bn.bias, 0.)
        self.bn.reset_running_stats()

    def _embed(self, x:torch.Tensor, keys:torch.Tensor, values:torch.Tensor) -> torch.Tensor:
        '''Fuse ``x`` with its attention reconstruction, then BN -> 1x1 avg-pool -> transform.

        Returns a tensor of shape [n, out_dim, 1, 1].
        '''
        queries = self.query_head(x)
        reconst = self.compute_attention(queries=queries, keys=keys, values=values)
        fused = x + self.args['scale_reconst'] * reconst   # original feature fusion
        fused = self.bn(fused)
        return self.transform_head(self.avgpool_fn(fused, (1, 1)))

    def _cosine_logits(self, features:torch.Tensor, prototypes:torch.Tensor) -> torch.Tensor:
        '''Scaled cosine similarity between every feature and every prototype: [n_feat, n_proto].'''
        return F.cosine_similarity(features.flatten(1).unsqueeze(1),
                                   prototypes.flatten(1).unsqueeze(0),
                                   dim=-1,
                                   eps=1e-30) * self._LOGIT_SCALE

    def forward_pass(self, context_x, context_y, aug_context) -> torch.Tensor:
        '''Training logits: each context sample against the context prototypes.

        Returns a tensor of shape [n_context, max(context_y) + 1].
        '''
        keys = self.key_head(aug_context)
        values = self.value_head(aug_context)
        context_features = self._embed(context_x, keys, values)
        prototypes = compute_prototypes(context_features, context_y)   # [n_classes, c, 1, 1]
        return self._cosine_logits(context_features, prototypes)

    def pred(self, target_x:torch.Tensor, aug_context:torch.Tensor, context_x:torch.Tensor, context_y:torch.Tensor) -> torch.Tensor:
        '''Evaluation logits: each target sample against prototypes built from the context set.

        Both sets attend over the same keys/values derived from ``aug_context``.
        Returns a tensor of shape [n_target, max(context_y) + 1].
        '''
        keys = self.key_head(aug_context)
        values = self.value_head(aug_context)
        context_features = self._embed(context_x, keys, values)
        target_features = self._embed(target_x, keys, values)
        prototypes = compute_prototypes(context_features, context_y)
        return self._cosine_logits(target_features, prototypes)

    @staticmethod
    def _to_position_rows(x:torch.Tensor) -> torch.Tensor:
        '''[n, c, h, w] -> [n*h*w, c]: one row per spatial position holding its channel vector.'''
        return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1])

    def compute_attention(self, queries, keys, values) -> torch.Tensor:
        '''Attend from every query position to every position of every key/value sample.

        Args:
            queries: [n_q, c, h_q, w_q]
            keys, values: [n_v, c, h_v, w_v]
        Returns:
            [n_q, c, h_q, w_q]; each query position is a weighted sum of value vectors.
        '''
        n_q, _, h_q, w_q = queries.shape
        n_v, c_v, h_v, w_v = values.shape

        # Channels must be moved last before flattening; a plain reshape of an NCHW tensor to
        # (-1, c) would interleave channels and spatial positions.
        flatten_queries = self._to_position_rows(queries)   # [n_q*h_q*w_q, c]
        flatten_keys = self._to_position_rows(keys)         # [n_v*h_v*w_v, c]
        flatten_values = self._to_position_rows(values)     # [n_v*h_v*w_v, c]

        inner_logits = torch.matmul(flatten_queries, flatten_keys.t()) * (self.out_dim ** -0.5)
        inner_logits = inner_logits.reshape(n_q, h_q * w_q, n_v * h_v * w_v)
        # NOTE(review): normalisation runs over dim=1 (the query positions of each query sample),
        # not over the keys as in standard attention. Kept as in the original; confirm intent.
        weights = torch.softmax(inner_logits, dim=1)

        res = torch.matmul(weights.reshape(n_q * h_q * w_q, n_v * h_v * w_v), flatten_values)
        return res.reshape(n_q, h_q, w_q, c_v).permute(0, 3, 1, 2)


def mean_confidence_interval(accs):
    '''Summarise per-episode accuracies.

    Args:
        accs: sequence of per-episode accuracies; presumably fractions in [0, 1], since they
            are scaled by 100 and reported as percentages.
    Returns:
        (mean, half_width), both scaled by 100. ``half_width = 1.96 * std / sqrt(n)`` is the
        normal-approximation 95% interval, using numpy's default population std (ddof=0).
    '''
    acc = np.array(accs) * 100
    return acc.mean(), (1.96 * acc.std()) / np.sqrt(len(acc))


if __name__ == '__main__':

    # Prepare data
    if args['data.train_source'] == 'singlesource':
        trainsets = args['data.train']
    elif args['data.train_source'] == 'multisource':
        trainsets = TRAIN_METADATASET_NAMES
    else:
        raise ValueError("Unrecognized key word.")

    testsets = ALL_METADATASET_NAMES
    testdataloader = MetaDatasetEpisodeReader(mode='test', test_set=testsets,
                                              test_type=args['test.type'])

    # Initialize models & objects of classes
    backbone = get_backbone()
    attention_head = AttentionHead(args)
    data_aug_generator = DataAugmentation(args['num_aug'])

    accs_names = ['NCC']
    n_test_episodes = 600
    # Datasets fine-tuned with the larger learning rate (1.0 instead of 0.1).
    high_lr_datasets = ('traffic_sign', 'mnist')

    # Loop-invariant settings, read once.
    weight_decay = args['weight_decay']
    max_inner_iter = args['inner_iter']
    if max_inner_iter < 1:
        # At least one step is needed to obtain training statistics for the episode.
        raise ValueError(f"args['inner_iter'] must be >= 1, got {max_inner_iter}")

    train_var_accs = dict()
    var_accs = dict()

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = False

    print(f"================= Learning on {args['experiment_name']} Starts! =================")

    with tf.compat.v1.Session(config=config) as session:
        for dataset in testsets:
            lr = 1.0 if dataset in high_lr_datasets else 0.1

            print(dataset)
            train_var_accs[dataset] = {name: [] for name in accs_names}
            var_accs[dataset] = {name: [] for name in accs_names}

            for _episode in tqdm(range(n_test_episodes)):
                # Sample an episode and extract frozen, unpooled backbone feature maps.
                with torch.no_grad():
                    sample = testdataloader.get_test_task(session, dataset)
                    context_features = backbone.embed(sample['context_images'], is_pooling=False)
                    aug_context_features = backbone.embed(data_aug_generator.generate_augmentations(sample['context_images']), is_pooling=False)
                    target_features = backbone.embed(sample['target_images'], is_pooling=False)
                    context_labels = sample['context_labels']
                    target_labels = sample['target_labels']

                # Every episode starts from the same initialisation with a fresh optimizer.
                attention_head.reset_params()
                attention_head.to(device)
                optimizer = get_optimizer(model=attention_head,
                                          atten_lr=lr,
                                          head_lr=lr,
                                          weight_decay=weight_decay)

                # Fine-tune the head on the context set.
                attention_head.train()
                for _step in range(max_inner_iter):
                    optimizer.zero_grad()
                    logits = attention_head.forward_pass(context_x=context_features,
                                                         context_y=context_labels,
                                                         aug_context=aug_context_features)
                    loss, train_stats, _ = cross_entropy_loss(logits=logits, targets=context_labels)
                    loss.backward()
                    optimizer.step()

                # Evaluate the fine-tuned head on the target set.
                attention_head.eval()
                with torch.no_grad():
                    val_res = attention_head.pred(target_x=target_features,
                                                  aug_context=aug_context_features,
                                                  context_x=context_features,
                                                  context_y=context_labels)
                    _, val_stats, _ = cross_entropy_loss(logits=val_res, targets=target_labels)

                train_var_accs[dataset]['NCC'].append(train_stats['acc'])
                var_accs[dataset]['NCC'].append(val_stats['acc'])

            train_mean, _ = mean_confidence_interval(train_var_accs[dataset]['NCC'])
            test_mean, test_conf = mean_confidence_interval(var_accs[dataset]['NCC'])
            print(f"{dataset}: train_acc {train_mean:.2f}%; test_acc {test_mean:.2f} +/- {test_conf:.2f}%")

    print('results of {} with P%{}'.format(args['model.name'], args['headmodel.name']))
    rows = []
    for dataset_name in testsets:
        row = [dataset_name]
        for model_name in accs_names:
            mean_acc, conf = mean_confidence_interval(var_accs[dataset_name][model_name])
            row.append(f"{mean_acc:0.2f} +/- {conf:0.2f}%")
        rows.append(row)
    outpath = os.path.join(args['out.dir'], 'weights')
    outpath = check_dir(outpath, True)
    outpath = os.path.join(outpath, '{}-sslattention-{}-test_results.npy'.format(args['model.name'], args['headmodel.name']))
    np.save(outpath, {'rows': rows})

    table = tabulate(rows, headers=['model \\ data'] + accs_names, floatfmt=".2f")
    print(table)
    print("\n")
    print(f"{args['experiment_name']} Done!")
