import timeit
import gc
import torch
import torch.backends.cudnn as cudnn
from thop import profile
import sklearn.metrics as metrics
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
# from models.model import cal_loss

def lat_measure(shape, model, num_runs, device=torch.device('cuda:0')):
    """Measure the average forward-pass latency of ``model`` in milliseconds.

    A random input of the requested shape is created with ``torch.rand``,
    the input and the model are moved to ``device``, a few untimed warm-up
    passes are executed, and then ``num_runs`` forward passes are timed with
    ``timeit``. The device is synchronized after every timed pass so that
    the measurement covers the kernels launched by that pass.

    Args:
        shape: Size passed to ``torch.rand`` to build the input batch.
        model: Callable model; it is moved to ``device`` in place.
        num_runs: Number of timed forward passes. Must be positive.
        device: CUDA device on which the measurement runs.

    Returns:
        Average wall-clock time of one forward pass, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is not positive.
        RuntimeError: If ``device`` is not a CUDA device or CUDA is not
            available on this machine.
    """
    if num_runs <= 0:
        raise ValueError(f'num_runs must be positive, got {num_runs}.')
    # Fail early with a clear message instead of crashing later inside
    # torch.cuda.synchronize() when CUDA was requested but is missing.
    if device.type != 'cuda' or not torch.cuda.is_available():
        raise RuntimeError(
            f'GPU is not available. Requested device: {device}')

    input_batch = torch.rand(shape).to(device)
    model.to(device)

    warmup_runs = 5
    for _ in range(warmup_runs):
        model(input_batch)
    # Drain the warm-up kernels so they are not counted in the timed region.
    torch.cuda.synchronize(device)

    # timeit disables garbage collection while timing; the setup statement
    # turns it back on. The names are handed over through a private
    # namespace so nothing is written into (or kept alive by) module globals.
    total_lat = timeit.timeit(
        'model(input_batch); torch.cuda.synchronize(device)',
        setup='gc.enable()',
        number=num_runs,
        globals={
            'model': model,
            'input_batch': input_batch,
            'device': device,
            'torch': torch,
            'gc': gc,
        },
    )

    # Seconds for all runs -> milliseconds per run.
    return total_lat / num_runs * 1e3

def lat_acc(data, seg, model, device=torch.device('cuda:0')):
    """Compute overall and balanced accuracy of ``model`` on a single batch.

    ``data`` and the model are moved to ``device``, one forward pass is run
    without gradient tracking, the predicted label of every point is taken
    as the arg-max over the model output's axis 1, and the flattened
    predictions are compared with the flattened ``seg`` labels.

    Args:
        data: Input tensor fed to ``model``.
        seg: Ground-truth label tensor; flattened before scoring.
            Presumably shaped (batch_size, num_points) - confirm with caller.
        model: Callable model; it is moved to ``device`` in place. Its output
            must be 3-D with the class scores on axis 1.
        device: CUDA device on which the forward pass runs.

    Returns:
        Tuple ``(train_acc, avg_per_class_acc)`` as returned by
        ``sklearn.metrics.accuracy_score`` and
        ``sklearn.metrics.balanced_accuracy_score``.

    Raises:
        RuntimeError: If ``device`` is not a CUDA device or CUDA is not
            available on this machine.
    """
    # Reject both a non-CUDA device and a CUDA device on a machine without
    # CUDA; the latter used to fall through and silently run on the CPU.
    if device.type != 'cuda' or not torch.cuda.is_available():
        raise RuntimeError(
            f'GPU is not available. Requested device: {device}')

    data = data.to(device)
    model.to(device)

    # Only predictions are needed, so skip building the autograd graph.
    with torch.no_grad():
        seg_pred = model(data)

    # Move the class axis last, then take the index of the best class.
    pred = seg_pred.permute(0, 2, 1).max(dim=2)[1]

    true_flat = seg.cpu().numpy().reshape(-1)
    pred_flat = pred.detach().cpu().numpy().reshape(-1)

    train_acc = metrics.accuracy_score(true_flat, pred_flat)
    avg_per_class_acc = metrics.balanced_accuracy_score(true_flat, pred_flat)

    return train_acc, avg_per_class_acc

