from copy import deepcopy
from os.path import join as pjoin

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import defaultdict

from data.sine import SINE
from utils import simplex_proj, get_logger, cg_solve
from models.building_blocks import MLP


class MAML():
    """Second-order MAML meta-learner; base class of the NashMAML variants below.

    The meta-parameters are ``self.weights`` (the parameters of ``self.model``).
    Every meta-iteration samples ``tasks_per_meta_batch`` tasks with ``2*K``
    points each. A copy of the meta-parameters is adapted on the first ``K``
    points of each task, and the loss on the remaining points is
    back-propagated into the meta-parameters. This goes through the inner
    updates, which are built with ``create_graph=True``.
    """

    def __init__(self, task_name, inner_lr, meta_lr, K=10, inner_steps=1, tasks_per_meta_batch=25, results_path="./results", mode="skewed"):
        """Build the task sampler, model, meta-optimiser and logger.

        Args:
            task_name: task family; only ``"sine"`` is implemented.
            inner_lr: step size of the inner (task-adaptation) gradient steps.
            meta_lr: learning rate of the Adam meta-optimiser.
            K: number of adaptation (support) points per task.
            inner_steps: number of inner gradient steps per task.
            tasks_per_meta_batch: number of tasks sampled per meta-iteration.
            results_path: directory that receives ``execution.log`` and result files.
            mode: sampling mode passed to ``self.task.sample_data`` during training.

        Raises:
            NotImplementedError: for any ``task_name`` other than ``"sine"``.
        """
        # Construct Model
        if task_name == "sine":
            self.task = SINE()
            self.model = MLP()
        else:
            raise NotImplementedError

        # Set Optimizer
        self.weights = list(self.model.parameters())  # the maml weights we will be meta-optimising
        self.criterion = nn.MSELoss()
        self.meta_optimiser = optim.Adam(self.weights, meta_lr)

        # Hyperparameters
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.K = K
        self.inner_steps = inner_steps  # with the current design of MAML, >1 is unlikely to work well
        self.tasks_per_meta_batch = tasks_per_meta_batch

        # metrics
        self.plot_every = 10
        self.print_every = 100
        self.train_info = defaultdict(list)

        # logdir
        self.results_path = results_path
        log_path = pjoin(results_path, "execution.log")
        self.logger = get_logger(log_path)

        self.mode = mode

    def inner_loop(self, X, y):
        """Adapt a copy of the meta-weights to one task and return its query loss.

        The first ``self.K`` entries of ``X``/``y`` drive ``self.inner_steps``
        gradient steps. The remaining entries give the returned loss. The inner
        gradients are built with ``create_graph=True``, so the returned loss
        remains differentiable w.r.t. ``self.weights`` (second-order MAML).
        """
        X_train, X_test = X[:self.K], X[self.K:]
        y_train, y_test = y[:self.K], y[self.K:]

        # reset inner model to current maml weights
        temp_weights = [w.clone() for w in self.weights]
        for step in range(self.inner_steps):
            loss = self.criterion(self.model.parameterised(X_train, temp_weights), y_train)

            # compute grad and update inner loop weights
            grad = torch.autograd.grad(loss, temp_weights, create_graph=True, retain_graph=True)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        loss = self.criterion(self.model.parameterised(X_test, temp_weights), y_test)
        return loss

    def train_log(self, epoch_loss, iteration, num_iterations):
        """Log and record summary statistics of per-task query losses.

        Args:
            epoch_loss: per-task query losses collected since the last call.
            iteration: current meta-iteration (used in the progress message).
            num_iterations: total number of meta-iterations.
        """
        epoch_loss = np.array(epoch_loss)
        loss_mean = np.mean(epoch_loss)
        loss_std = np.std(epoch_loss)
        loss_worst = np.max(epoch_loss)
        # Mean of the best 90% of tasks. Keep at least one task so the value is
        # defined (an empty slice would give NaN) for very small batches.
        n_top = max(1, int(0.9 * len(epoch_loss)))
        loss_top90 = np.mean(np.sort(epoch_loss)[:n_top])

        self.logger.info(f"{iteration}/{num_iterations}")
        self.logger.info(f"MSE(mean): {loss_mean:.4f}\tMSE(worst): {loss_worst:.4f}")
        self.logger.info(f"MSE(std): {loss_std:.4f}\tMSE(Top 90%): {loss_top90:.4f}")

        self.train_info["loss_mean"].append(loss_mean)
        self.train_info["loss_std"].append(loss_std)
        self.train_info["loss_worst"].append(loss_worst)
        self.train_info["loss_top90"].append(loss_top90)

    def train(self, num_iterations):
        """Meta-train for ``num_iterations`` meta-iterations.

        Each iteration sums the query losses of ``tasks_per_meta_batch`` tasks
        and takes one Adam step on ``self.weights``. Statistics are logged
        every ``print_every`` iterations, and the collected curve is written
        to ``training_curve.npz`` at the end.
        """
        losses = []
        for iteration in tqdm(range(1, num_iterations+1)):

            # compute meta loss
            meta_loss = 0.0
            batch_X, batch_y = self.task.sample_data(batch_size=self.tasks_per_meta_batch,
                                                     num_samples=2*self.K, mode=self.mode)
            for i in range(self.tasks_per_meta_batch):
                loss = self.inner_loop(batch_X[i], batch_y[i])
                meta_loss += loss
                losses.append(loss.item())

            # The meta-gradient is only consumed by the optimiser, never
            # differentiated again. Without create_graph/retain_graph, autograd
            # can free the graph here instead of keeping it alive through w.grad.
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            self.meta_optimiser.step()

            # log metrics
            if iteration % self.print_every == 0:
                self.train_log(losses, iteration, num_iterations)
                losses = []
        np.savez_compressed(f"{self.results_path}/training_curve.npz", loss_mean=np.array(self.train_info["loss_mean"]),
                                                                       loss_std=np.array(self.train_info["loss_std"]),
                                                                       loss_worst=np.array(self.train_info["loss_worst"]),
                                                                       loss_top90=np.array(self.train_info["loss_top90"]))

    def evaluate(self, num_tasks, K=5, n_steps=5, lr=0.001, save=True, mode="skewed"):
        """Fine-tune copies of the meta-model on fresh tasks and log their test losses.

        For each of ``num_tasks`` tasks (``2*K`` points each), a deep copy of
        ``self.model`` takes ``n_steps`` SGD steps on the first ``K`` points.
        Half the MSE on the last ``K`` points is recorded before adaptation and
        after every step. Summary statistics per step are logged, and the
        per-task losses are saved to ``performance_step{step}_{mode}.npz``.
        With ``save=True`` the meta-model's ``state_dict`` is saved as well.
        """
        losses = [[] for _ in range(n_steps+1)]

        X, y = self.task.sample_data(num_tasks, 2*K, mode=mode)
        for i in range(num_tasks):
            model = deepcopy(self.model)
            optimizer = optim.SGD(model.parameters(), lr=lr)
            losses[0].append(0.5 * self.criterion(model(X[i, K:]), y[i, K:]).item())

            for step in range(1, n_steps+1):
                optimizer.zero_grad()
                loss = self.criterion(model(X[i, :K]), y[i, :K])
                loss.backward()
                optimizer.step()

                losses[step].append(0.5 * self.criterion(model(X[i, K:]), y[i, K:]).item())

        for step in range(n_steps+1):
            losses_step = np.array(losses[step])
            sorted_losses = np.sort(losses_step)
            n = len(sorted_losses)
            results = {}
            results["mse_loss_avg"] = np.mean(losses_step)
            results["mse_loss_worst"] = np.max(losses_step)
            results["mse_loss_std"] = np.std(losses_step)
            # "best" statistics average the lowest-loss share of tasks. At least one
            # task is kept so they are defined (not NaN) when num_tasks is small.
            results["mse_loss_best_5percentile"] = np.mean(sorted_losses[:max(1, int(0.05*n))])
            results["mse_loss_best_10percentile"] = np.mean(sorted_losses[:max(1, int(0.1*n))])
            results["mse_loss_best_90percentile"] = np.mean(sorted_losses[:max(1, int(0.90*n))])
            results["mse_loss_best_95percentile"] = np.mean(sorted_losses[:max(1, int(0.95*n))])
            # "worst" statistics average the highest-loss tasks from the cut onward.
            results["mse_loss_worst_5percentile"] = np.mean(sorted_losses[int(0.95*n):])
            results["mse_loss_worst_10percentile"] = np.mean(sorted_losses[int(0.90*n):])
            results["mse_loss_worst_90percentile"] = np.mean(sorted_losses[int(0.10*n):])
            results["mse_loss_worst_95percentile"] = np.mean(sorted_losses[int(0.05*n):])
            for k, v in results.items():
                self.logger.info(f"{k}:\t{v:.2f}")
            np.savez_compressed(f"{self.results_path}/performance_step{step}_{mode}.npz", loss=losses_step)
        if save:
            torch.save(self.model.state_dict(), f"{self.results_path}/{self}.pt")

    def plot(self, num_tasks, K=5, n_steps=5, lr=0.01):
        """Save one figure per task showing the fit before and after fine-tuning.

        ``K`` random points of each task sampled in ``"plot"`` mode drive
        ``n_steps`` SGD steps on a deep copy of the model. The left panel shows
        the true function, the chosen points, and predictions at initialisation,
        after step 1 and after step ``n_steps``. The right panel shows the
        training loss of each step. Figures are written to
        ``sample_{name}_{i}.png``.
        """
        X, y = self.task.sample_data(batch_size=num_tasks, mode="plot")
        # Deduplicate: with n_steps == 1 a list [1, 1] stores one prediction but
        # plots two (IndexError). With n_steps < 1 no step runs, so only the
        # initial prediction is drawn.
        sampled_steps = sorted({1, n_steps}) if n_steps >= 1 else []
        # Points available per task. Assumes axis 1 of X indexes the points, as
        # the X[i, idx] indexing below requires.
        num_points = X.shape[1]

        for i in range(num_tasks):
            losses = []
            pred_ys = []
            idx = torch.randint(num_points, size=(2*K, ))

            model = deepcopy(self.model)
            optimizer = optim.SGD(model.parameters(), lr=lr)

            # Predictions are only plotted, so no autograd graph is needed.
            with torch.no_grad():
                pred_ys.append(model(X[i]))

            for step in range(1, n_steps+1):
                optimizer.zero_grad()
                loss = self.criterion(model(X[i, idx[:K]]), y[i, idx[:K]])
                loss.backward()
                optimizer.step()

                if step in sampled_steps:
                    with torch.no_grad():
                        pred_ys.append(model(X[i]))

                losses.append(loss.item())

            fig = plt.figure(figsize=(14.4, 4.8))

            # plot the model functions
            plt.subplot(1, 2, 1)

            plt.plot(X[i], y[i], '-', color=(0, 0, 1, 0.5), label='true function')
            plt.scatter(X[i, idx[:K]], y[i, idx[:K]], label='data')
            plt.plot(X[i], pred_ys[0].detach().numpy(), ':', color=(0.7, 0, 0, 1), label='initial weights')

            for j, step in enumerate(sampled_steps):
                plt.plot(X[i], pred_ys[j+1].detach().numpy(),
                        '-.' if step == 1 else '-', color=(0.5, 0, 0, 1),
                        label='model after {} steps'.format(step))

            plt.legend(loc='lower right')
            plt.title(f"Model fit: {str(self)}")

            # plot losses
            plt.subplot(1, 2, 2)
            plt.plot(losses)
            plt.title("Loss over time")
            plt.xlabel("gradient steps taken")
            plt.savefig(f"{self.results_path}/sample_{str(self)}_{i}.png")
            # Release the figure; otherwise every task leaves one figure open.
            plt.close(fig)

    def __str__(self):
        return "MAML"
    
