import scipy.io as scio
import torchvision
import torch
from torch import nn 
from torch.utils.data import Dataset,DataLoader,TensorDataset
from torchvision import datasets, transforms
import time
import numpy as np 
import pandas as pd
import random
import math
from torch.nn import functional as F
import csv
import copy

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import sys, random, time
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Embed fonts as TrueType (Type 42) in PDF and PostScript figures instead of Type 3.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# hypergrad supplies hg.GradientDescent and hg.CG, used for the hypergradient in run_for_each_round.
import hypergrad as hg

# Compute device for models and data: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Worker processes (one per active expert) are launched with this module in main_loop.
import torch.multiprocessing as mp

def setup_seed(seed):
    """Seed every random number generator used by this script.

    PyTorch (CPU and all CUDA devices), NumPy and Python's ``random`` module
    all receive the same ``seed``; cuDNN is additionally switched to its
    deterministic algorithms.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
# Global seed, applied immediately so that the random task centers drawn below
# (and everything downstream) are reproducible.
seed = 200
setup_seed(seed)

# Task stream: one reference-trajectory file per task -- 50 tasks each for the
# letters 'M', 'E' and 'T', followed by 51 tasks for 'A' (201 tasks in total).
filename_list_whole = [
    "../../ref_traj/" + letter + "_reftraj.mat"
    for letter, repeats in (('M', 50), ('E', 50), ('T', 50), ('A', 51))
    for _ in range(repeats)
]
# One 2-D constraint-circle center per task, drawn from a standard normal.
center_list_whole = np.random.normal(loc=0, scale=1, size=(len(filename_list_whole), 2))


# Mini-batch sizes of the inner (adaptation) and outer (meta) data loaders.
batch_size_K = 400
batch_size_outer = 400
# Strength of the bias regularizer that pulls adapted parameters toward the meta-parameters.
meta_lambda = 200.0
# Module-level epoch counts; run_for_each_round picks its own local values.
n_epochs = 100
n_inner_level_epochs = 100

# Circular position constraint: radius, side (less=True keeps positions inside
# the circle, False keeps them outside), penalty weight and softplus beta.
redius = 2.0
less = False
weight = 500.0
softplus_para = 200.0


class Model(torch.nn.Module):
    """Fully connected network mapping a scalar input (time) to a 2-D position.

    The weights live in a flat ``ParameterList`` ``[W0, b0, W1, b1, ...]`` so
    that every forward method can be evaluated with an externally supplied
    parameter list (e.g. task-adapted copies) instead of ``self.params``.
    """

    def __init__(self):
        super(Model, self).__init__()

        def _weight(out_features, in_features):
            # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) initialisation.
            bound = 1. / math.sqrt(in_features)
            return torch.nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))

        def _bias(n):
            return torch.nn.Parameter(torch.zeros(n))

        # Elements are built left to right, so weights draw from the RNG in layer order.
        self.params = torch.nn.ParameterList([
            _weight(128, 8), _bias(128),
            _weight(128, 128), _bias(128),
            _weight(128, 128), _bias(128),
            _weight(2, 128), _bias(2),
        ])

    def dense(self, x, params):
        """Evaluate the MLP on features ``x`` with parameters ``params``.

        ``params`` is a flat ``[W0, b0, ..., Wn, bn]`` sequence; every layer
        except the last one is followed by a ReLU.
        """
        y = x
        last = len(params) - 2
        for i in range(0, len(params), 2):
            y = F.linear(y, params[i], params[i + 1])
            if i < last:
                y = F.relu(y)
        return y

    def input_process(self, x):
        """Concatenate polynomial and trigonometric features of ``x`` along dim 1.

        Eight blocks: x, x^2, x^3, x^4, sin(3.14x), cos(3.14x), sin(6.28x),
        cos(6.28x). With the (128, 8) first layer this assumes ``x`` has a
        single column.
        """
        x2 = torch.pow(x, 2)
        x3 = torch.pow(x, 3)
        x4 = torch.pow(x, 4)
        x_sin = torch.sin(x * 3.14)
        x_cos = torch.cos(x * 3.14)
        x_sin_2 = torch.sin(2 * x * 3.14)
        x_cos_2 = torch.cos(2 * x * 3.14)
        return torch.cat((x, x2, x3, x4, x_sin, x_cos, x_sin_2, x_cos_2), 1)

    def forward(self, x, params):
        """Return ``[pos1, pos2, d pos1/dx, d pos2/dx]`` as an (N, 4) tensor.

        ``x`` must require grad. Derivatives are taken with ``create_graph=True``
        so losses on them stay differentiable w.r.t. ``params``. Because each row
        is processed independently, the ones-weighted vector-Jacobian product
        yields the per-sample derivative.
        """
        position = self.dense(self.input_process(x), params) * 10.0
        position1, position2 = position.split([1, 1], dim=1)
        # grad_outputs must match the outputs' shape, device and dtype.
        ones = torch.ones_like(position1)
        vel1 = torch.autograd.grad(position1, x, ones, retain_graph=True, create_graph=True)[0]
        vel2 = torch.autograd.grad(position2, x, ones, retain_graph=True, create_graph=True)[0]
        return torch.cat((position1, position2, vel1, vel2), 1)

    def forward1(self, x, params):
        """Return only the (N, 2) scaled positions, without derivatives."""
        position = self.dense(self.input_process(x), params) * 10.0
        return position

def my_mse_loss(outputs, Q, Sigma):
    """Mean quadratic form ``e^T Sigma^{-1} e`` of the residual ``e = outputs - Q``.

    Args:
        outputs: predictions, one row per sample (assumed shape (batch, d)).
        Q: targets with the same shape as ``outputs``.
        Sigma: per-sample (d, d) matrices, shape (batch, d, d).

    Returns:
        Scalar tensor: the quadratic form averaged over the batch.
    """
    err = outputs - Q
    # Residual dimension taken from the data (4 for Model.forward's position+velocity output).
    dim = err.shape[-1]
    err_col = torch.reshape(err, (-1, dim, 1))
    err_row = torch.reshape(err, (-1, 1, dim))
    return torch.mean(torch.matmul(torch.matmul(err_row, torch.inverse(Sigma)), err_col))

def constraint_voilations(outputs, center=(0.0, 0.0), less=less, redius=redius, weight=weight):
    """Soft penalty for positions on the wrong side of a circle.

    The first two columns of ``outputs`` (assumed (N, 4) as produced by
    ``Model.forward``) are treated as positions. With ``less=True``, positions
    farther than ``redius`` from ``center`` are penalized (stay inside).
    Otherwise, positions closer than ``redius`` are penalized (stay outside).

    Returns:
        Batch mean of ``(softplus(margin, beta=softplus_para) - 0.001) * weight``.
        The constant offset shifts the value but not the gradient.
    """
    position, _vel = outputs.split([2, 2], dim=1)
    # Build the center on the outputs' own device/dtype instead of a global device.
    center_tensor = torch.as_tensor(center, dtype=position.dtype, device=position.device)
    dist = torch.norm(position - center_tensor, dim=1)
    margin = dist - redius if less else redius - dist
    penalty = (F.softplus(margin, softplus_para) - 0.001) * weight
    return torch.mean(penalty)

def bias_reg(params, meta_parameter, lambada=meta_lambda):
    """Biased regularizer ``lambada * sum_i ||params[i] - meta_parameter[i]||^2``.

    The squared norm is computed as a sum of squares. This gives the same value
    as ``norm * norm``, but its gradient ``2 * (p - m)`` is smooth everywhere.
    The sqrt-based norm has a singular derivative at ``p == m``, which is
    exactly where the inner loop starts, and higher-order derivatives through
    it can become numerically unstable there.

    Args:
        params: adapted parameter tensors.
        meta_parameter: meta-parameter tensors, paired element-wise with ``params``.
        lambada: regularization strength.
    """
    total = 0.0
    for p, m in zip(params, meta_parameter):
        diff = p - m
        total = total + (diff * diff).sum()
    return total * lambada

def adjust_learning_rate(optimizer, epoch, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    ``epoch`` is accepted for call-site compatibility but is not used.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)

def inner_loop(hparams, params, optim, n_steps=50, create_graph=False):
    """Run ``n_steps`` updates of a hypergrad-style optimizer and record them.

    Returns:
        List of parameter lists: the optimizer's view of the initial
        ``params`` followed by the result of each step (``n_steps + 1`` items).
    """
    history = [optim.get_opt_params(params)]
    for _ in range(n_steps):
        latest = history[-1]
        history.append(optim(latest, hparams, create_graph=create_graph))
    return history


def inner_loop_my(hparams, params, loss, n_steps=50, create_graph=False, lr=0.001):
    """Adapt ``params`` in place with ``n_steps`` Adam steps on ``loss(params, hparams)``.

    Gradients are taken with respect to ``params`` only. A plain ``.backward()``
    would also accumulate gradients into every other leaf of the graph. That
    includes ``hparams`` whenever the loss depends on them (as the bias
    regularizer does), which silently adds inner-loop terms to whatever
    meta-gradient is being accumulated in ``hparams``' ``.grad``.

    Args:
        hparams: meta-parameters, passed through to ``loss`` unchanged.
        params: list of leaf tensors with ``requires_grad=True``; updated in place.
        loss: callable ``loss(params, hparams) -> scalar tensor``.
        n_steps: number of Adam steps.
        create_graph: unused; kept for signature parity with ``inner_loop``.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        ``params`` (the same list object, after adaptation).
    """
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=0.0)
    for _ in range(n_steps):
        loss_value = loss(params, hparams)
        # retain_graph conservatively keeps any subgraph shared with caller-held tensors alive.
        grads = torch.autograd.grad(loss_value, params, retain_graph=True, allow_unused=True)
        for p, g in zip(params, grads):
            # Parameters that do not influence the loss get no gradient; Adam skips them.
            p.grad = None if g is None else g.detach()
        optimizer.step()
    return params


def _load_trajectories(filenames):
    """Load reference trajectories from MATLAB ``.mat`` files.

    For every record of a file's ``'refTraj'`` entry, extract three things:
    the time stamp shifted by -1.0, the 4-component target vector, and the
    matrix ``data[2]`` regularized with ``0.001 * I`` (used as ``Sigma`` in
    ``my_mse_loss``).

    Returns:
        ``(t_list, y_list, sigma_list)`` with one numpy array per file, of
        shapes ``(n, 1)``, ``(n, 4)`` and presumably ``(n, 4, 4)``. The last
        assumes every ``data[2]`` is 4x4; TODO confirm against the data files.
    """
    t_list, y_list, sigma_list = [], [], []
    for filename in filenames:
        t_data, y_data, sigma_data = [], [], []
        for data in scio.loadmat(filename)['refTraj'][0]:
            t_data.append([data[0][0][0] - 1.0])
            y_data.append([data[1][0][0], data[1][1][0], data[1][2][0], data[1][3][0]])
            sigma_data.append(data[2] + 0.001 * np.identity(4))
        t_list.append(np.array(t_data))
        y_list.append(np.array(y_data))
        sigma_list.append(np.array(sigma_data))
    return t_list, y_list, sigma_list


def run_for_each_round(model, start_round, cur_round, expert_id, mp_queue):
    """Meta-train ``model`` on tasks ``[start_round, cur_round)`` and score it on task ``cur_round``.

    Each epoch samples up to 10 training tasks plus the held-out task. For each
    task, a copy of ``model.params`` is adapted with ``inner_loop_my``. For
    training tasks, ``hg.CG`` is then run; ``optimizer.step`` relies on it
    having populated ``model.params``' ``.grad`` with the hypergradient.
    The adapted parameters of the held-out task give the reported meta-test loss.

    Args:
        model: ``Model`` whose ``params`` are the meta-parameters (updated in place).
        start_round: index of the first training task in ``filename_list_whole``.
        cur_round: index of the held-out task; training tasks precede it.
        expert_id: used in the checkpoint file name and the queue message.
        mp_queue: receives ``(expert_id, loss_meta)``.

    Returns:
        ``loss_meta``: the meta-test loss summed over logged steps, divided by
        the number of epochs.
    """
    filename_list = filename_list_whole[start_round:cur_round]
    test_file_name_list = filename_list_whole[cur_round:cur_round + 1]
    center_list_train = center_list_whole[start_round:cur_round, :]
    center_list_test = center_list_whole[cur_round:cur_round + 1, :]
    whole_task_num = len(filename_list)
    task_test_num = len(test_file_name_list)

    t_data_list, y_data_list, sigma_data_list = _load_trajectories(filename_list)
    t_data_test_list, y_data_test_list, sigma_data_test_list = _load_trajectories(test_file_name_list)

    learning_rate = 0.001
    optimizer = torch.optim.Adam(model.params, lr=learning_rate, weight_decay=0.0)

    # Short task histories get one meta-epoch, longer ones two.
    n_epochs = 1 if cur_round - start_round < 20 else 2
    loss_meta = 0
    for epoch in range(n_epochs):
        task_num = min(whole_task_num, 10)
        n_inner_level_epochs = 200 if epoch < 10 else 100

        # Sampled training tasks first, held-out task(s) appended last.
        number_list = random.sample(range(whole_task_num), task_num)
        t_data_list_thisepoch = [t_data_list[i] for i in number_list] + t_data_test_list
        y_data_list_thisepoch = [y_data_list[i] for i in number_list] + y_data_test_list
        sigma_data_list_thisepoch = [sigma_data_list[i] for i in number_list] + sigma_data_test_list
        center_list_train_thisepoch = [center_list_train[i] for i in number_list]
        center_list_train_thisepoch.extend(center_list_test)

        data_loader_train_list = []
        data_loader_train_list_outer = []
        data_loader_train_list_constraint = []
        for t_np, y_np, sigma_np in zip(t_data_list_thisepoch, y_data_list_thisepoch, sigma_data_list_thisepoch):
            data_loader_train_list.append(torch.utils.data.DataLoader(
                TensorDataset(torch.tensor(t_np).float().requires_grad_(), torch.tensor(y_np).float(), torch.tensor(sigma_np).float()),
                shuffle=True, batch_size=batch_size_K))
            data_loader_train_list_outer.append(torch.utils.data.DataLoader(
                TensorDataset(torch.tensor(t_np).float().requires_grad_(), torch.tensor(y_np).float(), torch.tensor(sigma_np).float()),
                shuffle=True, batch_size=batch_size_outer))
            data_loader_train_list_constraint.append(torch.utils.data.DataLoader(
                TensorDataset(torch.tensor(t_np).float().requires_grad_(), torch.tensor(y_np).float()),
                shuffle=False, batch_size=400))

        # Each step yields inner batches for all tasks followed by outer batches for all tasks.
        data_train = zip(*data_loader_train_list, *data_loader_train_list_outer)
        data_train_constraint = zip(*data_loader_train_list_constraint)
        # Only the first (unshuffled) batch of each task is used for the constraint term.
        data_train_constraint_now = list(data_train_constraint)[0]

        model.train()
        loss_train_sum = 0.0
        loss_no_grad_sum = 0.0
        loss_test_sum = 0.0
        loss_no_grad_sum_test = 0
        loss_constraint_sum_test = 0

        print(f'Epoch {epoch + 1}/{n_epochs}'.center(40,'-'))

        for step_train, data_train_now in enumerate(data_train):
            theta_prime_list = []
            loss_meta_train_tensor = []
            loss_meta_test_tensor = []
            loss_no_grad = []
            loss_no_grad_test = []
            loss_constraint_meta_test_tensor = []

            optimizer.zero_grad()

            for number, data_loader_train_now in enumerate(data_train_now):
                if number < task_num + task_test_num:
                    # Inner level: adapt a copy of the meta-parameters to this task.
                    task_now = number
                    is_train_task = number < task_num
                    center_now = center_list_train_thisepoch[task_now]

                    features, labels, sigmas = data_loader_train_now
                    features = features.to(device)
                    labels = labels.to(device)
                    sigmas = sigmas.to(device)
                    outputs = model(features, model.params)

                    features_constraint, _ = data_train_constraint_now[task_now]
                    features_constraint = features_constraint.to(device)
                    outputs1 = model(features_constraint, model.params)

                    # Loss of the un-adapted meta-parameters, for logging only.
                    loss_train = my_mse_loss(outputs, labels, sigmas) + constraint_voilations(outputs1, center=center_now)
                    if is_train_task:
                        loss_no_grad.append(loss_train.item())
                    else:
                        loss_no_grad_test.append(loss_train.item())

                    def loss_train_call(params, hparams):
                        return (my_mse_loss(model(features, params), labels, sigmas)
                                + bias_reg(params, hparams)
                                + constraint_voilations(model(features_constraint, params), center=center_now))

                    # The validation loss is evaluated on the same mini-batch as the inner loss.
                    def loss_val_call(params, hparams):
                        return my_mse_loss(model(features, params), labels, sigmas)

                    theta_tem = [p.detach().clone().requires_grad_(True) for p in model.params]
                    theta_prime = inner_loop_my(model.params, theta_tem, loss_train_call, n_inner_level_epochs)
                    theta_prime_list.append(theta_prime)

                    if is_train_task:
                        cg_fp_map = hg.GradientDescent(loss_f=loss_train_call, step_size=1.)
                        hg.CG(theta_prime, list(model.params), K=5, fp_map=cg_fp_map, outer_loss=loss_val_call)
                else:
                    # Outer level: evaluate the adapted parameters of task ``task_now``.
                    task_now = number - (task_num + task_test_num)
                    features1, labels1, sigmas1 = data_loader_train_now
                    features1 = features1.to(device)
                    labels1 = labels1.to(device)
                    sigmas1 = sigmas1.to(device)
                    theta_prime = theta_prime_list[task_now]
                    outputs1 = model(features1, theta_prime)

                    features_constraint, _ = data_train_constraint_now[task_now]
                    features_constraint = features_constraint.to(device)
                    outputs2 = model(features_constraint, theta_prime)

                    constraint_loss = constraint_voilations(outputs2, center=center_list_train_thisepoch[task_now])
                    current_loss = my_mse_loss(outputs1, labels1, sigmas1) + constraint_loss
                    if task_now < task_num:
                        loss_meta_train_tensor.append(current_loss)
                    else:
                        loss_meta_test_tensor.append(current_loss)
                        loss_constraint_meta_test_tensor.append(constraint_loss)

            loss_meta_train = sum(loss_meta_train_tensor) / float(task_num)

            # Skip the meta-update if ANY parameter's gradient contains NaN;
            # parameters without a gradient count as NaN-free.
            nan_list = [pa.grad is not None and bool(torch.isnan(pa.grad).any()) for pa in model.params]
            print(nan_list)
            if not any(nan_list):
                optimizer.step()

            loss_meta_test = sum(loss_meta_test_tensor) / float(task_test_num)
            loss_constraint_meta_test = sum(loss_constraint_meta_test_tensor) / float(task_test_num)

            loss_train_sum += loss_meta_train.item()
            loss_test_sum += loss_meta_test.item()
            loss_constraint_sum_test += loss_constraint_meta_test.item()
            loss_no_grad_sum += sum(loss_no_grad) / task_num
            loss_no_grad_sum_test += sum(loss_no_grad_test) / task_test_num

            # Report and reset after every step.
            print(f'step = {step_train+1}, loss = {loss_train_sum:.6f}')
            print(f'step = {step_train+1}, test_loss = {loss_test_sum:.6f}')
            print('loss_no_grad:' + str(loss_no_grad_sum))
            print('loss_no_grad_test:' + str(loss_no_grad_sum_test))
            print('loss_constraint_test:' + str(loss_constraint_sum_test))

            loss_meta += loss_test_sum

            loss_train_sum = 0
            loss_no_grad_sum = 0
            loss_test_sum = 0
            loss_no_grad_sum_test = 0
            loss_constraint_sum_test = 0

    print("meta lambda:   "+str(meta_lambda))
    loss_meta = loss_meta / n_epochs
    ################ save model ################
    torch.save(model, './pkl/model_meta_'+str(cur_round)+'_expert_'+str(expert_id)+'_'+str(seed)+'.pkl')
    mp_queue.put((expert_id, loss_meta))
    return loss_meta


class Meta:
    """AdaNormalHedge meta-algorithm over a pool of possibly sleeping experts.

    For each expert it keeps the cumulative regret ``R`` and the cumulative
    absolute regret ``C``. Sampling probabilities come from the potential
    ``Phi(R, C) = exp([R]_+^2 / (3C))``. Only experts marked active by
    ``update_active_state`` take part; the ``w``/``R``/``C``/``prob``/``init_prob``
    properties read and write the active slice of the full arrays.

    Args:
        prob (numpy.ndarray): Initial (prior) weights over all ``N`` experts.
        N (int): Total number of experts.
    """

    def __init__(self, prob: np.ndarray, N: int):
        # Copy as float so in-place updates neither truncate nor alias the caller's array.
        self._prob = np.array(prob, dtype=float)
        self._init_prob = self._prob.copy()
        self.t = 0
        self._R = np.zeros(N)
        self._C = np.zeros(N)
        self._w = np.zeros(N)

    def _Phi(self, R, C):
        """Potential ``exp([R]_+^2 / (3C))``, element-wise."""
        R_plus = np.maximum(0, R)
        return np.exp(np.square(R_plus) / (3 * C))

    def _w_func(self, R, C):
        """AdaNormalHedge weight ``(Phi(R+1, C+1) - Phi(R-1, C+1)) / 2``.

        Both potentials use ``C + 1``, so the denominator is at least 3 for
        ``C >= 0``. Experts with ``R <= -1`` get weight 0.
        """
        return 0.5 * (self._Phi(R + 1, C + 1) - self._Phi(R - 1, C + 1))

    def _normalize_active_prob(self):
        """Rescale the active experts' probabilities to sum to one.

        When every active weight is zero (or the sum is not finite), fall back
        to the normalized prior instead of producing 0/0 = NaN probabilities.
        """
        prob = self.prob
        total = np.sum(prob)
        if total > 0 and np.isfinite(total):
            self.prob = prob / total
        else:
            prior = self.init_prob
            self.prob = prior / np.sum(prior)

    def update_prob(self, loss_bases: np.ndarray, loss_meta):
        """Feed one round of losses and recompute the active probabilities.

        Args:
            loss_bases: loss of each active expert, in active-index order.
            loss_meta: loss of the combined prediction for this round.
        """
        regret = loss_meta - loss_bases
        self.R += regret
        self.C += np.abs(regret)
        self.w = self._w_func(self.R, self.C)
        self.prob = self.init_prob * self.w
        self._normalize_active_prob()

    def update_active_state(self, active_state):
        """Set the active experts and reset those flagged as (re)started.

        ``active_state`` > 0 marks an active expert; the value 2 marks one that
        restarts now, whose ``R``/``C`` are reset and whose weight is reinitialised.
        """
        self._active_state = active_state
        self._active_index = np.where(active_state > 0)[0]
        re_init_idx = np.where(self._active_state == 2)[0]
        self._R[re_init_idx] = 0
        self._C[re_init_idx] = 0
        self._w[re_init_idx] = self._w_func(0, 0)
        self._prob[re_init_idx] = self._init_prob[re_init_idx] * self._w[re_init_idx]
        self._normalize_active_prob()

    def sample_expert(self):
        """Draw an active expert index according to the current probabilities."""
        return np.random.choice(self._active_index, p=self.prob)

    @property
    def w(self):
        return self._w[self._active_index]

    @w.setter
    def w(self, w):
        self._w[self._active_index] = w

    @property
    def R(self):
        return self._R[self._active_index]

    @R.setter
    def R(self, R):
        self._R[self._active_index] = R

    @property
    def C(self):
        return self._C[self._active_index]

    @C.setter
    def C(self, C):
        self._C[self._active_index] = C

    @property
    def prob(self):
        return self._prob[self._active_index]

    @prob.setter
    def prob(self, prob):
        self._prob[self._active_index] = prob

    @property
    def init_prob(self):
        """Get the initial probability over the current alive base-learners."""
        return self._init_prob[self._active_index]

class Schedule:
    """Geometric-covering scheduler deciding when each expert (re)starts.

    Expert ``k`` (re)starts at steps ``2**k - 1``, ``2 * 2**k - 1``, ..., i.e.
    every ``2**k`` steps. Per-expert ``active_state`` codes:
    0 = not started yet, 2 = (re)started at the current step, 1 = running.
    ``time_checkpoint[k]`` is the step at which expert ``k`` last (re)started.
    """

    def __init__(self, expert_num: int):
        self.t = 0
        self.exp_num = expert_num
        self.active_state = np.zeros(expert_num)
        self.time_checkpoint = np.zeros(expert_num, dtype=int)
        # Next restart step of each expert (float array, first value 2**k - 1).
        self.next_k = 2.0 ** np.arange(expert_num) - 1

    def update_t(self):
        """Advance one step: demote last step's restarts, then fire the due ones."""
        self.active_state[self.active_state == 2] = 1
        due = self.next_k == self.t
        self.time_checkpoint[due] = self.t
        self.active_state[due] = 2
        self.next_k[due] += 2.0 ** np.flatnonzero(due)
        self.t += 1

    def get_active_state(self):
        """Return the live (not copied) per-expert state array."""
        return self.active_state

