#    wlchat, a chat client for WhiteLeaf's fork of MemeLabs.
#    Copyright (C) 2022-2025  Alicia <alicia@ion.nu>
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU Affero General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU Affero General Public License for more details.
#
#    You should have received a copy of the GNU Affero General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

import requests
import json
import re
import threading
import websocket
import socket
import sys
import os
import time
import math
import cookies

# User-agent header value used for every HTTP request made by this module
useragent='wlchat 0.1'

# Pronoun labels. The list is 1-indexed, so slot 0 only holds the '0' padding.
pronouns=[
  '0',
  'he/him', 'she/her', 'they/them',
  'he/they', 'she/they',
  'it/its', 'any/all', 'ask/me', 'name/they',
]

# Supported authentication methods. 'reddit' is the working one; the others
# are experimental and probably broken.
authmethods=[
  'reddit',
  'google', 'twitter', 'twitch',
]

def readform(page, formstr):
  """Extract the name/value pairs of the HTML form that follows *formstr*.

  formstr is a regular expression marking where the form starts (for
  example its action attribute); the form ends at the next '</form'.
  Only attributes written exactly as name="..." value="..." are picked up,
  and values are returned as-is (no HTML unescaping).

  Returns a dict of field name -> value (possibly empty), or False when
  formstr does not occur in page.
  """
  pieces=re.split(formstr, page)
  if(len(pieces)<2):
    return False
  # Text between the first and second match of formstr, cut at the end of the form
  body=re.split('</form', pieces[1])[0]
  pairs=(re.split('"', attr) for attr in re.findall('name="[^"]*" value="[^"]*"', body))
  return {quoted[1]: quoted[3] for quoted in pairs}

