from collections import defaultdict
from time import time
import torch
from torchvision import transforms
from fvcore.nn import FlopCountAnalysis, flop_count_table
import torch.nn as nn
from torch.nn import init
from tqdm import tqdm

def report_sample_by_class(data_loader):
    """Count how many samples of each class label a data loader yields.

    Args:
        data_loader: Iterable of ``(batch, label)`` pairs, or ``None``. Each
            element of ``label`` must expose ``.item()``.

    Returns:
        ``None`` when ``data_loader`` is ``None``; otherwise a
        ``defaultdict(int)`` mapping each label value to its sample count.
    """
    if data_loader is None:
        return None

    # Only the labels are inspected; the image batch is unpacked and ignored.
    counts = defaultdict(int)
    for _images, labels in tqdm(data_loader):
        for lbl in labels:
            key = lbl.item()
            counts[key] += 1
    return counts

def get_logits(logits):
    """Return ``logits.logits`` when that attribute exists, else ``logits`` itself.

    Lets callers treat wrapper objects carrying a ``.logits`` attribute and
    plain outputs uniformly.

    Only a missing attribute triggers the fallback: any other exception raised
    while reading ``.logits`` (and signals such as ``KeyboardInterrupt``)
    propagates instead of being silently swallowed.
    """
    return getattr(logits, "logits", logits)

class Statistics:
    """Bookkeeping for elapsed wall-clock time and an estimated FLOP total.

    The per-sample forward cost is measured once, at construction, by running
    ``FlopCountAnalysis`` on a random single-image input. The ``add_*``
    methods then accumulate into ``total_flops``.
    """

    def __init__(self, model, args):
        """Measure the per-sample forward FLOPs of ``model``.

        Args:
            model: Module passed to ``FlopCountAnalysis``.
            args: Object providing ``model_name`` and ``device``, plus
                ``image_size`` when ``model_name`` is not ``'ViT'``.
        """
        self.model = model

        # ViT is profiled at a fixed 224x224; other models use args.image_size.
        side = 224 if args.model_name == 'ViT' else args.image_size
        mock_sample = torch.randn(1, 3, side, side).to(args.device)
        self.flops = FlopCountAnalysis(model, mock_sample).total()

        self.start_time = 0
        self.elapsed_time = 0
        self.total_flops = 0

    def start_record(self):
        """Stamp the start time and reset the accumulated FLOP total."""
        self.total_flops = 0
        self.start_time = time()

    def end_record(self):
        """Store the seconds elapsed since the last ``start_record`` call."""
        now = time()
        self.elapsed_time = now - self.start_time

    def add_forward_flops(self, num_samples):
        """Add the cost of one forward pass over ``num_samples`` samples."""
        forward_cost = self.flops * num_samples
        self.total_flops += forward_cost

    def add_backward_flops(self, num_samples):
        """Add a backward pass, counted as twice the forward cost per sample."""
        backward_cost = 2 * self.flops * num_samples
        self.total_flops += backward_cost

    def add_matrix_multiplication_flops(self, A, B):
        """Add the cost of the matrix product ``A @ B``.

        ``A.shape[1]`` must equal ``B.shape[0]`` (checked with ``assert``).
        Each of the ``m * n`` output entries is counted as ``p``
        multiplications plus ``p - 1`` additions, i.e. ``2 * p - 1`` FLOPs.
        """
        assert A.shape[1] == B.shape[0]
        rows = A.shape[0]
        cols = B.shape[1]
        inner = B.shape[0]
        self.total_flops += rows * cols * (2 * inner - 1)

    def add_flops_manual(self, flops):
        """Add an externally computed FLOP count to the running total."""
        self.total_flops += flops

class TwoTransform():
    """Callable that produces several transformed views of a single input.

    Args:
        transform: Callable applied independently for each view. ``None``
            (the default) means "no transform": the input itself is returned
            for every view instead of failing with a ``TypeError``.
        n_views: Number of views to produce per call.
    """

    def __init__(self, transform=None, n_views=2):
        self.transform = transform
        self.n_views = n_views

    def __call__(self, x):
        """Return a list of ``n_views`` results of applying the transform to ``x``."""
        if self.transform is None:
            # Identity fallback so the default-constructed object is usable.
            return [x for _ in range(self.n_views)]
        return [self.transform(x) for _ in range(self.n_views)]
        

# for efficient reset or retrain of the model
def getRetrainLayers(m, name, ret):
    """Collect every Conv2d, BatchNorm2d and Linear module under ``m``.

    The module tree is walked depth-first in ``named_children`` order, with
    ``m`` itself checked before its children.

    Args:
        m: Module to inspect.
        name: Dotted path used to label ``m``; children are labelled
            ``'<name>.<child_name>'``.
        ret: List that ``(module, path)`` tuples are appended to in place.

    Returns:
        The same ``ret`` list, for convenience.
    """
    retrainable_types = (nn.Conv2d, nn.BatchNorm2d, nn.Linear)
    if isinstance(m, retrainable_types):
        ret.append((m, name))

    for sub_name, sub_module in m.named_children():
        qualified = f'{name}.{sub_name}'
        getRetrainLayers(sub_module, qualified, ret)

    return ret

def resetFinalResnet(model, num_retrain, reinit=True):
    """Freeze ``model`` and re-enable training for its last few layers.

    All parameters are first frozen. The Conv2d/BatchNorm2d/Linear layers
    found by ``getRetrainLayers`` are then visited in reverse traversal order
    (last-collected first). Each visited layer is unfrozen, and Conv2d/Linear
    layers are additionally re-initialised when ``reinit`` is true. The walk
    stops once ``num_retrain`` Conv2d/Linear layers have been processed.
    BatchNorm2d layers met along the way are unfrozen but neither counted nor
    re-initialised.

    Args:
        model: Module to modify in place.
        num_retrain: Number of Conv2d/Linear layers to unfreeze, counted from
            the end. A value ``<= 0`` leaves the whole model frozen.
        reinit: Whether to re-initialise the selected Conv2d/Linear layers.

    Returns:
        The same ``model`` object.
    """
    for param in model.parameters():
        param.requires_grad = False

    done = 0
    for layer, _name in reversed(getRetrainLayers(model, 'M', [])):
        # Checked before touching the layer so that num_retrain <= 0 really
        # retrains nothing; for num_retrain >= 1 this stops at the same point
        # as checking after each layer would.
        if done >= num_retrain:
            break
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            if reinit:
                _reinit(layer)
            done += 1
        for param in layer.parameters():
            param.requires_grad = True

    return model

def _reinit(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)