import random
import numpy as np
import os
import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt
import learn2learn as l2l
from sklearn.manifold import TSNE
import sklearn.metrics as M

from sklearn.decomposition import PCA

# Default to GPU 6, but respect a CUDA_VISIBLE_DEVICES value the caller has
# already exported. The variable only has an effect if it is set before CUDA
# is initialised in this process.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '6')
def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt ``learner`` on part of ``batch`` and evaluate it on the rest.

    The samples at even positions ``0, 2, ..., 2*(shots*ways - 1)`` form the
    adaptation set; every other sample forms the evaluation set.

    Args:
        batch: ``(data, labels)`` tensors indexed along dimension 0.
        learner: adaptable model; must be callable and provide
            ``adapt(loss)`` (presumably a learn2learn clone -- confirm).
        loss: callable ``loss(predictions, labels)``.
        adaptation_steps: number of ``learner.adapt`` calls.
        shots, ways: together give the adaptation-set size ``shots*ways``.
        device: device the batch is moved to.

    Returns:
        ``(evaluation_error, evaluation_accuracy)``; the accuracy is the
        fraction of evaluation samples whose arg-max prediction (over dim 1)
        equals the label.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets (every other sample).
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    # No ``accuracy`` helper is defined or imported by this module, so the
    # top-1 classification accuracy is computed inline.
    predicted_classes = predictions.argmax(dim=1).view(evaluation_labels.shape)
    evaluation_accuracy = (predicted_classes == evaluation_labels).sum().float() / evaluation_labels.size(0)
    return evaluation_error, evaluation_accuracy

def fast_adapt_c(batch, learner, loss, adaptation_steps, shots, ways, device,pgs):
    """Adapt ``learner``, evaluate it, and return a flattened parameter embedding.

    The samples at even positions ``0, 2, ..., 2*(shots*ways - 1)`` form the
    adaptation set; every other sample forms the evaluation set.

    Args:
        batch: ``(data, labels)`` tensors indexed along dimension 0.
        learner: adaptable model; must be callable, provide ``adapt(loss)``
            and expose the wrapped network as ``learner.module``
            (presumably a learn2learn clone -- confirm with callers).
        loss: callable ``loss(predictions, labels)``.
        adaptation_steps: number of ``learner.adapt`` calls.
        shots, ways: together give the adaptation-set size ``shots*ways``.
        device: device the batch is moved to.
        pgs: if ``> 0``, embed only the last ``int(pgs)`` trainable parameter
            tensors (clamped to the number available); otherwise embed all.

    Returns:
        ``(evaluation_error, evaluation_accuracy, partial_emb)`` where
        ``partial_emb`` is a ``(1, N)`` tensor of the selected parameters,
        flattened and concatenated, or ``0`` if no tensor was selected.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets (every other sample).
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Parameters come in registration order, so for a conv backbone followed
    # by a ``classifier`` layer the last two tensors are ``classifier.weight``
    # and ``classifier.bias``.
    params = [p for p in learner.module.parameters() if p.requires_grad]
    if pgs > 0:
        # Clamp: requesting more tensors than exist selects all of them instead
        # of wrapping around through negative list indices.
        count = min(int(pgs), len(params))
        selected = params[len(params) - count:]
    else:
        selected = params
    partial_emb = torch.cat([p.reshape(1, -1) for p in selected], -1) if selected else 0

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    # No ``accuracy`` helper is defined or imported by this module, so the
    # top-1 classification accuracy is computed inline.
    predicted_classes = predictions.argmax(dim=1).view(evaluation_labels.shape)
    evaluation_accuracy = (predicted_classes == evaluation_labels).sum().float() / evaluation_labels.size(0)
    return evaluation_error, evaluation_accuracy, partial_emb

def fast_adapt_complete(batch, learner, loss, adaptation_steps, shots, ways, device,pgs):
    """Adapt ``learner`` on the whole batch and return a flattened parameter embedding.

    Every sample in ``batch`` is used for adaptation and no evaluation is
    performed. ``shots`` and ``ways`` are accepted for signature compatibility
    but are not used.

    Args:
        batch: ``(data, labels)`` tensors.
        learner: adaptable model; must be callable, provide ``adapt(loss)``
            and expose the wrapped network as ``learner.module``
            (presumably a learn2learn clone -- confirm with callers).
        loss: callable ``loss(predictions, labels)``.
        adaptation_steps: number of ``learner.adapt`` calls.
        device: device the batch is moved to.
        pgs: if ``> 0``, embed only the last ``int(pgs)`` trainable parameter
            tensors (clamped to the number available); otherwise embed all.

    Returns:
        ``(None, None, task_emb)`` where ``task_emb`` is a ``(1, N)`` tensor of
        the selected parameters, flattened and concatenated, or ``0`` if no
        tensor was selected.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Adapt the model on all available data.
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(data), labels)
        learner.adapt(adaptation_error)

    params = [p for p in learner.module.parameters() if p.requires_grad]
    if pgs > 0:
        # Clamp: requesting more tensors than exist selects all of them instead
        # of wrapping around through negative list indices.
        count = min(int(pgs), len(params))
        selected = params[len(params) - count:]
    else:
        selected = params
    task_emb = torch.cat([p.reshape(1, -1) for p in selected], -1) if selected else 0
    return None, None, task_emb