# The websocket is returned in conn[0]
def joinchat(domain, browser, browserprofile, nick, conn, ui, uictx, authmethod='reddit', localctx=None):
  """Connect to the chat on *domain*, optionally logging in, and start reading it.

  domain: chat host name, without scheme.
  browser, browserprofile: browser (must be listed in cookies.validbrowsers)
    and profile to borrow login cookies from. An empty/None browser skips
    the login step entirely.
  nick: our nickname; passed to the ui callbacks and used when registering.
  conn: one-element list; conn[0] is set to the websocket.WebSocketApp.
  ui: dict of callbacks. Chat events are delivered as
    ui[event](uictx, chan, nick, data); 'register' and 'topic' are optional.
  uictx: opaque value handed back to every ui callback.
  authmethod: one of authmethods; selects which browser cookies are used.
  localctx: dict shared with the caller; localctx['namelist'] tracks the
    nicks believed to be in the channel. A fresh dict is created when it is
    omitted, and the same dict is reused across automatic reconnects.

  Fetching the chat config is retried every second until it succeeds. The
  websocket runs on a daemon thread; when it closes or errors, joinchat() is
  called again to reconnect. Returns True.
  """
  # Not a default argument: one shared default dict would mix up the
  # namelists of separate connections
  if(localctx is None): localctx={}
  # Get whiteleaf config. Retry in a loop (not by recursing) so a long outage
  # cannot exhaust the stack; each attempt starts from a fresh session
  while True:
    session=requests.Session()
    try:
      wsurl=_fetch_wsurl(session, domain)
      break
    except Exception: # Wait 1s and retry
      time.sleep(1)

  headers={}
  if(browser and browser in cookies.validbrowsers):
    _import_browser_cookies(session, browser, browserprofile, authmethod)
    headers=_whiteleaf_login(session, domain, nick, ui, uictx, authmethod)

  localctx['namelist']=[] # To deal with bans and renames mostly
  def readws(c, msg):
    # Handle one websocket frame of the form 'COMMAND {json}'
    namelist=localctx['namelist']
    chan=domain
    if(msg.startswith('SMSG ')): # It's just a message, isn't it?
      msg='MSG '+msg[5:]
    if(msg.startswith('CMSG ')):
      msg='MSG '+msg[5:]
    if(msg.startswith('SMUTE ')): # Just mute, right?
      msg='MUTE '+msg[6:]
    if(msg.startswith('SBAN ')): # Just ban, right?
      msg='BAN '+msg[5:]
    if(msg.startswith('AWARE ')):
      data=json.loads(msg[6:])
      if(data['data']=='ping'):
        c.send(msg.replace('ping','pong'))
      else:
        # Other AWARE frames may carry a nested JSON document (embed status) in data['data']
        try: dd=json.loads(data['data'])
        except Exception: dd=False
        if(dd and dd.get('embedStatus') and dd.get('embedPlatform') and dd.get(dd['embedStatus'][0]+'Info')):
          info=dd[dd['embedStatus'][0]+'Info']
          value=info[dd['embedPlatform'][0]]
          if(dd['embedPlatform'][0]=='twitchChannel'):
            ui['notice'](uictx, chan, nick, dd['embedStatus'][0]+': https://twitch.tv/'+value)
          elif(dd['embedPlatform'][0]=='youtubeChannel'):
            ui['notice'](uictx, chan, nick, dd['embedStatus'][0]+': https://youtube.com/embed/live_stream?channel='+value)
          elif(dd['embedPlatform'][0]=='youtubeVOD'):
            ui['notice'](uictx, chan, nick, dd['embedStatus'][0]+': https://youtube.com/watch?v='+value)
          else:
            print('Unknown embedStatus: '+dd['embedStatus'][0]+', '+value)
        else:
          print('Unknown AWARE: '+msg)
    elif(msg.startswith('MSG ')):
      data=json.loads(msg[4:])
      ui['msg'](uictx, chan, nick, data)
    elif(msg.startswith('BROADCAST ')):
      data=json.loads(msg[10:])
      ui['broadcast'](uictx, chan, nick, data)
    elif(msg.startswith('PRIVMSG ')):
      data=json.loads(msg[8:])
      ui['privmsg'](uictx, chan, nick, data)
    elif(msg.startswith('NAMES ')):
      data=json.loads(msg[6:])
      ui['namelist'](uictx, chan, nick, data['users'])
      namelist.clear()
      for user in data['users']:
        namelist.append(user['nick'])
    elif(msg.startswith('JOIN ')):
      data=json.loads(msg[5:])
      ui['join'](uictx, chan, nick, data)
      namelist.append(data['nick'])
    elif(msg.startswith('QUIT ')):
      data=json.loads(msg[5:])
      ui['quit'](uictx, chan, nick, data)
      if(namelist.count(data['nick'])):
        namelist.remove(data['nick'])
    elif(msg.startswith('ERR ')):
      ui['notice'](uictx, chan, nick, msg)
    elif(msg.startswith('RENAME ')):
      # Full user list after a rename: report new names as joins, vanished names as quits
      data=json.loads(msg[7:])
      newnamelist=[]
      for user in data['users']:
        if(namelist.count(user['nick'])==0):
          namelist.append(user['nick'])
          ui['join'](uictx, chan, nick, user)
        newnamelist.append(user['nick'])
      # Iterate over a copy: removing from the list being iterated would skip the next entry
      for user in list(namelist):
        if(newnamelist.count(user)==0 and namelist.count(user)):
          namelist.remove(user)
          ui['quit'](uictx, chan, nick, {'nick':user})
    elif(msg.startswith('BAN ')):
      data=json.loads(msg[4:])
      ui['notice'](uictx, chan, nick, data['nick']+' bans '+data['data']+' for '+str(int(data['extradata'])/1000000000)+' seconds')
    elif(msg.startswith('MUTE ')):
      data=json.loads(msg[5:])
      ui['notice'](uictx, chan, nick, data['nick']+' mutes '+data['data']+' for '+str(int(data['extradata'])/1000000000)+' seconds')
    elif(msg.startswith('UNMUTE ')):
      data=json.loads(msg[7:])
      ui['notice'](uictx, chan, nick, data['nick']+' unmutes '+data['data'])
    elif(msg.startswith('MSGPIN ') and ui.get('topic')):
      ui['topic'](uictx, chan, nick, msg[7:])
    else:
      print('Unknown message: '+msg)
  def onclose(x,y,z):
    print('onclose')
    print(y)
    print(z)
    # Detach the handlers first so this dead socket cannot trigger a second reconnect
    x.on_close=None
    x.on_message=None
    x.close()
    # TODO: Use the notice callback for these?
    print('Connection closed, attempting to reconnect ('+domain+')')
    joinchat(domain, browser, browserprofile, nick, conn, ui, uictx, authmethod, localctx)
  def onerror(x,y):
    print('onerror')
    print(y)
    x.on_close=None
    x.on_message=None
    x.close()
    print('Connection closed (onerror), attempting to reconnect ('+domain+')')
    joinchat(domain, browser, browserprofile, nick, conn, ui, uictx, authmethod, localctx)
  ws=websocket.WebSocketApp(wsurl, header=headers, on_message=readws, on_close=onclose, on_error=onerror)
  conn[0]=ws
  # run_forever() gets no reconnect option of its own: onclose/onerror above
  # reconnect by calling joinchat() again
  thread=threading.Thread(target=ws.run_forever, daemon=True)
  thread.start()
  return True