def lat_acc_iter(trainDataLoader, model, device=torch.device('cuda:0')):
    """Train ``model`` for one pass over the loader and score its predictions.

    The model is moved to ``device`` and optimized with SGD (lr 0.1,
    momentum 0.9, weight decay 1e-4) under a cosine-annealing schedule that
    is stepped once per batch. Predictions made during that pass are
    collected and compared with the labels once the pass is complete.

    Args:
        trainDataLoader: Iterable yielding ``(data, seg)`` tensor pairs.
            ``data`` has its last two axes swapped before being fed to the
            model (source comments suggest (batch, points, channels) in,
            (batch, channels, points) out - confirm with the dataset).
        model: Model to train; moved to ``device`` and updated in place. Its
            output must be 3-D with the class scores on axis 1.
        device: CUDA device on which training runs.

    Returns:
        Tuple ``(train_acc, avg_per_class_acc)`` as returned by
        ``sklearn.metrics.accuracy_score`` and
        ``sklearn.metrics.balanced_accuracy_score`` over all points seen.

    Raises:
        RuntimeError: If ``device`` is not a CUDA device or CUDA is not
            available on this machine.
        ValueError: If ``trainDataLoader`` yields no batches.
    """
    # Validate once, up front, rather than inside the training loop.
    if device.type != 'cuda' or not torch.cuda.is_available():
        raise RuntimeError(
            f'GPU is not available. Requested device: {device}')

    # Place the model on the device before the optimizer captures its
    # parameters.
    model.to(device)

    opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(opt, 5, eta_min=1e-3)
    # NOTE(review): ``cal_loss`` is not imported in the visible part of this
    # file (the ``models.model`` import near the top is commented out). It
    # must be available at module level or this line raises NameError.
    criterion = cal_loss

    true_chunks = []
    pred_chunks = []
    for data, seg in trainDataLoader:
        # float() keeps the tensor on ``device``; no CUDA-specific tensor
        # type needs to be named.
        data = data.to(device).permute(0, 2, 1).float()
        seg = seg.to(device).to(dtype=torch.int64)

        opt.zero_grad()
        # Move the class axis last so each row of the flattened view is one
        # point's class scores.
        seg_pred = model(data).permute(0, 2, 1).contiguous()
        # Read the class count from the output instead of hard-coding it.
        num_classes = seg_pred.size(-1)
        loss = criterion(seg_pred.view(-1, num_classes), seg.reshape(-1))
        loss.backward()
        opt.step()
        scheduler.step()

        pred = seg_pred.max(dim=2)[1]
        true_chunks.append(seg.cpu().numpy().reshape(-1))
        pred_chunks.append(pred.detach().cpu().numpy().reshape(-1))

    if not true_chunks:
        raise ValueError('trainDataLoader yielded no batches.')

    train_true_cls = np.concatenate(true_chunks)
    train_pred_cls = np.concatenate(pred_chunks)
    train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls)
    avg_per_class_acc = metrics.balanced_accuracy_score(train_true_cls, train_pred_cls)

    return train_acc, avg_per_class_acc


def macs_param_measure(shape, model, device=torch.device('cuda:0')):
    """Profile ``model`` with ``thop.profile`` on a random input.

    A random input of the requested shape is created with ``torch.rand``,
    the input and the model are moved to ``device``, and the model is
    profiled once with that input.

    Args:
        shape: Size passed to ``torch.rand`` to build the input batch.
        model: Model to profile; it is moved to ``device`` in place.
        device: CUDA device on which profiling runs.

    Returns:
        Tuple ``(macs, params)`` exactly as returned by ``thop.profile``.

    Raises:
        RuntimeError: If ``device`` is not a CUDA device or CUDA is not
            available on this machine.
    """
    # Reject both a non-CUDA device and a CUDA device on a machine without
    # CUDA; the latter used to fall through and silently run on the CPU.
    if device.type != 'cuda' or not torch.cuda.is_available():
        raise RuntimeError(
            f'GPU is not available. Requested device: {device}')

    input_batch = torch.rand(shape).to(device)
    print("input_batch.device2", input_batch.device)
    model.to(device)

    macs, params = profile(model, inputs=(input_batch,), verbose=False)

    return macs, params


    
    
    