class SineTask():
  """A single sine regression task ``y = amp * sin(phase + x)``.

  Inputs are drawn uniformly from ``[min_x, max_x)``.
  """
  def __init__(self,amp,phase,min_x,max_x):
    self.phase=phase
    self.max_x=max_x
    self.min_x=min_x
    self.amp=amp

  def sample_data(self,size=1):
    """Draw ``size`` points; return ``(x, y)`` float32 tensors of shape (size, 1)."""
    # np.random.uniform takes (low, high); passing high < low is documented
    # by NumPy as undefined, so the bounds go in their natural order.
    x=np.random.uniform(self.min_x,self.max_x,size)
    y=self.true_sine(x)
    x=torch.tensor(x, dtype=torch.float).unsqueeze(1)
    y=torch.tensor(y, dtype=torch.float).unsqueeze(1)
    return x,y

  def true_sine(self,x):
    """Noise-free target ``amp * sin(phase + x)`` for scalar or array ``x``."""
    return self.amp*np.sin(self.phase+x)

class SineDistribution():
  """Distribution over sine tasks with uniformly drawn amplitude and phase.

  ``sample_task`` draws the amplitude from U[min_amp, max_amp), then the phase
  from U[min_phase, max_phase), and returns a ``SineTask`` carrying this
  distribution's x-range.
  """
  def __init__(self,min_amp,max_amp,min_phase,max_phase,min_x,max_x):
    self.min_amp,self.max_amp=min_amp,max_amp
    self.min_phase,self.max_phase=min_phase,max_phase
    self.min_x,self.max_x=min_x,max_x

  def sample_task(self):
    # Amplitude is drawn before phase; the order fixes which random number each gets.
    amplitude=np.random.uniform(self.min_amp,self.max_amp)
    shift=np.random.uniform(self.min_phase,self.max_phase)
    return SineTask(amplitude,shift,self.min_x,self.max_x)


#defining our sine-net
class SineNet(nn.Module):
  """MLP 1 -> 40 -> 40 -> 1 with ReLU activations, used for sine regression."""
  def __init__(self):
    super(SineNet,self).__init__()
    self.net=nn.Sequential(OrderedDict([
            ('l1',nn.Linear(1,40)),
            ('relu1',nn.ReLU()),
            ('l2',nn.Linear(40,40)),
            ('relu2',nn.ReLU()),
            ('l3',nn.Linear(40,1))
        ]))

  def forward(self,x):
    """Standard forward pass with the module's own weights.

    Without this method ``net(x)`` raised ``NotImplementedError``; callers had
    to reach into ``net.net`` instead.
    """
    return self.net(x)

  def argforward(self,x,weights):
    """Forward pass with an explicit list of weights.

    This lets MAML evaluate temporarily updated weights in the inner loop
    without touching the module's own parameters.

    Args:
      x: input batch fed to the first linear layer.
      weights: ``[l1.weight, l1.bias, l2.weight, l2.bias, l3.weight, l3.bias]``,
        i.e. the order of ``self.parameters()``.
    """
    x=F.relu(F.linear(x,weights[0],weights[1]))
    x=F.relu(F.linear(x,weights[2],weights[3]))
    return F.linear(x,weights[4],weights[5])