def _fetch_wsurl(session, domain):
  """Return the chat websocket URL for *domain*.

  Tries the WhiteLeaf JSON endpoint /api/chat/getinfo first. If its reply does
  not start with '{', falls back to scraping the data-ws-url attribute from
  /embed/chat. Network and parse errors propagate to the caller.
  """
  chatinfo=session.get('https://'+domain+'/api/chat/getinfo', headers={'User-agent': useragent})
  if(chatinfo.content.decode()[0]=='{'):
    chatinfo=json.loads(chatinfo.content)
    return 'wss://'+chatinfo['chatURL']+'/'+chatinfo['chatKey']
  page=session.get('https://'+domain+'/embed/chat', headers={'User-agent': useragent}).content.decode()
  page=page[page.find('data-ws-url="'):]
  return page.split('"')[1]

def _import_browser_cookies(session, browser, browserprofile, authmethod):
  """Copy the login cookie(s) for *authmethod* from the browser into *session*.

  Reddit seems to have added a captcha for logins, so instead of logging in we
  extract the browser's cookies (cookie-loading code borrowed from yt-dlp).
  Auth methods not listed below import nothing.
  """
  # authmethod -> (URL to read cookies for, cookie names to copy, domain to set them on)
  sources={
    'reddit': ('https://reddit.com', ('reddit_session',), 'reddit.com'),
    'google': ('https://www.youtube.com', ('SAPISID', '__Secure-3PAPISID', '__Secure-1PAPISID'), 'www.youtube.com'),
    'twitter': ('https://api.x.com/1.1/', ('auth_token',), 'x.com'),
    'twitch': ('https://gql.twitch.tv', ('auth_token',), 'gql.twitch.tv'),
    # TODO: Do the same for discord
  }
  source=sources.get(authmethod)
  if(not source): return
  url, names, cookiedomain=source
  cookiesobj=cookies.load_cookies(None, (browser,browserprofile), None)
  for c in cookiesobj.get_cookies_for_url(url):
    if(c.name in names):
      session.cookies.set(c.name, c.value, domain=cookiedomain)

def _whiteleaf_login(session, domain, nick, ui, uictx, authmethod):
  """Log in to *domain* through *authmethod* and return websocket headers.

  Returns {} when no authorization form is handled; the chat then falls back
  to read-only. Currently only reddit's form is handled. Otherwise returns
  {'Cookie': ...} built from the session's cookies, skipping any whose name
  contains 'reddit'. Shows the new-account form through ui['register'] when
  present, and reports an alert-danger error through ui['notice'].
  """
  # Request whiteleaf auth
  x=session.post('https://'+domain+'/login', headers={'User-agent': useragent}, params={'authProvider': authmethod})
  # Find the form, or fall back on read-only chat
  if(authmethod!='reddit'):
    return {} # TODO: Deal with equivalent forms on other platforms
  formdata=readform(x.content.decode(), 'action="/api/v1/authorize"')
  if(not formdata):
    return {}
  x=session.post('https://ssl.reddit.com/api/v1/authorize', headers={'User-agent': useragent}, params=formdata)
  page=x.content.decode()
  # Check for new account form
  formdata=readform(page, '<form action="/register"')
  if(formdata and ui.get('register')):
    captcha=re.findall('data:[^\'"]*', page.split('<div class=\'catchImage\'>')[1])[0]
    formdata['username']=nick
    # 'register' needs to set 'username', 'catch' (the captcha answer), and 'agreement' (to 'on')
    ui['register'](uictx, domain, formdata, captcha)
    session.post('https://'+domain+'/register', headers={'User-agent': useragent}, params=formdata)
  elif(page.find('<div class="alert alert-danger')>=0):
    start=page.find('<div class="alert alert-danger')
    err=page[:page.find('</div>',start)]
    err=err[err.rfind('>')+1:]
    ui['notice'](uictx, domain, nick, 'Error: '+err)
  cookieslist=session.cookies.get_dict()
  # Skip auth source cookies
  return {'Cookie': ';'.join(name+'='+value for name, value in cookieslist.items() if 'reddit' not in name)}

