import pathlib
import os
import random

import torch
import torch.nn as nn
import numpy as np
import inspect, re

from typing import Any, Optional, Tuple
from torch.autograd import Function


class MLP(nn.Module):
    """Two-layer perceptron: ``Linear -> ReLU -> Linear``.

    According to the original author's note it serves as theta (feature
    extractor), phi (classifier) and local_match (feature alignment).

    Args:
        h_list: Indexable of layer widths; ``h_list[0]`` is the input width,
            ``h_list[1]`` the hidden width and ``h_list[2]`` the output width.
    """

    def __init__(self, h_list):
        super().__init__()

        self.fm1 = nn.Linear(h_list[0], h_list[1])
        self.fm2 = nn.Linear(h_list[1], h_list[2])

        # Re-initialise both weight matrices with Kaiming-normal (fan-in,
        # ReLU gain). Biases keep the default nn.Linear initialisation.
        for layer in (self.fm1, self.fm2):
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Return the (un-activated) output of the second linear layer."""
        hidden = torch.relu(self.fm1(x))
        return self.fm2(hidden)


class LSTM(nn.Module):
    """Stacked LSTM that returns the output at a single sequence position.

    Args:
        input_dim: Feature size of each input step.
        output_dim: Hidden size of the LSTM (size of the returned feature).
        layers: Number of stacked LSTM layers.
        location: Index along dimension 1 of the LSTM output that is
            returned by ``forward`` (``-1`` selects the last position).
    """

    def __init__(self, input_dim=310, output_dim=256, layers=2, location=-1):
        super().__init__()
        self.location = location
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=output_dim,
            num_layers=layers,
            batch_first=True,
        )

    def forward(self, x):
        """Run the LSTM on ``x`` and pick the step at ``self.location``.

        The LSTM is built with ``batch_first=True``; the result is sliced as
        ``output[:, location, :]``, so ``x`` is assumed to be a batched 3-D
        input -- TODO confirm against callers.
        """
        self.lstm.flatten_parameters()
        # The final hidden/cell states are not needed, only the per-step output.
        outputs, _ = self.lstm(x)
        return outputs[:, self.location, :]


class GradientReverseFunction(Function):
    """Autograd function that is the identity forward and flips gradients backward.

    Forward returns a copy of the input values; backward multiplies the
    incoming gradient by ``-coeff``.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        """Return ``input`` unchanged in value and remember ``coeff`` for backward."""
        ctx.coeff = coeff
        # Multiplying by 1.0 yields a new tensor rather than returning the input object.
        return input * 1.0

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Return the negated, ``coeff``-scaled gradient; ``coeff`` gets no gradient."""
        reversed_grad = -grad_output * ctx.coeff
        return reversed_grad, None


class GRL(nn.Module):
    """Gradient-reversal layer: ``nn.Module`` wrapper around ``GradientReverseFunction``."""

    def __init__(self):
        super().__init__()

    def forward(self, *input):
        """Forward all positional arguments to ``GradientReverseFunction.apply``.

        The arguments are passed through unchanged, so a caller may supply
        just the tensor or the tensor followed by the reversal coefficient.
        """
        reverse = GradientReverseFunction.apply
        return reverse(*input)



class DomainMetric(nn.Module):
    """Domain metric network (``omega1`` in the original author's notation).

    The original comment describes it as a higher version of the local-match
    network. It is a two-layer tanh network whose output is averaged over
    dimension 0, wrapped between two gradient-reversal layers.

    Args:
        h_list: Indexable of layer widths ``(input, hidden, output)``.
    """

    def __init__(self, h_list):
        super().__init__()
        self.fc1 = nn.Linear(h_list[0], h_list[1])
        self.fc2 = nn.Linear(h_list[1], h_list[2])
        # Kaiming-normal weights with ReLU gain, although forward uses tanh.
        for layer in (self.fc1, self.fc2):
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        self.grl = GRL()

    def forward(self, x):
        """Return the dimension-0 mean of the tanh features, shape kept via ``keepdim``.

        Gradient reversal is applied both on the input and on the output, so
        gradients reaching ``fc1``/``fc2`` are negated once while gradients
        flowing back to ``x`` are negated twice.
        """
        hidden = torch.tanh(self.fc1(self.grl(x)))
        projected = torch.tanh(self.fc2(hidden))
        pooled = projected.mean(dim=0, keepdim=True)
        return self.grl(pooled)


class DomainShift(nn.Module):
    """Domain shift module (``omega2`` in the original author's notation).

    Holds an ``h x h`` linear layer whose weight is fixed to the identity
    matrix and excluded from gradient updates. Only the weight is frozen;
    the bias keeps its default initialisation and remains trainable.

    Args:
        h: Feature width of the input.
    """

    def __init__(self, h):
        super().__init__()
        self.f = nn.Linear(h, h)
        nn.init.eye_(self.f.weight)
        self.f.weight.requires_grad_(False)

    def forward(self, x):
        """Return the norm of ``f(mean of x over dimension 0)`` as a scalar tensor."""
        batch_mean = x.mean(dim=0, keepdim=True)
        shifted = self.f(batch_mean)
        return shifted.norm()


def batch_index_generator(total_size, batch_size):
    """Draw the indices of one random batch without replacement.

    A random permutation of ``0 .. total_size - 1`` is generated with NumPy's
    global random state and its first ``batch_size`` entries are returned.
    If ``batch_size`` exceeds ``total_size`` the result simply contains all
    ``total_size`` indices.
    """
    candidates = range(total_size)
    permuted = np.random.permutation(candidates)
    return permuted[:batch_size]


def record_acc(acc_file_path, *values):
    """Append one tab-separated record to a log file and echo it to stdout.

    Every value is converted with ``str()`` before being written, so numbers
    (e.g. accuracies) can be passed directly; previously a non-string value
    raised ``TypeError`` from ``file.write``.

    Args:
        acc_file_path: Path of the file to append to (created if missing).
        *values: Fields of the record. Each field is followed by a tab and
            the record is terminated by a newline.
    """
    print('--------')
    # Same layout as before: "<v1>\t<v2>\t...\t" followed by a newline.
    record = ''.join(str(value) + '\t' for value in values)
    with open(acc_file_path, mode='a') as f:
        f.write(record + '\n')
    print(record)


def varname(p):
    """Return the source-level name of the variable passed as ``p``.

    The caller's current source line is inspected for a call of the form
    ``varname(identifier)`` and ``identifier`` is returned. The value of
    ``p`` itself is never used.

    Returns:
        The identifier text, or ``None`` when the caller's source is not
        available or the line does not contain a matching call (for example
        when the argument is an expression rather than a plain name).
    """
    frame = inspect.currentframe()
    if frame is None:
        # Interpreter without Python stack-frame support.
        return None
    try:
        context = inspect.getframeinfo(frame.f_back).code_context
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame

    # code_context is None when the caller's source cannot be retrieved.
    for line in context or ():
        m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
        if m:
            return m.group(1)
    return None

class SaveModel:
    """Persist a named collection of models into one directory.

    Args:
        prefix: String prepended to every file name.
        path_dir: Target directory; created (including parents) on construction.
        model_dict: Mapping from model name to the object to be saved.
    """

    def __init__(self, prefix, path_dir, model_dict):
        self.prefix = prefix
        self.path_dir = path_dir
        self.models = model_dict
        # Make sure the output directory exists before anything is saved.
        pathlib.Path(path_dir).mkdir(parents=True, exist_ok=True)

    def save(self):
        """Write every object with ``torch.save`` to ``<path_dir>/<prefix><name>.ptr``."""
        for model_name in self.models:
            file_name = self.prefix + model_name + '.ptr'
            destination = os.path.join(self.path_dir, file_name)
            # The whole object is serialised, not just a state_dict.
            torch.save(self.models[model_name], destination)

class SetHyperParameter:
    """Container for the hyper-parameters and output paths of one training run.

    All values are fixed defaults set in ``__init__``; callers may overwrite
    the attributes afterwards.
    """

    def __init__(self):
        # Run index; embedded in the output paths and in the log header.
        self.i = 0
        # Learning rates, keyed by component name.
        self.lr = {'theta': 0.01, 'omega': 0.01, 'phi': 0.01, 'local_match': 0.01, 'theta_second': 0.01,
                    'ft_theta': 0.01, 'ft_phi': 0.01}
        # Weights of the auxiliary losses (the main training loss has weight 1,
        # per the text emitted by write_log).
        self.proportion = {'held_out': 0.01, 'local_match': 1e-8, 'train_dg': 0.001}
        # Iteration counts, keyed by loop name.
        self.ite = {'outer': 10, 'inner': 1, 'fine_tune': 3}
        # Weight-decay coefficients, keyed by component name.
        self.weight_decay = {'theta': 0.01, 'omega': 0.01, 'phi': 0.01, 'local_match': 0.01}
        self.batch_size = 1024

        # NOTE(review): both paths use hard-coded backslash separators.
        # The backslashes are now all escaped explicitly; the resulting
        # strings are unchanged ('\S' was an invalid escape sequence).
        self.model_path = 'SDS_model\\{}'.format(self.i)
        self.log_path = 'logs\\SDS_logs\\{}'.format(self.i)

        self.seed = None

    def write_log(self, acc):
        """Return the final multi-line log text for this run.

        Args:
            acc: Sized iterable of accuracy values. Its mean is reported as
                ``average_acc``; for an empty ``acc`` the mean is reported
                as ``nan`` instead of raising ``ZeroDivisionError``.
        """
        average_acc = sum(acc) / len(acc) if len(acc) else float('nan')
        return "#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#\n" \
               "The final logging of train process {} --\n" \
               "Batch_size: {} --\n" \
               "weight_decay: \n{} --\n" \
               "iteration times: {} --\n" \
               "loss proportion (the main train loss proportion is 1): \n{} --\n" \
               "learning rate: \n{} --\n\n" \
               "acc: (* 1%) {} --\n" \
               "average_acc: {}\n\n\n".format(self.i, self.batch_size, self.weight_decay, self.ite, self.proportion,
                                self.lr, acc, average_acc)


def fix_nn(model, theta):
    """Replace the parameters of ``model``'s leaf modules with tensors from ``theta``.

    The module tree is walked recursively. For every module without
    sub-modules, each parameter that is a tensor is overwritten in the
    module's ``_parameters`` dict by ``theta[<dotted name>]``, where the
    dotted name is the path of sub-module names followed by the parameter
    name (e.g. ``'fm1.weight'``). If ``model`` itself has no sub-modules the
    bare parameter name is used as the key (this case used to raise
    ``TypeError``).

    Note: parameters registered directly on a module that also has
    sub-modules are left untouched, as before.

    Args:
        model: Module to modify in place.
        theta: Mapping from dotted parameter name to replacement tensor.

    Returns:
        The same ``model`` object.

    Raises:
        KeyError: If ``theta`` lacks an entry for a leaf parameter.
    """
    def k_param_fn(tmp_model, name=None):
        if len(tmp_model._modules) != 0:
            for child_name, child in tmp_model._modules.items():
                if child is None:
                    # Empty module slot: nothing to descend into.
                    continue
                qualified = str(child_name) if name is None else str(name + '.' + child_name)
                k_param_fn(child, name=qualified)
            return

        for k, v in tmp_model._parameters.items():
            if not isinstance(v, torch.Tensor):
                # Skips unset parameters such as a disabled bias (None).
                continue
            key = str(k) if name is None else str(name + '.' + k)
            # Written straight into _parameters so a plain tensor is accepted.
            tmp_model._parameters[k] = theta[key]

    k_param_fn(model)
    return model


def write_log(flags, acc):
    """Return the final multi-line log text for a training run.

    Args:
        flags: Object exposing the attributes ``i``, ``batch_size``,
            ``weight_decay``, ``ite``, ``proportion`` and ``lr``.
        acc: Sized iterable of accuracy values. Its mean is reported as
            ``average_acc``; for an empty ``acc`` the mean is reported as
            ``nan`` instead of raising ``ZeroDivisionError``.
    """
    average_acc = sum(acc) / len(acc) if len(acc) else float('nan')
    template = (
        "#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#\n"
        "The final logging of train process {} --\n"
        "Batch_size: {} --\n"
        "weight_decay: \n{} --\n"
        "iteration times: {} --\n"
        "loss proportion (the main train loss proportion is 1): \n{} --\n"
        "learning rate: \n{} --\n\n"
        "acc: (* 1%) {} --\n"
        "average_acc: {}\n\n\n"
    )
    return template.format(flags.i, flags.batch_size, flags.weight_decay, flags.ite,
                           flags.proportion, flags.lr, acc, average_acc)

