import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import target2onehot, tensor2numpy
import copy

# Small constant; not referenced by the code visible in this file
# (presumably kept for numerical-stability use elsewhere — TODO confirm).
EPSILON = 1e-8

# Hyperparameters for the first task (self._cur_task == 0):
# number of epochs, SGD learning rate and weight decay, and the
# MultiStepLR milestones / decay factor (gamma).
init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.2
init_weight_decay = 0.0005

# Hyperparameters for every later (incremental) task, with the same roles
# as the init_* values above. batch_size and num_workers are shared by the
# train and test DataLoaders of all tasks.
epochs = 200  # 70 -> 200
lrate = 0.1
milestones = [60, 120, 170]
lrate_decay = 0.2
batch_size = 90
weight_decay = 0.0005
num_workers = 4
# Not referenced by the code visible in this file
# (presumably a distillation temperature — TODO confirm).
T = 2

# The triple-quoted block below is a bare string expression, so it has no
# effect at runtime. It appears to hold an alternative hyperparameter set
# kept for reference.
'''
init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005


epochs = 70
lrate = 0.1
milestones = [30, 50]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 4
T = 2
'''

class Replay(BaseLearner):
    """Incremental learner trained with plain cross-entropy on new data plus memory.

    For each task the classifier head is grown to cover the new classes, the
    network is trained on the new classes' training data together with
    whatever ``self._get_memory()`` returns, and the rehearsal memory is then
    rebuilt. While training, the weights that achieve the best test accuracy
    for the current task are kept in ``best_net_parameters`` and can be
    restored with :meth:`load_best_net_for_task`.

    Attributes such as ``_cur_task``, ``_known_classes``, ``_total_classes``,
    ``_device``, ``_multiple_gpus``, ``samples_per_class`` and
    ``exemplar_size``, and the helpers ``_get_memory``,
    ``build_rehearsal_memory`` and ``_compute_accuracy``, are not defined in
    this class (presumably provided by ``BaseLearner`` — verify there).
    """

    def __init__(self, args):
        """Create the network and the per-task best-model bookkeeping.

        Args:
            args: Configuration object forwarded unchanged to ``BaseLearner``
                and ``IncrementalNet``.
        """
        super().__init__(args)
        self._network = IncrementalNet(args, False)
        self._snapshot = None
        self.best_accuracy = 0  # Best test accuracy seen for the current task
        self.best_net_parameters = None  # state_dict achieving best_accuracy
        # NOTE: a disabled override used to live here that read args['epochs']
        # and picked the LR milestones from this table:
        #   70 -> [30, 50], 100 -> [50, 70], 150 -> [60, 100],
        #   200 -> [60, 120, 170], 250 -> [100, 160, 210],
        #   300 -> [90, 180, 240], 350 -> [100, 200, 300]
        # It is inactive; the module-level `epochs` / `milestones` are used.
        self.losses = []  # Mean training loss of every epoch, across all tasks

    def after_task(self):
        """Mark the classes of the task just finished as known."""
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Train on the next task and rebuild the rehearsal memory.

        Args:
            data_manager: Object providing ``get_task_size(task_index)`` and
                ``get_dataset(class_indices, source=..., mode=..., ...)``.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))

        # Training data: the new classes plus the current memory.
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        # Test data: every class seen so far.
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        # Wrap for multi-GPU training; the wrapper is removed again below,
        # after the rehearsal memory has been rebuilt.
        use_data_parallel = len(self._multiple_gpus) > 1
        if use_data_parallel:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)

        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if use_data_parallel:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Reset the best-model tracking and train with the task's schedule.

        The first task uses the ``init_*`` hyperparameters; every later task
        uses the incremental ones.
        """
        self.best_accuracy = 0
        self.best_net_parameters = None
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(), momentum=0.9, lr=init_lr, weight_decay=init_weight_decay
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """Train the first task for ``init_epoch`` epochs."""
        self._run_epochs(train_loader, test_loader, optimizer, scheduler, init_epoch)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Train an incremental task for ``epochs`` epochs."""
        self._run_epochs(train_loader, test_loader, optimizer, scheduler, epochs)

    def _unwrapped_network(self):
        """Return the underlying network, looking through ``nn.DataParallel``."""
        net = self._network
        return net.module if isinstance(net, nn.DataParallel) else net

    def _run_epochs(self, train_loader, test_loader, optimizer, scheduler, num_epochs):
        """Shared loop: train with cross-entropy, evaluate, keep the best weights.

        Args:
            train_loader: Iterable yielding ``(_, inputs, targets)`` batches.
            test_loader: Loader passed to ``self._compute_accuracy`` after
                every epoch.
            optimizer: Optimizer stepped once per batch.
            scheduler: LR scheduler stepped once per epoch.
            num_epochs: Number of epochs to run.
        """
        info = None  # Stays None when num_epochs == 0, so nothing is logged
        prog_bar = tqdm(range(num_epochs))
        for epoch in prog_bar:
            self._network.train()
            running_loss = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

                # Running training accuracy.
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            mean_loss = running_loss / len(train_loader)
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            # Evaluate after every epoch. `>=` is used for every task so that
            # a snapshot always exists once an epoch has run and ties favour
            # the most-trained weights.
            test_acc = self._compute_accuracy(self._network, test_loader)
            if test_acc >= self.best_accuracy:
                self.best_accuracy = test_acc
                # Snapshot the unwrapped module: a DataParallel state_dict
                # prefixes keys with "module.", which would not load into the
                # plain network once incremental_train() has unwrapped it.
                self.best_net_parameters = copy.deepcopy(
                    self._unwrapped_network().state_dict()
                )

            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                self._cur_task, epoch + 1, num_epochs, mean_loss, train_acc, test_acc,
            )
            prog_bar.set_description(info)
            self.losses.append(mean_loss)

        if info is not None:
            logging.info(info)

    def load_best_net_for_task(self):
        """Restore the best weights recorded for the current task, if any.

        Does nothing when no snapshot has been recorded yet.
        """
        if self.best_net_parameters is None:
            return
        self._unwrapped_network().load_state_dict(self.best_net_parameters)
        logging.info("Loaded best net for previous task with accuracy {:.2f}".format(self.best_accuracy))

