import numpy as np
import torch 
import random
import copy
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Seed every random generator the experiments draw from, so repeated runs of
# the script reproduce the same data, initialisations and Mixup draws.
torch.manual_seed(1)  # torch CPU generator
np.random.seed(0)     # NumPy global generator
random.seed(0)        # Python stdlib generator

def getdata(d0, n):
    """Draw ``n`` samples of dimension ``d0`` from a standard normal.

    Returns a float tensor of shape ``(n, d0)`` taken from the torch RNG.
    """
    shape = (n, d0)
    return torch.randn(*shape)

def getMSE(given, test):
    """Return the mean squared error between ``given`` and ``test``.

    The squared error is summed over the first (sample) dimension and divided
    by the number of samples. An ``(n, k)`` input therefore yields a tensor of
    shape ``(k,)``: shape ``(1,)`` for ``(n, 1)`` network outputs, and a 0-d
    tensor for 1-D inputs. The result stays on the autograd graph, so it can be
    used directly as a training loss.

    Args:
        given: predictions, broadcast-compatible with ``test``.
        test: targets; ``len(test)`` is taken as the number of samples.

    Raises:
        ValueError: if ``test`` has no samples (the mean is undefined).
    """
    num_samples = len(test)
    if num_samples == 0:
        raise ValueError("getMSE needs at least one sample")
    # Vectorised reduction over dim 0. It gives the same shape as Python's
    # builtin sum over rows, without a Python-level loop and one autograd
    # node per sample.
    return ((test - given) ** 2).sum(dim=0) / num_samples

def Mixup(x, y, z=0, alpha=1):
    """Return a Mixup-augmented version of the batch ``(x, y)``.

    Every sample is blended with a randomly chosen partner from the same batch:
    ``mixed = lam * v + (1 - lam) * v[perm]``. One permutation ``perm`` is
    shared by ``x`` and ``y``, so inputs and targets stay paired.

    Args:
        x: inputs whose first dimension is the batch.
        y: targets with the same batch size as ``x``. Any trailing shape is
            accepted, including 1-D.
        z: fixed mixing coefficient. ``0`` is a sentinel meaning "draw ``lam``
            from ``Beta(alpha, alpha)``", so ``lam = 0`` cannot be requested
            explicitly.
        alpha: Beta-distribution parameter, used only when ``z == 0``.

    Returns:
        ``(mixed_x, mixed_y)``, with the same shapes as ``x`` and ``y``.

    Raises:
        ValueError: if ``x`` and ``y`` have different batch sizes.
    """
    batch_size = x.size(0)
    if y.size(0) != batch_size:
        raise ValueError(
            f"x and y must share the batch dimension, got {batch_size} and {y.size(0)}"
        )

    # lam is drawn (NumPy RNG) before the permutation (torch RNG); keep this
    # order so seeded runs stay reproducible. float() avoids mixing a NumPy
    # scalar into torch arithmetic.
    if z == 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = z
    index = torch.randperm(batch_size)

    # Index only the batch dimension so tensors of any rank are supported.
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]
    return mixed_x, mixed_y


class Teacher(nn.Module):
    """Two-layer tanh network used to generate regression targets.

    Bias-free ``Linear(d0, hidden) -> tanh -> Linear(hidden, 1)``.

    Args:
        d0: input dimension (size of the last axis of the input).
        hidden: width of the hidden layer. Defaults to 5, the previously
            hard-coded width, so existing ``Teacher(d0)`` calls are unchanged.
    """

    def __init__(self, d0, hidden=5):
        super().__init__()
        self.fc1 = nn.Linear(d0, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, 1, bias=False)

    def forward(self, x):
        """Return ``fc2(tanh(fc1(x)))``."""
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)
    
class Student(nn.Module):
    """Two-layer tanh network trained to imitate the teacher.

    Bias-free ``Linear(d0, d) -> tanh -> Linear(d, 1)``. The layers keep the
    names ``fc3``/``fc4`` because the script addresses them by name (for
    example to freeze ``fc3.weight``), and the state-dict keys follow them.
    """

    def __init__(self, d0, d):
        super().__init__()
        # Hidden layer first, then the readout: construction order fixes the
        # order in which the initial weights are drawn from the torch RNG.
        self.fc3 = nn.Linear(d0, d, bias=False)
        self.fc4 = nn.Linear(d, 1, bias=False)

    def forward(self, x):
        """Return ``fc4(tanh(fc3(x)))``."""
        hidden = torch.tanh(self.fc3(x))
        return self.fc4(hidden)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
d0 = 10        # input dimension
n = 20         # number of training samples
d = 100        # student hidden width
lr = 1e-1      # base SGD learning rate
iter = 50      # evaluate on the test set every `iter` steps
epoch = 10000  # full-batch training steps per run