def fetch_emotes(domain):
  """Download the emote list for *domain*, keyed by emote prefix.

  Tries the WhiteLeaf asset path first. If that request is not OK, scrapes the
  CDN base URL (data-cdn attribute) from /embed/chat/ and fetches
  <cdn>/emotes/emotes.json instead (original memelabs? idk).

  The JSON may be a list of emote objects or an object whose values are emote
  objects. Entries that are not objects or have no 'prefix' are skipped.
  Emotes without 'imageWidth' take imageWidth/imageHeight from their first
  'image' entry, and are skipped if that is missing.

  Returns {} when no emote list could be obtained. A network error on the
  first (WhiteLeaf) request propagates to the caller.
  """
  # Try whiteleaf path
  emotes=requests.get('https://'+domain+'/for/apiAssets/static/images/emotes/info.json', headers={'User-agent': useragent})
  if(not emotes.ok): # Try other path (original memelabs? idk)
    try:
      page=requests.get('https://'+domain+'/embed/chat/', headers={'User-agent': useragent})
      if(not page.ok): return {}
      match=re.search(r'data-cdn=["\']([^"\']*)', page.content.decode())
      if(not match): return {}
      emotes=requests.get(match.group(1)+'/emotes/emotes.json', headers={'User-agent': useragent})
    except Exception: # Best effort: no emotes rather than a crash
      return {}
  if(not emotes.ok): return {}
  try:
    data=json.loads(emotes.content)
  except ValueError: # Not JSON (e.g. an HTML error page served with 200)
    return {}
  if(isinstance(data, dict)):
    entries=data.values()
  elif(isinstance(data, list)):
    entries=data
  else:
    return {}
  emotelist={}
  for emote in entries:
    if(not isinstance(emote, dict) or not emote.get('prefix')):
      continue
    if(not emote.get('imageWidth')):
      try:
        emote['imageWidth']=emote['image'][0]['width']
        emote['imageHeight']=emote['image'][0]['height']
      except (KeyError, IndexError, TypeError): # No usable size information
        continue
    emotelist[emote['prefix']]=emote
  return emotelist

def emote_mkflag(colors):
  """Return a 1-pixel-wide RGBA image with one row per color, top row first.

  colors: 0xRRGGBB values. Callers scale the result up to the emote size with
  nearest-neighbour resampling to get flag stripes.
  """
  import PIL.ImageFile
  stripes=PIL.Image.new('RGBA', (1,len(colors)))
  for row, color in enumerate(colors):
    red=int(color/0x10000)%0x100
    green=int(color/0x100)%0x100
    blue=int(color)%0x100
    stripes.putpixel((0,row), (red,green,blue))
  return stripes