class SineMAML():
  """First-order MAML trainer for sine regression with task-embedding regularisers.

  Each task in a meta-batch is adapted with one SGD step on a support set D.
  It is then scored on a query set D' (the meta-loss). Two cosine terms act on
  per-task "embeddings": adapted weights when ``delta == 0``, inner-loop
  gradients otherwise.

  * ``inc`` rewards agreement between the embedding from D and the embedding
    from one further step on D and D' together (``metaloss - inc * in_cos``).
  * ``outc`` weights the summed pairwise cosine similarity between the
    full-data embeddings of the tasks in a meta-batch.

  Args:
    net: model exposing ``parameters()`` and ``argforward(x, weights)``.
    alpha: inner-loop step size.
    beta: Adam learning rate for the meta update.
    tasks: object whose ``sample_task()`` returns tasks with
      ``sample_data(size=...) -> (x, y)``.
    k: number of points sampled for each of D and D'.
    num_metatasks: tasks per meta-batch.
    inc, outc: weights of the two cosine terms.
  """
  def __init__(self,net,alpha,beta,tasks,k,num_metatasks,inc,outc):
    self.net=net
    self.weights=list(net.parameters())
    self.alpha=alpha
    self.beta=beta
    self.tasks=tasks
    self.k=k
    self.num_tasks_meta=num_metatasks  # alias of num_metatasks, kept for compatibility
    self.criterion=nn.MSELoss()
    self.meta_optimiser=torch.optim.Adam(self.weights,self.beta)
    self.meta_losses=[]
    self.test_losses=[]
    self.plot_every=100
    self.print_every=1000
    self.num_metatasks=num_metatasks
    self.inc=inc
    self.outc=outc

  def _device(self):
    """Device of the meta-parameters; sampled task data is moved here."""
    return self.weights[0].device

  @staticmethod
  def _flatten(tensors):
    """Concatenate tensors into a single (1, total_numel) row vector."""
    return torch.cat([t.reshape(1,-1) for t in tensors],-1)

  def _sgd_step(self,weights,x,y,create_graph=False):
    """Take one inner-loop SGD step on ``(x, y)`` and return ``(new_weights, grads)``.

    The weight update always uses detached gradients (first-order MAML).
    ``create_graph=True`` additionally keeps ``grads`` differentiable, so that
    they can serve as embeddings that carry a meta-gradient.
    """
    loss=self.criterion(self.net.argforward(x,weights),y)/self.k
    grads=torch.autograd.grad(loss,weights,create_graph=create_graph)
    new_weights=[w-self.alpha*g.detach() for w,g in zip(weights,grads)]
    return new_weights,grads

  def inner_loop(self,task,delta=0):
    """Adapt to ``task``; return ``(regularised meta-loss, full-data embedding)``.

    The embedding is an L2-normalised (1, P) row vector. It is built from the
    adapted weights when ``delta == 0`` and from the inner-loop gradients
    otherwise.
    """
    device=self._device()
    grad_emb=delta!=0
    weights=[w.clone() for w in self.weights]

    xs,ys=task.sample_data(size=self.k)  # support set D
    xs,ys=xs.to(device),ys.to(device)
    # Gradient embeddings must stay in the autograd graph. Otherwise the
    # cosine regularisers would contribute no meta-gradient at all.
    adapted,grads=self._sgd_step(weights,xs,ys,create_graph=grad_emb)
    partial_emb=self._flatten(grads if grad_emb else adapted)

    xq,yq=task.sample_data(size=self.k)  # query set D'
    xq,yq=xq.to(device),yq.to(device)
    metaloss=self.criterion(self.net.argforward(xq,adapted),yq)/self.k

    # "Full" embedding: one further step on all sampled data D and D'.
    # NOTE(review): this continues from the D-adapted weights rather than
    # from the meta-weights; confirm that is the intended reference point.
    x=torch.cat((xs,xq),0)
    y=torch.cat((ys,yq),0)
    full_adapted,full_grads=self._sgd_step(adapted,x,y,create_graph=grad_emb)
    full_emb=self._flatten(full_grads if grad_emb else full_adapted)

    partial_emb=nn.functional.normalize(partial_emb, p=2.0, dim=-1)
    full_emb=nn.functional.normalize(full_emb, p=2.0, dim=-1)
    in_cos=torch.sum(partial_emb*full_emb)
    return metaloss-self.inc*in_cos,full_emb

  def inner_loop0(self,task):
    """Return the plain (unregularised) query loss after one adaptation step."""
    device=self._device()
    weights=[w.clone() for w in self.weights]
    x,y=task.sample_data(size=self.k)  # support set D
    adapted,_=self._sgd_step(weights,x.to(device),y.to(device))
    x,y=task.sample_data(size=self.k)  # query set D'
    return self.criterion(self.net.argforward(x.to(device),adapted),y.to(device))/self.k

  def test(self,num_tasks=100):
    """Return the mean post-adaptation query loss over ``num_tasks`` tasks from ``self.tasks``."""
    total=sum(self.inner_loop0(self.tasks.sample_task()).detach().item() for _ in range(num_tasks))
    return total/num_tasks

  def outer_loop(self,num_epochs,save='best_reg.pt',delta=0):
    """Meta-train for ``num_epochs`` meta-batches.

    Every ``plot_every`` epochs the test loss and the mean meta-loss are
    recorded. Every ``print_every`` epochs progress is printed, and the
    state_dict is saved to ``save`` whenever the best recorded test loss
    improves.
    """
    best_test=float('inf')
    total_loss=0
    for epoch in range(1,num_epochs+1):
      metaloss_sum=0
      embs=[]
      for _ in range(self.num_metatasks):
        task=self.tasks.sample_task()
        metaloss,task_emb=self.inner_loop(task,delta)
        metaloss_sum=metaloss_sum+metaloss
        embs.append(task_emb)
      emb_matrix=torch.cat(embs,0)
      # Pairwise cosine similarities (rows are unit-norm), summed and divided
      # by the number of tasks. The diagonal adds a constant.
      out_cos=torch.sum(emb_matrix@emb_matrix.t())/self.num_metatasks
      metaloss_sum=metaloss_sum+self.outc*out_cos
      metagrads=torch.autograd.grad(metaloss_sum,self.weights)
      # Hand the meta-gradients to the optimiser (overwrites any stale .grad).
      for w,g in zip(self.weights,metagrads):
        w.grad=g
      self.meta_optimiser.step()
      total_loss+=metaloss_sum.item()/self.num_metatasks

      if epoch % self.print_every == 0:
        # np.min([]) raises, so only consult the test losses once at least one
        # evaluation has been recorded.
        best_seen=np.min(self.test_losses) if self.test_losses else float('nan')
        print("{}/{}. loss: {}, best test: {}".format(epoch, num_epochs, total_loss / self.plot_every, best_seen))
        if self.test_losses and best_seen<best_test:
          best_test=best_seen
          torch.save(self.net.state_dict(), save)
      if epoch%self.plot_every==0:
        self.test_losses.append(self.test())
        self.meta_losses.append(total_loss/self.plot_every)
        total_loss = 0