# Inputs are drawn before the networks are built, keeping the torch RNG
# consumption order fixed.
data_x = getdata(d0, n)
test_x = getdata(d0, 100)

# Frozen teacher: its outputs are the regression targets.
teacher_net = Teacher(d0)
for param in teacher_net.parameters():
    param.requires_grad_(False)
data_y = teacher_net(data_x)
test_y = teacher_net(test_x)

# Student with a fixed first layer (random-feature model): only fc4 is trained.
student_net = Student(d0, d)
student_net.fc3.weight.requires_grad = False
student_init = copy.deepcopy(student_net)  # snapshot used to reset between runs

optimizer = torch.optim.SGD(
    [p for p in student_net.parameters() if p.requires_grad], lr
)



# --- Run 1: ERM baseline (train on the raw data, no Mixup) ------------------
losses = []       # per-step training loss
eval_losses = []  # not filled in by this run; kept so the name still exists
all_error = []    # test MSE recorded every `iter` steps
all_train = []    # not filled in by this run; kept so the name still exists

# Restart from the shared initialisation. Use a fresh optimizer at the base
# rate so this run never inherits a learning rate another run left on a shared
# optimizer (plain SGD keeps no other state).
student_net.load_state_dict(student_init.state_dict())
student_net.fc3.weight.requires_grad = False
optimizer = torch.optim.SGD(
    [p for p in student_net.parameters() if p.requires_grad], lr
)

student_net.train()
for e in range(epoch):
    out = student_net(data_x)
    loss = getMSE(out, data_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if e % iter == 0:
        with torch.no_grad():
            out_test = student_net(test_x)
            loss_test = getMSE(out_test, test_y)
        all_error.append(loss_test.item())

# --- Run 2: Mixup with lam ~ Beta(1, 1) redrawn every step ------------------
losses = []           # per-step training loss (on the mixed batch)
all_mixup_error = []  # test MSE recorded every `iter` steps

# Restart from the shared initialisation. Use a fresh optimizer at the base
# rate so this run never inherits a learning rate another run left on a shared
# optimizer (plain SGD keeps no other state).
student_net.load_state_dict(student_init.state_dict())
student_net.fc3.weight.requires_grad = False
optimizer = torch.optim.SGD(
    [p for p in student_net.parameters() if p.requires_grad], lr
)

student_net.train()
for e in range(epoch):
    mix_x, mix_y = Mixup(data_x, data_y)
    out = student_net(mix_x)
    loss = getMSE(out, mix_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if e % iter == 0:
        with torch.no_grad():
            out_test = student_net(test_x)
            loss_test = getMSE(out_test, test_y)
        all_mixup_error.append(loss_test.item())



# --- Run 3: Mixup with fixed lam = 0.5 and a step-decayed learning rate -----
losses = []                   # per-step training loss (on the mixed batch)
all_fix_error = []            # test MSE recorded every `iter` steps
lr_milestones = (2500, 5000)  # steps at which the learning rate is divided by 10

student_net.load_state_dict(student_init.state_dict())
student_net.fc3.weight.requires_grad = False
# Decay a run-local copy of the rate on a fresh optimizer, so the module-level
# `lr` and any optimizer shared with other runs are left untouched.
run_lr = lr
optimizer = torch.optim.SGD(
    [p for p in student_net.parameters() if p.requires_grad], run_lr
)

student_net.train()
for e in range(epoch):
    mix_x, mix_y = Mixup(data_x, data_y, z=0.5)
    if e in lr_milestones:
        run_lr /= 10
        for param_group in optimizer.param_groups:
            param_group['lr'] = run_lr
    out = student_net(mix_x)
    loss = getMSE(out, mix_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if e % iter == 0:
        with torch.no_grad():
            out_test = student_net(test_x)
            loss_test = getMSE(out_test, test_y)
        all_fix_error.append(loss_test.item())

# --- Plot: test MSE of the three runs, one point every `iter` steps ---------
plt.figure()
# Raw strings keep the backslashes for matplotlib's mathtext. In a normal
# string the "\t" of "\times" would be read as a TAB escape, and "\l" is an
# invalid escape (SyntaxWarning on Python 3.12+).
plt.plot(all_mixup_error, color='blue', linewidth=1.0, linestyle='-.',
         label=r'Random $\lambda$')
plt.plot(all_fix_error, color='red', linewidth=1.0, linestyle='-.',
         label=r'Fixed $\lambda=0.5$')
plt.plot(all_error, color='orange', linewidth=1.0, linestyle='-.',
         label='ERM')
# The x axis counts evaluations; derive the scale from the evaluation interval.
plt.xlabel(rf"epoch ($\times {iter}$)")
plt.ylabel('MSE')
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
plt.legend()
plt.grid()
plt.show()