# NashMAML with regularization term (lam/2)*||\theta - avg_i(\phi_i)||_2^2, where avg_i is the element-wise mean of all task weights
class Aggregated_Penalty_MAML(MAML):
    """NashMAML variant penalising the distance between theta and the mean task weights.

    Each meta-iteration runs ``num_ve_iterations`` sweeps. In every sweep,
    each task takes ``inner_steps`` gradient steps on its support loss plus
    ``0.5 * lam * ||theta - mean_i(phi_i)||_2^2``. The mean is taken
    element-wise over all task weights at the start of the sweep. The
    meta-loss is the sum of the query losses at the final task weights.
    """

    def __init__(self, task_name, inner_lr, meta_lr, lam=1, num_ve_iterations=5, alpha=1/25, norm=2, update_concurrent=True,
                  K=10, inner_steps=1, tasks_per_meta_batch=25, results_path="./results", mode="skewed"):
        """See ``MAML.__init__`` for the shared arguments.

        Args:
            lam: weight of the aggregated penalty.
            num_ve_iterations: number of sweeps over all tasks per meta-iteration.
            alpha: stored for parity with the sibling classes; not used here.
            norm: stored for parity with the sibling classes; this class always
                applies a squared L2 penalty.
            update_concurrent: accepted for signature parity; ignored.
        """
        super().__init__(task_name, inner_lr, meta_lr, K, inner_steps, tasks_per_meta_batch, results_path, mode)

        self.lam = lam
        self.norm = norm
        self.alpha = alpha
        self.num_ve_iterations = num_ve_iterations

    @staticmethod
    def _average_weights(weights_list):
        """Return the element-wise mean of each parameter tensor over ``weights_list``."""
        # mean(dim=0) keeps every parameter's shape; summing the stacked tensor
        # without a dim would collapse each layer's average to a scalar.
        return [torch.stack(ws).mean(dim=0) for ws in zip(*weights_list)]

    def inner_loop(self, X, y, temp_weights, avg_weights, compute_loss=False):
        """Adapt one task's weights, or evaluate its query loss.

        Args:
            X, y: task data; the first ``self.K`` entries form the support set,
                the rest the query set.
            temp_weights: the task's current weight list.
            avg_weights: element-wise mean of all task weights, used in the penalty.
                Ignored when ``compute_loss`` is True.
            compute_loss: if True, return the query loss at ``temp_weights``
                without updating them.

        Returns:
            ``(weights, loss)``: the updated weights with ``0.0`` when adapting,
            or the unchanged weights with the query loss when ``compute_loss``.
        """
        X_train, X_test = X[:self.K], X[self.K:]
        y_train, y_test = y[:self.K], y[self.K:]

        if compute_loss:
            loss = self.criterion(self.model.parameterised(X_test, temp_weights), y_test)
            return temp_weights, loss

        for _ in range(self.inner_steps):
            loss = self.criterion(self.model.parameterised(X_train, temp_weights), y_train)
            # NOTE: the penalty reaches temp_weights only through avg_weights.
            # train() builds avg_weights from the very tensors passed as
            # temp_weights, so this holds for the first inner step only.
            regu_loss = 0.0
            for w_0, w_avg in zip(self.weights, avg_weights):
                regu_loss += 0.5 * self.lam * torch.sum((w_0 - w_avg)**2)

            loss = loss + regu_loss

            # compute grad and update inner loop weights (graph kept for the meta-gradient)
            grad = torch.autograd.grad(loss, temp_weights, create_graph=True, retain_graph=True)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        return temp_weights, 0.0

    def train(self, num_iterations):
        """Meta-train for ``num_iterations`` meta-iterations.

        Per iteration, every task starts from a clone of ``self.weights`` and
        ``num_ve_iterations`` adaptation sweeps are run. The query losses are
        then summed and one Adam step is taken on the meta-weights, with the
        gradient norm clipped to 10. Statistics are logged every
        ``print_every`` iterations, and the curve is saved to
        ``training_curve.npz`` at the end.
        """
        losses = []

        for iteration in tqdm(range(1, num_iterations+1)):

            # compute meta loss
            meta_loss = 0.0
            task_weights_list = [[w.clone() for w in self.weights] for _ in range(self.tasks_per_meta_batch)]
            batch_X, batch_y = self.task.sample_data(batch_size=self.tasks_per_meta_batch,
                                                     num_samples=2*self.K, mode=self.mode)

            for _ in range(self.num_ve_iterations):
                avg_weights = self._average_weights(task_weights_list)
                # The comprehension reads the previous sweep's weights before rebinding.
                task_weights_list = [
                    self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], avg_weights)[0]
                    for i in range(self.tasks_per_meta_batch)
                ]

            # Evaluate the final weights. task_weights_list is defined even when
            # num_ve_iterations == 0; a per-sweep temporary would not be.
            for i in range(self.tasks_per_meta_batch):
                _, loss = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], None, compute_loss=True)
                meta_loss += loss
                losses.append(loss.item())

            # The meta-gradient is not differentiated again, so no create_graph is needed.
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(self.weights, 10.0)
            self.meta_optimiser.step()

            # log metrics
            if iteration % self.print_every == 0:
                self.train_log(losses, iteration, num_iterations)
                losses = []
        np.savez_compressed(f"{self.results_path}/training_curve.npz", loss_mean=np.array(self.train_info["loss_mean"]),
                                                                       loss_std=np.array(self.train_info["loss_std"]),
                                                                       loss_worst=np.array(self.train_info["loss_worst"]),
                                                                       loss_top90=np.array(self.train_info["loss_top90"]))

    def __str__(self):
        return "Aggregated_Penalty_MAML"

