from argparse import Namespace
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
from es import SimpleGA, OpenES, CMAES
from torch.nn import ModuleList
from torch.optim import Adam

from backbone.pec_modules import get_single_pec_network
from datasets import get_dataset
from models.utils.continual_model import ContinualModel, get_lr_scheduler
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser, str2bool
from utils.loggers import wandb_safe_log


def get_parser() -> ArgumentParser:
    """Build the command-line parser for the PEC model.

    The shared management, experiment and rehearsal argument groups are
    registered first, followed by the PEC-specific score and calibration
    options listed below.

    Returns:
        The fully configured ``ArgumentParser``.
    """
    parser = ArgumentParser(description="CIL with PEC")
    for register_group in (add_management_args, add_experiment_args, add_rehearsal_args):
        register_group(parser)

    # (flag, add_argument keyword arguments), registered in this order.
    # Options without a "default" entry fall back to argparse's default of None.
    pec_options = (
        ("--make_equal_task_sizes", {"type": str2bool, "default": False}),
        ("--sqrt_score", {"type": str2bool, "default": False}),
        ("--unfair_calibration", {"type": str2bool, "default": False}),
        ("--evolution_calibration", {"type": str2bool, "default": False}),
        ("--calibration_data_source", {"type": str}),
        ("--es_calibration_use_biases", {"type": str2bool, "default": False}),
        ("--es_calibration_weight_kind", {"type": str, "default": "exp"}),
        ("--es_calibration_algorithm", {"type": str, "default": "simplega"}),
        ("--es_calibration_sigma_init", {"type": float}),
        ("--es_calibration_popsize", {"type": int}),
        ("--es_calibration_elite_ratio", {"type": float}),
        ("--es_calibration_forget_best", {"type": str2bool}),
        ("--es_calibration_iterations", {"type": int, "default": 10_000}),
    )
    for flag, options in pec_options:
        parser.add_argument(flag, **options)

    return parser


