import logging
import numpy as np
import torch
from torch import nn
from torch.serialization import load
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils_online.inc_net_online import IncrementalNetOnline
from models_online.base import BaseLearner
from utils.toolkit import target2onehot, tensor2numpy
from buffer.buffer import ProtoBuffer, Reservoir, QueryReservoir
from utils_online.si_blurry import IndexedDataset, OnlineSampler, OnlineTestSampler

# Number of worker processes for every DataLoader built in this module
# (train/test loaders in incremental_train and the per-epoch blurry test loader).
num_workers: int = 8

class Learner(BaseLearner):
    """Online continual learner trained with experience replay.

    Every incoming training batch is concatenated with up to 100 samples drawn
    from ``self.buffer`` (a ``Reservoir``). The network is updated with
    cross-entropy on the combined batch, and the incoming batch is then written
    to the buffer. Two data settings are supported: a class-incremental split
    (``args['blurry']`` false) and a blurry stream driven by ``OnlineSampler``
    (``args['blurry']`` true).
    """

    def __init__(self, args):
        """Build the network, classifier head and replay buffer from ``args``."""
        super().__init__(args)
        self._network = IncrementalNetOnline(args, True)
        # NOTE(review): 768 is assumed to be the backbone's feature width — confirm.
        self._network.fc = self._network.generate_fc(768, args['nb_classes'])
        self.batch_size = args["batch_size"]
        self.init_lr = args["init_lr"]
        weight_decay = args.get("weight_decay")
        self.weight_decay = weight_decay if weight_decay is not None else 0.0005
        self.recab_coef = args["recab_coef"]
        self.args = args

        # Replay memory shared across all tasks (capacity fixed at 1000 samples).
        self.buffer = Reservoir(
            max_size=1000,
            n_classes=args['nb_classes'],
            img_size=224,
            device=self._device
        )

        # The backbone is trainable unless args['freeze'] is set;
        # the classifier head is always trainable.
        trainable_backbone = not self.args['freeze']
        for p in self._network.parameters():
            p.requires_grad = trainable_backbone
        for p in self._network.fc.parameters():
            p.requires_grad = True

    def after_task(self):
        """Mark every class of the finished task as known."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager, **kwargs):
        """Prepare the loaders for the next task and train on it.

        Args:
            data_manager: Object providing ``get_dataset`` and ``get_task_size``.
            **kwargs: ``n_tasks`` (default 10) is forwarded to ``OnlineSampler``
                in blurry mode.
        """
        self._cur_task += 1
        n_tasks = kwargs.get('n_tasks', 10)

        if self.args['blurry']:
            # Blurry setting: the dataset spans every class. The OnlineSampler
            # (positioned with set_task below) presumably selects this task's stream.
            self._total_classes = self.args['nb_classes']
            train_classes = np.arange(0, self._total_classes)
        else:
            self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
            logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
            train_classes = np.arange(self._known_classes, self._total_classes)

        # The test set covers classes [0, _total_classes); it is fetched once and
        # reused for the test loader below.
        self.train_dataset = data_manager.get_dataset(train_classes, source="train", mode="train")
        self.test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")

        if self.args['blurry']:
            self.train_dataset = IndexedDataset(self.train_dataset)
            # NOTE(review): the positional 10, 50, False, 1 are hard-coded
            # OnlineSampler settings — confirm their meaning in utils_online.si_blurry.
            self.train_sampler = OnlineSampler(self.train_dataset, n_tasks, 10, 50, self.args['seed'], False, 1)
            self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.train_sampler, num_workers=num_workers, pin_memory=True)
            self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
            self.train_sampler.set_task(self._cur_task)
        else:
            self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
            self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=num_workers)

        self.data_manager = data_manager

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Move the network to the device, build the optimizer on task 0, then train.

        The optimizer is created only for the first task and reused afterwards.

        Raises:
            ValueError: if ``args["optimizer"]`` is neither "sgd" nor "adam".
        """
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer_name = self.args["optimizer"]
            if optimizer_name == "sgd":
                self.optim = optim.SGD(
                    self._network.parameters(),
                    momentum=0.9,
                    lr=self.args["init_lr"],
                    weight_decay=self.args["init_weight_decay"],
                )
            elif optimizer_name == "adam":
                self.optim = optim.Adam(
                    self._network.parameters(),
                    lr=self.args["init_lr"],
                )
            else:
                # Fail fast instead of crashing later on the missing self.optim.
                raise ValueError(
                    "Unsupported optimizer {!r}; expected 'sgd' or 'adam'".format(optimizer_name)
                )

        self._init_train(train_loader, test_loader)

    def _init_train(self, train_loader, test_loader):
        """Run ``args["init_epoch"]`` epochs of experience-replay training.

        Each step retrieves up to 100 buffered samples and trains on
        [memory; incoming] with cross-entropy. It then pushes the incoming batch
        into the buffer. Train accuracy is measured on the incoming samples only.

        In blurry mode the test loader is rebuilt every epoch from the labels of
        a buffer sample and stored on ``self.test_loader``. Test accuracy is
        computed every 5th epoch.
        """
        n_epochs = self.args["init_epoch"]
        info = ""  # stays empty (and is not logged) when n_epochs == 0
        prog_bar = tqdm(range(n_epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            n_batches = 0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                mem_x, mem_y = self.buffer.random_retrieve(n_imgs=100)

                if mem_x.size(0) > 0:
                    # combine() puts memory first, so the incoming batch is the tail.
                    combined_x, combined_y = self.combine(inputs, targets, mem_x, mem_y)
                else:
                    combined_x, combined_y = inputs, targets

                logits = self._network(combined_x)["logits"]
                # cross_entropy already reduces to a scalar mean.
                loss = F.cross_entropy(logits, combined_y.long())

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                losses += loss.item()
                n_batches += 1

                # Store the incoming batch after the optimizer step.
                self.buffer.update(imgs=inputs, labels=targets)

                # Accuracy on the incoming samples only (last len(inputs) rows).
                _, preds = torch.max(logits[-len(inputs):, :], dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            # Guard against an empty loader (no batches -> no division by zero).
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) if total > 0 else 0.0
            avg_loss = losses / n_batches if n_batches > 0 else 0.0

            if self.args['blurry']:
                # NOTE(review): only nb_classes random images are drawn, so a seen
                # class missing from this sample is left out of the test sampler —
                # confirm this is intended.
                mem_x, mem_y = self.buffer.random_retrieve(n_imgs=self.args['nb_classes'])
                test_sampler = OnlineTestSampler(self.test_dataset, mem_y.unique())
                test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=test_sampler, num_workers=num_workers)
                self.test_loader = test_loader

            if (epoch + 1) % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                )
            prog_bar.set_description(info)

        if info:
            logging.info(info)

    def combine(self, batch_x, batch_y, mem_x, mem_y):
        """Concatenate memory samples (first) with the incoming batch on ``self._device``.

        Returns:
            ``(combined_x, combined_y)``; the incoming batch occupies the last
            ``len(batch_x)`` rows.
        """
        mem_x, mem_y = mem_x.to(self._device), mem_y.to(self._device)
        batch_x, batch_y = batch_x.to(self._device), batch_y.to(self._device)
        combined_x = torch.cat([mem_x, batch_x])
        combined_y = torch.cat([mem_y, batch_y])
        return combined_x, combined_y

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Train for ``args["epochs"]`` epochs on the current task's classes only.

        The loss is computed on the logits of the new classes
        (``logits[:, _known_classes:]``) with targets shifted by
        ``_known_classes``. Test accuracy is computed every 5 epochs, starting
        at the first epoch.
        """
        n_epochs = self.args["epochs"]
        info = ""  # stays empty (and is not logged) when n_epochs == 0
        prog_bar = tqdm(range(n_epochs))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            n_batches = 0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                # Shift targets so they index into the new-class slice of the logits.
                fake_targets = targets - self._known_classes
                loss = F.cross_entropy(
                    logits[:, self._known_classes :], fake_targets
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                n_batches += 1

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            # Guard against an empty loader (no batches -> no division by zero).
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) if total > 0 else 0.0
            avg_loss = losses / n_batches if n_batches > 0 else 0.0
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    n_epochs,
                    avg_loss,
                    train_acc,
                )
            prog_bar.set_description(info)
        if info:
            logging.info(info)