# NashMAML with per-task regularization (lam/2)*[alpha*||\theta - \phi_i||_p^p + (1-alpha)*||\phi_i - (\sum_{j != i} \phi_j)/(M-1)||_p^p]
# (p = norm, M = tasks per meta-batch)
class Meta_Separated_Penalty_MAML(MAML):
    """NashMAML variant with a meta-anchor penalty plus a penalty towards the other tasks' mean.

    Inner objective of task i (p = ``norm``, M = ``tasks_per_meta_batch``)::

        support_loss(phi_i)
          + 0.5*lam*alpha     * sum |theta - phi_i|^p
          + 0.5*lam*(1-alpha) * sum |phi_i - mean_{j != i}(phi_j)|^p

    The tasks are updated for ``num_ve_iterations`` sweeps per meta-iteration.
    """

    def __init__(self, task_name, inner_lr, meta_lr, lam=1, num_ve_iterations=5, alpha=1/25, norm=2, update_concurrent=True,
                  K=10, inner_steps=1, tasks_per_meta_batch=25, results_path="./results", mode="skewed"):
        """See ``MAML.__init__`` for the shared arguments.

        Args:
            lam: overall penalty weight.
            num_ve_iterations: number of sweeps over all tasks per meta-iteration.
            alpha: share of the penalty on ``theta - phi_i``; ``1 - alpha`` goes
                to the distance from the other tasks' mean.
            norm: exponent p of the element-wise penalty ``sum |.|^p``.
            update_concurrent: if True, every task in a sweep sees the other
                tasks' weights from the previous sweep. If False, tasks earlier
                in the sweep contribute their already-updated weights.
            mode: task-sampling mode used in training (same as ``MAML``).
        """
        super().__init__(task_name, inner_lr, meta_lr, K, inner_steps, tasks_per_meta_batch, results_path, mode)

        self.lam = lam
        self.norm = norm
        self.alpha = alpha
        self.num_ve_iterations = num_ve_iterations
        self.update_concurrent = update_concurrent

    @staticmethod
    def _average_weights(weights_list):
        """Return the element-wise mean of each parameter tensor over ``weights_list``."""
        # mean(dim=0) keeps every parameter's shape and divides by the number of
        # lists given (M-1 other tasks), not by the full meta-batch size.
        return [torch.stack(ws).mean(dim=0) for ws in zip(*weights_list)]

    def inner_loop(self, X, y, temp_weights, other_weights_list, compute_loss=False):
        """Adapt one task's weights, or evaluate its query loss.

        Args:
            X, y: task data; the first ``self.K`` entries form the support set,
                the rest the query set.
            temp_weights: the task's current weight list.
            other_weights_list: weight lists of the other tasks in the meta-batch.
                Ignored when ``compute_loss`` is True.
            compute_loss: if True, return the query loss at ``temp_weights``
                without updating them.

        Returns:
            ``(weights, loss)``: the updated weights with ``0.0`` when adapting,
            or the unchanged weights with the query loss when ``compute_loss``.
        """
        X_train, X_test = X[:self.K], X[self.K:]
        y_train, y_test = y[:self.K], y[self.K:]

        if compute_loss:
            loss = self.criterion(self.model.parameterised(X_test, temp_weights), y_test)
            return temp_weights, loss

        # The other tasks' weights do not change during these inner steps.
        avg_weights = self._average_weights(other_weights_list) if other_weights_list else None

        for _ in range(self.inner_steps):
            loss = self.criterion(self.model.parameterised(X_train, temp_weights), y_train)
            regu_loss = 0.0
            # |x|**p (rather than |x**p|) stays real for non-integer p.
            for w, w_0 in zip(temp_weights, self.weights):
                regu_loss += 0.5 * self.lam * self.alpha * torch.sum(torch.abs(w_0 - w)**self.norm)
            # With no other tasks (M == 1) only the theta-anchor term applies.
            if avg_weights is not None:
                for w, w_avg in zip(temp_weights, avg_weights):
                    regu_loss += 0.5 * self.lam * (1-self.alpha) * torch.sum(torch.abs(w - w_avg)**self.norm)
            loss = loss + regu_loss
            # compute grad and update inner loop weights (graph kept for the meta-gradient)
            grad = torch.autograd.grad(loss, temp_weights, create_graph=True, retain_graph=True)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        return temp_weights, 0.0

    def train(self, num_iterations):
        """Meta-train for ``num_iterations`` meta-iterations.

        Per iteration, every task starts from a clone of ``self.weights`` and
        ``num_ve_iterations`` adaptation sweeps are run, each continuing from
        the previous sweep's weights. The query losses are then summed and one
        Adam step is taken on the meta-weights, with the gradient norm clipped
        to 10. Statistics are logged every ``print_every`` iterations, and the
        curve is saved to ``training_curve.npz`` at the end.
        """
        losses = []

        for iteration in tqdm(range(1, num_iterations+1)):

            # compute meta loss
            meta_loss = 0.0
            task_weights_list = [[w.clone() for w in self.weights] for _ in range(self.tasks_per_meta_batch)]
            batch_X, batch_y = self.task.sample_data(batch_size=self.tasks_per_meta_batch,
                                                     num_samples=2*self.K, mode=self.mode)

            for _ in range(self.num_ve_iterations):
                # Weights produced by this sweep; entries before i are already updated.
                next_weights_list = [[w.clone() for w in t_w] for t_w in task_weights_list]
                for i in range(self.tasks_per_meta_batch):
                    source = task_weights_list if self.update_concurrent else next_weights_list
                    other_task_weights_list = source[:i] + source[i+1:]
                    task_weights, _ = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], other_task_weights_list)
                    next_weights_list[i] = task_weights
                # Advance to the new iterate so the next sweep continues from it
                # rather than restarting from theta.
                task_weights_list = next_weights_list

            for i in range(self.tasks_per_meta_batch):
                _, loss = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], [], compute_loss=True)
                meta_loss += loss
                losses.append(loss.item())

            # The meta-gradient is not differentiated again, so no create_graph is needed.
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(self.weights, 10.0)
            self.meta_optimiser.step()

            # log metrics
            if iteration % self.print_every == 0:
                self.train_log(losses, iteration, num_iterations)
                losses = []
        np.savez_compressed(f"{self.results_path}/training_curve.npz", loss_mean=np.array(self.train_info["loss_mean"]),
                                                                       loss_std=np.array(self.train_info["loss_std"]),
                                                                       loss_worst=np.array(self.train_info["loss_worst"]),
                                                                       loss_top90=np.array(self.train_info["loss_top90"]))

    def __str__(self):
        return "Meta_Separated_Penalty_MAML"
    
