import copy
import warnings
import math
from copy import deepcopy

import torch
import torch.nn as nn
from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler
from .time2vec import SineActivation, CosineActivation
import pdb
import random
from .meter_utils import MetricMeter

import torch.nn.functional as F

class T2AMLP(Module):
    """Multilayer perceptron mapping a feature vector to one logit per output.

    The input is flattened to 1-D when it has more than one dimension, scaled
    by ``tsr``, and passed through ``fc1 -> ReLU -> fc2``.

    Args:
        input_size: Number of features ``fc1`` expects (after flattening).
        output_size: Number of logits produced by ``fc2``.
        tsr: Scalar multiplied into the input before ``fc1``.
        mlp_hid_size: Width of the hidden layer.
    """

    def __init__(self, input_size=1, output_size=10, tsr=1.0, mlp_hid_size=64):
        super().__init__()
        # start_dim=0 collapses every dimension into a single 1-D vector.
        # This must be a module instance, not a tuple, so it is callable and
        # registered as a submodule.
        self.flat = nn.Flatten(start_dim=0)
        self.fc1 = nn.Linear(input_size, mlp_hid_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(mlp_hid_size, output_size)
        self.tsr = tsr

    def forward(self, x):
        """Return the MLP logits for ``x``.

        Args:
            x: A tensor, or a list/tuple whose first element is the tensor to
                use. A tensor with more than one dimension is flattened to 1-D
                and must then contain ``input_size`` elements.

        Returns:
            Output of ``fc2``.
        """
        if isinstance(x, (list, tuple)):
            x = x[0]

        if x.dim() > 1:
            x = self.flat(x)

        x = self.tsr * x
        return self.fc2(self.relu(self.fc1(x)))

class Time2Vec(Module):
    """Encode a time input with a sine branch and a cosine branch.

    The two branch outputs are concatenated, with the ``SineActivation``
    output first.
    """

    def __init__(self, input_size = 1, output_size = 10):
        super(Time2Vec, self).__init__()
        # NOTE: the attribute names are swapped relative to the classes they
        # hold (``cos_act`` is the SineActivation, ``sin_act`` the
        # CosineActivation). They are kept because renaming submodules would
        # change this module's state_dict keys.
        self.cos_act = SineActivation(input_size, output_size)
        self.sin_act = CosineActivation(input_size, output_size)

    def forward(self, t):
        """Encode ``t[0]``; ``t`` must have exactly one element along dim 0."""
        assert len(t) == 1
        value = t[0]
        sine_part = self.cos_act(value)
        cosine_part = self.sin_act(value)
        return torch.cat([sine_part, cosine_part])


class WAveragedModel(Module):
    """Weight-space ensemble whose mixing weights are predicted from a time input.

    The parameters of a frozen list of base models are combined
    position-by-position into ``self.module``, using softmax weights produced
    by ``self.acn_fun`` (a ``T2AMLP``). When ``use_t2v`` is set, the MLP is
    applied to a ``Time2Vec`` encoding of the time input instead of the raw
    input. Only the weight-predicting networks are optimized (see
    ``arch_step``).

    Args:
        model_ls: Base models. ``get_forward_model()`` is used when present.
            A deep copy of the first one becomes the averaged container
            ``self.module``.
        device: Device the base models, the helper networks and the averaged
            model are moved to.
        acp_type: Weight-predictor type; only ``'mlp'`` is accepted.
        use_t2v: Encode the time input with ``Time2Vec`` before the MLP.
        t2v_size: Passed as ``output_size`` to ``Time2Vec``; the MLP input
            size is then ``2 * t2v_size``.
        tsr: Input scale forwarded to ``T2AMLP``.
        update_iter: The optimizer steps every ``update_iter`` calls to
            ``arch_step``; gradients accumulate in between.
        opt: Options object; ``lr``, ``beta1`` and ``weight_decay`` are read
            for Adam. Required despite the ``None`` default.
        hypermeters: Mapping; ``'output_search'`` and ``'mlp_hid_size'`` are
            read. Required despite the ``None`` default.
    """

    def filter(self, model):
        """Return ``model.get_forward_model()`` if that method exists, else ``model``."""
        if hasattr(model, "get_forward_model"):
            model = model.get_forward_model()
        return model

    def __init__(self, model_ls, device=None, acp_type='mlp', use_t2v=False, t2v_size=64,
                 tsr=1.0, update_iter=1, opt=None, hypermeters=None):
        super(WAveragedModel, self).__init__()
        # Container whose parameters avg_fn overwrites with the weighted average.
        self.module = deepcopy(self.filter(model_ls[0]))
        self.opt = opt
        self.device = device
        self.avg_num = 0  # incremented on every arch_step call
        self.tsr = tsr
        self.real_random = False  # not read anywhere in this class
        self.hypms = hypermeters
        self.output_search = hypermeters['output_search']
        self.model_ls = [self.filter(ms) for ms in model_ls]
        self.show_weights = False  # debug switch: print mixing weights
        for base in self.model_ls:
            base.eval()

        # Keep each base model's parameter list (in .parameters() order) and
        # freeze it; only the weight-predicting networks are trained.
        self.param_ls = []
        for base in self.model_ls:
            base = base.to(self.device)
            self.param_ls.append(list(base.parameters()))
        for model_params in self.param_ls:
            for param in model_params:
                param.requires_grad = False

        self.acp_type, self.use_t2v, self.t2v_size = acp_type, use_t2v, t2v_size
        arch_net_param_ls = self.init_t2v([])
        self.update_iter = update_iter
        self.arch_iters = 0
        assert self.acp_type == 'mlp'

        self.arch_net_param_ls = self.init_arch_mlp(arch_net_param_ls)
        self.arch_optimizer = torch.optim.Adam(
            self.arch_net_param_ls,
            lr=self.opt.lr,
            betas=(self.opt.beta1, 0.999),
            weight_decay=self.opt.weight_decay,
        )

        if device is not None:
            self.module = self.module.to(device)

        # Registered for state_dict compatibility; not updated by this class.
        self.register_buffer("n_averaged", torch.tensor(0, dtype=torch.long, device=device))

    def forward(self, t, *args, **kwargs):
        """Average the base models for time ``t``, then call ``self.module(*args, **kwargs)``."""
        self.avg_fn(self._arch_weights(t))
        return self.module(*args, **kwargs)

    def _arch_weights(self, t):
        """Return softmax mixing weights (one per base model) for time input ``t``."""
        t_feat = self.t2v_fun(t) if self.t2v_fun is not None else t
        # Explicit dim: the logits are expected to be 1-D (one per base model),
        # where dim=-1 is the same axis the deprecated implicit default chose.
        return F.softmax(self.acn_fun(t_feat), dim=-1)

    def init_arch_mlp(self, arch_net_param_ls):
        """Create ``self.acn_fun`` and append its parameters to ``arch_net_param_ls``.

        The MLP outputs one logit per base model.
        """
        self.acn_fun = T2AMLP(self.input_size, len(self.param_ls), self.tsr,
                              self.hypms['mlp_hid_size']).to(self.device)
        arch_net_param_ls += self.acn_fun.parameters()
        return arch_net_param_ls

    def init_t2v(self, arch_net_param_ls):
        """Set up the optional ``Time2Vec`` encoder and the MLP input size.

        Sets ``self.input_size`` and ``self.t2v_fun``; the encoder's
        parameters (if any) are appended to ``arch_net_param_ls``, which is
        returned.
        """
        if self.use_t2v:
            self.input_size = self.t2v_size * 2
            self.t2v_fun = Time2Vec(1, self.t2v_size).to(self.device)
            arch_net_param_ls += self.t2v_fun.parameters()
        else:
            self.input_size = 1
            self.t2v_fun = None
        return arch_net_param_ls

    def prepare_param_tensor(self, single_net_params, device):
        """Return a detached copy of ``single_net_params`` on ``device``."""
        snp_o = single_net_params.detach().to(device)
        snp_o.requires_grad = False
        return snp_o

    def avg_fn(self, arch_params):
        """Write the ``arch_params``-weighted sum of base-model parameters into ``self.module``.

        Parameters are paired positionally (``.parameters()`` order), which
        assumes every base model shares ``self.module``'s parameter order and
        shapes. Buffers are not touched.
        """
        if self.show_weights:
            print(arch_params)
        base_mods_params = list(zip(*self.param_ls))
        tgt_mod_params = list(self.module.parameters())
        # The result is copied into detached storage, so gradients can never
        # flow through it; no_grad avoids building an unused autograd graph.
        with torch.no_grad():
            for avg_param, param_ls in zip(tgt_mod_params, base_mods_params):
                device = avg_param.device
                param_ls = [self.prepare_param_tensor(p, device) for p in param_ls]
                wavg_params = 0
                for param, acp in zip(param_ls, arch_params):
                    wavg_params = wavg_params + acp * param
                avg_param.detach().copy_(wavg_params)

    def arch_step(self, t, *args, **kwargs):
        """Run one optimization step for the mixing-weight networks.

        Each base model's ``single_pred_meta(*args, **kwargs)`` is evaluated
        under ``no_grad``. Element ``[0]`` of each result is mixed with the
        softmax weights for ``t``, and ``loss_meta`` of the first base model
        is computed against element ``[1]`` of the first result (presumably
        the targets) and back-propagated. The optimizer steps and clears
        gradients every ``update_iter`` calls.

        Returns:
            ``{'Arch Loss': float}``.
        """
        if self.show_weights:
            print("searching!!!")
        self.module.do_step = False
        self.avg_num += 1
        arch_params = self._arch_weights(t)
        self.arch_iters += 1
        assert self.output_search

        # Base predictions are constants; gradient reaches only arch_params.
        with torch.no_grad():
            outputs = [base.single_pred_meta(*args, **kwargs) for base in self.model_ls]
        final_pred = sum(arch_params[idx] * out[0] for idx, out in enumerate(outputs))
        all_y = outputs[0][1]
        arch_loss = self.model_ls[0].loss_meta(final_pred, all_y)

        arch_loss.backward()
        if self.arch_iters % self.update_iter == 0:
            self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()
        return {'Arch Loss': arch_loss.item()}

    def predict(self, t, *args, **kwargs):
        """Average the base models for ``t`` and return ``self.module.predict_meta(*args, **kwargs)``."""
        self.avg_fn(self._arch_weights(t))
        return self.module.predict_meta(*args, **kwargs)

    def get_target_model(self, t):
        """Average the base models for ``t`` and return ``self.module``."""
        self.avg_fn(self._arch_weights(t))
        return self.module

    @property
    def network(self):
        """The ``network`` attribute of the averaged model."""
        return self.module.network

    def clone(self):
        """Return a deep copy of the averaged model with a fresh optimizer.

        Relies on ``self.module`` exposing ``network`` and ``new_optimizer``.
        """
        clone = copy.deepcopy(self.module)
        clone.optimizer = clone.new_optimizer(clone.network.parameters())
        return clone


def _snapshot_model(algorithm, step, acc):
    """Return the model to store for ``algorithm``, tagged with ``step`` (as str) and ``acc``."""
    if hasattr(algorithm, "get_forward_model"):
        model = algorithm.get_forward_model()
    else:
        model = copy.deepcopy(algorithm).cpu()
    model.step = str(step)
    model.acc = acc
    return model


def sample_best_meta_model(save_Q, algorithm, n_save = 8, step = 0, acc = -1):
    """Keep the ``n_save`` highest-accuracy snapshots of ``algorithm``.

    ``save_Q`` holds two parallel lists, ``save_Q[0]`` (models) and
    ``save_Q[1]`` (accuracies), and is updated in place. While fewer than
    ``n_save`` models are stored, the new snapshot is appended. Otherwise it
    replaces the lowest-accuracy entry if ``acc`` is strictly higher.

    A snapshot is ``algorithm.get_forward_model()`` when that method exists,
    else a CPU deep copy of ``algorithm``. It is tagged with ``step`` (as a
    string) and ``acc``. The snapshot is only taken once it is known to be
    kept, so rejected candidates no longer cost a deep copy.

    Returns:
        ``save_Q`` (the same object passed in).
    """
    models, accs = save_Q[0], save_Q[1]
    if len(models) < n_save:
        models.append(_snapshot_model(algorithm, step, acc))
        accs.append(acc)
        print("cur accs", accs)
        return save_Q

    if not accs:
        # n_save <= 0 with nothing stored: there is no slot to fill or replace.
        return save_Q

    min_acc = min(accs)
    if acc > min_acc:
        min_idx = accs.index(min_acc)
        print("replacing the prev model with acc", min_acc, "step", models[min_idx].step, "to the cur one with acc", acc, 'cur step', step)
        accs[min_idx] = acc
        models[min_idx] = _snapshot_model(algorithm, step, acc)
    return save_Q

def arch_opt(meta_swad_algorithm, train_arch_loaders, arch_steps, show_step = 50, device = 'cpu'):
    """Train the mixing-weight networks of ``meta_swad_algorithm`` for ``arch_steps`` steps.

    Each step draws one batch from every loader in ``train_arch_loaders``
    (zipped together) and calls ``meta_swad_algorithm.arch_step`` once per
    batch, in a random order. Each batch is unpacked as ``(x, y, d)``. Every
    tensor is moved to ``device`` and given a leading dimension of size 1,
    and ``d`` is also passed as the time input ``t``.

    The returned loss dicts are accumulated in a ``MetricMeter``, which is
    printed and reset after every ``show_step`` completed steps. If the
    zipped loaders run out they are re-iterated. If they yield nothing even
    then, a warning is issued and training stops early.
    """
    arch_minibatches_iterator = zip(*train_arch_loaders)
    search_res = MetricMeter()
    for step in range(arch_steps):
        try:
            batches_dictlist = next(arch_minibatches_iterator)
        except StopIteration:
            # Finite loaders: start a new pass instead of leaking StopIteration.
            arch_minibatches_iterator = zip(*train_arch_loaders)
            try:
                batches_dictlist = next(arch_minibatches_iterator)
            except StopIteration:
                warnings.warn(
                    f"arch_opt: loaders yielded no batches; stopping after "
                    f"{step} of {arch_steps} steps"
                )
                break

        batch_indxs = list(range(len(batches_dictlist)))
        random.shuffle(batch_indxs)

        for bidx in batch_indxs:
            x, y, d = batches_dictlist[bidx]
            cur_batch = {'x': x.to(device).unsqueeze(0), 'y': y.to(device).unsqueeze(0), 'd': d.to(device).unsqueeze(0)}
            ts_vecs = cur_batch['d']

            inputs = {**cur_batch, "step": step}
            assert len(ts_vecs) == 1
            val_loss = meta_swad_algorithm.arch_step(ts_vecs, **inputs)
            search_res.update(val_loss)

        # Report after the step so "Step N" covers steps up to and including N.
        if (step + 1) % show_step == 0:
            print("Step", str(step + 1), str(search_res))
            search_res.reset()
    