class Pec(ContinualModel):
    """Continual model that trains one student network per class against a shared teacher.

    Every class owns a ``PecStudentTeacherPair``; the value a pair returns for
    an input is that class's score.  Logits are the negated scores (optionally
    square-rooted first when ``--sqrt_score`` is set), rescaled per class by
    ``calibration_w`` and shifted by ``calibration_b``.  The calibration
    parameters start at ones / zeros and can be fitted afterwards with
    ``unfair_calibration`` or ``evolution_calibration``.
    """

    NAME = "pec"
    COMPATIBILITY = ["class-il", "task-il"]

    # Accepted values of --calibration_data_source and --es_calibration_algorithm.
    _CALIBRATION_SOURCES = ("test", "train500")
    _ES_ALGORITHMS = ("simplega", "openes", "cmaes")

    def __init__(self, backbone: nn.Module, loss: nn.Module,
                 dataset_config: dict, args: Namespace, transform: nn.Module) -> None:
        """Create one student/teacher pair, optimizer and scheduler slot per class.

        Args:
            backbone: forwarded unchanged to ``ContinualModel.__init__``.
            loss: forwarded to the base class; ``self.loss`` is reset to
                ``None`` afterwards because ``observe`` computes its own loss.
            dataset_config: must provide ``"classes"`` (the number of classes);
                it is also handed to ``get_single_pec_network``.
            args: parsed command-line options (see ``get_parser``).
            transform: forwarded unchanged to ``ContinualModel.__init__``.
        """
        super().__init__(backbone, loss, dataset_config, args, transform)

        self.loss = None

        num_classes = dataset_config["classes"]
        # A single teacher instance is shared by every pair.
        teacher = get_single_pec_network(dataset_config, args, is_teacher=True)
        self.pec_modules = ModuleList(
            [PecStudentTeacherPair(dataset_config, args, teacher=teacher) for _ in range(num_classes)])
        # Start every student from the same weights as the first one.
        students = [pair.student for pair in self.pec_modules]
        for student in students[1:]:
            student.load_state_dict(students[0].state_dict())
        self.net = self.pec_modules

        # One optimizer per class, covering only that class's student parameters.
        self.opt = [self.optim_class(student.parameters(), lr=self.args.lr) for student in students]
        self.per_step_lr_scheduler = [None for _ in range(num_classes)]

        # Per-class scale and shift applied to the negated scores in ``forward``.
        self.calibration_w = nn.Parameter(torch.ones(num_classes, device=self.device))
        self.calibration_b = nn.Parameter(torch.zeros(num_classes, device=self.device))

        self.task_size = None

    def reset_optimizer(self):
        """Not supported: optimizers are created once per class in ``__init__``."""
        raise NotImplementedError

    def reset_per_step_lr_scheduler(self, num_steps, class_start, class_end):
        """Create a fresh LR scheduler for every class in ``[class_start, class_end)``.

        Args:
            num_steps: total number of optimisation steps for the task; capped
                at ``task_size * n_epochs`` when ``self.task_size`` is set.
            class_start: index of the first class of the task (inclusive).
            class_end: index one past the last class of the task.
        """
        if self.task_size:
            num_steps = min(num_steps, self.task_size * self.args.n_epochs)
        # The step budget is split evenly between the classes of the task.
        steps_per_class = num_steps // (class_end - class_start)
        for c in range(class_start, class_end):
            self.per_step_lr_scheduler[c] = get_lr_scheduler(self.opt[c], self.args, num_steps=steps_per_class)

    @staticmethod
    def _single_label(labels) -> int:
        """Return the one label shared by every element of ``labels``."""
        label = int(labels[0])
        # Technical assumption. Can be guaranteed by one-class tasks or batch_size=1.
        assert torch.all(labels == label)
        return label

    def _uncalibrated_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return the negated per-class scores for ``x``, before calibration.

        The pairs are evaluated without gradient tracking and stacked along
        dimension 1 (one entry per class).  The optional square root is applied
        here so that inference and both calibration routines operate on exactly
        the same quantity.
        """
        with torch.no_grad():
            scores = torch.stack([module(x) for module in self.pec_modules], dim=1)
            if self.args.sqrt_score:
                scores = torch.sqrt(scores)
            return -scores

    def step_lr_scheduler(self, labels):
        """Advance the LR scheduler of the class that ``labels`` belongs to."""
        self.per_step_lr_scheduler[self._single_label(labels)].step()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return calibrated class logits for ``x`` (higher means more likely).

        No gradient flows into the student or teacher networks; the result only
        depends differentiably on ``calibration_w`` and ``calibration_b``.
        """
        return self._uncalibrated_logits(x) * self.calibration_w + self.calibration_b

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step on the student of the batch's single class.

        Args:
            inputs: training batch; all samples must share one label.
            labels: labels of ``inputs``.
            not_aug_inputs: unused by this model.

        Returns:
            The batch-mean pair score that was minimised, as a Python float.
        """
        label = self._single_label(labels)
        module = self.pec_modules[label]
        opt = self.opt[label]

        opt.zero_grad()
        loss = torch.mean(module(inputs))
        loss.backward()
        opt.step()

        return loss.item()

    def unfair_calibration(self, dataset):
        """Fit ``calibration_w`` with Adam and cross-entropy on the joint test loader.

        Only the weights are optimised; ``calibration_b`` is left untouched.
        The logits being scaled are the same ones ``forward`` scales, including
        the square root when ``--sqrt_score`` is enabled.

        Args:
            dataset: object exposing ``NUM_CLASSES``.
        """
        print('doing calibration...')

        # Re-create the dataset as a single task that contains every class.
        args_joint = deepcopy(self.args)
        args_joint.classes_per_task = dataset.NUM_CLASSES
        args_joint.classes_first_task = dataset.NUM_CLASSES
        args_joint.batch_size = 32
        joint_data = get_dataset(args_joint)
        _, test_loader = joint_data.get_data_loaders()
        n_epochs = 5
        cross_ent = torch.nn.CrossEntropyLoss()
        opti = Adam([self.calibration_w])
        for _ in range(n_epochs):
            for x, y in test_loader:
                x, y = x.to(self.device), y.to(self.device)
                # The multiplication happens outside no_grad, so gradients reach calibration_w only.
                logits = self._uncalibrated_logits(x) * self.calibration_w
                loss = cross_ent(logits, y)
                opti.zero_grad()
                loss.backward()
                opti.step()

        print("calibration done, ", self.calibration_w)

    def evolution_calibration(self, dataset):
        """Fit ``calibration_w`` / ``calibration_b`` with an evolution strategy.

        The solver optimises accuracy on the data selected by
        ``--calibration_data_source`` (``"test"`` for the whole test set,
        ``"train500"`` for ``500 // NUM_CLASSES`` training samples per class);
        the improvement over the uncalibrated baseline is always measured on
        the test set inside ``evolution_solve``.

        Args:
            dataset: object exposing ``NUM_CLASSES``.

        Raises:
            ValueError: if ``--calibration_data_source`` or
                ``--es_calibration_algorithm`` has an unsupported value.
        """
        print('doing evolution calibration...')

        # Validate the configuration before running any network forward passes.
        source = self.args.calibration_data_source
        if source not in self._CALIBRATION_SOURCES:
            raise ValueError(
                f"unsupported calibration_data_source {source!r}; expected one of {self._CALIBRATION_SOURCES}")
        algorithm = self.args.es_calibration_algorithm
        if algorithm not in self._ES_ALGORITHMS:
            raise ValueError(
                f"unsupported es_calibration_algorithm {algorithm!r}; expected one of {self._ES_ALGORITHMS}")

        test_y, test_logits = self.get_data_for_calibration("test", dataset.NUM_CLASSES, dataset.NUM_CLASSES)

        if source == "test":
            train_y, train_logits = test_y, test_logits
        else:  # "train500"
            samples_per_class = 500 // dataset.NUM_CLASSES
            train_y, train_logits = self.get_data_for_calibration("train", dataset.NUM_CLASSES, 1, samples_per_class)

        # One weight coefficient and one bias coefficient per class.
        ga = self._build_es_solver(num_params=2 * dataset.NUM_CLASSES)

        w, b = evolution_solve(ga, train_y, train_logits, test_y, test_logits, args=self.args)

        self.calibration_w.data[:] = w
        self.calibration_b.data[:] = b

        print("calibration done, ", w, b)

    def _build_es_solver(self, num_params):
        """Instantiate the solver selected by ``--es_calibration_algorithm``.

        Raises:
            ValueError: if the algorithm name is not one of ``_ES_ALGORITHMS``.
        """
        args = self.args
        algorithm = args.es_calibration_algorithm
        shared_options = dict(
            num_params=num_params,  # number of model parameters
            sigma_init=args.es_calibration_sigma_init,  # initial standard deviation
            popsize=args.es_calibration_popsize,  # population size
            weight_decay=0.0,  # weight decay coefficient
        )
        if algorithm == "simplega":
            return SimpleGA(
                elite_ratio=args.es_calibration_elite_ratio,  # percentage of the elites
                forget_best=args.es_calibration_forget_best,  # forget the historical best elites
                **shared_options,
            )
        if algorithm == "openes":
            return OpenES(
                forget_best=args.es_calibration_forget_best,  # forget the historical best elites
                **shared_options,
            )
        if algorithm == "cmaes":
            return CMAES(**shared_options)
        raise ValueError(
            f"unsupported es_calibration_algorithm {algorithm!r}; expected one of {self._ES_ALGORITHMS}")

    def get_data_for_calibration(self, train_or_test, num_classes, classes_per_task, samples_per_task=None):
        """Collect labels and uncalibrated logits from a freshly built dataset.

        Args:
            train_or_test: ``"train"`` or ``"test"``; selects which loader of
                each task is read.
            num_classes: total number of classes.
            classes_per_task: classes per task in the rebuilt dataset;
                ``num_classes // classes_per_task`` tasks are visited.
            samples_per_task: if given, only the first ``samples_per_task``
                samples produced by each task's loader are kept.

        Returns:
            ``(all_y, all_logits)`` as NumPy arrays concatenated over tasks;
            ``all_logits`` holds one entry per class along axis 1.
        """
        args_copy = deepcopy(self.args)
        args_copy.classes_per_task = classes_per_task
        args_copy.classes_first_task = classes_per_task
        args_copy.batch_size = 32
        calibration_dataset = get_dataset(args_copy)
        num_tasks = num_classes // classes_per_task

        all_y = []
        all_logits = []

        for _ in range(num_tasks):
            task_y = []
            task_logits = []

            # NOTE(review): each call is presumed to return the loaders of the next task — confirm
            # against the dataset implementation.
            train_loader, test_loader = calibration_dataset.get_data_loaders()
            loader = {"train": train_loader, "test": test_loader}[train_or_test]

            for xy in loader:
                # Batches are indexed rather than unpacked because they may carry extra elements.
                x, y = xy[0], xy[1]
                logits = self._uncalibrated_logits(x.to(self.device))
                task_y.append(y.cpu())
                task_logits.append(logits.cpu())

            task_y = torch.cat(task_y).numpy()
            task_logits = torch.cat(task_logits).numpy()

            if samples_per_task is not None:
                task_y = task_y[:samples_per_task]
                task_logits = task_logits[:samples_per_task]

            all_y.append(task_y)
            all_logits.append(task_logits)

        all_y = np.concatenate(all_y)
        all_logits = np.concatenate(all_logits)

        return all_y, all_logits


def get_weights_and_biases(coefs, args):
    """Split raw solver coefficients into calibration weights and biases.

    Args:
        coefs: 2-D array with one candidate solution per row; the first half of
            the columns are weight coefficients, the second half bias
            coefficients.
        args: options object providing ``es_calibration_use_biases`` and
            ``es_calibration_weight_kind``.

    Returns:
        ``(weights, biases)``, each with one row per candidate.  Weights are
        ``exp(c)`` for kind ``"exp"`` and ``max(c + 1, 0)`` for kind
        ``"plus1"``, so an all-zero coefficient row maps to unit weights.
        Biases are all zeros unless ``es_calibration_use_biases`` is set.

    Raises:
        ValueError: if ``es_calibration_weight_kind`` is not ``"exp"`` or
            ``"plus1"``.
    """
    num_classes = coefs.shape[1] // 2
    raw_weights, biases = coefs[:, :num_classes], coefs[:, num_classes:]

    if not args.es_calibration_use_biases:
        biases = np.zeros_like(biases)

    weight_kind = args.es_calibration_weight_kind
    if weight_kind == "exp":
        weights = np.exp(raw_weights)
    elif weight_kind == "plus1":
        weights = np.maximum(raw_weights + 1.0, 0.0)
    else:
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        raise ValueError(
            f"unsupported es_calibration_weight_kind {weight_kind!r}; expected 'exp' or 'plus1'")

    return weights, biases


def get_score(y, logits, coefs, args):
    """Return the accuracy (in percent) of every candidate calibration.

    Args:
        y: 1-D array of true labels, one per sample.
        logits: 2-D array of uncalibrated logits, samples along axis 0 and
            classes along axis 1.
        coefs: 2-D array of raw solver coefficients, one candidate per row
            (decoded by ``get_weights_and_biases``).
        args: options object forwarded to ``get_weights_and_biases``.

    Returns:
        1-D array with one accuracy value per candidate.
    """
    weights, biases = get_weights_and_biases(coefs, args)

    # Broadcast to (candidate, sample, class): each candidate rescales every sample's logits.
    calibrated = np.expand_dims(weights, 1) * np.expand_dims(logits, 0) + np.expand_dims(biases, 1)
    predictions = calibrated.argmax(axis=2)
    hits_per_candidate = (predictions == y).sum(axis=1)
    return 100 * hits_per_candidate / len(y)


def evolution_solve(solver, y, logits, test_y, test_logits, args):
    """Search for calibration coefficients with an ask/tell evolution solver.

    The solver is driven for ``args.es_calibration_iterations`` rounds, using
    the accuracy from ``get_score`` on ``(y, logits)`` as fitness.  The gain of
    the best solution over the all-zero (identity) coefficients is measured on
    ``(test_y, test_logits)``, printed, and logged as ``"es_advantage"``.

    Args:
        solver: object exposing ``num_params``, ``ask()``, ``tell(fitness)``
            and ``result()``; the first two elements of ``result()`` are read
            as the best solution and its fitness.
        y: labels used for the fitness evaluation.
        logits: uncalibrated logits matching ``y``.
        test_y: labels used for the baseline / final comparison.
        test_logits: uncalibrated logits matching ``test_y``.
        args: options object providing ``es_calibration_iterations`` plus the
            fields read by ``get_weights_and_biases``.

    Returns:
        ``(weights, biases)`` of the best solution as two 1-D torch tensors.

    Raises:
        ValueError: if ``args.es_calibration_iterations`` is smaller than 1,
            since no solution would ever be produced.
    """
    num_iterations = args.es_calibration_iterations
    if num_iterations < 1:
        # Without at least one round there is no best solution to report or return.
        raise ValueError(f"es_calibration_iterations must be at least 1, got {num_iterations}")

    # All-zero coefficients decode to unit weights and zero biases, i.e. no calibration.
    baseline_score = get_score(test_y, test_logits, np.zeros((1, solver.num_params)), args)[0]
    print("baseline score", baseline_score)

    for j in range(num_iterations):
        solutions = solver.ask()
        fitness_list = get_score(y, logits, coefs=solutions, args=args)
        solver.tell(fitness_list)
        res = solver.result()  # first element is the best solution, second element is the best fitness
        best_solution, best_fitness = res[0], res[1]
        if (j + 1) % 100 == 0:
            print("fitness at iteration", (j + 1), best_fitness)
    print("local optimum discovered by solver:\n", best_solution)
    print("fitness score at this local optimum:", best_fitness)

    final_score = get_score(test_y, test_logits, best_solution[None, :], args)[0]
    advantage = final_score - baseline_score
    print("Got advantage of ", advantage)
    wandb_safe_log({"es_advantage": advantage})

    w, b = get_weights_and_biases(best_solution[None, :], args)
    return torch.tensor(w[0]), torch.tensor(b[0])


class PecStudentTeacherPair(nn.Module):
    """A trainable student network paired with a teacher used as a fixed target.

    The pair's output for an input is the squared difference between the
    student's and the teacher's outputs, averaged over dimension 1.
    """

    def __init__(self, dataset_config, args, teacher):
        """Build a new student and keep a reference to the given teacher.

        Args:
            dataset_config: forwarded to ``get_single_pec_network``.
            args: forwarded to ``get_single_pec_network``.
            teacher: teacher module; stored as-is (not copied), so the same
                instance can be shared between several pairs.
        """
        super().__init__()
        self.student = get_single_pec_network(dataset_config, args)
        self.teacher = teacher

    def forward(self, x):
        """Return the per-sample mean squared student/teacher difference.

        Gradients flow only through the student; the teacher is evaluated
        without autograd tracking, so no graph is built for it.
        """
        with torch.no_grad():
            target = self.teacher(x)
        return torch.mean((self.student(x) - target) ** 2, dim=1)