# NashMAML with per-task regularization (lam/2)*[alpha*||\theta - \phi_i||_p^p + (1-alpha)/(M-1)*\sum_{j != i} ||\phi_i - \phi_j||_p^p]
# (p = norm, M = tasks per meta-batch)
class Separated_Penalty_MAML(MAML):
    """NashMAML variant with a meta-anchor penalty plus pairwise penalties between tasks.

    Inner objective of task i (p = ``norm``, M = ``tasks_per_meta_batch``)::

        support_loss(phi_i)
          + 0.5*lam*alpha             * sum |theta - phi_i|^p
          + 0.5*lam*(1-alpha)/(M-1)   * sum_{j != i} sum |phi_i - phi_j|^p

    The tasks are updated for ``num_ve_iterations`` sweeps per meta-iteration.
    """

    def __init__(self, task_name, inner_lr, meta_lr, lam=1, num_ve_iterations=5, alpha=1/25, norm=2, update_concurrent=True,
                  K=10, inner_steps=1, tasks_per_meta_batch=25, results_path="./results", mode="skewed"):
        """See ``MAML.__init__`` for the shared arguments.

        Args:
            lam: overall penalty weight.
            num_ve_iterations: number of sweeps over all tasks per meta-iteration.
            alpha: share of the penalty on ``theta - phi_i``; ``1 - alpha`` is
                spread over the pairwise distances to the other tasks.
            norm: exponent p of the element-wise penalty ``sum |.|^p``.
            update_concurrent: if True, every task in a sweep sees the other
                tasks' weights from the previous sweep. If False, tasks earlier
                in the sweep contribute their already-updated weights.
            mode: task-sampling mode used in training (same as ``MAML``).
        """
        super().__init__(task_name, inner_lr, meta_lr, K, inner_steps, tasks_per_meta_batch, results_path, mode)

        self.lam = lam
        self.norm = norm
        self.alpha = alpha
        self.num_ve_iterations = num_ve_iterations
        self.update_concurrent = update_concurrent

    def inner_loop(self, X, y, temp_weights, other_weights_list, compute_loss=False):
        """Adapt one task's weights, or evaluate its query loss.

        Args:
            X, y: task data; the first ``self.K`` entries form the support set,
                the rest the query set.
            temp_weights: the task's current weight list.
            other_weights_list: weight lists of the other tasks in the meta-batch.
                Ignored when ``compute_loss`` is True.
            compute_loss: if True, return the query loss at ``temp_weights``
                without updating them.

        Returns:
            ``(weights, loss)``: the updated weights with ``0.0`` when adapting,
            or the unchanged weights with the query loss when ``compute_loss``.
        """
        X_train, X_test = X[:self.K], X[self.K:]
        y_train, y_test = y[:self.K], y[self.K:]

        if compute_loss:
            loss = self.criterion(self.model.parameterised(X_test, temp_weights), y_test)
            return temp_weights, loss

        # M - 1 in train(); the division is only reached when the list is non-empty.
        num_others = len(other_weights_list)
        for _ in range(self.inner_steps):
            loss = self.criterion(self.model.parameterised(X_train, temp_weights), y_train)
            regu_loss = 0.0
            # |x|**p (rather than |x**p|) stays real for non-integer p.
            for w, w_0 in zip(temp_weights, self.weights):
                regu_loss += 0.5 * self.lam * self.alpha * torch.sum(torch.abs(w_0 - w)**self.norm)
            for w_other_task_weights in other_weights_list:
                for w, w_j in zip(temp_weights, w_other_task_weights):
                    regu_loss += 0.5 * self.lam * (1 - self.alpha) / num_others * torch.sum(torch.abs(w - w_j)**self.norm)
            loss = loss + regu_loss
            # compute grad and update inner loop weights (graph kept for the meta-gradient)
            grad = torch.autograd.grad(loss, temp_weights, create_graph=True, retain_graph=True)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        return temp_weights, 0.0

    def train(self, num_iterations):
        """Meta-train for ``num_iterations`` meta-iterations.

        Per iteration, every task starts from a clone of ``self.weights`` and
        ``num_ve_iterations`` adaptation sweeps are run, each continuing from
        the previous sweep's weights. The query losses are then summed and one
        Adam step is taken on the meta-weights, with the gradient norm clipped
        to 10. Statistics are logged every ``print_every`` iterations, and the
        curve is saved to ``training_curve.npz`` at the end.
        """
        losses = []

        for iteration in tqdm(range(1, num_iterations+1)):

            # compute meta loss
            meta_loss = 0.0
            task_weights_list = [[w.clone() for w in self.weights] for _ in range(self.tasks_per_meta_batch)]
            batch_X, batch_y = self.task.sample_data(batch_size=self.tasks_per_meta_batch,
                                                     num_samples=2*self.K, mode=self.mode)

            for _ in range(self.num_ve_iterations):
                # Weights produced by this sweep; entries before i are already updated.
                next_weights_list = [[w.clone() for w in t_w] for t_w in task_weights_list]
                for i in range(self.tasks_per_meta_batch):
                    source = task_weights_list if self.update_concurrent else next_weights_list
                    other_task_weights_list = source[:i] + source[i+1:]
                    task_weights, _ = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], other_task_weights_list)
                    next_weights_list[i] = task_weights
                # Advance to the new iterate so the next sweep continues from it
                # rather than restarting from theta.
                task_weights_list = next_weights_list

            for i in range(self.tasks_per_meta_batch):
                _, loss = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], [], compute_loss=True)
                meta_loss += loss
                losses.append(loss.item())

            # The meta-gradient is not differentiated again, so no create_graph is needed.
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(self.weights, 10.0)
            self.meta_optimiser.step()

            # log metrics
            if iteration % self.print_every == 0:
                self.train_log(losses, iteration, num_iterations)
                losses = []
        np.savez_compressed(f"{self.results_path}/training_curve.npz", loss_mean=np.array(self.train_info["loss_mean"]),
                                                                       loss_std=np.array(self.train_info["loss_std"]),
                                                                       loss_worst=np.array(self.train_info["loss_worst"]),
                                                                       loss_top90=np.array(self.train_info["loss_top90"]))

    def __str__(self):
        return "Separated_Penalty_MAML"