# Experiment configuration. These names stay at module level because other
# functions in this file read them as globals.
# Fall back to the CPU when no GPU is visible instead of failing later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sine_tasks=SineDistribution(0.1, 5, 0, np.pi, -5, 5)
D=1      # ``delta`` for outer_loop: 0 embeds adapted weights, nonzero embeds gradients
ic=0     # ``inc``: weight of the partial/full embedding cosine term
oc=0.01  # ``outc``: weight of the cross-task embedding cosine term

save='reg_D'+str(D)+'_best_test_in'+str(ic)+'_out'+str(oc)+'.pt'

if __name__ == '__main__':
  # Meta-train only when executed as a script, so that importing this module
  # (e.g. to reuse its analysis helpers) does not start a 50k-epoch run.
  net=SineNet()
  net=net.to(device)
  maml=SineMAML(net,alpha=0.01,beta=0.001,tasks=sine_tasks,k=10,num_metatasks=32,inc=ic,outc=oc)
  maml.outer_loop(num_epochs=50000,save=save,delta=D)
  torch.save(maml.net.state_dict(), 'fully_trained'+str(ic)+'_out'+str(oc)+'.pt')
  #maml.net.load_state_dict(torch.load(save))
  #print(maml.test(1024))
  #
  #plt.plot(maml.meta_losses)


