import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import copy
import time
import pickle
from torch.utils.data import DataLoader
from skimage.util.shape import view_as_windows
from tqdm import tqdm
from tensorboardX import SummaryWriter
from collections import defaultdict
from dataset import Train_dataset, Test_dataset, Finetune_dataset
from torch.optim.lr_scheduler import LRScheduler
from sklearn.metrics.pairwise import cosine_similarity
from utils import *
from ResNet_cifar import resnet32
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.util.shape import view_as_windows
from sklearn.metrics import confusion_matrix
# Expose only physical GPU 0 to CUDA for this process (module-level side effect at import time).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def init():
    """Build the run configuration object.

    Creates an ``Argument`` instance, attaches the torch device, fills in the
    data directory and class count when the dataset is CIFAR10 or CIFAR100,
    attaches a tensorboardX ``SummaryWriter`` logging to ``outputs`` and
    finally calls ``init_before_training()`` on the object.

    Returns:
        The populated ``Argument`` instance.
    """
    config = Argument()

    # Use the first visible GPU when CUDA is usable, otherwise fall back to CPU.
    gpu_index = 0
    cuda_usable = torch.cuda.is_available() and (gpu_index >= 0)
    device_name = f"cuda:{gpu_index}" if cuda_usable else "cpu"
    config.device = torch.device(device_name)

    # (dataset name, data directory, number of classes)
    known_datasets = (
        ('CIFAR10', 'cifar-10-python/cifar-10-batches-py', 10),
        ('CIFAR100', 'cifar-100-python', 100),
    )
    # Any other dataset name leaves data_path / action_dim as Argument() set them.
    for dataset_name, data_dir, num_classes in known_datasets:
        if config.dataset_name == dataset_name:
            config.data_path = data_dir
            config.action_dim = num_classes

    config.writer = SummaryWriter('outputs')
    config.init_before_training()
    return config

def load_pretrained_model(args, load_pretrain=True, ckpt_path=None):
    """Create a ResNet-32 on the GPU and optionally load checkpoint weights.

    Args:
        args: configuration object; ``dataset_name`` selects the architecture
            and ``action_dim`` is passed to ``resnet32`` as its only argument.
        load_pretrain: when truthy, weights are loaded from ``ckpt_path``.
        ckpt_path: path of a state-dict checkpoint readable by ``torch.load``.

    Returns:
        The model, moved to CUDA.

    Raises:
        ValueError: if ``args.dataset_name`` is not a supported dataset, or if
            ``load_pretrain`` is truthy but no ``ckpt_path`` was given.
    """
    supported_datasets = ('CIFAR10', 'CIFAR100')
    if args.dataset_name not in supported_datasets:
        # Previously this fell through to an UnboundLocalError on `model`.
        raise ValueError(
            f'unsupported dataset {args.dataset_name!r}; expected one of {supported_datasets}'
        )
    if load_pretrain and ckpt_path is None:
        raise ValueError('ckpt_path must be provided when load_pretrain is True')

    model = resnet32(args.action_dim).cuda()
    if load_pretrain:
        # Deserialize on CPU so a checkpoint saved on another device still
        # loads; load_state_dict copies the values onto the model's device.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