def main_loop(round):
    """Run the online meta-learning experiment for ``round`` rounds.

    Builds ``floor(log2(round)) + 1`` expert models, a ``Schedule`` that
    (re)starts them on a geometric grid, and a ``Meta`` (AdaNormalHedge)
    combiner. Each round, every active expert is trained in its own spawned
    process by ``run_for_each_round``. The losses the workers report, scaled by
    ``loss_scale``, update the combiner's probabilities.

    Args:
        round: number of online rounds. The name is kept for caller
            compatibility even though it shadows the builtin here.

    Raises:
        RuntimeError: if an expert's worker process exits with a non-zero code.
            Without this check, that expert would keep its zero-initialized loss
            and look like the best expert.
    """
    n_rounds = round

    # construct meta-learner and scheduler
    expert_num = math.floor(math.log(n_rounds, 2)) + 1
    meta = Meta(np.ones(expert_num), expert_num)
    scheduler = Schedule(expert_num)
    loss_scale = 0.01

    # construct model pool
    models = [Model() for _ in range(expert_num)]
    models = [m.to(device) for m in models]

    # CUDA in subprocesses needs the 'spawn' start method. Only set it when
    # needed, so a second call does not raise "context has already been set".
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    mp_queue = mp.Queue()

    for i in range(n_rounds):
        inst_loss = np.zeros(expert_num)

        # schedule
        scheduler.update_t()
        active_state = scheduler.get_active_state()
        meta.update_active_state(active_state)

        # submit the model to the environment (the sampled id is not used further here)
        model_submit_id = meta.sample_expert()

        # one worker process per active expert
        processes = []
        for j in range(expert_num):
            if active_state[j]:
                print(f"round {i+1}, expert {j} start at round {scheduler.time_checkpoint[j]}")
                p = mp.Process(target=run_for_each_round, args=(models[j], scheduler.time_checkpoint[j], i+1, j, mp_queue))
                p.start()
                processes.append((j, p))

        for _, p in processes:
            p.join()

        failed = [j for j, p in processes if p.exitcode != 0]
        if failed:
            raise RuntimeError(f"round {i + 1}: worker process of expert(s) {failed} exited abnormally")

        # Each successful worker puts exactly one (expert_id, loss) tuple before returning.
        for _ in processes:
            expert_id, loss = mp_queue.get()
            inst_loss[expert_id] = loss

        inst_loss = inst_loss[np.where(active_state > 0)[0]]
        inst_loss = inst_loss * loss_scale
        meta_loss = np.mean(inst_loss)
        meta.update_prob(inst_loss, meta_loss)
        print(f'Round {i + 1}/{n_rounds}'.center(40,'-'))
        print(f'Loss: {inst_loss}')
        print(f'Prob: {meta.prob}')
        
if __name__ == "__main__":
    round = 200
    main_loop(round)