import torch
import torch.nn.functional as F
import numpy as np
import time
import sys
import higher
import os
from tqdm import tqdm

from collections import OrderedDict
from maml.utils import tensors_to_device, compute_accuracy, gradient_update_parameters, apply_grad, mix_grad
from concurrent.futures import ThreadPoolExecutor
import torch.nn.functional as F
from maml.hessianfree import HessianFree

# Public API of this module: every meta-learner class defined below plus the
# short ``MAML`` alias.
__all__ = [
    'ModelAgnosticMetaLearning',
    'MAML',
    'FOMAML',
    'iMAML',
    'MetaMinibatchProx',
    'FOMuML',
]


class ModelAgnosticMetaLearning(object):
    """Model-Agnostic Meta-Learning (MAML) with loss-ranked task reweighting.

    For every task in a meta-batch, the model is adapted on the task's
    ``train`` split with ``num_adaptation_steps`` inner gradient steps. It is
    then evaluated on the task's ``test`` split. The per-task outer losses are
    sorted in ascending order. Tasks at sorted rank ``i < dyfac`` are weighted
    by ``beta``; every other task has weight 1. The weighted sum is divided by
    the total weight.

    Parameters
    ----------
    model : module
        Model whose forward accepts a ``params=`` keyword. When
        ``per_param_step_size`` is True it must also expose
        ``meta_named_parameters()`` (presumably a torchmeta ``MetaModule``).
    optimizer : torch.optim.Optimizer, optional
        Outer-loop optimizer; required by :meth:`train`.
    step_size : float
        Inner-loop learning rate.
    first_order : bool
        Use the first-order approximation in the inner loop.
    learn_step_size : bool
        Make the inner learning rate(s) trainable. When an optimizer is given,
        they are added to it as an extra parameter group with lr ``1e-7``.
    per_param_step_size : bool
        Keep one step size per meta-parameter instead of a single scalar.
    num_adaptation_steps : int
        Number of inner-loop gradient steps per task.
    beta : float
        Weight applied to the lowest-loss tasks (see ``dyfac``).
    dyfac : float
        Rank threshold; tasks at sorted position ``i < dyfac`` get ``beta``.
    scheduler : torch.optim.lr_scheduler.LRScheduler, optional
        Outer-loop scheduler. Its ``base_lrs`` are refreshed when the
        step-size parameter group is added.
    loss_function : callable
        Loss taking ``(logits, targets)``.
    device : torch.device or str, optional
        Device the model and step sizes are placed on.
    """

    def __init__(self, model, optimizer=None, step_size=0.1, first_order=False,
                 learn_step_size=False, per_param_step_size=False,
                 num_adaptation_steps=1, beta=1.0, dyfac=0.2, scheduler=None,
                 loss_function=F.cross_entropy, device=None):
        self.model = model.to(device=device)
        self.optimizer = optimizer
        self.first_order = first_order
        self.num_adaptation_steps = num_adaptation_steps
        self.beta = beta
        self.dyfac = dyfac
        self.scheduler = scheduler
        self.loss_function = loss_function
        self.device = device

        if per_param_step_size:
            self.step_size = OrderedDict(
                (name, torch.tensor(step_size, dtype=param.dtype,
                                    device=self.device,
                                    requires_grad=learn_step_size))
                for (name, param) in model.meta_named_parameters())
        else:
            self.step_size = torch.tensor(step_size, dtype=torch.float32,
                device=self.device, requires_grad=learn_step_size)

        if (self.optimizer is not None) and learn_step_size:
            self.optimizer.add_param_group({
                'params': (self.step_size.values() if per_param_step_size
                           else [self.step_size]),
                'lr': 0.0000001,
            })
            if scheduler is not None:
                # The new group has no 'initial_lr' yet; schedulers read it.
                for group in self.optimizer.param_groups:
                    group.setdefault('initial_lr', group['lr'])
                # `base_lrs` is a plain list attribute on PyTorch schedulers:
                # it must be reassigned (calling it raises TypeError).
                self.scheduler.base_lrs = [group['initial_lr']
                    for group in self.optimizer.param_groups]

    def get_outer_loss(self, batch):
        """Adapt on every task in ``batch`` and compute the outer (test) losses.

        Parameters
        ----------
        batch : dict
            Must contain ``'train'`` and ``'test'`` entries. Each is an
            ``(inputs, targets)`` pair whose first dimension indexes tasks.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The first element is the rank-reweighted mean outer loss
            (differentiable; used for the meta-update). The second is the plain
            mean outer loss, built from ``.item()`` values so it carries no
            graph.

        Raises
        ------
        RuntimeError
            If ``batch`` has no ``'test'`` entry.
        """
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (not test_targets.dtype.is_floating_point)

        mean_outer_loss = torch.tensor(0., device=self.device)
        outer_losses_all_tasks = []
        loss_id = []

        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            params, _ = self.adapt(train_inputs, train_targets,
                is_classification_task=is_classification_task,
                num_adaptation_steps=self.num_adaptation_steps,
                step_size=self.step_size, first_order=self.first_order)

            # Only build a graph through the test loss in training mode.
            with torch.set_grad_enabled(self.model.training):
                test_logits = self.model(test_inputs, params=params)
                outer_loss = self.loss_function(test_logits, test_targets)
                mean_outer_loss += outer_loss.item()
                loss_id.append((outer_loss.item(), task_id))
                outer_losses_all_tasks.append(outer_loss)

        mean_outer_loss.div_(num_tasks)

        # Rank tasks by ascending outer loss (sort is stable: ties keep order).
        sorted_task_ids = [task_id for _, task_id in sorted(loss_id, key=lambda x: x[0])]
        reweight_mean_outer_loss = self._reweighted_mean(
            outer_losses_all_tasks, sorted_task_ids, num_tasks)
        return reweight_mean_outer_loss, mean_outer_loss

    def _reweighted_mean(self, outer_losses, sorted_task_ids, num_tasks):
        """Return the weighted mean of ``outer_losses`` by loss rank.

        ``sorted_task_ids`` lists task indices in ascending-loss order. The task
        at rank ``i`` is scaled by ``self.beta`` when ``i < self.dyfac`` and by
        1 otherwise. The sum is divided by the total weight
        ``n_plain + (num_tasks - n_plain) * beta``.
        """
        # NOTE(review): `dyfac` is compared against the integer rank `i`, so
        # with the default 0.2 only the single lowest-loss task is reweighted.
        # If it was meant as a fraction of the meta-batch this should be
        # `i < self.dyfac * num_tasks` -- confirm intent.
        total = torch.tensor(0., device=self.device)
        num_unweighted = 0
        for i, task_id in enumerate(sorted_task_ids):
            if i < self.dyfac:
                total += outer_losses[task_id] * self.beta
            else:
                total += outer_losses[task_id]
                num_unweighted += 1
        total.div_(num_unweighted + (num_tasks - num_unweighted) * self.beta)
        return total

    def adapt(self, inputs, targets, is_classification_task=None,
              num_adaptation_steps=1, step_size=0.1, first_order=False):
        """Run ``num_adaptation_steps`` inner-loop gradient steps on one task.

        Returns
        -------
        (params, results)
            ``params`` are the adapted parameters returned by
            ``gradient_update_parameters`` (``None`` if no step is taken).
            ``results`` holds ``'inner_losses'`` (one float per step). For
            classification tasks with at least one step, it also holds
            ``'accuracy_before'``.
        """
        if is_classification_task is None:
            is_classification_task = (not targets.dtype.is_floating_point)
        params = None

        results = {'inner_losses': np.zeros(
            (num_adaptation_steps,), dtype=np.float32)}

        for step in range(num_adaptation_steps):
            logits = self.model(inputs, params=params)
            inner_loss = self.loss_function(logits, targets)
            results['inner_losses'][step] = inner_loss.item()

            if (step == 0) and is_classification_task:
                results['accuracy_before'] = compute_accuracy(logits, targets)

            self.model.zero_grad()
            # In eval mode the update is always first-order (no meta-gradient
            # is needed there).
            params = gradient_update_parameters(self.model, inner_loss,
                step_size=step_size, params=params,
                first_order=(not self.model.training) or first_order)

        return params, results

    def train(self, dataloader):
        """Run one meta-training pass over ``dataloader``.

        Returns the average of the per-batch plain mean outer losses, or 0.0
        if ``dataloader`` yields nothing.

        Raises
        ------
        RuntimeError
            If no optimizer was provided.
        """
        if self.optimizer is None:
            raise RuntimeError('Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        self.model.train()
        train_loss = []
        for batch in dataloader:
            self.optimizer.zero_grad()
            batch = tensors_to_device(batch, device=self.device)
            outer_loss, loss = self.get_outer_loss(batch)
            train_loss.append(loss)
            outer_loss.backward()
            self.optimizer.step()
        return sum(train_loss) / len(train_loss) if train_loss else 0.0

    def evaluate(self, dataloader):
        """Return the average plain mean outer loss over ``dataloader`` (0.0 if empty)."""
        self.model.eval()
        val_loss = []
        for batch in dataloader:
            batch = tensors_to_device(batch, device=self.device)
            _, loss = self.get_outer_loss(batch)
            val_loss.append(loss)
        return sum(val_loss) / len(val_loss) if val_loss else 0.0


# Short alias for ModelAgnosticMetaLearning.
MAML = ModelAgnosticMetaLearning

class FOMAML(ModelAgnosticMetaLearning):
    """First-order MAML: ``ModelAgnosticMetaLearning`` with ``first_order=True``.

    ``beta`` and ``dyfac`` are forwarded to the parent so that the loss-ranked
    task reweighting can be configured from this class too. They are trailing
    keyword parameters whose defaults equal the parent's, so existing calls
    behave exactly as before.
    """

    def __init__(self, model, optimizer=None, step_size=0.1,
                 learn_step_size=False, per_param_step_size=False,
                 num_adaptation_steps=1, scheduler=None,
                 loss_function=F.cross_entropy, device=None,
                 beta=1.0, dyfac=0.2):
        super(FOMAML, self).__init__(model, optimizer=optimizer, first_order=True,
            step_size=step_size, learn_step_size=learn_step_size,
            per_param_step_size=per_param_step_size,
            num_adaptation_steps=num_adaptation_steps, beta=beta, dyfac=dyfac,
            scheduler=scheduler, loss_function=loss_function, device=device)


class iMAML:
    """Implicit MAML (iMAML) with loss-ranked task reweighting.

    For every task, the inner loop runs ``num_adaptation_steps`` SGD steps
    (learning rate ``step_size``) through ``higher.innerloop_ctx`` with
    ``track_higher_grads=False``. The per-task meta-gradient is the approximate
    solution of ``(I + H / lamda) x = g``, computed with at most ``n_cg``
    conjugate-gradient iterations. ``H`` is the Hessian of the inner (train)
    loss and ``g`` the gradient of the test loss, both taken at the adapted
    parameters. Task meta-gradients are mixed with weights that scale the
    lowest-loss tasks (rank ``i < dyfac``) by ``beta``; the weights are then
    normalised to sum to one.

    NOTE(review): ``dyfac`` is compared against the integer rank ``i``, so with
    the default 0.2 only the single lowest-loss task is reweighted. If it was
    meant as a fraction of the meta-batch, the test would be
    ``i < dyfac * num_tasks``; confirm intent. ``scheduler`` is accepted but
    neither stored nor used.
    """

    def __init__(self, model, optimizer=None, step_size=0.1, lamda=100,
                 num_adaptation_steps=1, beta=1.0, dyfac=0.2, scheduler=None,
                 loss_function=F.cross_entropy, n_cg=5, device=None):
        self.model = model.to(device=device)
        self.optimizer = optimizer
        self.inner_optimizer = torch.optim.SGD(self.model.parameters(), lr=step_size)
        self.lamb = lamda
        self.n_cg = n_cg
        self.num_adaptation_steps = num_adaptation_steps
        self.beta = beta
        self.dyfac = dyfac
        self.loss_function = loss_function
        self.device = device

    @torch.enable_grad()
    def inner_loop(self, fmodel, diffopt, train_input, train_target):
        """Take one differentiable-optimizer step on the train loss (grad always enabled)."""
        train_logit = fmodel(train_input)
        inner_loss = self.loss_function(train_logit, train_target)
        diffopt.step(inner_loss)

        return None

    @torch.no_grad()
    def cg(self, in_grad, outer_grad, params):
        """Approximately solve ``(I + H / lamb) x = outer_grad`` by conjugate gradient.

        Starts from ``x0 = outer_grad`` and runs at most ``self.n_cg``
        iterations. It stops early once the residual is exactly zero, because
        the next step would divide 0 by 0 and turn ``x`` into NaN. Returns
        ``x`` split into tensors shaped like ``self.model.parameters()``.
        """
        x = outer_grad.clone().detach()
        r = outer_grad.clone().detach() - self.hv_prod(in_grad, x, params)
        p = r.clone().detach()
        rs_old = r @ r
        for _ in range(self.n_cg):
            if rs_old == 0:
                # x already solves the system exactly.
                break
            Ap = self.hv_prod(in_grad, p, params)
            alpha = rs_old / (p @ Ap)
            x = x + alpha * p
            r = r - alpha * Ap
            rs_new = r @ r
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return self.vec_to_grad(x)

    def vec_to_grad(self, vec):
        """Split flat ``vec`` into tensors shaped like ``self.model.parameters()``."""
        pointer = 0
        res = []
        for param in self.model.parameters():
            num_param = param.numel()
            res.append(vec[pointer:pointer + num_param].view_as(param).data)
            pointer += num_param
        return res

    @torch.enable_grad()
    def hv_prod(self, in_grad, x, params):
        """Return ``(I + H / lamb) x`` as a detached flat vector.

        ``H x`` is the vector-Jacobian product of ``in_grad`` (the train-loss
        gradient, built with ``create_graph=True``) with respect to
        ``params``. The graph is retained so that CG can call this repeatedly.
        """
        hv = torch.autograd.grad(in_grad, params, retain_graph=True, grad_outputs=x)
        hv = [h.contiguous() for h in hv]

        hv = torch.nn.utils.parameters_to_vector(hv).detach()
        return hv / self.lamb + x

    def get_outer_loss(self, batch, is_train):
        """Adapt on every task and return the mean test loss over tasks.

        When ``is_train`` is True, this also computes each task's implicit
        meta-gradient, mixes them with the rank-based weights, writes the
        result into the model via ``apply_grad`` and calls
        ``self.optimizer.step()``.

        Returns
        -------
        torch.Tensor
            Mean outer loss, accumulated from ``.item()`` values (no graph).

        Raises
        ------
        RuntimeError
            If ``batch`` has no ``'test'`` entry.
        """
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)

        mean_outer_loss = torch.tensor(0., device=self.device)
        loss_id = []
        grad_list = []

        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            with higher.innerloop_ctx(self.model, self.inner_optimizer,
                                      track_higher_grads=False) as (fmodel, diffopt):
                for _ in range(self.num_adaptation_steps):
                    self.inner_loop(fmodel, diffopt, train_inputs, train_targets)

                train_logit = fmodel(train_inputs)
                in_loss = self.loss_function(train_logit, train_targets)

                test_logit = fmodel(test_inputs)
                outer_loss = self.loss_function(test_logit, test_targets)
                mean_outer_loss += outer_loss.item()

                if is_train:
                    # Gradients at the final (adapted) parameters.
                    params = list(fmodel.parameters(time=-1))
                    in_grad = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(in_loss, params, create_graph=True))
                    outer_grad = torch.nn.utils.parameters_to_vector(
                        torch.autograd.grad(outer_loss, params))
                    implicit_grad = self.cg(in_grad, outer_grad, params)
                    grad_list.append(implicit_grad)
                    loss_id.append((outer_loss.item(), task_id))

        mean_outer_loss.div_(num_tasks)

        if is_train:
            sorted_task_ids = [task_id for _, task_id in sorted(loss_id, key=lambda x: x[0])]
            self.optimizer.zero_grad()
            weight = torch.ones(len(grad_list))
            for i, task_id in enumerate(sorted_task_ids):
                if i < self.dyfac:
                    weight[task_id] *= self.beta
            weight /= weight.sum()
            grad = mix_grad(grad_list, weight)
            apply_grad(self.model, grad)
            self.optimizer.step()

        return mean_outer_loss

    def train(self, dataloader):
        """Run one meta-training pass; return the average per-batch mean outer loss (0.0 if empty).

        Raises
        ------
        RuntimeError
            If no optimizer was provided.
        """
        if self.optimizer is None:
            raise RuntimeError('Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        self.model.train()
        train_loss = []
        for batch in dataloader:
            self.optimizer.zero_grad()
            batch = tensors_to_device(batch, device=self.device)
            loss = self.get_outer_loss(batch, True)
            train_loss.append(loss)

        return sum(train_loss) / len(train_loss) if train_loss else 0.0

    def evaluate(self, dataloader):
        """Return the average per-batch mean outer loss without meta-updates (0.0 if empty)."""
        self.model.eval()
        val_loss = []
        for batch in dataloader:
            batch = tensors_to_device(batch, device=self.device)
            loss = self.get_outer_loss(batch, False)
            val_loss.append(loss)
        return sum(val_loss) / len(val_loss) if val_loss else 0.0
    
class MetaMinibatchProx:
    """Meta-MinibatchProx-style meta-learner with loss-ranked task reweighting.

    Each task is adapted with ``num_adaptation_steps`` first-order gradient
    steps (``step_size``) on ``loss + lamda * sum((theta - theta_meta) ** 2)``.
    The adapted parameters of all tasks in a meta-batch are averaged. For this
    average, tasks are ranked by their final adapted objective, and those at
    rank ``i < dyfac`` are weighted by ``beta``. Each meta-parameter is then
    moved toward the average by ``meta_lr * lamda * (theta_meta - theta_avg)``.

    NOTE(review): ``dyfac`` is compared against the integer rank ``i``, so with
    the default 0.2 only the single lowest-loss task is reweighted. Confirm
    whether a fraction of the meta-batch was intended. ``optimizer`` is stored
    but not used by the methods of this class.
    """

    def __init__(self, model, optimizer=None, lamda=0.1, step_size=0.1, meta_lr=0.001,
                 num_adaptation_steps=15, beta=1.0, dyfac=0.2, loss_function=F.cross_entropy, device=None):
        self.model = model.to(device=device)
        self.optimizer = optimizer
        self.step_size = step_size
        self.meta_lr = meta_lr
        self.beta = beta
        self.dyfac = dyfac
        self.lamda = lamda
        self.num_adaptation_steps = num_adaptation_steps
        self.loss_function = loss_function
        self.device = device

    def get_outer_parameter(self, batch):
        """Adapt on every task in ``batch`` and return the rank-weighted parameter average.

        ``batch`` is unpacked as a 2-sequence ``(inputs, targets)`` and indexed
        as ``batch[0][task_id]`` / ``batch[1][task_id]``. The first dimension
        therefore presumably indexes tasks; verify against the dataloader.
        """
        _, targets = batch
        num_tasks = targets.size(0)
        is_classification_task = (not targets.dtype.is_floating_point)
        task_params = []
        task_losses = []
        for task_id in range(num_tasks):
            inputs = batch[0][task_id].squeeze(dim=0)
            task_targets = batch[1][task_id].squeeze(dim=0)
            params, losses = self.adapt(inputs, task_targets,
                is_classification_task=is_classification_task)
            task_params.append(params)
            # Rank on plain floats so the per-task graphs are not kept alive.
            task_losses.append(losses.item())
        return self._weighted_average(task_params, task_losses)

    def _weighted_average(self, task_params, task_losses):
        """Average per-task parameter dicts with loss-rank weights.

        Tasks are sorted by ascending ``task_losses`` (stable sort). The task at
        rank ``i < self.dyfac`` is scaled by ``self.beta``, every other task by
        1. The sum is divided by the total weight. The accumulators are keyed
        by ``self.model.meta_named_parameters()``.
        """
        num_tasks = len(task_params)
        sorted_task_ids = [task_id for _, task_id in
                           sorted(zip(task_losses, range(num_tasks)), key=lambda x: x[0])]
        accumulated_params = OrderedDict(
            (name, torch.zeros_like(param))
            for name, param in self.model.meta_named_parameters())
        cou = 0
        for i, task_id in enumerate(sorted_task_ids):
            params = task_params[task_id]
            if i < self.dyfac:
                for name, param in params.items():
                    accumulated_params[name] += param * self.beta
            else:
                for name, param in params.items():
                    accumulated_params[name] += param
                cou += 1

        total_weight = cou + (num_tasks - cou) * self.beta
        return OrderedDict((name, param / total_weight)
                           for name, param in accumulated_params.items())

    def get_outer_loss(self, batch):
        """Return the mean test loss over tasks after proximal adaptation (no graph).

        Raises
        ------
        RuntimeError
            If ``batch`` has no ``'test'`` entry.
        """
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')
        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (not test_targets.dtype.is_floating_point)

        mean_outer_loss = torch.tensor(0., device=self.device)

        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            params, _ = self.adapt(train_inputs, train_targets,
                is_classification_task=is_classification_task)

            with torch.no_grad():
                test_logits = self.model(test_inputs, params=params)
                outer_loss = self.loss_function(test_logits, test_targets)
                mean_outer_loss += outer_loss.item()

        mean_outer_loss.div_(num_tasks)
        return mean_outer_loss

    def adapt(self, inputs, targets, is_classification_task=None):
        """Adapt the meta-parameters to one task with a proximal penalty.

        Each step minimises ``loss + lamda * sum((theta - theta_meta) ** 2)``
        with a first-order update. ``theta_meta`` is read from
        ``state_dict()``, so it carries no graph.

        Returns
        -------
        (params, losses)
            The adapted parameters, and the objective (task loss + proximal
            term) evaluated at them.
        """
        if is_classification_task is None:
            is_classification_task = not targets.dtype.is_floating_point

        params = OrderedDict((name, param.to(self.device)) for name, param in self.model.meta_named_parameters())
        meta_params = OrderedDict((name, param.to(self.device)) for name, param in self.model.state_dict().items())

        meta_params_tensor = torch.cat([meta_params[name].view(-1) for name in params.keys()])

        for step in range(self.num_adaptation_steps):
            logits = self.model(inputs, params=params)
            inner_loss = self.loss_function(logits, targets)
            params_tensor = torch.cat([param.view(-1) for param in params.values()])

            proximal_loss = F.mse_loss(params_tensor, meta_params_tensor, reduction='sum') * self.lamda

            total_loss = inner_loss + proximal_loss
            self.model.zero_grad()

            params = gradient_update_parameters(self.model, total_loss, step_size=self.step_size, params=params, first_order=True)

        params_tensor = torch.cat([param.view(-1) for param in params.values()])
        proximal_loss = F.mse_loss(params_tensor, meta_params_tensor, reduction='sum') * self.lamda
        losses = self.loss_function(self.model(inputs, params=params), targets) + proximal_loss
        return params, losses

    def train(self, dataloader, evaluate_dataloader, epoch):
        """Apply one pass of proximal meta-updates over ``dataloader``, then evaluate.

        ``epoch`` is accepted but unused. Returns
        ``self.evaluate(evaluate_dataloader)``.
        """
        self.model.train()
        meta_lr = self.meta_lr
        for batch in dataloader:
            batch = tensors_to_device(batch, device=self.device)
            adaption_params = self.get_outer_parameter(batch)
            # Pair meta-parameters with adapted parameters by name, not by
            # position. `adaption_params` is keyed by meta_named_parameters(),
            # which presumably can be a subset of named_parameters(); a
            # positional zip would then silently misalign.
            meta_params = dict(self.model.meta_named_parameters())
            for name, adaption_param in adaption_params.items():
                meta_param = meta_params[name]
                meta_param.data -= meta_lr * self.lamda * (meta_param.data - adaption_param.data)

        return self.evaluate(evaluate_dataloader)

    def evaluate(self, dataloader):
        """Return the average per-batch mean test loss (0.0 if ``dataloader`` is empty)."""
        self.model.eval()
        val_loss = []
        for batch in dataloader:
            batch = tensors_to_device(batch, device=self.device)
            outer_loss = self.get_outer_loss(batch)
            val_loss.append(outer_loss)
        return sum(val_loss) / len(val_loss) if val_loss else 0.0

class FOMuML:
    """First-order meta-learner with a proximal inner loop and a reweighted outer loss.

    Inner loop: ``num_adaptation_steps`` first-order gradient steps
    (``step_size``) on ``task_loss + lamda * sum((theta - theta_meta) ** 2)``.
    Outer loss of a task: its test loss plus
    ``lamda * sum_k ||theta_task[k] - theta_meta[k]||_2 ** 2``. Tasks are ranked
    by that value (ascending). Those at rank ``i < dyfac`` are weighted by
    ``beta`` and the rest by 1; the weighted sum is divided by the total weight.

    NOTE(review): ``dyfac`` is compared against the integer rank, so with the
    default 0.2 only the single lowest-loss task is reweighted. Confirm
    whether a fraction of the meta-batch was intended.
    """

    def __init__(self, model, optimizer=None, lamda=0.1, step_size=0.1,
                 num_adaptation_steps=15, beta=1.0, dyfac=0.2,
                 loss_function=F.cross_entropy, device=None):
        self.model = model.to(device=device)
        self.device = device
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.step_size = step_size
        self.num_adaptation_steps = num_adaptation_steps
        self.lamda = lamda
        self.beta = beta
        self.dyfac = dyfac

    def _proximal_term(self, params, anchor):
        """Return ``lamda * sum_k ||params[k] - anchor[k]||_2 ** 2`` (0.0 when ``params`` is empty)."""
        penalty = 0.0
        for key, value in params.items():
            penalty += self.lamda * torch.norm(value - anchor[key], p=2) ** 2
        return penalty

    def get_outer_loss(self, batch):
        """Adapt on every task and compute the outer losses.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The rank-reweighted mean outer loss (differentiable in training
            mode) and the plain mean outer loss (built from ``.item()``).

        Raises
        ------
        RuntimeError
            If ``batch`` has no ``'test'`` entry.
        """
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, query_targets = batch['test']
        num_tasks = query_targets.size(0)
        classification = not query_targets.dtype.is_floating_point

        mean_outer_loss = torch.tensor(0., device=self.device)
        per_task = []
        ranking = []
        # Snapshot of the meta-parameters (state_dict tensors carry no graph).
        anchor = OrderedDict(self.model.state_dict())

        splits = zip(*batch['train'], *batch['test'])
        for task_id, (support_x, support_y, query_x, query_y) in enumerate(splits):
            adapted = self.adapt(support_x, support_y,
                                 is_classification_task=classification)

            with torch.set_grad_enabled(self.model.training):
                task_loss = self.loss_function(self.model(query_x, params=adapted), query_y)
                task_loss += self._proximal_term(adapted, anchor)
                scalar = task_loss.item()
                mean_outer_loss += scalar
                ranking.append((scalar, task_id))
                per_task.append(task_loss)

        mean_outer_loss.div_(num_tasks)

        ordered_ids = [task_id for _, task_id in sorted(ranking, key=lambda pair: pair[0])]

        weighted_sum = torch.tensor(0., device=self.device)
        n_plain = 0
        for rank, task_id in enumerate(ordered_ids):
            if rank >= self.dyfac:
                weighted_sum += per_task[task_id]
                n_plain += 1
            else:
                weighted_sum += per_task[task_id] * self.beta
        weighted_sum.div_(n_plain + (num_tasks - n_plain) * self.beta)
        return weighted_sum, mean_outer_loss

    def adapt(self, inputs, targets, is_classification_task=None):
        """Return parameters adapted to one task with a proximal penalty.

        Each step minimises ``loss + lamda * sum((theta - theta_meta) ** 2)``
        using a first-order update. ``theta_meta`` is read from
        ``state_dict()``.
        """
        if is_classification_task is None:
            is_classification_task = not targets.dtype.is_floating_point

        params = OrderedDict((name, param.to(self.device))
                             for name, param in self.model.meta_named_parameters())
        state = OrderedDict((name, param.to(self.device))
                            for name, param in self.model.state_dict().items())
        anchor_vec = torch.cat([state[name].view(-1) for name in params])

        for _ in range(self.num_adaptation_steps):
            predictions = self.model(inputs, params=params)
            fit_loss = self.loss_function(predictions, targets)
            flat = torch.cat([value.view(-1) for value in params.values()])
            penalty = F.mse_loss(flat, anchor_vec, reduction='sum') * self.lamda
            objective = fit_loss + penalty
            self.model.zero_grad()
            params = gradient_update_parameters(self.model, objective,
                                                step_size=self.step_size,
                                                params=params,
                                                first_order=True)

        return params

    def train(self, dataloader):
        """Run one meta-training pass; return the average per-batch mean outer loss (0.0 if empty).

        Raises
        ------
        RuntimeError
            If no optimizer was provided.
        """
        if self.optimizer is None:
            raise RuntimeError('Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        self.model.train()
        batch_losses = []
        for raw_batch in dataloader:
            self.optimizer.zero_grad()
            batch = tensors_to_device(raw_batch, device=self.device)
            weighted_loss, mean_loss = self.get_outer_loss(batch)
            batch_losses.append(mean_loss)
            weighted_loss.backward()
            self.optimizer.step()

        if not batch_losses:
            return 0.0
        return sum(batch_losses) / len(batch_losses)

    def evaluate(self, dataloader):
        """Return the average per-batch mean outer loss in eval mode (0.0 if empty)."""
        self.model.eval()
        batch_losses = [
            self.get_outer_loss(tensors_to_device(batch, device=self.device))[1]
            for batch in dataloader
        ]
        if not batch_losses:
            return 0.0
        return sum(batch_losses) / len(batch_losses)