def test(og_net,x,y,lr,optim=torch.optim.SGD,re=False,delta=0,num_step=1):
  """Fine-tune a copy of ``og_net`` on ``(x, y)`` and report the result.

  The copy is trained for ``num_step`` steps of ``optim`` (learning rate
  ``lr``) on the MSE loss; ``og_net`` itself is not modified.

  Args:
    og_net: network to start from; ``x`` and ``y`` must live on its device.
    x, y: inputs and targets for fine-tuning.
    lr: learning rate passed to ``optim``.
    optim: optimiser class, called as ``optim(params, lr=lr)``.
    re: False returns predictions for plotting; True returns an embedding.
    delta: embedding type when ``re`` is True. 0 flattens the fine-tuned
      parameters; any other value flattens (fine-tuned - original).
    num_step: number of optimisation steps (default 1).

  Returns:
    If ``re`` is False: ``(outputs, axis)``. ``axis`` holds 1000 evenly spaced
    points in [-5, 5]. ``outputs['minitrained']`` and ``outputs['initial']``
    are numpy predictions of the fine-tuned and original networks at
    ``axis.view(-1, 1)``.
    If ``re`` is True: a ``(1, P)`` numpy array.
  """
  import copy

  dev=next(og_net.parameters()).device
  axis=torch.tensor(np.linspace(-5,5,1000),dtype=torch.float,device=dev)

  # A deep copy works for any architecture and leaves og_net untouched.
  dummy_net=copy.deepcopy(og_net)
  loss_fn=nn.MSELoss()
  opt=optim(dummy_net.parameters(),lr=lr)
  for _ in range(num_step):
    loss=loss_fn(dummy_net(x),y)
    dummy_net.zero_grad()
    loss.backward()
    opt.step()

  if not re:
    with torch.no_grad():
      outputs={
        'minitrained': dummy_net(axis.view(-1, 1)).cpu().clone().numpy(),
        'initial': og_net(axis.view(-1, 1)).cpu().clone().numpy(),
      }
    return outputs,axis

  with torch.no_grad():
    if delta==0:
      pieces=[p.reshape(1,-1) for p in dummy_net.parameters()]
    else:
      pieces=[(w-w0).reshape(1,-1) for w,w0 in zip(dummy_net.parameters(),og_net.parameters())]
    return torch.cat(pieces,-1).cpu().numpy()
     

def plot_test(og_net,x,y,task,optim=torch.optim.SGD,lr=0.001,save='test.png'):
  """Plot the true sine, the sample points, and predictions before/after fine-tuning.

  Fine-tuning is delegated to ``test(og_net, x, y, lr, optim)``. The figure
  is written to ``save`` and then closed, so repeated calls do not
  accumulate open figures.
  """
  outputs,axis=test(og_net,x,y,lr,optim)
  grid=axis.cpu().clone().numpy()
  plt.figure(figsize=(10,5))
  plt.plot(grid,task.true_sine(grid), '-', color=(0, 0, 1, 0.5), label='true sine')
  plt.scatter(x.cpu().clone().numpy(), y.cpu().clone().numpy(), label='data')
  plt.plot(grid, outputs['initial'], ':', color=(0.7, 0, 0, 1), label='initial')
  plt.plot(grid, outputs['minitrained'], '-', color=(0.5, 0, 0, 1), label='1-step update')
  plt.legend(loc='upper right')
  plt.tight_layout()
  plt.savefig(save,dpi=300)
  plt.close()

# Disabled example, kept as an inert string literal. It samples one
# amplitude-3, phase-0 sine task and plots a 1-step adaptation for two saved
# checkpoints ('best_test.pt' and 'best_test_c.pt'). It refers to ``net``,
# ``ic``, ``oc`` and ``device``, which it expects to exist at module level.
'''K=10
sine_tasks=SineDistribution(3, 3, 0, 0, -5, 5)
task=sine_tasks.sample_task()
x,y=task.sample_data(K)
x=x.to(device)
y=y.to(device)
maml=SineMAML(net,alpha=0.01,beta=0.001,tasks=sine_tasks,k=10,num_metatasks=32,inc=ic,outc=oc)
maml.net.load_state_dict(torch.load('best_test.pt'))
plot_test(og_net=maml.net.net,x=x,y=y,task=task,save='test.png')
maml=SineMAML(net,alpha=0.01,beta=0.001,tasks=sine_tasks,k=10,num_metatasks=32,inc=ic,outc=oc)
maml.net.load_state_dict(torch.load('best_test_c.pt'))
plot_test(og_net=maml.net.net,x=x,y=y,task=task,save='test_c.png')'''

