import logging
import numpy as np
import torch
import os
from torch import nn
from torch.serialization import load
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from utils.data_manager import DummyDataset
from utils.inc_net import SDENet
from models.base import BaseLearner
from utils.toolkit import count_parameters
from utils.toolkit import target2onehot, tensor2numpy
from torchvision import datasets, transforms
from utils.autoaugment import CIFAR10Policy
from models.Ball import BallList


# Small additive constant for numerical safety in the loss helpers.
EPSILON = 1.0e-8

class BallIL(BaseLearner):
    """Class-incremental learner that summarizes each learned class as a "ball" in feature space.

    Task 0 trains backbone (``convnet``) and classifier (``fc``) with
    cross-entropy plus a contrastive term. Later tasks add:

    * knowledge distillation from the frozen previous network,
    * a penalty computed against the stored balls,
    * a learned drift estimator (``drift_estimate_larer``). It maps
      old-network features to current-network features, so that stored ball
      centres can be moved along with the backbone.

    After every later task the classifier is re-fitted on frozen features
    plus the stored ball centres. Finally, new balls are created for the
    classes of the current task.
    """

    def __init__(self, args):
        super().__init__(args)
        self.init_epoch = args["init_epoch"]
        self.init_lr = args["init_lr"]
        self.init_milestones = args["init_milestones"]
        self.init_lr_decay = args["init_lr_decay"]
        self.init_weight_decay = args["init_weight_decay"]
        self.epochs = args["epochs"]
        self.lrate = args["lrate"]
        self.milestones = args["milestones"]
        self.lrate_decay = args["lrate_decay"]
        self.batch_size = args["batch_size"]
        self.weight_decay = args["weight_decay"]
        self.num_workers = args["num_workers"]
        # Forwarded to get_dataset; False is normalized to None in incremental_train.
        self.shot = args["shot"]
        # Temperature shared by the KD, contrastive, ball and confusion losses.
        self.T = args["T"]
        # Loss weights.
        self.w_dist = args["w_dist"]
        self.w_confu = args["w_confu"]
        self.w_ball_cls = args["w_ball_cls"]
        self.w_cons = args["w_cons"]
        self.w_concp = args["w_concp"]
        # Passed straight to BallList; its meaning is defined there.
        self.blur_r = args["blur_r"]

        self.args = args
        self._network = SDENet(args, False)
        self._balls = BallList(self.blur_r)

    def _unwrap_network(self):
        """Return the underlying module when the network is wrapped (e.g. by nn.DataParallel)."""
        if hasattr(self._network, "module"):
            return self._network.module
        return self._network

    def _checkpoint_prefix(self):
        """Return the path prefix used both when saving (after_task) and loading (_train) checkpoints."""
        return os.path.join(
            self.args["model_dir"],
            "{}_{}_{}".format(self.args["dataset"], self.args["init_cls"], self.args["increment"]),
        )

    def _load_checkpoint_if_available(self, file_path):
        """Load ``model_state_dict`` from ``file_path`` when resuming is enabled and the file exists.

        Returns:
            True if weights were loaded, False otherwise.
        """
        if not self.args["resume"] or not os.path.exists(file_path):
            return False
        # map_location="cpu" lets a checkpoint written on another device load here;
        # load_state_dict then copies the values into the existing parameters.
        state_dict = torch.load(file_path, map_location="cpu")["model_state_dict"]
        self._unwrap_network().load_state_dict(state_dict, strict=True)
        return True

    def after_task(self):
        """Freeze a copy of the current network as the teacher and advance ``_known_classes``.

        When ``args['resume']`` is set, also saves a checkpoint under
        ``args['model_dir']``; the directory is created if missing.
        """
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        if self.args['resume']:
            os.makedirs(self.args["model_dir"], exist_ok=True)
            self.save_checkpoint(self._checkpoint_prefix())

    def incremental_train(self, data_manager):
        """Run one incremental step: grow the classifier, build the loaders and train.

        The train loader covers only the classes of the current task; the
        test loader covers every class seen so far. When more than one GPU is
        configured, the network is wrapped in ``nn.DataParallel`` only for the
        duration of training.
        """
        self.data_manager = data_manager
        self._cur_task += 1

        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)

        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self.shot is False:
            self.shot = None
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            shot=self.shot,
        )
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)

        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Train the network for the current task, optionally resuming from a checkpoint.

        If ``args['resume']`` is set and a checkpoint for this task exists,
        its weights are loaded and the gradient-based training of the task is
        skipped. For later tasks the ball bookkeeping (drift update and
        classifier consolidation) still runs. New balls are generated for the
        current classes in every case.
        """
        # NOTE(review): assumes BaseLearner.save_checkpoint appends "_<task>.pkl" to the
        # prefix it is given, as the original load path already did — confirm.
        file_path = "{}_{}.pkl".format(self._checkpoint_prefix(), self._cur_task)
        if self._cur_task == 0:
            resumed = self._load_checkpoint_if_available(file_path)
            self._network.to(self._device)
            if not resumed:
                self._first_task_train(train_loader, test_loader)
        else:
            # The drift estimator must be (re)initialized before loading so the
            # strict state-dict load finds its parameters.
            self._unwrap_network().init_de()
            resumed = self._load_checkpoint_if_available(file_path)
            self._network.to(self._device)
            if self._old_network is not None:
                self._old_network.to(self._device)
            if not resumed:
                self._follow_task_train(train_loader, test_loader)
            self._update_balls(train_loader)
            self._consolidate_fc(train_loader)
        self._generate_balls()

    def _generate_balls(self):
        """Add one ball per class of the current task, built from that class's training features."""
        for class_idx in range(self._known_classes, self._total_classes):
            _, _, class_dataset = self.data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1), source='train', mode='test', shot=self.shot, ret_data=True
            )
            class_loader = DataLoader(class_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
            feas, _ = self._extract_vectors(class_loader)
            self._balls._add_ball(feas=feas, label=class_idx, task_id=self._cur_task)

    def _update_balls(self, train_loader):
        """Re-fit the drift estimator, then pass every stored ball centre through it."""
        self._drift_learner(train_loader)
        network = self._unwrap_network()
        network.freeze_de()
        with torch.no_grad():
            for ball in self._balls._get_balls():
                center = torch.tensor(ball._get_center()).to(self._device)
                drifted = network.drift_estimate_larer(center)['logits']
                ball._update_center(drifted.detach().cpu().numpy())

    def _first_task_train(self, train_loader, test_loader):
        """Train backbone and classifier on the first task.

        Loss = cross-entropy + ``w_cons`` * contrastive loss. Test accuracy is
        evaluated every 25 epochs (starting at epoch 0).
        """
        network = self._unwrap_network()

        optimizer = optim.SGD(
            list(network.convnet.parameters()) + list(network.fc.parameters()),
            momentum=0.9, lr=self.init_lr, weight_decay=self.init_weight_decay,
        )
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=self.init_milestones, gamma=self.init_lr_decay)
        prog_bar = tqdm(range(self.init_epoch), ascii=True)
        for epoch in prog_bar:
            network.train()
            losses = 0.0
            L_cons = 0.0
            L_cls = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = network(inputs)
                logits = outputs["logits"]
                feas = outputs["features"]
                loss_cls = F.cross_entropy(logits, targets.long())
                L_cls += loss_cls.item()
                loss_cons = _contr_loss(feas, targets, T=self.T)
                L_cons += loss_cons.item()
                loss = loss_cls + loss_cons * self.w_cons

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 25 == 0:
                test_acc = self._compute_accuracy(network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_cls {:.3f}, Loss_cons {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.init_epoch,
                    losses / len(train_loader),
                    L_cls / len(train_loader),
                    L_cons * self.w_cons / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_cls {:.3f}, Loss_cons {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.init_epoch,
                    losses / len(train_loader),
                    L_cls / len(train_loader),
                    L_cons * self.w_cons / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)
        logging.info(info)

    def _follow_task_train(self, train_loader, test_loader):
        """Train backbone and classifier on a later task.

        The loss is the sum of:

        * cross-entropy on the new-class logits (targets shifted by ``_known_classes``),
        * ``w_dist`` * KD from the frozen old network on the old-class logits,
        * ``w_cons`` * contrastive loss,
        * ``w_concp`` * ball penalty against drift-corrected ball centres.

        The drift estimator is re-fitted at epochs 40, 60 and 80.
        ``test_loader`` is not used here.
        """
        network = self._unwrap_network()

        epoch_update_drift = [40, 60, 80]
        prog_bar = tqdm(range(self.epochs), ascii=True)

        optimizer_conv_fc = optim.SGD(
            list(network.convnet.parameters()) + list(network.fc.parameters()),
            lr=self.lrate, momentum=0.9, weight_decay=self.weight_decay,
        )
        scheduler_conv_fc = optim.lr_scheduler.MultiStepLR(optimizer=optimizer_conv_fc, milestones=self.milestones, gamma=self.lrate_decay)

        balls_center = torch.tensor(self._balls._get_center()).to(self._device)
        balls_radius = torch.tensor(self._balls._get_radius(self._cur_task)).to(self._device)

        train_type = 'conv+fc'
        for epoch in prog_bar:
            L_clf = 0.0
            L_kd = 0.0
            L_cons = 0.0
            L_concp = 0.0
            losses = 0.0
            correct, total = torch.tensor(0), torch.tensor(0)
            if epoch in epoch_update_drift:
                self._drift_learner(train_loader)
            # _drift_learner leaves conv/fc frozen; restore them for this stage.
            network.freeze_de()
            network.unfreeze_conv()
            network.unfreeze_fc()

            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                new_net_outputs = network(inputs)
                new_net_logits = new_net_outputs["logits"]
                new_net_features = new_net_outputs["features"]

                # Classification restricted to the classes added by this task.
                fake_targets = targets - self._known_classes
                loss_clf = F.cross_entropy(
                    new_net_logits[:, self._known_classes:], fake_targets.long()
                )
                L_clf += loss_clf.item()

                # Old-network logits are only soft targets; no graph is needed for them.
                with torch.no_grad():
                    old_net_logits = self._old_network(inputs)["logits"]
                loss_kd = _KD_loss(
                    new_net_logits[:, :self._known_classes],
                    old_net_logits,
                    self.T,
                )
                L_kd += loss_kd.item()

                with torch.no_grad():
                    balls_center_update = network.drift_estimate_larer(balls_center)['logits'].detach()
                loss_cons = _contr_loss(new_net_features, targets, T=self.T)
                L_cons += loss_cons.item()

                concp_loss = _concp_loss(new_net_features, T=self.T, balls_center=balls_center_update, balls_radius=balls_radius)
                L_concp += concp_loss.item()

                loss = self.w_dist * loss_kd + loss_clf + loss_cons * self.w_cons + concp_loss * self.w_concp
                optimizer_conv_fc.zero_grad()
                loss.backward()
                optimizer_conv_fc.step()
                losses += loss.item()

                with torch.no_grad():
                    _, preds = torch.max(new_net_logits, dim=1)
                    correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                    total += len(targets)

            scheduler_conv_fc.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / (total + 0.0001), decimals=2)
            info = "Task {}, train {}, Epoch {}/{} => Loss {:.3f}, L_clf {:.3f}, L_kd {:.3f}, L_cons {:.3f},L_concp {:.3f}, Train_accy {:.2f}".format(
                self._cur_task, train_type, epoch + 1, self.epochs,
                losses / len(train_loader),
                L_clf / len(train_loader),
                L_kd * self.w_dist / len(train_loader),
                L_cons * self.w_cons / len(train_loader),
                L_concp * self.w_concp / len(train_loader),
                train_acc,
            )
            prog_bar.set_description(info)
        logging.info(info)

    def _drift_learner(self, train_loader):
        """Fit the drift estimator to map old-network features onto current-network features.

        Runs 20 epochs of Adam (lr 1e-3) on an MSE objective. Backbone and
        classifier are frozen; only ``drift_estimate_larer`` parameters are
        given to the optimizer.
        """
        network = self._unwrap_network()
        epoch_all = 20
        prog_bar = tqdm(range(epoch_all), ascii=True)
        optimizer_de = optim.Adam(network.drift_estimate_larer.parameters(), lr=0.001)
        network.freeze_conv()
        network.freeze_fc()
        network.unfreeze_de()
        mse = torch.nn.MSELoss()
        train_type = 'de'
        for epoch in prog_bar:
            L_de = 0.0
            losses = 0.0
            for _, inputs, _ in train_loader:
                inputs = inputs.to(self._device)
                new_net_features = network(inputs)["features"]
                # Old features are only inputs to the estimator; no gradient flows into the old network.
                with torch.no_grad():
                    old_net_features = self._old_network(inputs)["features"]
                estimate_old_features = network.drift_estimate_larer(old_net_features)['logits']
                drift_estimate_loss = mse(estimate_old_features, new_net_features)
                L_de += drift_estimate_loss.item()

                loss = drift_estimate_loss
                optimizer_de.zero_grad()
                loss.backward()
                optimizer_de.step()
                losses += loss.item()
            info = "Task {}, train {}, Epoch {}/{} => Loss {:.3f}, L_de {:.3f}".format(
                self._cur_task, train_type, epoch + 1, epoch_all,
                losses / len(train_loader),
                L_de / len(train_loader),
            )
            prog_bar.set_description(info)
        logging.info(info)

    def _consolidate_fc(self, train_loader):
        """Re-fit only the classifier on frozen current features plus the stored ball centres.

        The per-batch objective is:

        * cross-entropy on the batch features,
        * + ``w_ball_cls`` * cross-entropy on all ball centres with their labels,
        * + ``w_confu`` * ``_confu_loss``.

        Backbone and drift estimator are frozen during this stage and
        unfrozen afterwards.
        """
        network = self._unwrap_network()
        network.freeze_conv()
        network.freeze_de()
        network.unfreeze_fc()
        optimizer = optim.SGD(network.fc.parameters(), lr=0.01, momentum=0.9, weight_decay=self.weight_decay)
        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info("Trainable params: {}".format(count_parameters(self._network, True)))
        epoch_num = self.args["epoch_fc"]
        balls_center = torch.tensor(self._balls._get_center()).to(self._device)
        balls_radius = torch.tensor(self._balls._get_radius(self._cur_task)).to(self._device)
        balls_targets = torch.tensor(self._balls._get_labels()).to(self._device)

        balls_radius_origin = torch.tensor(self._balls._get_radius_orig()).to(self._device)
        logging.info("balls_radius_origin: {}".format(str(balls_radius_origin)))

        # Every accumulated loss is already a per-batch mean, so all reported
        # values are averaged over batches (guarded against an empty loader).
        num_batches = max(len(train_loader), 1)
        for epoch in range(epoch_num):
            L_cls = 0.0
            L_cls_old = 0.0
            L_confu = 0.0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                with torch.no_grad():
                    samp_fea = network(inputs)["features"].detach()
                new_sample_logits = network.fc(samp_fea)["logits"]

                loss_clf = F.cross_entropy(new_sample_logits, targets.long())
                L_cls += loss_clf.item()

                # Ball centres act as stored exemplars of the old classes.
                ball_logits = network.fc(balls_center)["logits"]
                loss_ball_cls = F.cross_entropy(ball_logits, balls_targets.long())
                L_cls_old += loss_ball_cls.item()

                loss_confu = _confu_loss(balls_center, balls_radius, network, self.T)
                L_confu += loss_confu.item()
                loss = loss_clf + loss_confu * self.w_confu + loss_ball_cls * self.w_ball_cls
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if self.args["print_info"]:
                print('recompute fc,epoch:{}/{},loss_clf:{}, L_cls_old:{}, loss_confu:{}'.format(
                    epoch, epoch_num,
                    L_cls / num_batches,
                    L_cls_old * self.w_ball_cls / num_batches,
                    L_confu * self.w_confu / num_batches,
                ))
        network.unfreeze_conv()
        network.unfreeze_de()
    

def _KD_loss(pred, soft, T):
    pred = torch.log_softmax(pred / T, dim=1)
    soft = torch.softmax(soft / T, dim=1)
    return -1 * torch.mul(soft, pred).sum() / pred.shape[0]

def _confu_loss(balls_center, balls_radius, model, T):
    W = model.fc.weight
    bias = model.fc.bias
    norm_W = torch.norm(W, dim=1, keepdim=True)
    distances = torch.abs(torch.matmul(balls_center, W.t()) + bias.unsqueeze(0)) 
    distances = distances/ norm_W.t()
    dis_cent_bun = distances.min(dim=1)[0]
    prob = 1/torch.clamp(torch.exp((balls_radius - dis_cent_bun)/T), min=1.0)
    entropy =-torch.log(prob) * prob
    return entropy.sum() / len(balls_center)


def _contr_loss(features, targets, T):
    pairwise_distances = torch.cdist(features, features, p=2)
    similarity_matrix = torch.exp(-pairwise_distances**2 / (2 * T**2)) 
    
    positive_samples = targets.unsqueeze(1) == targets.unsqueeze(0) 
    negative_samples = targets.unsqueeze(1) != targets.unsqueeze(0)
    exp_similarity = torch.exp(similarity_matrix / T)
    same_class_similarity = exp_similarity * positive_samples.float()
    diff_class_similarity = exp_similarity * negative_samples.float()

    normal_similarity = same_class_similarity.sum(dim=1) / diff_class_similarity.sum(dim=1)
    loss =  -torch.log (normal_similarity) / positive_samples.sum(dim=1)
    return loss.mean()

def _concp_loss(features, T, balls_center=None, balls_radius=None):
    dis_center_smp= torch.cdist(balls_center, features, p=2)

    mask = dis_center_smp < balls_radius.unsqueeze(1)
    tmp = torch.exp((balls_radius.unsqueeze(1) - dis_center_smp) / T)*mask
    # print(tmp)
    prob = 1/ torch.clamp((tmp.sum(0)+EPSILON),min=1)
    entropy =-torch.log(prob) * prob
    return entropy.sum() / len(balls_center)