'''
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNet
from utils.toolkit import target2onehot, tensor2numpy
import copy

EPSILON = 1e-8

init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.2
init_weight_decay = 0.0005

epochs = 200  # 70 -> 200
lrate = 0.1
milestones = [60, 120, 170]
lrate_decay = 0.2
batch_size = 90
weight_decay = 0.0005
num_workers = 4
T = 2

class Replay(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self._network = IncrementalNet(args, False)
        self._snapshot = None
        self.best_accuracy = 0  # Initialize best accuracy
        self.best_convnet_parameters = None  # Store best convnet parameters

    def after_task(self):
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))

        # Loader
        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        # Procedure
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)

        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(), momentum=0.9, lr=init_lr, weight_decay=init_weight_decay
            )
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay)
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=lrate_decay)
            self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(init_epoch))
        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            # Track the best model based on test accuracy
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                if test_acc > self.best_accuracy:  # Update best accuracy
                    self.best_accuracy = test_acc
                    self.best_convnet_parameters = copy.deepcopy(self._network.convnet.state_dict())  # Save convnet state
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task, epoch + 1, init_epoch, losses / len(train_loader), train_acc, test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task, epoch + 1, init_epoch, losses / len(train_loader), train_acc,
                )

            prog_bar.set_description(info)

        logging.info(info)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(epochs))
        for _, epoch in enumerate(prog_bar):
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss_clf = F.cross_entropy(logits, targets)
                loss = loss_clf

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                # Calculate accuracy
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                if test_acc > self.best_accuracy:  # Track best accuracy
                    self.best_accuracy = test_acc
                    self.best_convnet_parameters = copy.deepcopy(self._network.convnet.state_dict())  # Save convnet
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task, epoch + 1, epochs, losses / len(train_loader), train_acc, test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task, epoch + 1, epochs, losses / len(train_loader), train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)

    # Load the best convnet parameters only once after all tasks
    def load_best_convnet(self):
        if self.best_convnet_parameters:
            self._network.convnet.load_state_dict(self.best_convnet_parameters)
            logging.info("Loaded best convnet with accuracy {:.2f}".format(self.best_accuracy))
'''