def emotemods(img, mods, frame=0):
  """Apply chat emote modifiers to a PIL image and return the result.

  img: the emote image (any mode; converted to RGBA first).
  mods: modifier names, e.g. ['wide', 'spin']; unknown names are ignored.
  frame: animation phase in [0, 1) for the animated modifiers (spin, rustle,
    snow); 0 for still images.

  Modifiers are applied in a fixed order, regardless of their order in mods:
  geometry, grayscale, flag backgrounds, spin/rustle, dank, frozen, snow.
  dank/frozen/snow read dank.png/frozen.png/snow.png relative to the current
  working directory.
  """
  import PIL.ImageFile
  # :flip turns upside down
  # :mirror mirrors
  # :wide scales to 150% width
  # :cancelled grays out (TODO: put an X over)
  # :trans, :enby, :bi, :lesbian, :asex, :genderfluid, :pan, :pride, :mlm put the matching flag in the background
  # :dank adds a rainbowy multi-striped V in the background
  # :spin adds a rotation animation, seems to be ~60 RPM clockwise
  # :rustle shifts the emote side to side by a couple of pixels
  # :frozen overlays an icecube and shifts color channels to BRGA
  # :snow slides snow.png down over the emote
  # TODO: Handle remaining modifiers:
  # :worth seems to cycle through color hues reciprocally?
  # :rustle seems to make the emote shake side to side 2 or 3 pixels each
  # :love adds hearts on the sides, spawning and floating up
  # :hop

  # Background flags, applied in this order when their modifier is present.
  # Colors helpfully provided by https://www.flagcolorcodes.com/
  flags=(
    ('trans', [0x5BCEFA, 0xF5A9B8, 0xFFFFFF, 0xF5A9B8, 0x5BCEFA]),
    ('enby', [0xFCF434, 0xFFFFFF, 0x9C59D1, 0x2C2C2C]),
    ('bi', [0xD60270, 0xD60270, 0x9B4F96, 0x0038A8, 0x0038A8]),
    ('lesbian', [0xD52D00, 0xEF7627, 0xFF9A56, 0xFFFFFF, 0xD162A4, 0xB55690, 0xA30262]),
    ('asex', [0x000000, 0xA3A3A3, 0xFFFFFF, 0x800080]),
    ('genderfluid', [0xFF76A4, 0xFFFFFF, 0xC011D7, 0x000000, 0x2F3CBE]),
    ('pan', [0xFF218C, 0xFFD800, 0x21B1FF]),
    ('pride', [0xE40303, 0xFF8C00, 0xFFED00, 0x008026, 0x24408E, 0x732982]),
    ('mlm', [0x078D70, 0x26CEAA, 0x98E8C1, 0xFFFFFF, 0x7BADE2, 0x5049CC, 0x3D1A78]),
  )
  img=img.convert('RGBA')
  if('wide' in mods):
    img=img.resize((int(img.size[0]*1.5), img.size[1]))
  if('mirror' in mods):
    img=img.transpose(PIL.Image.Transpose.FLIP_LEFT_RIGHT)
  if('flip' in mods):
    img=img.transpose(PIL.Image.Transpose.FLIP_TOP_BOTTOM)
  if('cancelled' in mods): # TODO: Add an X
    # Go through 'LA' rather than 'L' so the alpha channel survives; with 'L'
    # every transparent pixel would come back fully opaque
    img=img.convert('LA').convert('RGBA')
  for name, colors in flags:
    if(name in mods):
      flag=emote_mkflag(colors).resize(img.size, resample=PIL.Image.Resampling.NEAREST)
      flag.alpha_composite(img) # Emote drawn on top of the flag
      img=flag
  if('spin' in mods):
    img=img.rotate(-frame*360, resample=PIL.Image.Resampling.BICUBIC) # Negative angle turns clockwise
  if('rustle' in mods):
    rps=5 # Rustles per second
    rx=4 # Amount of rustle movement
    shift=(frame*rps*rx)%rx-rx/2 # Horizontal offset in [-rx/2, rx/2)
    img=img.crop((shift,0,img.width+shift,img.height))
  # Overlay assets are converted to RGBA because alpha_composite requires RGBA on both sides
  if('dank' in mods):
    dank=PIL.Image.open('dank.png').convert('RGBA') # TODO: Better path
    x=int((dank.width-img.width)/2)
    y=int((dank.height-img.height)/2)
    dank.alpha_composite(img, dest=(x,y)) # Emote centred on the V
    img=dank
  if('frozen' in mods):
    r=img.getchannel('R')
    g=img.getchannel('G')
    b=img.getchannel('B')
    a=img.getchannel('A')
    img=PIL.Image.merge('RGBA', (b,r,g,a))
    if('love' not in mods): # Love removes icecube? idk
      frozen=PIL.Image.open('frozen.png').convert('RGBA') # TODO: Better path
      frozen=frozen.resize((img.width, img.height))
      img.alpha_composite(frozen, dest=(0,0))
  if('snow' in mods):
    snow=PIL.Image.open('snow.png').convert('RGBA') # TODO: Better path
    snow=snow.resize((img.width, snow.height))
    # frame 0: snow layer just above the emote; frame 1: just below it
    img.alpha_composite(snow, dest=(0,int(-snow.height+(snow.height+img.height)*frame)))
  return img