# NashMAML with joint constraint sum(||\theta-\phi_i||_2^2)<=r^2
class Constrained_MAML(MAML):
    """NashMAML variant with the joint constraint sum_i ||theta - phi_i||_2^2 <= radius^2.

    Each meta-iteration runs ``num_ve_iterations`` sweeps. Every task first
    takes plain gradient steps on its support loss. If the joint deviation
    ``d = sqrt(sum_i ||phi_i - theta||^2)`` then exceeds ``radius``, all
    deviations are scaled by ``radius / d``, projecting the stacked task
    weights back onto the feasible ball around theta.
    """

    def __init__(self, task_name, inner_lr, meta_lr, radius=0.05, num_ve_iterations=5,
                  K=10, inner_steps=1, tasks_per_meta_batch=25, results_path="./results", mode="skewed"):
        """See ``MAML.__init__`` for the shared arguments.

        Args:
            radius: radius r of the joint constraint ball around ``self.weights``.
            num_ve_iterations: number of adapt-then-project sweeps per meta-iteration.
        """
        super().__init__(task_name, inner_lr, meta_lr, K, inner_steps, tasks_per_meta_batch, results_path, mode)

        self.radius = radius
        self.num_ve_iterations = num_ve_iterations

    def inner_loop(self, X, y, temp_weights, compute_loss=False):
        """Adapt one task's weights by plain gradient steps, or evaluate its query loss.

        Args:
            X, y: task data; the first ``self.K`` entries form the support set,
                the rest the query set.
            temp_weights: the task's current weight list.
            compute_loss: if True, return the query loss at ``temp_weights``
                without updating them.

        Returns:
            ``(weights, loss)``: the updated weights with ``0.0`` when adapting,
            or the unchanged weights with the query loss when ``compute_loss``.
        """
        X_train, X_test = X[:self.K], X[self.K:]
        y_train, y_test = y[:self.K], y[self.K:]

        if compute_loss:
            loss = self.criterion(self.model.parameterised(X_test, temp_weights), y_test)
            return temp_weights, loss

        for _ in range(self.inner_steps):
            loss = self.criterion(self.model.parameterised(X_train, temp_weights), y_train)

            # compute grad and update inner loop weights (graph kept for the meta-gradient)
            grad = torch.autograd.grad(loss, temp_weights, create_graph=True, retain_graph=True)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        return temp_weights, 0.0

    def train(self, num_iterations):
        """Meta-train for ``num_iterations`` meta-iterations.

        Per iteration, the adapt-then-project sweeps are run from clones of
        ``self.weights``. The query losses are then summed and one Adam step is
        taken on the meta-weights, with the gradient norm clipped to 10.
        Statistics are logged every ``print_every`` iterations, and the curve
        is saved to ``training_curve.npz`` at the end.
        """
        losses = []

        for iteration in tqdm(range(1, num_iterations+1)):

            # compute meta loss
            meta_loss = 0.0
            task_weights_list = [[w.clone() for w in self.weights] for _ in range(self.tasks_per_meta_batch)]
            batch_X, batch_y = self.task.sample_data(batch_size=self.tasks_per_meta_batch,
                                                     num_samples=2*self.K, mode=self.mode)

            for _ in range(self.num_ve_iterations):
                dist_square = torch.tensor(0.)
                for i in range(self.tasks_per_meta_batch):
                    task_weights, _ = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i])
                    task_weights_list[i] = task_weights
                    dist_square = dist_square + sum(torch.sum(torch.square(theta - phi))
                                                    for phi, theta in zip(task_weights, self.weights))

                d = torch.sqrt(dist_square)
                r = self.radius

                if d > r:
                    # Projection onto the ball of radius r around theta:
                    #   phi_i <- theta + (r/d)(phi_i - theta) = (r*phi_i + (d-r)*theta) / d
                    # The theta/phi coefficients must not be swapped; swapping them
                    # leaves a deviation of norm d - r instead of r.
                    for i in range(self.tasks_per_meta_batch):
                        task_weights_list[i] = [(r * phi + (d - r) * theta) / d
                                                for phi, theta in zip(task_weights_list[i], self.weights)]

            for i in range(self.tasks_per_meta_batch):
                _, loss = self.inner_loop(batch_X[i], batch_y[i], task_weights_list[i], True)
                meta_loss += loss
                losses.append(loss.item())

            # The meta-gradient is not differentiated again, so no create_graph is needed.
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(self.weights, 10.0)
            self.meta_optimiser.step()

            # log metrics
            if iteration % self.print_every == 0:
                self.train_log(losses, iteration, num_iterations)
                losses = []
        np.savez_compressed(f"{self.results_path}/training_curve.npz", loss_mean=np.array(self.train_info["loss_mean"]),
                                                                       loss_std=np.array(self.train_info["loss_std"]),
                                                                       loss_worst=np.array(self.train_info["loss_worst"]),
                                                                       loss_top90=np.array(self.train_info["loss_top90"]))

    def __str__(self):
        return "Constrained_MAML"
