import json
import math
import multiprocessing
from multiprocessing import shared_memory
import torch.multiprocessing as mp
import os
import random
import time
from cmath import isnan
from collections import defaultdict
from copy import deepcopy
from multiprocessing import Manager, Process, Value, Queue
from pickletools import optimize
from posixpath import split

import numpy as np
import torch
import torch.nn.functional as F
from torch import autograd, nn
from torch.autograd import Variable
from torch.nn import DataParallel

from .loss_function import *
from .models import CosClassifier
from .models.utils.load_parameters import (load_ckpt,
                                                model_load_state_dict,
                                                save_ckpt)
from .models.utils.lr_strategies import get_lr_strategy
from .utils import may_make_dir, to_scalar, calculate_cos_affinity, calculate_single_cos_affinity, norm_2
from .utils.meters import AverageMeter
from .utils.smooth_label import smooth_label


class Memory_Trainer(object):
    def __init__(self, backbone_model, data_loader, tester, TVT, TMO, device_ids=None):
        """Store the collaborators and set up the trainer's bookkeeping state.

        Args:
            backbone_model: wrapped model; ``backbone_model.module.feature_in_dim``
                is read here, so it is presumably a ``DataParallel``-style
                wrapper (TODO confirm against the caller).
            data_loader: object exposing ``class_number``.
            tester: stored unchanged on ``self.tester``.
            TVT: callable applied to freshly created tensors elsewhere in this
                class (presumably a device-transfer helper -- confirm).
            TMO: stored unchanged on ``self.TMO``.
            device_ids: stored unchanged on ``self.device_ids``.
        """
        # Collaborators handed in by the caller.
        self.device_ids = device_ids
        self.backbone_model, self.data_loader, self.tester = backbone_model, data_loader, tester
        self.TVT, self.TMO = TVT, TMO

        # Training state.
        self.classfier = None
        self.epoch_done = False

        # Memory-bank geometry: one row per class, one column per feature.
        self.class_number = self.data_loader.class_number
        self.memo_col = self.feature_in_dim = self.backbone_model.module.feature_in_dim

        # Mean-reduced MSE loss.
        self.mse_loss = get_loss_func('MSEloss', reduction='mean')

        # Timers.
        self.step_star_time = time.time()
        self.epoch_start_time = time.time()

        # Running loss meters.
        self.backbone_f_loss = AverageMeter()
        self.backbone_l_loss = AverageMeter()
        self.backbone_mse_loss = AverageMeter()

        # Multiprocessing helpers, then the (initially empty) R-matrix cache.
        self._init_memory_manager()
        self.R_dict = {}

    def _hug_matrix_cos_affinity(self, matrix_1, matrix_2, batchsize=2**14, del_diagonal=False):
        """Compute the affinity between two large matrices chunk by chunk.

        Both inputs are split into blocks of at most ``batchsize`` rows so that
        only one (batchsize x batchsize) affinity block is computed at a time.
        The blocks are then stitched back together into the full result.

        Args:
            matrix_1: 2-D tensor of shape (n, c).
            matrix_2: 2-D tensor of shape (b, c); a non 2-D tensor raises
                ``ValueError`` when its shape is unpacked.
            batchsize: maximum number of rows per chunk.
            del_diagonal: when True, subtract a (b, b) identity matrix from the
                result. This only lines up when ``matrix_1`` has as many rows
                as ``matrix_2`` (typically the same matrix passed twice).

        Returns:
            Tensor of shape (n, b) holding the output of
            ``calculate_cos_affinity`` for every row pair (assumes that helper
            returns a (rows_1, rows_2) block -- TODO confirm).
        """
        b, _ = matrix_2.shape

        # The identity matrix is b x b, so only build it when it is needed.
        split_eye_matrix = None
        if del_diagonal:
            split_eye_matrix = torch.split(self.TVT(torch.eye(b, b).byte()), batchsize)

        split_cols = len(matrix_2) > batchsize
        row_blocks = []
        for i, matrix_i in enumerate(torch.split(matrix_1, batchsize)):
            # matrix_i: (<=batchsize, c)
            if split_cols:
                col_blocks = [
                    calculate_cos_affinity(matrix_i, matrix_j)
                    for matrix_j in torch.split(matrix_2, batchsize)
                ]
                # (<=batchsize, b)
                block = torch.cat(col_blocks, -1)
            else:
                block = calculate_cos_affinity(matrix_i, matrix_2)
            if del_diagonal:
                block -= split_eye_matrix[i]
            # Keep every row chunk; previously only the last chunk survived
            # the loop and rows of earlier chunks were silently dropped.
            row_blocks.append(block)

        return torch.cat(row_blocks, 0)

    def _init_memory(self, memory_file=None):
        """Create the memory bank, either from a saved archive or from scratch.

        Args:
            memory_file: optional path to an ``.npz`` archive containing the
                arrays ``memory``, ``initial_memory`` and ``memory_mean``.
                When falsy, fresh banks are allocated instead.

        Sets:
            self.memory: (class_number, memo_col) tensor, zeros when fresh.
            self.initial_memory: (class_number,) bool tensor, all True when
                fresh.
            self.memory_mean: (class_number, 100) tensor, zeros when fresh.
        """
        if memory_file:
            # np.load on an archive returns a lazy NpzFile that keeps the file
            # open; read all arrays inside the context so the handle is closed.
            with np.load(memory_file) as memory_files:
                tmp_memory = memory_files['memory']
                tmp_initial_memo = memory_files['initial_memory']
                tmp_memory_mean = memory_files['memory_mean']
            self.memory = self.TVT(torch.from_numpy(tmp_memory))
            self.initial_memory = self.TVT(torch.from_numpy(tmp_initial_memo))
            self.memory_mean = self.TVT(torch.from_numpy(tmp_memory_mean))
        else:
            self.memory = self.TVT(torch.zeros(self.class_number, self.memo_col))
            self.initial_memory = self.TVT(torch.ones(self.class_number).bool())
            self.memory_mean = self.TVT(torch.zeros(self.class_number, 100))

    def _save_memory(self, path=None):
        """Write the memory bank tensors to an ``.npz`` archive.

        The arrays are stored under the keys ``memory``, ``initial_memory``
        and ``memory_mean``.

        Args:
            path: destination file; falls back to ``self.memory_file`` when
                ``None``.
        """
        # Pull each bank to the CPU and convert it to a NumPy array.
        arrays = {
            key: np.array(getattr(self, key).data.cpu())
            for key in ('memory', 'initial_memory', 'memory_mean')
        }
        target = self.memory_file if path is None else path
        np.savez(target, **arrays)

    def _clear_current_memory(self):
        """Reset the memory-bank rows that belong to the current task.

        For every index in ``self.data_loader.current_task_label`` the rows of
        ``self.memory`` and ``self.memory_mean`` are zeroed and the matching
        entries of ``self.initial_memory`` are set back to True.
        """
        task_label = self.data_loader.current_task_label
        # Fill with a scalar instead of building zero tensors of a hard-coded
        # width: this works whatever the column count or dtype of the banks
        # (e.g. after loading them from a file) and avoids temporary tensors.
        self.memory[task_label] = 0
        self.memory_mean[task_label] = 0
        self.initial_memory[task_label] = True

    def _init_memory_manager(self):
        """Create the multiprocessing primitives used by the trainer.

        Starts a ``Manager`` and creates a dict, a list, a lock and a cleared
        event through it, plus two shared integers initialised to -1.
        ``self.all_process`` starts out as ``None``.
        """
        manager = mp.Manager()
        self.manager = manager

        # Shared integers, both initialised to -1.
        self.memo_index_num = mp.Value('i', -1)
        self.feature_number = mp.Value('i', -1)

        # Manager-backed proxy objects.
        self.trans_feature_dict = manager.dict()
        self.lock = manager.Lock()
        self.indx_memo_l = manager.list()

        self.all_process = None

        event = manager.Event()
        self.event = event
        event.clear()

    def _reset_memory_manager(self):
        """Return the shared multiprocessing state to its initial values.

        Clears the event, empties the shared dict and list in place, and sets
        ``memo_index_num`` back to -1. ``feature_number`` is left untouched.
        """
        self.event.clear()
        # Empty the shared list in place so existing references stay valid.
        del self.indx_memo_l[:]
        self.trans_feature_dict.clear()
        self.memo_index_num.value = -1

    def _clear_record(self):
        """Reset the progress flags and all running loss meters.

        Sets ``epoch_done`` to False and ``step`` to 0, then clears the three
        backbone loss meters.
        """
        self.epoch_done = False
        self.step = 0
        # Clear every meter created in __init__. The MSE meter used to be
        # skipped, so its running average leaked across resets.
        for meter in (self.backbone_f_loss, self.backbone_l_loss, self.backbone_mse_loss):
            meter.clear()

    def _load_parameters(self, optimizer, module_optimizer, model_path):
        """Load a checkpoint while keeping the optimizer's current learning rates.

        The ``lr`` of every param group is recorded before ``load_ckpt`` runs
        and written back afterwards, so whatever the checkpoint does to the
        learning rates is undone.

        Args:
            optimizer: optimizer whose ``param_groups`` learning rates are
                preserved.
            module_optimizer: forwarded to ``load_ckpt`` as its first argument.
            model_path: forwarded to ``load_ckpt`` as its second argument.

        Sets:
            self.start_epoch: the value returned by ``load_ckpt``.
        """
        saved_lrs = [group['lr'] for group in optimizer.param_groups]
        self.start_epoch = load_ckpt(module_optimizer, model_path)
        # Write the remembered learning rates back, group by group.
        for idx, group in enumerate(optimizer.param_groups):
            group['lr'] = saved_lrs[idx]

    def _get_parameters(self, model, include='', not_include=''):
        """Return detached CPU tensors for the parameters of ``model``.

        Args:
            model: object exposing ``parameters()`` and ``named_parameters()``.
            include: when non-empty, keep only parameters whose name contains
                this substring.
            not_include: when non-empty, drop parameters whose name contains
                this substring.

        Both filters are applied together when both are given. Previously
        ``not_include`` was silently ignored whenever ``include`` was set.

        Returns:
            List of detached tensors moved to the CPU, in model order.
        """
        if include == '' and not_include == '':
            return [param.detach().cpu() for param in model.parameters()]
        return [
            param.detach().cpu()
            for name, param in model.named_parameters()
            # ``'' in name`` is always True, so an empty ``include`` keeps all.
            if include in name and (not_include == '' or not_include not in name)
        ]

    def _update_memory(self):
        """Update the memory bank.

        The base implementation always raises; presumably subclasses are
        expected to override it.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def get_R_matrxi(self, skip_current_task=True):
        """Get a rotation matrix using multiple processes.

        The base implementation always raises; presumably subclasses are
        expected to override it.

        Args:
            skip_current_task: unused here; presumably tells an override to
                leave out the current task -- confirm against subclasses.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def release_memory(self, feature_number):
        """Release samples from the memory bank.

        The base implementation always raises; presumably subclasses are
        expected to override it.

        Args:
            feature_number: unused here; presumably the number of samples to
                release -- confirm against subclasses.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def repeat_memory(self, feature_number):
        """Repeat the prototypes stored in the memory bank as samples.

        The base implementation always raises; presumably subclasses are
        expected to override it.

        Args:
            feature_number: unused here; presumably the number of samples to
                produce -- confirm against subclasses.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError