from os import environ
import torch
from torch._C import device
from models.initializer import initialize_model
from algorithms.single_model_algorithm import SingleModelAlgorithm
from wilds.common.utils import split_into_groups
import torch.autograd as autograd
from wilds.common.metrics.metric import ElementwiseMetric, MultiTaskMetric
from optimizer import initialize_optimizer

import learn2learn as l2l
import copy
from torch.cuda.amp import autocast

import torch.nn as nn
import torch.nn.functional as F

def shift(seq, n):
    """Rotate a sliceable sequence to the left by ``n`` positions.

    Element ``i`` of the result is ``seq[(i + n) % len(seq)]``. ``n`` may be
    negative or larger than ``len(seq)``; it is reduced modulo the length.

    Args:
        seq: any sequence supporting ``len``, slicing and ``+`` (list, tuple, str).
        n (int): number of positions to rotate by.

    Returns:
        A new sequence of the same type as ``seq``.
    """
    # An empty sequence has nothing to rotate; guarding here avoids the
    # ZeroDivisionError that ``n % 0`` would raise.
    if len(seq) == 0:
        return seq[:]
    n = n % len(seq)
    return seq[n:] + seq[:n]

class BLOOD(SingleModelAlgorithm):
    """Bi-level training algorithm with multiplicatively updated group weights.

    The model is wrapped in learn2learn's ``MAML`` so that ``objective`` can
    ``clone()`` it, ``adapt()`` the clone on one group of the batch and
    evaluate the adapted clone on another group.

    ``self.group_weights`` holds two slots per group: slot ``2*g`` is paired
    with the loss on correctly predicted samples and slot ``2*g + 1`` with the
    loss on wrongly predicted samples (see ``objective``).
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps, is_group_in_train):
        """
        Args:
            - config: run configuration; ``device``, ``lr`` and ``first_order``
              are read here, and it is forwarded to ``initialize_model``.
            - d_out: output dimension forwarded to ``initialize_model``.
            - grouper: object exposing ``n_groups``.
            - loss: must be an ElementwiseMetric or a MultiTaskMetric.
            - metric: forwarded to the parent class.
            - n_train_steps: forwarded to the parent class.
            - is_group_in_train: used (after ``repeat_interleave(2)``) to index
              the weight vector; presumably a boolean tensor with one entry
              per group -- TODO confirm against caller.
        """
        self.device = config.device
        model = initialize_model(config, d_out).to(config.device)

        # For bi-level optimization
        model = l2l.algorithms.MAML(model, lr=config.lr, first_order=config.first_order)

        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps
        )
        self.logged_fields.append('group_weight')

        self.config = config

        # Two weight slots per group; only groups selected by
        # is_group_in_train start with non-zero mass, then normalize to sum 1.
        self.group_weights = torch.zeros(grouper.n_groups*2)
        self.group_weights[is_group_in_train.repeat_interleave(2)] = 1

        self.group_weights = self.group_weights/self.group_weights.sum()
        self.group_weights = self.group_weights.to(self.device)

        # objective() relies on self.loss.compute_flattened.
        assert isinstance(self.loss, (ElementwiseMetric, MultiTaskMetric))

    def process_batch(self, batch):
        """
        A helper function for update() and evaluate() that processes the batch
        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): information about the batch, as returned by
              the parent class, plus
                - group_weight (Tensor): the current self.group_weights vector
                  (two entries per group)
        """
        results = super().process_batch(batch)
        results['group_weight'] = self.group_weights
        return results

    def objective(self, results):
        """
        Computes the weighted outer loss and updates the group weights.

        For every group present in the batch, a clone of the model takes one
        adaptation step on that group and is then evaluated on a different
        group of the batch, chosen by rotating the list of groups. The
        evaluation losses are split into correctly / wrongly predicted samples
        and stored in the two slots of a group.

        Args:
            - results (dictionary): reads 'g', 'x', 'y_true' and 'y_pred'.
        Output:
            - total_outer_loss (Tensor): scalar loss.
        Side effects:
            - results['y_pred'] is overwritten in place with the adapted
              learner's outputs for each evaluated group.
            - results['group_weight'] is set to the per-group sum of the two
              slots of the updated weights.
            - self.group_weights is updated multiplicatively and renormalized.
        """
        unique_groups, group_indices, _ = split_into_groups(results['g'])
        n_batch_groups = len(group_indices)

        outer_loss_list = torch.zeros(self.grouper.n_groups*2, device=self.device)

        # Pick how far to rotate the group list so that each group is paired
        # with another group for the outer evaluation.
        if n_batch_groups > 2:
            # NOTE(review): torch.randint's upper bound is exclusive, so this
            # draws from [1, n_batch_groups - 2]; a rotation of
            # n_batch_groups - 1 is never selected -- confirm this is intended.
            shuffle_index = torch.randint(1, n_batch_groups-1, (1,)).item()
        elif n_batch_groups == 2:
            shuffle_index = 1
        else:
            # Fewer than two groups in the batch: there is no other group to
            # pair with, so keep the identity pairing instead of leaving
            # shuffle_index undefined.
            shuffle_index = 0

        if shuffle_index:
            for_outer = shift(group_indices, shuffle_index)
        else:
            for_outer = list(group_indices)

        for env_idx, (group_train, group_eval) in enumerate(zip(group_indices, for_outer)):
            learner = self.model.clone()

            # Inner step: adapt the clone once on the training group.
            outputs = learner(results['x'][group_train])
            group_losses, _ = self.loss.compute_flattened(
                outputs,
                results['y_true'][group_train],
                return_dict=False)
            learner.adapt(group_losses.mean())

            # Outer step: evaluate the adapted clone on the paired group.
            y_eval = results['y_true'][group_eval]
            outputs = learner(results['x'][group_eval])
            group_losses, _ = self.loss.compute_flattened(
                outputs,
                y_eval,
                return_dict=False)

            preds = torch.argmax(outputs, dim=-1).float()
            correct_idx = preds == y_eval
            wrong_idx = preds != y_eval

            results['y_pred'][group_eval] = outputs

            # NOTE(review): group_eval sits at position
            # (env_idx + shuffle_index) of the group list, while the slot
            # below uses (env_idx - shuffle_index); the two coincide only when
            # 2 * shuffle_index is a multiple of the number of groups (e.g.
            # two groups) -- confirm the intended attribution.
            slot_group = unique_groups[(env_idx-shuffle_index)%len(unique_groups)]

            # The mean of an empty selection is NaN; in that case the slot
            # keeps its initial zero.
            correct_loss = group_losses[correct_idx].mean()
            if not correct_loss.isnan():
                outer_loss_list[slot_group*2] = correct_loss
            wrong_loss = group_losses[wrong_idx].mean()
            if not wrong_loss.isnan():
                outer_loss_list[slot_group*2+1] = wrong_loss

        # Exponentiated update of the weights with step size 1e-2, computed on
        # detached losses so no gradient flows into the weights.
        self.group_weights = self.group_weights * torch.exp(1e-2 * outer_loss_list.detach())
        self.group_weights = (self.group_weights/(self.group_weights.sum()))

        # Report one weight per group: sum of its correct and wrong slots.
        results['group_weight'] = self.group_weights[::2] + self.group_weights[1::2]

        total_outer_loss = outer_loss_list @ self.group_weights
        # Rescale by (number of slots / number of non-zero slots).
        # NOTE(review): if every slot is zero this divides by zero and the
        # result is NaN.
        total_outer_loss = total_outer_loss * (outer_loss_list.size(0)/(outer_loss_list != 0).sum()).item()

        return total_outer_loss