def get_emote_img(emote, emotelist, domain):
  """Return the path of a cached image file for a chat emote, creating it if needed.

  emote: one entry produced by findemotes(), i.e. {'name': ..., optionally
    'modifiers': [...]}.
  emotelist: the dict from fetch_emotes(), keyed by emote prefix (= name).
  domain: chat host, used for the image URL of emotes that have an 'id'.

  Files live in ~/.cache/wlchat/. The base image is downloaded. Animated
  sprite sheets are cut into frames and saved as a looping GIF. Modifiers are
  applied with emotemods(); spin/rustle/snow/love turn a still emote into a
  36-frame GIF. A variant with modifiers gets its own file, named after the
  deduplicated, sorted modifiers. If the download does not return HTTP 200,
  nothing is written but the path is still returned.
  """
  import PIL.ImageFile
  e=emotelist[emote['name']]
  cachedir=os.environ['HOME']+'/.cache/wlchat'
  # File names come from server-supplied emote data: keep only the last path
  # component so they cannot point outside the cache directory
  if(e.get('id')):
    filename=os.path.basename(e['id']+'.png')
  else:
    filename=os.path.basename(e['image'][0]['name'])
  modifiers=emote.get('modifiers')
  if(modifiers):
    # Insert the unique, sorted modifier names before the file's own extension
    # (split on the file name, not the whole path, which contains '.cache').
    # Slashes are dropped since modifiers come from chat text.
    stem, ext=os.path.splitext(filename)
    filename=stem+(''.join(sorted(set(modifiers)))+ext).replace('/','')
  path=cachedir+'/'+filename
  if(os.path.exists(path)):
    return path
  os.makedirs(cachedir, exist_ok=True)
  # TODO: If modified, reuse local base instead of fetching again? Though that's complicated by the fact we turn animated spritesheets into gifs
  if(e.get('id')):
    url='https://'+domain+'/for/apiAssets/stage/images/emotes/aaa/1/'+e['id']+'.png'
  else:
    url=e['image'][0]['url']
  r=requests.get(url, headers={'User-agent': useragent})
  if(r.status_code!=200):
    return path # Nothing written; the returned path does not exist
  if(e.get('animated')): # Turn sprite into animated gif
    parser=PIL.ImageFile.Parser()
    parser.feed(r.content)
    img=parser.close()
    # For long animations, do animationmods multiple times; do at least one, even if fast
    sec=int(e['animationLength']/1000) or 1
    framecount=img.size[0]/e['imageWidth'] # Frames sit side by side in the sprite sheet
    # For animation modifiers, ensure that we have enough frames for a smooth animation
    framemultiplier=1
    if(modifiers and 'spin' in modifiers and framecount<12):
      framemultiplier=math.ceil(12/framecount)
    frames=[]
    i=0
    while(i<framecount*framemultiplier):
      j=i/framemultiplier
      jf=math.floor(j)
      frame=img.crop((jf*e['imageWidth'],0,(jf+1)*e['imageWidth'],e['imageHeight']))
      if(modifiers):
        frame=emotemods(frame, modifiers, (j*sec/framecount)%1)
      frames.append(frame)
      i+=1
    if(e['animationFrames']==0): e['animationFrames']=1 # No /0
    duration=e['animationLength']/e['animationFrames']/framemultiplier # total length to frame length (milliseconds)
    frames[0].save(path, format='gif', save_all=True, append_images=frames[1:], optimize=False, loop=0, disposal=2, duration=duration)
  elif(modifiers): # Not animated, but we still need to apply modifiers
    parser=PIL.ImageFile.Parser()
    parser.feed(r.content)
    img=parser.close()
    if(any(m in modifiers for m in ('spin', 'rustle', 'snow', 'love'))): # These need to make it animated
      frames=[emotemods(img, modifiers, i/36) for i in range(36)]
      frames[0].save(path, format='gif', save_all=True, append_images=frames[1:], optimize=False, loop=0, disposal=2, duration=1000/36)
    else:
      emotemods(img, modifiers).save(path)
  else:
    with open(path, 'wb') as f:
      f.write(r.content)
  return path

def findemotes(data, emotes):
  """Annotate a chat message dict, in place, with the emotes it contains.

  data: message dict whose 'data' field is the message text.
  emotes: mapping of emote name -> info; a word is an emote when
    emotes.get(name) is truthy.

  The text is split on single spaces (newlines count as spaces). A word of the
  form name:mod1:mod2 is an emote with modifiers. Each hit is appended to
  data['nodes']['emotes'] as {'name': ..., 'bounds': [start, end]}, plus a
  sorted 'modifiers' list when any were given. bounds are character offsets
  of the whole word in data['data']. A non-empty emote list already present is
  cleared first; 'nodes' is only created when an emote is found.
  """
  existing=data.get('nodes')
  if(existing and existing.get('emotes')):
    existing['emotes']=[]
  offset=0
  for token in data['data'].replace('\n',' ').split(' '):
    name, *modifiers=token.split(':')
    if(emotes.get(name)):
      nodes=data.get('nodes') or {}
      data['nodes']=nodes
      nodes['emotes']=nodes.get('emotes') or []
      hit={'name':name,'bounds':[offset,offset+len(token)]}
      if(modifiers):
        hit['modifiers']=sorted(modifiers)
      nodes['emotes'].append(hit)
    offset+=len(token)+1 # +1 for the separator