class GradualWarmupScheduler(LRScheduler):
    """Linear learning-rate warmup that hands over to another scheduler.

    While ``last_epoch < total_epoch`` every parameter group uses
    ``base_lr / total_epoch * (last_epoch + 1)``. Once ``last_epoch`` reaches
    ``total_epoch`` the rate comes from ``after_scheduler`` (when given) and
    every later ``step()`` call is forwarded to it; without an
    ``after_scheduler`` the rate stays at ``base_lr``.

    Args:
        optimizer: the wrapped optimizer.
        total_epoch: number of warmup steps.
        after_scheduler: optional scheduler that takes over after warmup. It
            must be built on the same optimizer.
    """

    def __init__(self, optimizer, total_epoch, after_scheduler=None):
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Becomes True the first time get_lr() runs past the warmup phase.
        self.finished = False
        # The base-class constructor performs an initial step(), so the
        # attributes above must exist before it is called.
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        if self.last_epoch >= self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Make the follow-up scheduler start from this scheduler's base rates.
                    self.after_scheduler.base_lrs = list(self.base_lrs)
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return list(self.base_lrs)

        return [(base_lr / self.total_epoch) * (self.last_epoch + 1) for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step.

        During warmup this is the regular ``LRScheduler.step``. Afterwards the
        call is delegated to ``after_scheduler`` (with ``epoch`` shifted by
        ``total_epoch`` when it is given explicitly).
        """
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
            # Mirror the delegate's rate; otherwise get_last_lr() on this
            # wrapper keeps returning the last warmup value forever.
            self._last_lr = self.after_scheduler.get_last_lr()
        else:
            return super().step(epoch)


def test(data_loader, epoch, model, data_num_N):
    """Evaluate ``model`` on ``data_loader`` and print accuracy summaries.

    Prints the overall top-1 accuracy, then the mean per-class accuracy of
    three class groups defined by the entries of ``data_num_N``: "many"
    (> 100), "med" (21..100) and "few" (<= 20).

    Args:
        data_loader: iterable of batches; ``batch[0]`` holds the inputs and
            ``batch[1]`` the labels (one label per sample).
        epoch: only used to tag the printed line.
        model: network to evaluate; it is switched to eval mode.
        data_num_N: per-class sample counts. Entry ``c`` is taken to belong to
            class id ``c`` (presumably training-set counts - confirm with the
            caller).

    Returns:
        None. Results are printed.
    """
    model.eval()
    start_time = time.time()

    class_sizes = np.asarray(data_num_N)
    num_classes = len(class_sizes)

    # Feed inputs to whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_targets, batch_preds = [], []
    with torch.no_grad():
        for batch in data_loader:
            data, gt = batch[0], batch[1]
            output = model(data.to(device).float())
            batch_preds.append(np.argmax(output.detach().cpu().numpy(), axis=1))
            # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
            # sample into a 0-d array, which cannot be iterated or concatenated.
            batch_targets.append(gt.detach().cpu().numpy().reshape(-1))

    if not batch_preds:
        # Nothing to evaluate; avoids a division by zero below.
        print(f'epoch:{epoch}, no evaluation samples')
        return

    all_preds = np.concatenate(batch_preds)
    all_targets = np.concatenate(batch_targets)
    total_num = all_preds.shape[0]
    true_predict = int(np.sum(all_preds == all_targets))
    print(f'epoch:{epoch}, accuracy:{true_predict / total_num:.4f}, time:{time.time() - start_time:.2f}s')

    # Pin the label set so the matrix always has one row per class, even when
    # a class never occurs in the targets or the predictions; otherwise the
    # boolean masks built from data_num_N would not line up with cls_acc.
    cf = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes)).astype(float)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)
    # Classes without any evaluation sample count as accuracy 0 instead of 0/0.
    cls_acc = np.divide(cls_hit, cls_cnt, out=np.zeros_like(cls_hit), where=cls_cnt > 0)

    many_shot = class_sizes > 100
    medium_shot = (class_sizes <= 100) & (class_sizes > 20)
    few_shot = class_sizes <= 20
    # eps keeps the averages finite when a group contains no class.
    eps = np.finfo(np.float64).eps
    print(f'many avg:{float(cls_acc[many_shot].sum() / (many_shot.sum() + eps)):.4f}, '
          f'med avg:{float(cls_acc[medium_shot].sum() / (medium_shot.sum() + eps)):.4f}, '
          f'few avg:{float(cls_acc[few_shot].sum() / (few_shot.sum() + eps)):.4f}')




def _pretrain_full_network(model, criterion, train_loader, test_loader, data_num_N, num_epochs):
    """Train every parameter of ``model`` with SGD and evaluate after each epoch."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-3)
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # 20 warmup epochs, then the cosine schedule takes over.
    warmup_scheduler = GradualWarmupScheduler(optimizer, total_epoch=20, after_scheduler=cosine_scheduler)

    for epoch in range(num_epochs):
        # test() leaves the model in eval mode, so re-enable train mode every epoch.
        model.train()
        for batch in tqdm(train_loader):
            data, gt = batch[0], batch[1]
            data = data.cuda().float()
            # reshape(-1) instead of squeeze(): a final batch holding a single
            # sample must stay 1-D for CrossEntropyLoss.
            gt = gt.reshape(-1).long().cuda()
            output, _ = model.forward_with_feature(data)
            loss = criterion(output, gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        warmup_scheduler.step()
        test(test_loader, epoch + 1, model, data_num_N)


def _finetune_classifier(model, criterion, finetune_loader, test_loader, data_num_N, num_epochs):
    """Train only ``model.fc_cb`` on pre-computed features and evaluate after each epoch."""
    finetune_optimizer = torch.optim.SGD(model.fc_cb.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-3)
    finetune_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(finetune_optimizer, T_max=num_epochs)
    # 5 warmup epochs, then the cosine schedule takes over.
    warmup_scheduler = GradualWarmupScheduler(finetune_optimizer, total_epoch=5, after_scheduler=finetune_scheduler)

    for epoch in range(num_epochs):
        for batch in tqdm(finetune_loader):
            data, gt = batch[0], batch[1]
            data = data.cuda().float()
            gt = gt.reshape(-1).long().cuda()
            # The loader yields features, so only the classifier head is applied.
            output = model.fc_cb(data)
            loss = criterion(output, gt)

            finetune_optimizer.zero_grad(set_to_none=True)
            loss.backward()
            finetune_optimizer.step()

        warmup_scheduler.step()
        print(finetune_optimizer.param_groups[0]['lr'])
        test(test_loader, epoch + 1, model, data_num_N)


def train_and_evaluate(args):
    """Run the two-stage pipeline: pretraining, then classifier fine-tuning.

    Stage 1 trains a fresh ResNet-32 on the training set for
    ``args.pretrained_epochs`` epochs. Stage 2 loads a checkpoint from
    ``model/``, derives a fine-tuning feature set through the project helpers
    (``cal_statistic_single``, ``cal_sim_topk``, ``cal_sim_topk_logit``,
    ``extract_weight_diff``, ``risk_bounded_calibrate``) and trains only
    ``model.fc_cb`` on it for ``args.finetune_epochs`` epochs. The model is
    evaluated with ``test`` after every epoch of both stages.

    Args:
        args: configuration object providing ``data_path``, ``dataset_name``,
            ``imbalance_ratio``, ``batch_size``, ``action_dim``,
            ``pretrained_epochs`` and ``finetune_epochs``.
    """
    #####  prepare data   #####
    data_path = args.data_path
    train_data, train_gt, test_data, test_gt, data_num_N = prepare_data(
        data_path, args.dataset_name, args.imbalance_ratio)

    # Head/tail split: classes holding at least `ratio` times the size of
    # class 0 are head classes, the rest are tail classes.
    # NOTE(review): this assumes data_num_N[0] is the largest class - confirm.
    ratio = 0.5
    max_sample_size = data_num_N[0]
    head_class_index = np.where(data_num_N >= ratio * max_sample_size)[0]
    tail_class_index = np.where(data_num_N < ratio * max_sample_size)[0]

    train_dataset = Train_dataset(train_data, train_gt, args.dataset_name)
    test_dataset = Test_dataset(test_data, test_gt, args.dataset_name)
    # Plain shuffled loader. A class-weighted WeightedRandomSampler could be
    # plugged in via `sampler=` instead (it cannot be combined with shuffle=True).
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)

    #####  stage 1: pretrain the whole network   #####
    model = resnet32(args.action_dim).cuda()
    criterion = nn.CrossEntropyLoss()
    _pretrain_full_network(model, criterion, train_data_loader, test_data_loader,
                           data_num_N, args.pretrained_epochs)

    #####  stage 2: fine-tune the classifier   #####
    # NOTE(review): the weights trained above are replaced by this checkpoint,
    # which is never written in this function, and the file name hard-codes
    # "200" rather than args.pretrained_epochs - confirm this is intended.
    model = load_pretrained_model(args, load_pretrain=True,
                                  ckpt_path=f'model/param_200_{args.dataset_name}_{args.imbalance_ratio}.pth'
                                  )
    test(test_data_loader, 1, model, data_num_N)
    cls_statistic, feature_per_class, logit_per_class, origin_feature, origin_label, head_feature, head_label = (
        cal_statistic_single(train_data_loader, model, data_num_N, head_class_index))

    topk = cal_sim_topk(cls_statistic, head_class_index, tail_class_index, k=5)
    topk_oppo = cal_sim_topk_logit(logit_per_class, tail_class_index, k=15)
    weight_diff = extract_weight_diff(model, tail_class_index, topk_oppo)

    finetune_feature, finetune_label, _, _ = risk_bounded_calibrate(cls_statistic, topk, feature_per_class,
                                                       head_class_index, tail_class_index, data_num_N, weight_diff)

    finetune_dataset = Finetune_dataset(finetune_feature, finetune_label)
    finetune_data_loader = DataLoader(finetune_dataset, batch_size=args.batch_size, shuffle=True)

    _finetune_classifier(model, criterion, finetune_data_loader, test_data_loader,
                         data_num_N, args.finetune_epochs)




# Script entry point: build the configuration, then run the full
# pretrain + fine-tune pipeline. `args` is bound at module level.
if __name__ == '__main__':
    args = init()
    train_and_evaluate(args)