def get_adp(tn=32,ti=4,K=10,save=None,seed=0):
  """Collect per-task adaptation embeddings from a saved ``SineNet`` checkpoint.

  For each of ``tn`` random sine tasks, ``ti`` independent ``K``-point samples
  are drawn. Each sample fine-tunes the loaded network through
  ``test(..., 0.001, torch.optim.SGD, True, D)`` and records the resulting
  flattened embedding. ``D`` and ``device`` are read from module-level
  globals.

  Args:
    tn: number of tasks.
    ti: samples (embeddings) per task.
    K: points per sample.
    save: path of a saved SineNet ``state_dict``; required.
    seed: seed for NumPy's and torch's global RNGs, so that calls with the same
      seed (e.g. for two checkpoints) draw the same tasks. ``None`` leaves the
      RNG state untouched.

  Returns:
    ``(adps, label, 1 - inds, 1 - ods)``:
    * ``adps``: (tn*ti, P) embeddings, grouped task by task.
    * ``label``: (tn*ti,) float task index of each row.
    * ``1 - inds``: cosine distance of each embedding to its task's mean embedding.
    * ``1 - ods``: cosine distances between the mean embeddings of different
      tasks (off-diagonal pairs, row-major).

  Raises:
    ValueError: if ``save`` is None.
  """
  if save is None:
    raise ValueError("get_adp requires 'save', the path of a saved SineNet state_dict")
  if seed is not None:
    # Task sampling draws from np.random, so seeding makes the tasks reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)

  sine_tasks=SineDistribution(0.1, 5, 0, np.pi, -5, 5)
  net=SineNet().to(device)
  net.load_state_dict(torch.load(save,map_location=device))

  adps=[]
  inds=[]
  ts=[]
  for _ in range(tn):
    task=sine_tasks.sample_task()
    task_adps=[]
    for _ in range(ti):
      x,y=task.sample_data(K)
      task_adps.append(test(net.net,x.to(device),y.to(device),0.001,torch.optim.SGD,True,D))
    adps.extend(task_adps)

    # Only this task's embeddings are needed, so avoid re-concatenating everything.
    tt=torch.tensor(np.concatenate(task_adps,0))
    t=tt.mean(0,True)
    ts.append(t)
    tt=nn.functional.normalize(tt, p=2.0, dim=-1)
    t=nn.functional.normalize(t, p=2.0, dim=-1)
    inds.append(torch.sum(tt*t,-1).numpy())

  inds=np.concatenate(inds,0)
  adps=np.concatenate(adps,0)

  ts=nn.functional.normalize(torch.cat(ts,0), p=2.0, dim=-1)
  ods=torch.mm(ts,ts.t())
  off_diagonal=~torch.eye(tn,dtype=torch.bool)
  ods=ods[off_diagonal].numpy()
  label=np.repeat(np.arange(tn),ti).astype(float)
  return adps,label,1-inds,1-ods



def plot_adp(adps,tn,ti,save):
    """Save a 2-D t-SNE scatter of adaptation embeddings, one colour per task.

    ``adps`` rows are expected in consecutive runs of ``ti`` rows per task, for
    ``tn`` tasks. t-SNE uses the cosine metric and perplexity ``ti - 1``. Axis
    ticks are hidden and the figure is closed after being written to ``save``.
    """
    plt.figure()
    projected=TSNE(n_components=2,perplexity=int(ti-1),metric='cosine').fit_transform(adps)
    for task_id in range(tn):
        rows=slice(task_id*ti,(task_id+1)*ti)
        plt.scatter(projected[rows,0],projected[rows,1])

    axes=plt.gca()
    axes.xaxis.set_visible(False)
    axes.yaxis.set_visible(False)
    plt.tight_layout()
    plt.savefig(save,dpi=300)
    plt.close()

def plot_ds(ds,names,save):
    """Overlay 20-bin histogram outlines of several samples and mark each mean.

    Args:
        ds: sequence of 1-D samples (e.g. distance arrays).
        names: legend label for each sample, aligned with ``ds``.
        save: output image path; the figure is closed after saving.
    """
    import itertools

    plt.figure()
    cs=['gray','rosybrown','seagreen','red']
    # Cycle the palette: zipping against the 4-colour list directly silently
    # dropped every dataset after the fourth.
    for d,name,c in zip(ds,names,itertools.cycle(cs)):
        # The bars are invisible (alpha=0); only the counts and bin edges are
        # used, to draw the outline through the bin centres.
        n,limit,_=plt.hist(d, bins=20,density=False,alpha=0)
        centers=0.5*(limit[:-1]+limit[1:])
        plt.plot(centers,n,color=c,label=name)
        plt.axvline(np.mean(d),linestyle='--',color=c)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save,dpi=300)
    plt.close()


